import os

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms

# 📦 Class names. predict() maps output index i of the network to
# class_names[i], so the order here must match the order used in training.
class_names = [
    "accordion",
    "banjo",
    "drum",
    "flute",
    "guitar",
    "harmonica",
    "saxophone",
    "sitar",
    "tabla",
    "violin",
]

# 📐 Transformations (same as during training)
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


# 🧠 Load model
def load_model(model_path="music_model.pth"):
    """Build a ResNet-18 classifier and load its weights from *model_path*.

    The final fully connected layer is replaced so that it has one output
    per entry in ``class_names``. The weights are mapped onto the CPU and
    the network is switched to evaluation mode before being returned.

    Args:
        model_path: Path to the saved state dict.

    Returns:
        The ResNet-18 model, in eval mode, with the loaded weights.
    """
    net = models.resnet18(weights=None)
    net.fc = nn.Linear(net.fc.in_features, len(class_names))
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source; consider
    # passing weights_only=True if the installed torch version supports it.
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    net.load_state_dict(state_dict)
    net.eval()
    return net


model = load_model("music_model.pth")


# 🔍 Prediction function
def predict(image):
    """Classify an instrument image and report per-class confidences.

    Args:
        image: Image as a NumPy array, as supplied by the
            ``gr.Image(type="numpy")`` input. ``None`` when the user submits
            without providing an image.

    Returns:
        A ``(prediction, confidences)`` tuple: the name of the top-scoring
        class, and a dict mapping every class name to its softmax
        probability as a float.

    Raises:
        gr.Error: If no image was provided.
    """
    # Gradio passes None when the form is submitted empty; fail with a
    # readable message instead of an AttributeError from Image.fromarray.
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    rgb_image = Image.fromarray(image).convert("RGB")
    batch = transform(rgb_image).unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        logits = model(batch)[0]
        probabilities = torch.nn.functional.softmax(logits, dim=0)

    prediction = class_names[torch.argmax(logits).item()]
    confidences_dict = dict(zip(class_names, probabilities.tolist()))
    return prediction, confidences_dict


# 🎛️ Gradio Interface
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy", label="Upload Instrument Image"),
    outputs=[
        gr.Label(label="Predicted Instrument"),
        gr.Label(label="Confidence Scores"),
    ],
    title="🎵 Musical Instrument Classifier",
    description="Upload an image of a musical instrument and get the predicted class (accordion, guitar, etc.)",
)

# 🚀 Launch the app
if __name__ == "__main__":
    interface.launch()