PsychicFireSong's picture
Initial upload of Gradio app and baseline model
b404ed5
raw
history blame
3.98 kB
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import os
import pandas as pd
# --- Model Loading ---
def get_model_paths(model_dir="models"):
    """Return a dictionary mapping model names to checkpoint file paths.

    Every file in ``model_dir`` whose name ends with ``.pth`` becomes one
    entry; the key is the file name without its extension.

    Args:
        model_dir: Directory to scan. Defaults to ``"models"``.

    Returns:
        dict: ``{model_name: path}``, ordered alphabetically by file name.
        Empty when ``model_dir`` does not exist or is not a directory.
    """
    try:
        entries = os.listdir(model_dir)
    except (FileNotFoundError, NotADirectoryError):
        # A missing directory should not abort module import; the app can
        # still start and report individual models as unavailable.
        print(f"Warning: Model directory '{model_dir}' does not exist.")
        return {}

    # os.listdir returns entries in arbitrary order; sort so the resulting
    # dict (and anything built from its key order) is stable across runs.
    checkpoint_files = sorted(f for f in entries if f.endswith(".pth"))

    # You can create more descriptive names here if you want
    return {
        os.path.splitext(f)[0]: os.path.join(model_dir, f)
        for f in checkpoint_files
    }
# Checkpoints discovered on disk, keyed by model name.
MODEL_PATHS = get_model_paths()

# Register placeholder entries for the other two models. Their files may not
# exist yet; they are added after the discovered models, in this order.
MODEL_PATHS.update(
    {
        "Future Model 1": "models/future_model_1.pth",
        "Future Model 2": "models/future_model_2.pth",
    }
)
# This is a placeholder for your actual model loading logic.
# You will need to replace the architecture below with your specific model.

# Successfully loaded models, keyed by absolute checkpoint path. Each value is
# ``(mtime, model)`` so a checkpoint that is replaced on disk gets reloaded
# instead of serving a stale network.
_MODEL_CACHE = {}


def load_model(model_path):
    """Load the model stored at ``model_path``, reusing a cached instance.

    Rebuilding the network and re-reading the checkpoint is expensive, so a
    model that loaded successfully is kept in ``_MODEL_CACHE`` and returned
    again for as long as the file's modification time is unchanged. Failures
    are never cached, so a later call can succeed once the problem is fixed.

    Args:
        model_path: Path to a ``.pth`` file holding a state dict.

    Returns:
        The model in eval mode, or ``None`` when the file does not exist or
        anything goes wrong while building/loading it (the error is printed).
    """
    print(f"Loading model from: {model_path}")
    if not os.path.exists(model_path):
        print("Warning: Model file does not exist. Using a dummy model.")
        return None

    cache_key = os.path.abspath(model_path)
    try:
        mtime = os.path.getmtime(model_path)
        cached = _MODEL_CACHE.get(cache_key)
        if cached is not None and cached[0] == mtime:
            return cached[1]

        # Imported lazily so a missing or broken `baseline` module is
        # reported through the error path below instead of at import time.
        # This is a guess, you'll need to replace with your actual model class
        from baseline import convnext_v2_base

        model = convnext_v2_base(num_classes=10)  # Or whatever your number of classes is
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only load trusted checkpoints; consider passing
        # weights_only=True if the installed torch version supports it.
        state_dict = torch.load(model_path, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)
        model.eval()
    except Exception as e:
        # Best-effort boundary: callers treat None as "model unavailable".
        print(f"Error loading model: {e}")
        print("Using a dummy model.")
        return None

    _MODEL_CACHE[cache_key] = (mtime, model)
    return model
# --- Image Preprocessing ---
# Adjust these steps to match the preprocessing your model expects.
# Per-channel normalisation constants (these are the values commonly used
# for ImageNet-pretrained networks -- confirm against your training setup).
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Applied in order: resize to 256, centre-crop to 224, convert to a tensor,
# then normalise each channel.
_PREPROCESS_STEPS = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
preprocess = transforms.Compose(_PREPROCESS_STEPS)
# --- Prediction ---
def _build_species_lookups(frame):
    """Build the two lookup tables used to turn a model output into a name.

    Args:
        frame: DataFrame with ``class_id`` and ``species_name`` columns.

    Returns:
        tuple: ``(index_to_id, id_to_name)`` where ``index_to_id`` maps each
        row's index label to its ``class_id`` and ``id_to_name`` maps each
        ``class_id`` to its ``species_name``.
    """
    index_to_id = {}
    id_to_name = {}
    for row_label, record in frame.iterrows():
        species_id = record['class_id']
        index_to_id[row_label] = species_id
        id_to_name[species_id] = record['species_name']
    return index_to_id, id_to_name


# Load species list: a header-less, ';'-separated file with one
# "class_id;species_name" pair per line.
species_df = pd.read_csv(
    'species_list.txt',
    sep=';',
    header=None,
    names=['class_id', 'species_name'],
)
idx_to_class, class_id_to_name = _build_species_lookups(species_df)
def predict(model_name, image):
    """Make a prediction on an image using the selected model.

    Args:
        model_name: Key into ``MODEL_PATHS`` chosen in the dropdown.
        image: Image as a NumPy array, or ``None`` when no image is present
            (for example when the model is changed before an upload).

    Returns:
        str: ``"Prediction: <species name>"`` on success, otherwise a
        human-readable message describing why no prediction was made.
    """
    # .get() so an unknown or unset model name yields the same friendly
    # message as a missing file instead of a KeyError.
    model_path = MODEL_PATHS.get(model_name)
    if model_path is None or not os.path.exists(model_path):
        return f"Model '{model_name}' not found. Please upload the model file."

    # Checked before loading the model so an empty image costs nothing.
    if image is None:
        return "Please upload an image to classify."

    model = load_model(model_path)
    if model is None:
        return f"Could not load model '{model_name}'."

    # Let PIL infer the mode from the array shape, then convert, so
    # grayscale or RGBA input is turned into proper 3-channel RGB.
    pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')
    processed_image = preprocess(pil_image).unsqueeze(0)

    with torch.no_grad():
        outputs = model(processed_image)
        # Accept both output objects exposing `.logits` and models that
        # return the logits tensor directly.
        logits = getattr(outputs, 'logits', outputs)
        _, predicted = torch.max(logits, 1)
    predicted_idx = predicted.item()

    # The model may have more outputs than the species list has rows.
    if predicted_idx not in idx_to_class:
        return f"Prediction: unknown class index {predicted_idx}"
    class_id = idx_to_class[predicted_idx]
    class_name = class_id_to_name[class_id]
    return f"Prediction: {class_name}"
# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown("# Plant Classification")
    gr.Markdown("Select a model and upload an image to classify.")

    # Dropdown entries follow MODEL_PATHS insertion order; the first one is
    # preselected when there is at least one model.
    model_choices = list(MODEL_PATHS)
    default_choice = model_choices[0] if MODEL_PATHS else None

    # NOTE(review): the row grouping below is assumed (model selector and
    # image side by side, result underneath) -- confirm against the intended
    # layout.
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=model_choices,
            label="Select Model",
            value=default_choice,
        )
        image_input = gr.Image(type="numpy")
    output_text = gr.Textbox(label="Prediction")

    # Re-run the prediction whenever either the image or the selected model
    # changes; both events use the same handler, inputs and output.
    for trigger in (image_input, model_dropdown):
        trigger.change(
            fn=predict,
            inputs=[model_dropdown, image_input],
            outputs=output_text,
        )

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()