"""Gradio web app that classifies an uploaded image into a waste category.

At import time the script makes sure the trained weights file exists locally
(downloading it from Google Drive if it is missing), loads it into the model
returned by ``model.get_model()``, and builds the Gradio interface. The web
server is only started when the file is executed as a script.
"""

import os
from typing import Optional

import gdown
import gradio as gr
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image

from model import CLASS_NAMES, get_model

MODEL_PATH = "waste_classifier.pth"
MODEL_URL = "https://drive.google.com/uc?id=1RDBXrDvQ7B71SU-nUybDzbIpXkzHBStV"


def _ensure_weights(path: str, url: str) -> None:
    """Download the weights file from *url* to *path* unless it already exists.

    Raises:
        RuntimeError: if the file is still missing after the download attempt.
    """
    if os.path.exists(path):
        return
    gdown.download(url, path, quiet=False)
    # Depending on the gdown version, some failures are reported by returning
    # None instead of raising; check explicitly so the user gets a clear error
    # here rather than a confusing one from torch.load below.
    if not os.path.exists(path):
        raise RuntimeError(f"Failed to download model weights from {url} to {path}")


_ensure_weights(MODEL_PATH, MODEL_URL)

# Load model on CPU and switch to inference mode (disables dropout/batch-norm updates).
model = get_model()
# The weights come from the network: weights_only=True restricts unpickling to
# tensors and plain containers, so a tampered file cannot execute arbitrary code.
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu", weights_only=True))
model.eval()

# Preprocessing: resize to 224x224 and convert to a float tensor in [0, 1].
# NOTE(review): no mean/std normalization is applied here — confirm this matches
# the preprocessing used when the model was trained.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


def predict(image: Optional[Image.Image]) -> str:
    """Classify a single PIL image and report the top class with its confidence.

    Args:
        image: the uploaded image, or None when the user submitted nothing
            (Gradio passes None for an empty ``gr.Image`` input).

    Returns:
        A string of the form ``"<class name> (<confidence>%)"``, or a prompt to
        upload an image when *image* is None.
    """
    if image is None:
        return "Please upload an image."
    # Force 3 channels so grayscale/RGBA/palette uploads match the model input.
    image = image.convert("RGB")
    batch = transform(image).unsqueeze(0)  # add a batch dimension of size 1

    with torch.no_grad():
        outputs = model(batch)
        probs = F.softmax(outputs, dim=1)
        confidence, predicted = torch.max(probs, 1)

    return f"{CLASS_NAMES[predicted.item()]} ({confidence.item()*100:.2f}%)"


# Gradio UI
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Waste Classifier",
    description="Upload an image to classify waste",
)

if __name__ == "__main__":
    interface.launch()