import os

# ── Fix Matplotlib cache permission errors ──
# Matplotlib resolves (and caches) its config/cache directory from
# MPLCONFIGDIR the first time it is imported, so the variable has to be set
# BEFORE any import that may pull matplotlib in (gradio presumably does —
# confirm). /tmp is typically writable on hosts where $HOME is not.
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"

import gradio as gr  # noqa: E402  (must follow the MPLCONFIGDIR assignment)
import torch  # noqa: E402
import torch.nn as nn  # noqa: E402
from PIL import Image  # noqa: E402
from torchvision import models, transforms  # noqa: E402

# ── Constants and Paths ──
# Checkpoint file; a relative path, so it is resolved against the current
# working directory rather than this script's directory.
MODEL_PATH = "./mobilenetv2.pth"
# Classifier output index i is reported as CLASS_NAMES[i]; this order must
# match the label encoding used at training time (NOTE(review): confirm).
CLASS_NAMES = ["undercooked", "raw", "cooked"]
# (height, width) that every input image is resized to before inference.
IMAGE_SIZE = (224, 224)
# ── Load model ──
def load_model(path=MODEL_PATH):
    """Build a MobileNetV2 classifier and load the trained weights.

    The stock torchvision MobileNetV2 is created without pretrained weights,
    its final Linear layer is replaced with one emitting ``len(CLASS_NAMES)``
    logits, and the saved ``state_dict`` at ``path`` is loaded onto the CPU.

    Args:
        path: Filesystem path of the saved ``state_dict``. Defaults to
            ``MODEL_PATH``.

    Returns:
        The model on CPU, switched to eval mode for inference.
    """
    # weights=None: no ImageNet download; load_state_dict below (strict by
    # default) overwrites every parameter and buffer anyway.
    model = models.mobilenet_v2(weights=None)
    # Swap the classification head so it outputs one logit per class.
    model.classifier[1] = nn.Linear(model.last_channel, len(CLASS_NAMES))
    # weights_only=True (torch >= 1.13) restricts unpickling to tensors and
    # plain containers, which is all a state_dict needs, instead of running
    # arbitrary pickled code from the checkpoint file.
    state_dict = torch.load(path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    # Eval mode: dropout off, batch-norm uses its running statistics.
    model.eval()
    return model


# Loaded once at import time and shared by every request.
model = load_model()
# ── Image transform ──
# Preprocessing applied to each uploaded image before inference: resize to
# IMAGE_SIZE (aspect ratio is not preserved), then convert the PIL image to a
# float tensor in [0, 1] laid out as (C, H, W).
# NOTE(review): there is no transforms.Normalize step; this must mirror the
# preprocessing the checkpoint was trained with — confirm.
transform = transforms.Compose(
    transforms=[
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
    ]
)
def classify(image, progress=gr.Progress(track_tqdm=True)):
    """Predict how cooked the food in ``image`` is.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when "Classify" is clicked before anything was uploaded.
        progress: Gradio progress tracker (tqdm tracking enabled); not used
            directly by this function.

    Returns:
        A ``{label: confidence}`` dict for ``gr.Label``: the softmax
        probability of every entry in ``CLASS_NAMES``. If there is no image,
        or preprocessing/inference fails, a single zero-confidence entry
        describing the problem is returned instead.
    """
    import logging  # local import keeps this block self-contained

    if image is None:
        # Nothing uploaded yet; avoid logging a spurious traceback.
        return {"No image uploaded": 0.0}

    try:
        # The network takes 3-channel input; uploads may be RGBA, grayscale
        # or palette images.
        rgb = image.convert("RGB")
        batch = transform(rgb).unsqueeze(0)  # add a batch dimension of 1
        with torch.no_grad():
            logits = model(batch)
        # Softmax over the class axis gives real per-class confidences
        # instead of a hard-coded 1.0 for the argmax class.
        probs = torch.softmax(logits, dim=1)[0].tolist()
    except Exception:
        # Top-level UI boundary: log the full traceback, keep the app alive.
        logging.getLogger(__name__).exception("Error during prediction")
        return {"Error": 0.0}

    return dict(zip(CLASS_NAMES, probs))
# ── Gradio Layout ──
with gr.Blocks(css="#main-col {max-width: 640px; margin: auto;}") as demo:
    # Single centered column, width-limited by the #main-col CSS rule.
    with gr.Column(elem_id="main-col"):
        gr.Markdown("## 🍳 MobileNetV2 Food Doneness Classifier")
        gr.Markdown(
            "Upload an image of food to determine if it's "
            "**undercooked**, **raw**, or **cooked**."
        )
        with gr.Row():
            # type="pil" makes Gradio pass classify() a PIL image.
            image_input = gr.Image(type="pil", label="Upload Image")
        run_button = gr.Button("Classify", variant="primary")
        # Displays the {label: confidence} dict returned by classify().
        result_output = gr.Label(label="Prediction")
        gr.HTML("<p>Custom model trained with MobileNetV2</p>")
    # Event wiring must happen inside the Blocks context.
    run_button.click(fn=classify, inputs=image_input, outputs=result_output)
if __name__ == "__main__":
    # Bind on all interfaces at port 7860 and ask Gradio for a public share
    # link. NOTE(review): share=True exposes the app to anyone with the
    # link; drop it if that is not intended.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    demo.launch(**launch_options)