ma4389 commited on
Commit
560014e
·
verified ·
1 Parent(s): 94ab7ca

Upload 3 files

Browse files
Files changed (3) hide show
  1. best_model (2).pth +3 -0
  2. covid.py +55 -0
  3. requirements.txt +4 -0
best_model (2).pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:919069e965c757cefdc54331e2e95e2268ea588ac8c01e302366b94c2c0814a2
3
+ size 98551872
covid.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ from PIL import Image
5
+ import torchvision.models as models
6
+ import torch.nn as nn
7
+
8
+ # 🔹 Device setup
9
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+
11
+ # 🔹 Load ResNet50 and modify for 2-class output
12
+ model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
13
+ in_features = model.fc.in_features
14
+ model.fc = nn.Sequential(
15
+ nn.Linear(in_features, 512),
16
+ nn.ReLU(),
17
+ nn.Dropout(0.4),
18
+ nn.Linear(512, 2) # 2 classes: NORMAL, PNEUMONIA
19
+ )
20
+
21
+ # 🔹 Load the trained model
22
+ model.load_state_dict(torch.load("best_model (2).pth", map_location=device))
23
+ model.to(device)
24
+ model.eval()
25
+
26
+ # 🔹 Preprocessing - exactly matching your val_transforms
27
+ transform = transforms.Compose([
28
+ transforms.Lambda(lambda img: img.convert("RGB")), # 🧠 Force RGB
29
+ transforms.Resize((224, 224)),
30
+ transforms.ToTensor(),
31
+ transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
32
+ ])
33
+
34
+ # 🔹 Class labels
35
+ class_names = ["NORMAL", "PNEUMONIA"]
36
+
37
+ # 🔹 Inference function
38
+ def classify_image(img):
39
+ img = transform(img).unsqueeze(0).to(device)
40
+ with torch.no_grad():
41
+ outputs = model(img)
42
+ probs = torch.nn.functional.softmax(outputs, dim=1)
43
+ return {class_names[i]: float(probs[0][i]) for i in range(len(class_names))}
44
+
45
+ # 🔹 Gradio Interface
46
+ interface = gr.Interface(
47
+ fn=classify_image,
48
+ inputs=gr.Image(type="pil"),
49
+ outputs=gr.Label(num_top_classes=2),
50
+ title="🩺 Pneumonia Classifier",
51
+ description="Upload a chest X-ray image. The model predicts whether it's NORMAL or shows signs of PNEUMONIA."
52
+ )
53
+
54
+ # 🔹 Launch the app
55
+ interface.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio
4
+ pillow