| import gradio as gr |
| from PIL import Image |
| from torchvision import transforms |
| from model import load_model, class_names |
| import torch |
|
|
# Load the classifier once at import time; predict() reads this module-level
# object on every call.
# NOTE(review): load_model() lives in model.py and is not visible here --
# confirm it returns the network in eval mode (model.eval()); otherwise
# dropout / batch-norm layers would behave as in training during inference.
model = load_model()


# Preprocessing applied to every uploaded image before it reaches the model.
# The mean/std values match the standard ImageNet channel statistics listed in
# the torchvision docs; presumably the model was trained with the same
# normalization -- confirm against the training code.
_INPUT_SIZE = (224, 224)
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

_preprocessing_steps = [
    transforms.Resize(_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
]
transform = transforms.Compose(_preprocessing_steps)
|
|
def predict(image):
    """Classify an uploaded image and return per-class probabilities.

    Args:
        image: PIL image supplied by the Gradio ``Image`` input, or ``None``
            when the form is submitted without an upload.

    Returns:
        A dict mapping each entry of ``class_names`` to its softmax
        probability as a Python float, or ``None`` when no image was given.
    """
    # Gradio hands the callback None when nothing was uploaded; bail out early
    # instead of failing with AttributeError on ``image.convert``.
    if image is None:
        return None

    # Force three channels: the Normalize step in ``transform`` is configured
    # with three per-channel statistics.
    img = image.convert("RGB")
    batch = transform(img).unsqueeze(0)  # add a leading batch dimension

    with torch.no_grad():  # inference only, no gradient bookkeeping
        output = model(batch)

    # Select the single batch entry explicitly rather than using a bare
    # squeeze(), which would also collapse the class dimension when the model
    # has exactly one output and make the result un-indexable.
    probs = torch.softmax(output, dim=1)[0]
    return dict(zip(class_names, probs.tolist()))
|
|
# Gradio UI wiring: a PIL image goes into predict(), and the returned
# class -> probability mapping is rendered as a label widget limited to the
# two highest-scoring classes.
_image_input = gr.Image(type="pil")
_label_output = gr.Label(num_top_classes=2)

demo = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_label_output,
    title="Fracture X-Ray Classifier",
    description="Upload an X-ray image to detect fractures.",
)
|
|
# Start the web server only when this file is executed as a script, so that
# importing the module (for example to call ``predict`` from a test) does not
# launch and block on a running server.
if __name__ == "__main__":
    demo.launch()
|
|