# Gradio image-classification app (MobileNetV2) for a Hugging Face Space.
import os
# ── Ensure Matplotlib can write its cache without permission errors ──
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"
import gradio as gr
import torch
from torchvision import transforms, models
from torchvision.models import MobileNet_V2_Weights
from PIL import Image
import torch.nn as nn
# Artifact locations, resolved relative to the process working directory.
BASE_DIR = os.getcwd()
_SRC_DIR = os.path.join(BASE_DIR, "src")
MODEL_PATH = os.path.join(_SRC_DIR, "mobilenetv2.pth")
CLASSES_PATH = os.path.join(_SRC_DIR, "classes.txt")
# Load class labels, one per line; list index i corresponds to model logit i.
# Explicit encoding avoids depending on the host's locale default, which can
# mis-decode non-ASCII label files inside minimal containers.
with open(CLASSES_PATH, "r", encoding="utf-8") as f:
    class_names = [line.strip() for line in f]
# Rebuild the MobileNetV2 architecture (no pretrained weights) and load the
# fine-tuned checkpoint for CPU inference.
model = models.mobilenet_v2(weights=None)  # architecture only, per the new weights API
# Swap the final classifier layer so its output width matches our label set.
model.classifier[1] = nn.Linear(model.last_channel, len(class_names))
# weights_only=True restricts torch.load to tensor data, refusing arbitrary
# pickled code in the checkpoint file (requires torch >= 1.13).
state_dict = torch.load(MODEL_PATH, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # inference mode: freeze dropout / batch-norm statistics
# Preprocessing pipeline: resize to the 224x224 input MobileNetV2 expects,
# then convert the PIL image to a float tensor in [0, 1].
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)
def classify_image(image):
    """Return the predicted class name for a PIL image, or an error string.

    The image is forced to RGB, preprocessed to a single-item batch, and run
    through the model with gradients disabled; the label with the highest
    logit is returned. Any failure is reported as a string so the Gradio UI
    shows the message instead of crashing.
    """
    try:
        rgb = image.convert("RGB")
        batch = transform(rgb).unsqueeze(0)
        with torch.no_grad():
            logits = model(batch)
        best = torch.argmax(logits, dim=1).item()
        return class_names[best]
    except Exception as e:
        return f"Error: {e}"
# Gradio UI: one image upload in, one predicted label out.
_image_input = gr.Image(type="pil", label="Upload Image")
_label_output = gr.Label(label="Prediction")
iface = gr.Interface(
    fn=classify_image,
    inputs=_image_input,
    outputs=_label_output,
    title="MobileNetV2 Classifier",
    description="Upload an image to classify it using a custom MobileNetV2 model.",
    # NOTE(review): allow_flagging was renamed flagging_mode in Gradio 5 and
    # raises TypeError there — confirm the pinned gradio version supports it.
    allow_flagging="never",
)
# Entry point: bind to all interfaces on 7860 so the container port is reachable.
if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    iface.launch(**launch_options)