Spaces:
Build error
Build error
File size: 2,094 Bytes
3481ee8 84eafa0 3481ee8 fe39adc 3481ee8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 | import os
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
# ── Fix Matplotlib cache permission errors ──
# Point Matplotlib at a writable directory under /tmp so it does not fail
# trying to create its config/cache dir in a read-only home directory.
# NOTE(review): this assignment runs after `import gradio`; if gradio (or one
# of its dependencies) has already imported matplotlib by then, the setting
# may come too late. Consider moving it above the third-party imports.
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"

# ── Constants and Paths ──
# Resolve the checkpoint next to this script rather than relative to the
# current working directory, so the app also works when started elsewhere.
MODEL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "mobilenetv2.pth")
# Output index i of the classifier head is reported as CLASS_NAMES[i].
# NOTE(review): this order must match the label indices used in training.
# torchvision's ImageFolder assigns alphabetical order (cooked, raw,
# undercooked), which differs from this list — confirm against the training script.
CLASS_NAMES = ["undercooked", "raw", "cooked"]
# Target (height, width) that input images are resized to before inference.
IMAGE_SIZE = (224, 224)
# ── Load model ──
def load_model():
    """Build the MobileNetV2 classifier and load the fine-tuned weights.

    The architecture is created without pretrained weights, its final linear
    layer is resized to ``len(CLASS_NAMES)`` outputs, and the state dict at
    MODEL_PATH is loaded onto the CPU.

    Returns:
        The model, switched to eval mode.

    Errors from ``torch.load`` / ``load_state_dict`` (e.g. a missing file or
    mismatched keys/shapes) are not caught and propagate to the caller.
    """
    # Architecture only: every weight is overwritten from MODEL_PATH below.
    net = models.mobilenet_v2(weights=None)
    # Swap the last layer of the classifier head for one sized to our classes.
    net.classifier[1] = nn.Linear(net.last_channel, len(CLASS_NAMES))
    # weights_only=True limits unpickling to tensors and plain containers, so a
    # tampered .pth cannot execute arbitrary code when loaded. A plain
    # state_dict loads fine under this restriction (requires torch >= 1.13;
    # it is the default from torch 2.6).
    state_dict = torch.load(MODEL_PATH, map_location="cpu", weights_only=True)
    net.load_state_dict(state_dict)
    # Put dropout / batch-norm layers into inference behaviour.
    net.eval()
    return net


model = load_model()
# ── Image transform ──
# Preprocessing applied to every uploaded image before inference:
# resize to exactly IMAGE_SIZE (aspect ratio is not preserved), then convert
# the PIL image to a float tensor scaled to [0, 1] with channels first.
# NOTE(review): no mean/std normalization is applied here; this must match
# whatever preprocessing the checkpoint was trained with — confirm.
transform = transforms.Compose(
    [
        transforms.Resize(size=IMAGE_SIZE),
        transforms.ToTensor(),
    ]
)
def classify(image, progress=gr.Progress(track_tqdm=True)):
    """Predict the doneness class of a food image.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when "Classify" is clicked before anything is uploaded.
        progress: Gradio progress tracker; not used in the body, kept so the
            callback signature is unchanged.

    Returns:
        A dict mapping every name in CLASS_NAMES to its softmax probability,
        in the shape ``gr.Label`` expects. On missing input or any failure
        during prediction, returns ``{"Error": 0.0}``.
    """
    if image is None:
        # Nothing uploaded: same fallback as an error, without logging noise.
        return {"Error": 0.0}
    try:
        # Force 3 channels (e.g. RGBA/greyscale uploads) and add a batch dim.
        tensor = transform(image.convert("RGB")).unsqueeze(0)
        with torch.no_grad():
            logits = model(tensor)
        # The head is a bare Linear layer, so outputs are logits; softmax turns
        # them into per-class confidences instead of a hard-coded 1.0.
        probs = torch.softmax(logits, dim=1)[0].tolist()
        return dict(zip(CLASS_NAMES, probs))
    except Exception as e:
        # UI boundary: report the failure but keep the app responsive.
        print("Error during prediction:", e)
        return {"Error": 0.0}
# ── Gradio Layout ──
# Single centred column; the CSS rule targets it via elem_id "main-col" and
# caps its width at 640px.
with gr.Blocks(css="#main-col {max-width: 640px; margin: auto;}") as demo:
    with gr.Column(elem_id="main-col"):
        gr.Markdown("## 🍳 MobileNetV2 Food Doneness Classifier")
        gr.Markdown("Upload an image of food to determine if it's **undercooked**, **raw**, or **cooked**.")
        with gr.Row():
            # type="pil" hands classify() a PIL.Image (or None if empty).
            image_input = gr.Image(type="pil", label="Upload Image")
            run_button = gr.Button("Classify", variant="primary")
        # Displays the label -> confidence dict returned by classify().
        result_output = gr.Label(label="Prediction")
        gr.HTML("<br><small>Custom model trained with MobileNetV2</small>")
    # Event listeners must be registered inside the Blocks context.
    run_button.click(fn=classify, inputs=image_input, outputs=result_output)
if __name__ == "__main__":
    # Listen on all interfaces on port 7860 and request a public share link.
    # NOTE(review): share links are not supported when running on Hugging Face
    # Spaces (Gradio warns and ignores it); locally, share=True exposes the app
    # through a public URL — confirm this is intended.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    demo.launch(**launch_options)