"""Gradio web app that classifies an uploaded image with a custom MobileNetV2.

Class labels are read from ``src/classes.txt`` and model weights from
``src/mobilenetv2.pth`` (both relative to the current working directory);
the UI is served on 0.0.0.0:7860.
"""
import os

# ── Ensure Matplotlib can write its cache without permission errors ──
# Set before importing gradio, which may pull in matplotlib.
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"

import gradio as gr
import torch
from torchvision import transforms, models
from torchvision.models import MobileNet_V2_Weights
from PIL import Image
import torch.nn as nn

# Paths are resolved against the current working directory, not this file's
# location. NOTE(review): the app must therefore be launched from the
# directory that contains ``src/``.
BASE_DIR = os.getcwd()
MODEL_PATH = os.path.join(BASE_DIR, "src", "mobilenetv2.pth")
CLASSES_PATH = os.path.join(BASE_DIR, "src", "classes.txt")


def _load_class_names(path):
    """Return the stripped, non-empty lines of *path*, one label per line.

    Blank lines (e.g. extra newlines at the end of the file) are skipped so
    they neither become bogus labels nor inflate the classifier's output size,
    which must match the saved checkpoint exactly.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [label for label in (line.strip() for line in f) if label]


# Load labels; the index of each label is the model output index it maps to.
class_names = _load_class_names(CLASSES_PATH)

# Build the MobileNetV2 architecture with no pretrained weights, swap the final
# linear layer for one with an output per class label, then load the
# fine-tuned weights onto the CPU.
model = models.mobilenet_v2(weights=None)  # explicitly no pretrained weights
model.classifier[1] = nn.Linear(model.last_channel, len(class_names))
# NOTE(review): torch.load unpickles the file; if the checkpoint could ever come
# from an untrusted source, pass weights_only=True (available in torch >= 1.13).
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
model.eval()  # put dropout / batch-norm layers into inference behaviour

# Image transform: resize to 224x224 and convert to a float tensor in [0, 1].
# NOTE(review): no mean/std normalization is applied here - confirm this matches
# the preprocessing that was used when the model was trained.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


def classify_image(image):
    """Predict the class label of a single image.

    Args:
        image: PIL image supplied by the Gradio ``Image`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        The predicted class name, or a string starting with ``"Error: "`` when
        no image was supplied or inference failed.
    """
    if image is None:
        return "Error: no image provided"
    try:
        # Force 3 channels (e.g. RGBA/greyscale uploads) and add a batch dim.
        img_tensor = transform(image.convert("RGB")).unsqueeze(0)
        with torch.no_grad():
            output = model(img_tensor)
        predicted_class = torch.argmax(output, dim=1).item()
        return class_names[predicted_class]
    except Exception as e:  # UI boundary: report the failure instead of crashing
        return f"Error: {e}"


# Gradio interface
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Label(label="Prediction"),
    title="MobileNetV2 Classifier",
    description="Upload an image to classify it using a custom MobileNetV2 model.",
    allow_flagging="never",
)

if __name__ == "__main__":
    # NOTE(review): share=True creates a public Gradio link in addition to
    # binding on all interfaces - confirm this exposure is intended.
    iface.launch(server_name="0.0.0.0", server_port=7860, share=True)