File size: 2,094 Bytes
3481ee8
 
 
 
 
 
 
 
 
 
 
84eafa0
3481ee8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe39adc
 
 
 
3481ee8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image

# ── Fix Matplotlib cache permission errors ──
# Point Matplotlib's config/cache directory at a writable location.
# NOTE(review): this runs after the imports above; it only helps if matplotlib
# has not already been imported by one of them — confirm.
os.environ.update(MPLCONFIGDIR="/tmp/matplotlib")

# ── Constants and Paths ──
# Input resolution every image is resized to before inference.
IMAGE_SIZE = (224, 224)
# Index i of the model output is reported as CLASS_NAMES[i].
# NOTE(review): the order must match the label indices used in training — confirm.
CLASS_NAMES = [
    "undercooked",
    "raw",
    "cooked",
]
# Checkpoint location, relative to the current working directory.
MODEL_PATH = "./mobilenetv2.pth"


# ── Load model ──
def load_model():
    """Build a MobileNetV2 classifier and load its weights onto the CPU.

    Returns:
        The model in evaluation mode, with its final classifier layer sized
        to ``len(CLASS_NAMES)`` outputs and its parameters loaded from
        ``MODEL_PATH``.
    """
    model = models.mobilenet_v2(weights=None)
    # Swap the final linear layer so the output size matches our class count.
    model.classifier[1] = nn.Linear(model.last_channel, len(CLASS_NAMES))
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot execute arbitrary code on load.
    state_dict = torch.load(MODEL_PATH, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Loaded once at import time; classify() reads this module-level instance.
model = load_model()

# ── Image transform ──
# Resize to IMAGE_SIZE and convert to a tensor; no normalization is applied.
# NOTE(review): this should mirror the preprocessing used at training time — confirm.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])


def classify(image, progress=gr.Progress(track_tqdm=True)):
    """Classify an uploaded image and return per-class confidences.

    Args:
        image: PIL image from the Gradio image input, or ``None`` when no
            image has been provided.
        progress: Gradio progress tracker; not referenced in the body.

    Returns:
        A ``{label: confidence}`` dict holding the softmax probability of
        every entry in ``CLASS_NAMES``, or ``{"Error": 0.0}`` when no image
        was provided or prediction failed.
    """
    # Guard explicitly instead of relying on an AttributeError being swallowed.
    if image is None:
        print("Error during prediction: no image provided")
        return {"Error": 0.0}

    try:
        image = image.convert("RGB")
        # Add a leading batch dimension: the model expects a batch of images.
        tensor = transform(image).unsqueeze(0)

        with torch.no_grad():
            outputs = model(tensor)
            # Report real probabilities rather than a hard-coded 1.0 for the
            # top class, so the displayed confidence reflects the model.
            probabilities = torch.softmax(outputs, dim=1)[0].tolist()

        return dict(zip(CLASS_NAMES, probabilities))
    except Exception as e:
        # UI boundary: log and fall back to a fixed label so the app keeps running.
        print("Error during prediction:", e)
        return {"Error": 0.0}



# ── Gradio Layout ──
# Build the UI. The CSS rule targets the column's elem_id ("main-col") to cap
# its width at 640px and center it.
with gr.Blocks(css="#main-col {max-width: 640px; margin: auto;}") as demo:
    with gr.Column(elem_id="main-col"):
        # Title and short usage instructions.
        gr.Markdown("## 🍳 MobileNetV2 Food Doneness Classifier")
        gr.Markdown("Upload an image of food to determine if it's **undercooked**, **raw**, or **cooked**.")

        # Image input and the trigger button share one row.
        with gr.Row():
            # type="pil" so classify() receives a PIL image it can .convert().
            image_input = gr.Image(type="pil", label="Upload Image")
            run_button = gr.Button("Classify", variant="primary")

        # Displays the {label: confidence} dict returned by classify().
        result_output = gr.Label(label="Prediction")

        gr.HTML("<br><small>Custom model trained with MobileNetV2</small>")

        # Inference runs only when the button is clicked.
        run_button.click(fn=classify, inputs=image_input, outputs=result_output)

if __name__ == "__main__":
    # Listen on every interface at port 7860 and request a Gradio share link.
    # NOTE(review): share=True may be unnecessary on a hosted deployment — confirm.
    launch_kwargs = dict(server_name="0.0.0.0", server_port=7860, share=True)
    demo.launch(**launch_kwargs)