# AMLTestSpace / app.py
# Last commit by PsychicFireSong: "Update app.py" (5bbce66, verified)
import gradio as gr
import torch
from torchvision import transforms
import timm
from PIL import Image
import os
# 1. Load Labels
# One class name per line; line i (0-based) names model output index i.
# An explicit UTF-8 encoding keeps species names containing non-ASCII
# characters from being decoded with a platform-dependent locale codec.
with open('labels.txt', 'r', encoding='utf-8') as f:
    labels = [line.strip() for line in f]
# 2. Model Definition
def get_model(num_classes=200, model_path='models/final_model_best.pth'):
    """Build the ConvNeXt V2 Large classifier and load fine-tuned weights.

    The architecture is created with ``pretrained=False``; every parameter
    comes from the checkpoint at *model_path*, which is mapped onto the CPU.

    Args:
        num_classes: Size of the classification head. Must match the head
            stored in the checkpoint, otherwise ``load_state_dict`` fails.
        model_path: Path to a ``state_dict`` checkpoint written by
            ``torch.save``.

    Returns:
        The model in eval mode, or ``None`` if the checkpoint is missing or
        could not be loaded (the reason is printed). Callers must check for
        ``None`` rather than catching an exception.
    """
    # Check for the checkpoint before building the network: ConvNeXt V2 Large
    # is a large model, so don't allocate it when there is nothing to load.
    if not os.path.exists(model_path):
        print(f"Error: Model file not found at {model_path}.")
        return None

    model = timm.create_model(
        'convnextv2_large.fcmae_ft_in22k_in1k',
        pretrained=False,
        num_classes=num_classes,
        drop_path_rate=0.2,
    )

    try:
        # NOTE: torch.load unpickles the file; only load checkpoints from a
        # trusted source.
        state_dict = torch.load(model_path, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)
    except Exception as e:
        # Deliberate best-effort boundary: report the failure and return None
        # so the app can still start and explain that the model is missing.
        print(f"An error occurred while loading the model: {e}")
        return None

    model.eval()
    print("Model loaded successfully.")
    return model
# Loaded once at import time; holds None when get_model() could not load it.
model = get_model()

# 3. Image Transformations
# Standard ImageNet per-channel mean / std used to normalize input tensors.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        # Fixed 224x224 input; a (h, w) tuple does not preserve aspect ratio.
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
# 4. Prediction Function
def predict(image):
    """Classify a bird image and return its most likely species.

    Args:
        image: A PIL image, as delivered by ``gr.Image(type="pil")``. It is
            ``None`` when the user submits without uploading anything.

    Returns:
        A dict mapping up to three label names to their softmax probability
        (Python floats), in the shape ``gr.Label`` consumes.

    Raises:
        gr.Error: If the model failed to load or no image was provided;
            Gradio displays the message to the user.
    """
    if model is None:
        # gr.Label reads dict values as numeric confidences, so an error
        # message cannot be returned as a label value; raise it instead.
        raise gr.Error("Model is not loaded. Please check the logs for errors.")
    if image is None:
        raise gr.Error("Please upload an image first.")

    # Normalize() uses 3-channel statistics; convert RGBA / grayscale /
    # palette images so they have exactly three channels.
    batch = transform(image.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
    probabilities = torch.nn.functional.softmax(logits, dim=1)[0]

    # Top 3 predictions (fewer if the model has fewer than 3 classes).
    k = min(3, probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, k)
    return {
        labels[idx]: prob
        for idx, prob in zip(top_indices.tolist(), top_probs.tolist())
    }
# 5. Gradio Interface
title = "Bird Species Classifier"
description = (
    "Upload an image of a bird to classify it into one of 200 species. "
    "This model is a ConvNeXt V2 Large, fine-tuned on a dataset of 200 bird species."
)

# Input: an uploaded image handed to the function as a PIL image.
# Output: a label panel showing the three highest-scoring classes.
image_input = gr.Image(type="pil", label="Upload Bird Image")
label_output = gr.Label(num_top_classes=3, label="Predictions")

iface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    title=title,
    description=description,
)


def _main():
    """Start the Gradio server when this file is executed as a script."""
    iface.launch()


if __name__ == "__main__":
    _main()