import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import gradio as gr

# Preprocessing must match what was used during training: force 3-channel
# RGB, resize to the ResNet input size, and normalize each channel to
# [-1, 1] (mean/std of 0.5).
transform = transforms.Compose([
    transforms.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])


def load_model():
    """Rebuild the training-time architecture and load the fine-tuned weights.

    Returns
    -------
    nn.Module
        A ResNet-50 (no ImageNet weights -- the checkpoint supplies
        everything) with the same 2-class classification head used at
        training time, loaded onto CPU and set to eval mode.
    """
    model = models.resnet50(weights=None)  # don't load pretrained again
    in_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Linear(in_features, 512),
        nn.ReLU(),
        nn.Dropout(0.4),
        nn.Linear(512, 2)  # 2 classes: Fractured, Non-Fractured
    )
    model.load_state_dict(
        torch.load("fract_model.pth", map_location=torch.device('cpu'))
    )
    model.eval()  # disable dropout / freeze batch-norm stats for inference
    return model


model = load_model()
class_names = ["Fractured", "Non-Fractured"]


def predict(image):
    """Classify an X-ray image and return per-class confidences.

    Parameters
    ----------
    image : PIL.Image.Image or None
        Image supplied by the Gradio widget; ``None`` when the user
        submits without uploading anything.

    Returns
    -------
    dict[str, float]
        Softmax probability for EVERY class, so that
        ``gr.Label(num_top_classes=2)`` can actually display both classes
        (the previous version returned only the winning class, leaving
        the second label slot empty). Empty dict when no image is given.
    """
    if image is None:  # guard: Gradio passes None when no image is uploaded
        return {}
    batch = transform(image).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)[0]
    return {name: float(p) for name, p in zip(class_names, probs)}


# Gradio UI: one image in, a two-class confidence label out.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=2),
    title="Bone Fracture Detection",
    description="Upload an X-ray image to detect if it's Fractured or Non-Fractured."
)

if __name__ == "__main__":
    interface.launch()