File size: 1,736 Bytes
8be9bad
 
 
 
742ffc8
e449388
e73752e
8be9bad
e449388
3d5dbf8
b9019c2
742ffc8
781db03
742ffc8
 
b9019c2
742ffc8
 
 
b9019c2
8be9bad
 
742ffc8
9c6e1ee
742ffc8
 
b3bc3f6
742ffc8
 
 
 
 
 
 
b9019c2
742ffc8
b9019c2
 
 
 
8be9bad
b9019c2
8be9bad
742ffc8
 
 
 
 
 
 
f99f29b
 
af0a313
742ffc8
 
8258396
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import os
# ── Ensure Matplotlib can write its cache without permission errors ──
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"

import gradio as gr
import torch
from torchvision import transforms, models
from torchvision.models import MobileNet_V2_Weights
from PIL import Image
import torch.nn as nn

# Paths
# NOTE: everything is resolved against the current working directory, so the
# app has to be started from the directory that contains "src/".
BASE_DIR = os.getcwd()
MODEL_PATH = os.path.join(BASE_DIR, "src", "mobilenetv2.pth")
CLASSES_PATH = os.path.join(BASE_DIR, "src", "classes.txt")

# Load labels: one class name per line; the line order defines the class index.
# Blank lines are skipped so a stray empty line (e.g. at the end of the file)
# does not add a phantom class and change the size of the classifier head.
# The encoding is pinned so non-ASCII labels decode the same on every platform.
with open(CLASSES_PATH, "r", encoding="utf-8") as f:
    class_names = [name for name in (line.strip() for line in f) if name]

# Fail early with a readable message instead of an obscure size-mismatch
# error from load_state_dict further down.
if not class_names:
    raise ValueError(f"No class names found in {CLASSES_PATH}")

# Load model (using the new weights API)
model = models.mobilenet_v2(weights=None)  # explicitly no pretrained weights
# Replace the final linear layer so the output size matches the label list.
model.classifier[1] = nn.Linear(model.last_channel, len(class_names))
# NOTE(review): torch.load unpickles the checkpoint; only load files from a
# trusted source.
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
model.eval()  # inference mode for dropout / batch-norm layers

# Image transform: resize to 224x224 and convert to a tensor.
# NOTE(review): no normalization is applied here — presumably this matches the
# preprocessing used at training time; confirm against the training code.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Inference function
def classify_image(image):
    """Classify a single image and return the predicted class name.

    Args:
        image: An image object exposing ``convert("RGB")`` (the interface in
            this file supplies PIL images), or ``None`` when no image was
            provided.

    Returns:
        The name of the highest-scoring class, or a string starting with
        ``"Error: "`` when no image was given or inference failed.
    """
    # Guard the "submitted without an image" case explicitly; otherwise it
    # surfaces as a cryptic "'NoneType' object has no attribute 'convert'".
    if image is None:
        return "Error: no image provided"

    try:
        rgb_image = image.convert("RGB")
        # unsqueeze(0) adds the leading batch dimension before the forward pass.
        batch = transform(rgb_image).unsqueeze(0)
        with torch.no_grad():
            logits = model(batch)
        predicted_class = torch.argmax(logits, dim=1).item()
        return class_names[predicted_class]
    except Exception as e:
        # Deliberate catch-all at the UI boundary: report the failure as text
        # instead of letting the request crash.
        return f"Error: {e}"

# Gradio interface: one uploaded image (delivered as a PIL object) goes in,
# the text returned by classify_image comes out. Flagging is turned off.
iface = gr.Interface(
    title="MobileNetV2 Classifier",
    description="Upload an image to classify it using a custom MobileNetV2 model.",
    fn=classify_image,
    inputs=gr.Image(label="Upload Image", type="pil"),
    outputs=gr.Label(label="Prediction"),
    allow_flagging="never",
)

if __name__ == "__main__":
    # Listen on all interfaces on port 7860 and request a share link.
    iface.launch(
        share=True,
        server_port=7860,
        server_name="0.0.0.0",
    )