Update app.py
Browse files
app.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
# app.py
|
| 2 |
import torch
|
| 3 |
import gradio as gr
|
| 4 |
from PIL import Image
|
|
@@ -9,11 +9,9 @@ MODEL_NAME = "microsoft/swin-tiny-patch4-window7-224"
|
|
| 9 |
MODEL_PATH = "best_model_swin.pth"
|
| 10 |
NUM_CLASSES = 3
|
| 11 |
CLASS_NAMES = ['COVID19', 'NORMAL', 'PNEUMONIA']
|
| 12 |
-
device = torch.device("cpu")
|
| 13 |
|
| 14 |
-
# --- ADDED ---
|
| 15 |
# We will reject any prediction where the model's top guess is below 90% confidence.
|
| 16 |
-
# You can adjust this value (e.g., to 0.95 or 0.85)
|
| 17 |
CONFIDENCE_THRESHOLD = 0.90
|
| 18 |
|
| 19 |
processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
|
|
@@ -41,7 +39,6 @@ def classify_image(input_image: Image.Image):
|
|
| 41 |
|
| 42 |
probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
|
| 43 |
|
| 44 |
-
# --- START OF MODIFICATION ---
|
| 45 |
|
| 46 |
# Get the top class and its confidence score
|
| 47 |
top_confidence, top_idx = torch.max(probabilities, dim=1)
|
|
@@ -53,7 +50,6 @@ def classify_image(input_image: Image.Image):
|
|
| 53 |
# Return a custom label for low-confidence predictions
|
| 54 |
return {f"Invalid Image or Low Confidence ({top_class_name})": top_confidence_score}
|
| 55 |
|
| 56 |
-
# --- END OF MODIFICATION ---
|
| 57 |
|
| 58 |
# If confidence is high enough, return the normal dictionary
|
| 59 |
confidences = {CLASS_NAMES[i]: prob.item() for i, prob in enumerate(probabilities[0])}
|
|
|
|
| 1 |
+
# app.py
|
| 2 |
import torch
|
| 3 |
import gradio as gr
|
| 4 |
from PIL import Image
|
|
|
|
| 9 |
MODEL_PATH = "best_model_swin.pth"
|
| 10 |
NUM_CLASSES = 3
|
| 11 |
CLASS_NAMES = ['COVID19', 'NORMAL', 'PNEUMONIA']
|
| 12 |
+
device = torch.device("cpu")
|
| 13 |
|
|
|
|
| 14 |
# We will reject any prediction where the model's top guess is below 90% confidence.
|
|
|
|
| 15 |
CONFIDENCE_THRESHOLD = 0.90
|
| 16 |
|
| 17 |
processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
|
|
|
|
| 39 |
|
| 40 |
probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
|
| 41 |
|
|
|
|
| 42 |
|
| 43 |
# Get the top class and its confidence score
|
| 44 |
top_confidence, top_idx = torch.max(probabilities, dim=1)
|
|
|
|
| 50 |
# Return a custom label for low-confidence predictions
|
| 51 |
return {f"Invalid Image or Low Confidence ({top_class_name})": top_confidence_score}
|
| 52 |
|
|
|
|
| 53 |
|
| 54 |
# If confidence is high enough, return the normal dictionary
|
| 55 |
confidences = {CLASS_NAMES[i]: prob.item() for i, prob in enumerate(probabilities[0])}
|