Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from huggingface_hub import PyTorchModelHubMixin | |
| from torchvision.models import mobilenet_v3_large | |
| from torchvision.transforms import v2 | |
| from PIL import Image | |
class TrashMobileNet(nn.Module, PyTorchModelHubMixin):
    """MobileNetV3-Large with a replaced final classifier layer.

    Every backbone parameter is frozen; only the newly created last linear
    layer of the classifier keeps ``requires_grad=True``.

    Args:
        num_classes: Number of output units of the new final layer.
    """

    def __init__(self, num_classes=6):
        super().__init__()
        backbone = mobilenet_v3_large(weights="DEFAULT")

        # Freeze everything that came with the backbone.
        backbone.requires_grad_(False)

        # Replace the last classifier layer with one sized for our classes,
        # reusing the input width of the layer being replaced.
        head = nn.Linear(backbone.classifier[-1].in_features, num_classes)
        head.requires_grad_(True)
        backbone.classifier[-1] = head

        # Attribute name ``model`` is part of the state-dict key layout.
        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        return self.model(x)
# Hub repository id the fine-tuned weights are loaded from.
model_name = "pradanaadn/trash-clasification"

# Load the weights and switch to evaluation mode (``eval()`` returns the
# module itself, so the call can be chained).
model = TrashMobileNet.from_pretrained(model_name).eval()

# Preprocessing applied to every input image: resize to 224x224, convert to
# a tensor image, then cast to float32 with value scaling enabled.
# NOTE(review): no mean/std normalization is applied here - confirm this
# matches the preprocessing used when the weights were trained.
transform = v2.Compose(
    [
        v2.Resize(size=(224, 224)),
        v2.ToImage(),
        v2.ToDtype(dtype=torch.float32, scale=True),
    ]
)
def predict(image):
    """Classify an image and return the probability of each label.

    Args:
        image: A PIL image, or an array accepted by ``Image.fromarray``.
            ``None`` is accepted and means "no image supplied".

    Returns:
        A dict mapping each label name to its softmax probability (a Python
        float), or ``None`` when ``image`` is ``None``.
    """
    # An empty image widget hands the callback None; bail out early instead
    # of crashing inside Image.fromarray.
    if image is None:
        return None

    # NOTE(review): order is assumed to match the class indices used when
    # the model was trained - verify against the training code.
    labels = ("cardboard", "glass", "metal", "paper", "plastic", "trash")

    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Guard against grayscale / RGBA / palette inputs: the network is fed a
    # three-channel tensor.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Add the leading batch dimension the model call expects.
    batch = transform(image).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.nn.functional.softmax(logits, dim=1)[0].tolist()

    # tolist() already yields Python floats, so the pairs can be zipped as-is.
    return dict(zip(labels, probabilities))
# Rows offered in the examples gallery, each as [image path, caption].
# NOTE(review): gr.Examples below is wired with a single input component
# (input_image), so the caption column has no matching input - confirm that
# is intentional. "examples/plastic.png" is captioned "Mixed trash".
examples = [
    [path, caption]
    for path, caption in (
        ("examples/cardbox.jpeg", "A cardboard box"),
        ("examples/glass.jpeg", "A glass bottle"),
        ("examples/plastic.png", "Mixed trash"),
    )
]
# Two-column layout: image upload and button on the left, class scores on
# the right, followed by the examples gallery.
with gr.Blocks() as iface:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                type="pil",
                label="Upload Image",
                elem_id="image_upload",
            )
            submit_btn = gr.Button("Classify", variant="primary")

        with gr.Column():
            output_label = gr.Label(
                num_top_classes=6,
                label="Classification Results",
            )

    gr.Markdown("### Example Images")
    gr.Examples(
        examples=examples,
        fn=predict,
        inputs=input_image,
        outputs=output_label,
        cache_examples=True,
    )

    # The button runs the same callback that the examples use.
    submit_btn.click(predict, inputs=input_image, outputs=output_label)

# Launch the interface
iface.launch()