Spaces:
Runtime error
Runtime error
| import os | |
| # ── Ensure Matplotlib can write its cache without permission errors ── | |
| os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib" | |
| import gradio as gr | |
| import torch | |
| from torchvision import transforms, models | |
| from torchvision.models import MobileNet_V2_Weights | |
| from PIL import Image | |
| import torch.nn as nn | |
# Paths (resolved against the current working directory)
BASE_DIR = os.getcwd()
MODEL_PATH = os.path.join(BASE_DIR, "src", "mobilenetv2.pth")
CLASSES_PATH = os.path.join(BASE_DIR, "src", "classes.txt")

# Load labels: one class name per line. Decode as UTF-8 explicitly so the
# result does not depend on the host locale, and skip blank lines so a stray
# empty line cannot inflate the class count and break the state-dict load
# below with a classifier shape mismatch.
with open(CLASSES_PATH, "r", encoding="utf-8") as f:
    class_names = [name for name in (line.strip() for line in f) if name]

# Load model (using the new weights API)
model = models.mobilenet_v2(weights=None)  # explicitly no pretrained weights
# Replace the final classifier layer so its output size equals the label count.
model.classifier[1] = nn.Linear(model.last_channel, len(class_names))
# NOTE(review): torch.load unpickles the checkpoint; only load trusted files,
# and consider weights_only=True if the installed PyTorch supports it.
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
model.eval()

# Image transform
# NOTE(review): no mean/std normalization is applied here; confirm this
# matches the preprocessing used when the checkpoint was trained.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
# Inference function
def classify_image(image):
    """Classify an uploaded image and return the predicted class name.

    Args:
        image: The PIL image supplied by the Gradio ``Image`` input, or
            ``None`` when the form is submitted without an image.

    Returns:
        The entry of ``class_names`` at the arg-max of the model output, or
        a string starting with ``"Error: "`` when no image was provided or
        inference failed.
    """
    # An empty submission arrives as None; report that plainly instead of
    # leaking an AttributeError message into the UI.
    if image is None:
        return "Error: no image provided"
    try:
        image = image.convert("RGB")
        img_tensor = transform(image).unsqueeze(0)  # add the batch dimension
        with torch.no_grad():
            output = model(img_tensor)
        predicted_class = torch.argmax(output, dim=1).item()
        return class_names[predicted_class]
    except Exception as e:
        # UI boundary: surface the failure as text rather than crash the app.
        return f"Error: {e}"
# Gradio interface: one PIL image in, one label out.
_image_input = gr.Image(type="pil", label="Upload Image")
_prediction_output = gr.Label(label="Prediction")

iface = gr.Interface(
    fn=classify_image,
    inputs=_image_input,
    outputs=_prediction_output,
    title="MobileNetV2 Classifier",
    description="Upload an image to classify it using a custom MobileNetV2 model.",
    allow_flagging="never",
)

if __name__ == "__main__":
    # Bind to all interfaces on port 7860 and request a public share link.
    iface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )