rabbydatainsight commited on
Commit
6e41f2b
·
verified ·
1 Parent(s): 5cfe6ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -53
app.py CHANGED
@@ -1,54 +1,73 @@
1
- # app.py (Use this code for Hugging Face)
2
- import torch
3
- import gradio as gr
4
- from PIL import Image
5
- from transformers import SwinForImageClassification, ViTImageProcessor
6
-
7
- # --- 1. Load Model & Processor ---
8
- MODEL_NAME = "microsoft/swin-tiny-patch4-window7-224"
9
- MODEL_PATH = "best_model_swin.pth"
10
- NUM_CLASSES = 3
11
- CLASS_NAMES = ['COVID19', 'NORMAL', 'PNEUMONIA']
12
- device = torch.device("cpu") # Use CPU for free-tier hosting
13
-
14
- processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
15
- model = SwinForImageClassification.from_pretrained(
16
- MODEL_NAME,
17
- num_labels=NUM_CLASSES,
18
- ignore_mismatched_sizes=True
19
- )
20
- model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
21
- model.to(device)
22
- model.eval()
23
-
24
- # --- 2. Define Prediction Function ---
25
- def classify_image(input_image: Image.Image):
26
- if input_image is None:
27
- return "Please upload an image."
28
- if input_image.mode != "RGB":
29
- input_image = input_image.convert("RGB")
30
-
31
- inputs = processor(images=input_image, return_tensors="pt")
32
- pixel_values = inputs['pixel_values'].to(device)
33
-
34
- with torch.no_grad():
35
- outputs = model(pixel_values)
36
-
37
- probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
38
-
39
- # Create a dictionary of {class_name: probability}
40
- confidences = {CLASS_NAMES[i]: prob.item() for i, prob in enumerate(probabilities[0])}
41
-
42
- return confidences
43
-
44
- # --- 3. Create the Gradio Interface ---
45
- iface = gr.Interface(
46
- fn=classify_image,
47
- inputs=gr.Image(type="pil", label="Upload Chest X-Ray"),
48
- outputs=gr.Label(num_top_classes=3, label="Predictions"),
49
- title="Swin Transformer Chest X-Ray Classifier",
50
- description="Upload an X-ray image to classify it as COVID-19, Normal, or Pneumonia."
51
- )
52
-
53
- # --- 4. Launch the app ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  iface.launch()
 
1
+ # app.py (Use this code for Hugging Face)
2
+ import torch
3
+ import gradio as gr
4
+ from PIL import Image
5
+ from transformers import SwinForImageClassification, ViTImageProcessor
6
+
7
+ # --- 1. Load Model & Processor ---
8
+ MODEL_NAME = "microsoft/swin-tiny-patch4-window7-224"
9
+ MODEL_PATH = "best_model_swin.pth"
10
+ NUM_CLASSES = 3
11
+ CLASS_NAMES = ['COVID19', 'NORMAL', 'PNEUMONIA']
12
+ device = torch.device("cpu") # Use CPU for free-tier hosting
13
+
14
+ # --- ADDED ---
15
+ # We will reject any prediction where the model's top guess is below 90% confidence.
16
+ # You can adjust this value (e.g., to 0.95 or 0.85)
17
+ CONFIDENCE_THRESHOLD = 0.90
18
+
19
+ processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
20
+ model = SwinForImageClassification.from_pretrained(
21
+ MODEL_NAME,
22
+ num_labels=NUM_CLASSES,
23
+ ignore_mismatched_sizes=True
24
+ )
25
+ model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
26
+ model.to(device)
27
+ model.eval()
28
+
29
+ # --- 2. Define Prediction Function ---
30
+ def classify_image(input_image: Image.Image):
31
+ if input_image is None:
32
+ return "Please upload an image."
33
+ if input_image.mode != "RGB":
34
+ input_image = input_image.convert("RGB")
35
+
36
+ inputs = processor(images=input_image, return_tensors="pt")
37
+ pixel_values = inputs['pixel_values'].to(device)
38
+
39
+ with torch.no_grad():
40
+ outputs = model(pixel_values)
41
+
42
+ probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
43
+
44
+ # --- START OF MODIFICATION ---
45
+
46
+ # Get the top class and its confidence score
47
+ top_confidence, top_idx = torch.max(probabilities, dim=1)
48
+ top_confidence_score = top_confidence.item()
49
+ top_class_name = CLASS_NAMES[top_idx.item()]
50
+
51
+ # Check if the confidence is below our threshold
52
+ if top_confidence_score < CONFIDENCE_THRESHOLD:
53
+ # Return a custom label for low-confidence predictions
54
+ return {f"Invalid Image or Low Confidence ({top_class_name})": top_confidence_score}
55
+
56
+ # --- END OF MODIFICATION ---
57
+
58
+ # If confidence is high enough, return the normal dictionary
59
+ confidences = {CLASS_NAMES[i]: prob.item() for i, prob in enumerate(probabilities[0])}
60
+
61
+ return confidences
62
+
63
+ # --- 3. Create the Gradio Interface ---
64
+ iface = gr.Interface(
65
+ fn=classify_image,
66
+ inputs=gr.Image(type="pil", label="Upload Chest X-Ray"),
67
+ outputs=gr.Label(num_top_classes=3, label="Predictions"),
68
+ title="Swin Transformer Chest X-Ray Classifier",
69
+ description="Upload an X-ray image to classify it as COVID-19, Normal, or Pneumonia."
70
+ )
71
+
72
+ # --- 4. Launch the app ---
73
  iface.launch()