"""Gradio demo that classifies a flower photo with a fine-tuned ResNet-50.

On import the script loads the model weights and a JSON file of per-flower
metadata, then builds a Gradio interface whose six text outputs describe the
predicted flower.
"""

import json

import gradio as gr
from PIL import Image
import torch
import torchvision.models as models
import torchvision.transforms as transforms

# Number of output classes of the fine-tuned classification head.
NUM_CLASSES = 102

# Number of text outputs shown in the UI (must match the `outputs` list below).
NUM_OUTPUT_FIELDS = 6


# ---------------------------
# Load model
# ---------------------------
def load_model(model_path="fine_tuned_resnet50.pth"):
    """Build a ResNet-50 with a 102-way head and load weights from disk.

    Args:
        model_path: Path to a state-dict checkpoint compatible with a
            ResNet-50 whose final layer has ``NUM_CLASSES`` outputs.

    Returns:
        The model on CPU, switched to evaluation mode.
    """
    # NOTE(review): newer torchvision releases deprecate `pretrained=` in
    # favour of `weights=None`; switch once the minimum version is confirmed.
    net = models.resnet50(pretrained=False)
    # Reuse the backbone's own feature width instead of hard-coding it.
    net.fc = torch.nn.Linear(
        in_features=net.fc.in_features, out_features=NUM_CLASSES
    )
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source (consider `weights_only=True` where supported).
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    net.load_state_dict(state_dict)
    net.eval()
    return net


model = load_model("fine_tuned_resnet50.pth")

# ---------------------------
# Load flower info
# ---------------------------
# Explicit UTF-8 so non-ASCII text decodes the same on every platform.
with open("flower with discription.json", "r", encoding="utf-8") as f:
    # NOTE(review): lookups below use the model's class index (0..101) as the
    # key; confirm the JSON `id` values are integers with the same numbering.
    flower_info = {flower["id"]: flower for flower in json.load(f)}

# ---------------------------
# Image preprocessing
# ---------------------------
# Resize to 224x224, convert to a tensor, then normalise each channel with
# the ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# ---------------------------
# Inference function
# ---------------------------
def classify_image(image):
    """Predict the flower in ``image`` and return its metadata for display.

    Args:
        image: A PIL image, or ``None`` when no image was supplied.

    Returns:
        A list of six strings: name, scientific name, genus, fun fact,
        where found, and description. Every entry is ``"Unknown"`` when no
        image is given or the predicted class has no metadata entry.
    """
    unknown = ["Unknown"] * NUM_OUTPUT_FIELDS

    # Gradio passes None when the user submits without an image.
    if image is None:
        return unknown

    rgb_image = image.convert("RGB")
    batch = transform(rgb_image).unsqueeze(0)  # add a batch dimension

    with torch.no_grad():
        logits = model(batch)
        predicted_class = torch.argmax(logits, dim=1).item()

    info = flower_info.get(predicted_class, None)
    if not info:
        return unknown

    return [
        info["name"].title(),
        info["scientific_name"],
        info["genus"],
        info["fun_fact"],
        info["where_found"],
        info.get("description", "No description available."),
    ]


# ---------------------------
# Gradio interface
# ---------------------------
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Textbox(label="Flower Name"),
        gr.Textbox(label="Scientific Name"),
        gr.Textbox(label="Genus"),
        gr.Textbox(label="Fun Fact"),
        gr.Textbox(label="Where Found"),
        gr.Textbox(label="Description"),
    ],
    title="Flower Classification",
    description="🌸 Upload a flower image to get its name, genus, scientific name, fun fact, and more.",
    allow_flagging="never",
)

if __name__ == "__main__":
    iface.launch()