SHXn3's picture
Update app.py
781db03 verified
import os
# ── Ensure Matplotlib can write its cache without permission errors ──
# Points Matplotlib's config/cache directory at a location under /tmp.
# NOTE(review): this is set before the remaining imports, presumably because
# one of them imports Matplotlib; confirm before reordering.
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"
import gradio as gr
import torch
from torchvision import transforms, models
from torchvision.models import MobileNet_V2_Weights
from PIL import Image
import torch.nn as nn
# Paths are resolved against the current working directory, so the app has to
# be started from the directory that contains "src/".
BASE_DIR = os.getcwd()
MODEL_PATH = os.path.join(BASE_DIR, "src", "mobilenetv2.pth")
CLASSES_PATH = os.path.join(BASE_DIR, "src", "classes.txt")

# Load labels: one class name per line; the list index is used as the class id.
# Blank lines are skipped so a stray empty line cannot add a phantom class and
# change the size of the classifier head built from len(class_names).
# The encoding is pinned so label text does not depend on the host locale.
with open(CLASSES_PATH, "r", encoding="utf-8") as f:
    class_names = [name for name in (line.strip() for line in f) if name]
# Load model (using the new weights API).
# Build the architecture without pretrained weights, resize the last
# classifier layer to one output per label, then restore the saved parameters.
model = models.mobilenet_v2(weights=None)  # explicitly no pretrained weights
model.classifier[1] = nn.Linear(model.last_channel, len(class_names))
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot run arbitrary code while it is being loaded.
state_dict = torch.load(MODEL_PATH, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # put dropout / batch-norm layers into evaluation behaviour
# Image transform: every input is resized to 224x224 and converted to a
# tensor. No mean/std normalisation step is part of this pipeline.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)
# Inference function
def classify_image(image):
    """Classify an image and return the name of the predicted class.

    Args:
        image: A PIL image from the UI, or ``None`` when nothing was
            uploaded.

    Returns:
        The entry of ``class_names`` whose model score is highest, or a
        string starting with ``"Error: "`` when no image was supplied or
        inference failed.
    """
    # Without this guard a missing upload would surface an internal
    # AttributeError message as if it were the prediction.
    if image is None:
        return "Error: no image provided"
    try:
        rgb_image = image.convert("RGB")  # force 3 channels (drops alpha/greyscale)
        batch = transform(rgb_image).unsqueeze(0)  # add the batch dimension
        with torch.no_grad():  # inference only, no gradient tracking needed
            logits = model(batch)
        predicted_class = torch.argmax(logits, dim=1).item()
        return class_names[predicted_class]
    except Exception as e:
        # UI boundary: report the failure as text instead of crashing the app.
        return f"Error: {e}"
# Gradio interface: a single PIL image goes in, a single label comes out.
_image_input = gr.Image(type="pil", label="Upload Image")
_label_output = gr.Label(label="Prediction")
iface = gr.Interface(
    fn=classify_image,
    inputs=_image_input,
    outputs=_label_output,
    title="MobileNetV2 Classifier",
    description="Upload an image to classify it using a custom MobileNetV2 model.",
    allow_flagging="never",
)
if __name__ == "__main__":
    # Start the web server only when run as a script: bind to every network
    # interface on port 7860.
    # NOTE(review): share=True additionally requests a public share link;
    # confirm that is wanted together with binding to 0.0.0.0.
    iface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )