ma4389 commited on
Commit
01610e4
·
verified ·
1 Parent(s): ec3dc17

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +52 -0
  2. reqruiements.txt +4 -0
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models, transforms
4
+ from PIL import Image
5
+ import gradio as gr
6
+
7
+ # Define the same preprocessing as during training
8
+ transform = transforms.Compose([
9
+ transforms.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
10
+ transforms.Resize((224, 224)),
11
+ transforms.ToTensor(),
12
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
13
+ ])
14
+
15
+ # Define model architecture (same as training)
16
+ def load_model():
17
+ model = models.resnet50(weights=None) # Don't load pretrained again
18
+ in_features = model.fc.in_features
19
+ model.fc = nn.Sequential(
20
+ nn.Linear(in_features, 512),
21
+ nn.ReLU(),
22
+ nn.Dropout(0.4),
23
+ nn.Linear(512, 2) # 2 classes: Fractured, Non-Fractured
24
+ )
25
+ model.load_state_dict(torch.load("fract_model.pth", map_location=torch.device('cpu')))
26
+ model.eval()
27
+ return model
28
+
29
+ model = load_model()
30
+ class_names = ["Fractured", "Non-Fractured"]
31
+
32
+ # Prediction function
33
+ def predict(image):
34
+ image = transform(image).unsqueeze(0) # Add batch dimension
35
+ with torch.no_grad():
36
+ outputs = model(image)
37
+ _, predicted = torch.max(outputs, 1)
38
+ class_idx = predicted.item()
39
+ confidence = torch.softmax(outputs, dim=1)[0][class_idx].item()
40
+ return {class_names[class_idx]: float(confidence)}
41
+
42
+ # Gradio Interface
43
+ interface = gr.Interface(
44
+ fn=predict,
45
+ inputs=gr.Image(type="pil"),
46
+ outputs=gr.Label(num_top_classes=2),
47
+ title="Bone Fracture Detection",
48
+ description="Upload an X-ray image to detect if it's Fractured or Non-Fractured."
49
+ )
50
+
51
+ if __name__ == "__main__":
52
+ interface.launch()
reqruiements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch>=2.0
2
+ torchvision>=0.15
3
+ gradio>=4.0
4
+ pillow>=9.0