import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

from model import load_model, class_names
# Load the trained network once at import time and put it in inference mode.
# ``eval()`` switches layers such as Dropout and BatchNorm to their
# evaluation behavior; it is a harmless no-op if ``load_model`` already did
# this (its body is not visible here). NOTE(review): assumes ``load_model``
# returns a ``torch.nn.Module`` — confirm.
model = load_model()
model.eval()

# Per-channel RGB normalization constants (the widely used ImageNet
# mean/std). NOTE(review): these presumably must match the preprocessing
# used when the model was trained — confirm against the training code.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing for every uploaded image: resize to 224x224, convert to a
# float tensor, then normalize each of the three channels.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
def predict(image):
    """Classify an uploaded X-ray and return per-class probabilities.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        Dict mapping each name in ``class_names`` to its softmax
        probability as a Python float, the format ``gr.Label`` displays.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of an ``AttributeError`` traceback.
    """
    if image is None:
        raise gr.Error("Please upload an X-ray image.")

    # Force three channels so the 3-value mean/std Normalize in ``transform``
    # applies regardless of the uploaded image's mode (e.g. grayscale).
    img = image.convert("RGB")
    tensor = transform(img).unsqueeze(0)  # add a batch dimension of size 1

    with torch.no_grad():  # inference only; skip autograd bookkeeping
        output = model(tensor)

    # Assumes ``output`` is (1, num_classes) logits — TODO confirm.
    # Take batch row 0 instead of ``squeeze()``: squeeze would also drop
    # the class axis when there is a single output column, leaving a 0-d
    # tensor that cannot be indexed.
    probs = torch.softmax(output, dim=1)[0]
    return {name: float(probs[i]) for i, name in enumerate(class_names)}
# Gradio front end: one PIL-image input wired to ``predict``, and a label
# widget showing the two highest-probability classes it returns.
_xray_input = gr.Image(type="pil")
_prediction_output = gr.Label(num_top_classes=2)

demo = gr.Interface(
    fn=predict,
    inputs=_xray_input,
    outputs=_prediction_output,
    title="Fracture X-Ray Classifier",
    description="Upload an X-ray image to detect fractures.",
)

# Start the Gradio web app.
demo.launch()