"""Gradio demo that classifies food doneness with a fine-tuned MobileNetV2."""

import os

# ── Fix Matplotlib cache permission errors ──
# Must be set before any third-party import that may pull in Matplotlib,
# because Matplotlib resolves its config/cache directory at import time.
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"

import gradio as gr  # noqa: E402
import torch  # noqa: E402
import torch.nn as nn  # noqa: E402
from torchvision import models, transforms  # noqa: E402
from PIL import Image  # noqa: E402,F401

# ── Constants and Paths ──
MODEL_PATH = "./mobilenetv2.pth"
# NOTE(review): the order must match the label indices used at training
# time — confirm against the training script.
CLASS_NAMES = ["undercooked", "raw", "cooked"]
IMAGE_SIZE = (224, 224)


# ── Load model ──
def load_model():
    """Build a MobileNetV2 with a ``len(CLASS_NAMES)``-way head and load weights.

    Returns:
        The model in eval mode with the state dict from ``MODEL_PATH`` loaded
        onto the CPU.

    Raises:
        FileNotFoundError: if ``MODEL_PATH`` does not exist.
        RuntimeError: if the checkpoint keys/shapes do not match the model.
    """
    model = models.mobilenet_v2(weights=None)
    # Replace the final classifier layer so the output size matches our classes.
    model.classifier[1] = nn.Linear(model.last_channel, len(CLASS_NAMES))
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(MODEL_PATH, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model


model = load_model()

# ── Image transform ──
# NOTE(review): no mean/std normalization is applied here — confirm this
# matches the preprocessing used during training.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])


def classify(image, progress=gr.Progress(track_tqdm=True)):
    """Classify a PIL image and return per-class confidences for ``gr.Label``.

    Args:
        image: PIL image supplied by the ``gr.Image`` input, or ``None`` when
            the user clicked "Classify" without uploading anything.
        progress: Gradio progress tracker (kept for interface compatibility;
            not used directly).

    Returns:
        A dict mapping every name in ``CLASS_NAMES`` to its softmax
        probability, or ``{"Error": 0.0}`` when no image was given or
        prediction failed.
    """
    if image is None:
        # Explicit guard instead of relying on an AttributeError being caught.
        print("Error during prediction:", "no image provided")
        return {"Error": 0.0}

    try:
        rgb_image = image.convert("RGB")
        batch = transform(rgb_image).unsqueeze(0)  # add batch dimension
        with torch.no_grad():
            logits = model(batch)
        # Report real confidences rather than a hard-coded 1.0 for the argmax.
        probabilities = torch.softmax(logits, dim=1)[0]
        return {
            name: float(prob)
            for name, prob in zip(CLASS_NAMES, probabilities)
        }
    except Exception as e:  # top-level UI boundary: log and show a fallback
        print("Error during prediction:", e)
        # Return fixed label with zero confidence or default fallback
        return {"Error": 0.0}


# ── Gradio Layout ──
# NOTE(review): the exact nesting under gr.Row() is assumed (only the image
# input is placed in the row) — confirm against the intended layout.
with gr.Blocks(css="#main-col {max-width: 640px; margin: auto;}") as demo:
    with gr.Column(elem_id="main-col"):
        gr.Markdown("## 🍳 MobileNetV2 Food Doneness Classifier")
        gr.Markdown("Upload an image of food to determine if it's **undercooked**, **raw**, or **cooked**.")

        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload Image")

        run_button = gr.Button("Classify", variant="primary")
        result_output = gr.Label(label="Prediction")

        # The literal previously spanned two physical lines (unterminated
        # string); the newline is kept via an escape so the text is unchanged.
        gr.HTML("\nCustom model trained with MobileNetV2")

    run_button.click(fn=classify, inputs=image_input, outputs=result_output)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)