import os

import gradio as gr
import timm
import torch
from PIL import Image
from torchvision import transforms

# 1. Load Labels
# One class name per line; line order must match the model's output indices.
# UTF-8 is pinned so non-ASCII species names decode the same on every platform.
with open('labels.txt', 'r', encoding='utf-8') as f:
    labels = [line.strip() for line in f]


# 2. Model Definition
def get_model(num_classes=200, model_path='models/final_model_best.pth'):
    """Build the ConvNeXt V2 Large classifier and load its fine-tuned weights.

    Args:
        num_classes: Size of the classification head.
        model_path: Path to the saved ``state_dict`` checkpoint.

    Returns:
        The model switched to eval mode, or ``None`` when the checkpoint is
        missing or fails to load (the reason is printed).
    """
    # Check for the checkpoint first so a missing file does not cost the
    # construction of the full network.
    if not os.path.exists(model_path):
        print(f"Error: Model file not found at {model_path}.")
        return None

    net = timm.create_model(
        'convnextv2_large.fcmae_ft_in22k_in1k',
        pretrained=False,
        num_classes=num_classes,
        drop_path_rate=0.2,
    )
    try:
        # NOTE(review): torch.load unpickles the file -- only load trusted
        # checkpoints (consider weights_only=True where the installed torch
        # version supports it).
        state_dict = torch.load(model_path, map_location=torch.device('cpu'))
        net.load_state_dict(state_dict)
        net.eval()
        print("Model loaded successfully.")
        return net
    except Exception as e:
        # Best-effort startup: report the problem and let the UI come up so
        # the failure is surfaced to the user at prediction time.
        print(f"An error occurred while loading the model: {e}")
        return None


model = get_model()

# 3. Image Transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


# 4. Prediction Function
def predict(image):
    """Classify a bird image and return its most likely species.

    Args:
        image: PIL image supplied by Gradio, or ``None`` when nothing was
            uploaded.

    Returns:
        Dict mapping species label to softmax probability for the (up to)
        three highest-scoring classes.

    Raises:
        gr.Error: If the model failed to load or no image was provided.
    """
    # gr.Label expects label -> float confidences, so failures are reported
    # through gr.Error rather than as a string-valued dict.
    if model is None:
        raise gr.Error("Model is not loaded. Please check the logs for errors.")
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    # Normalize is configured for exactly three channels; force RGB so
    # grayscale or RGBA inputs cannot break it.
    batch = transform(image.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.nn.functional.softmax(logits, dim=1)[0]

    # Get top 3 predictions (fewer if the head has fewer than 3 classes).
    k = min(3, probabilities.numel())
    top_prob, top_indices = torch.topk(probabilities, k)
    return {
        labels[i]: p
        for i, p in zip(top_indices.tolist(), top_prob.tolist())
    }


# 5. Gradio Interface
title = "Bird Species Classifier"
description = (
    "Upload an image of a bird to classify it into one of 200 species. "
    "This model is a ConvNeXt V2 Large, fine-tuned on a dataset of 200 bird species."
)
# Wire the prediction function into a simple upload-image -> top-3 labels UI.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Bird Image"),
    outputs=gr.Label(num_top_classes=3, label="Predictions"),
    title=title,
    description=description,
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()