"""Gradio demo that classifies an uploaded image as dog, cat, or wild animal.

Rebuilds a ResNet-50 with a custom classification head, loads fine-tuned
weights from ``best_model.pth`` and serves predictions through a Gradio UI.
"""
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import gradio as gr

# 🖥️ Device: use the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 🏷️ Class names, indexed by the position of the model's output logits.
# NOTE(review): an earlier comment on the model head listed the order as
# "dog, wild, cat", which disagrees with this list. The order must match the
# label indices used during training (torchvision's ImageFolder, for example,
# assigns indices in sorted folder-name order). Confirm against the training
# script before trusting the labels.
class_names = ["dog", "cat", "wild"]

# 🔨 Rebuild the architecture used at training time.
# weights=None: every parameter and buffer is overwritten by the checkpoint
# loaded below (strict load), so downloading ImageNet weights first is wasted
# network traffic and startup time.
resnet = models.resnet50(weights=None)
in_features = resnet.fc.in_features
resnet.fc = nn.Sequential(
    nn.Linear(in_features, 1024),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(1024, len(class_names)),  # one logit per entry in class_names
)
resnet = resnet.to(device)

# 📥 Load the fine-tuned weights, then switch to eval mode so Dropout is
# disabled and BatchNorm uses its stored running statistics.
# NOTE(review): torch.load unpickles the file — only load checkpoints you
# trust (on torch >= 1.13, passing weights_only=True is the safer option).
resnet.load_state_dict(torch.load("best_model.pth", map_location=device))
resnet.eval()

# 🖼️ Inference-time preprocessing.
# NOTE(review): mean/std of 0.5 presumably mirror the training-time
# normalization (not the ImageNet statistics) — confirm they match.
val_transforms = transforms.Compose([
    # Force 3 channels so grayscale / RGBA / palette uploads don't break the
    # 3-channel first conv layer.
    transforms.Lambda(lambda img: img.convert("RGB")),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])


# 🔮 Prediction function
def classify_image(img):
    """Return per-class probabilities for a single PIL image.

    Args:
        img: A PIL image from the Gradio ``Image(type="pil")`` input, or
            ``None`` when the input has been cleared.

    Returns:
        A dict mapping each name in ``class_names`` to its softmax
        probability (the format ``gr.Label`` expects), or ``None`` when no
        image was provided, which leaves the label output empty.
    """
    if img is None:
        # Gradio calls the function with None when the image is cleared;
        # transforming None would raise AttributeError.
        return None

    # Add a batch dimension and move to the model's device.
    batch = val_transforms(img).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = resnet(batch)
    # Row 0 is the only image in the batch.
    probs = torch.softmax(logits, dim=1)[0].cpu().tolist()
    return dict(zip(class_names, probs))


# 🎨 Gradio Interface
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=len(class_names)),
    title="Dog/Wild/Cat Classifier 🐶🐯🐱",
    description="Upload an image to classify it as Dog, Wild Animal, or Cat.",
)

# Launch only when run as a script, so importing this module (e.g. for tests
# or reuse of classify_image) doesn't start a web server.
if __name__ == "__main__":
    iface.launch()