Spaces:
Build error
Build error
| import torch | |
| from torchvision import transforms | |
| from PIL import Image | |
| import gradio as gr | |
| import timm | |
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE: torch.load unpickles the file, so ./model.pth must come from a trusted
# source. map_location places the loaded tensors on the selected device.
checkpoint = torch.load("./model.pth", map_location=torch.device(device))

# Build the EfficientNet-B0 architecture without pretrained weights and with
# 12 output classes, then fill it with the weights stored in the checkpoint
# under the 'model_state_dict' key.
model = timm.create_model("efficientnet_b0", pretrained=False, num_classes=12)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)
model.eval()  # evaluation mode: this script only runs inference
# Category names. predict() indexes this list with the position of the model's
# highest output, so the order of the entries is significant.
# NOTE(review): presumably this matches the class order used at training
# time - confirm against the training code.
class_labels = [
    'battery',
    'biological',
    'brown-glass',
    'cardboard',
    'clothes',
    'green-glass',
    'metal',
    'paper',
    'plastic',
    'shoes',
    'trash',
    'white-glass',
]
# Preprocessing applied to every input image before it is passed to the model.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # EfficientNet-B0 input size
    transforms.ToTensor(),
    # Per-channel mean and std for a 3-channel image (the commonly used
    # ImageNet statistics).
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
def predict(image):
    """Classify one image and return the name of the predicted category.

    Args:
        image: PIL image to classify (the Gradio input below is configured
            with ``type="pil"``), or ``None`` when nothing was uploaded.

    Returns:
        The entry of ``class_labels`` at the index of the model's highest
        output.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        # Surface a readable message in the UI instead of an opaque TypeError
        # raised from inside the transform pipeline.
        raise gr.Error("Please upload an image before submitting.")

    # The Normalize step in `transform` is configured for three channels, so
    # make sure other modes (e.g. RGBA or grayscale) are converted first.
    image = image.convert("RGB")

    # Add the batch dimension and move the tensor to the model's device.
    batch = transform(image).unsqueeze(0).to(device)
    with torch.inference_mode():
        output = model(batch)
    _, predicted = torch.max(output, 1)
    return class_labels[predicted.item()]
# Gradio UI: a single image input, handed to predict() as a PIL image, and a
# plain-text output showing the returned label.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="TSYP Garbage Classification Model",
    description="Upload an image of garbage to classify it into one of 12 categories(make sure it's the only thing in the photo , except background)",
)
# Start the app; this runs at module top level, as in the original script.
interface.launch()