"""Gradio web demo for the fracture X-ray classifier provided by ``model``."""
import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

from model import class_names, load_model

model = load_model()
# Inference only: switch dropout / batch-norm layers to evaluation mode.
# This is idempotent, so it is harmless if load_model() already did it.
model.eval()

# Resize to 224x224, convert to a tensor and normalise with the ImageNet
# mean/std. NOTE(review): assumes the model was trained with this exact
# preprocessing — confirm against the training pipeline.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def predict(image: "Image.Image | None") -> "dict[str, float]":
    """Return a ``{class_name: probability}`` mapping for one uploaded image.

    Args:
        image: PIL image from the Gradio upload widget, or ``None`` when the
            user submits without uploading anything.

    Returns:
        The softmax probability for every entry of ``class_names``.

    Raises:
        gr.Error: if no image was provided (shown to the user in the UI).
    """
    if image is None:
        raise gr.Error("Please upload an X-ray image.")
    # Uploaded X-rays may be greyscale or RGBA; the 3-value Normalize above
    # needs exactly 3 channels.
    img = image.convert("RGB")
    batch = transform(img).unsqueeze(0)  # add batch dim -> (1, 3, 224, 224)
    with torch.no_grad():
        logits = model(batch)
    # Take the single batch row explicitly. squeeze() would also collapse
    # the class axis if the model had only one output class.
    probs = torch.softmax(logits, dim=1)[0].tolist()
    # Index by position so a mismatch between the model's outputs and
    # class_names still raises instead of being silently truncated.
    return {name: float(probs[i]) for i, name in enumerate(class_names)}


demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=2),
    title="Fracture X-Ray Classifier",
    description="Upload an X-ray image to detect fractures.",
)

demo.launch()