IFMedTechdemo commited on
Commit
552875d
·
verified ·
1 Parent(s): 5a2c57d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -0
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2
3
+ import torch
4
+ import torchvision
5
+ import numpy as np
6
+ from PIL import Image
7
+ import matplotlib.pyplot as plt
8
+ import matplotlib.patches as patches
9
+ from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
10
+ import os
11
+ import io
12
+
13
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
14
+
15
+ # Class names and colors
16
+ CLASS_NAMES = {1: 'Nipple', 2: 'Lump'}
17
+ CLASS_COLORS = {1: 'white', 2: 'white'}
18
+
19
+ def preprocess_image(image):
20
+ """Load and preprocess image for Faster R-CNN."""
21
+ # Convert PIL Image to numpy array
22
+ image = np.array(image)
23
+
24
+ # Convert RGB to RGB (already in correct format from Gradio)
25
+ image = image.astype(np.float32) / 255.0 # Normalize to [0,1]
26
+
27
+ # Normalize using ImageNet mean and std
28
+ mean = np.array([0.485, 0.456, 0.406])
29
+ std = np.array([0.229, 0.224, 0.225])
30
+ image = (image - mean) / std
31
+
32
+ return torch.tensor(image.transpose(2, 0, 1), dtype=torch.float32)
33
+
34
+ def load_model(checkpoint_path, device):
35
+ """Load Faster R-CNN model with fine-tuned weights."""
36
+ model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None)
37
+ in_features = model.roi_heads.box_predictor.cls_score.in_features
38
+ model.roi_heads.box_predictor = FastRCNNPredictor(in_features, len(CLASS_NAMES) + 1)
39
+ model.load_state_dict(torch.load(checkpoint_path, map_location=device))
40
+ model.to(device).eval()
41
+ return model
42
+
43
+ def predict(image, score_thresh=0.5):
44
+ """Run inference and return image with bounding boxes."""
45
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
+
47
+ # Load model
48
+ checkpoint_path = "lumps.pth" # This will be downloaded from the model repo
49
+ model = load_model(checkpoint_path, device)
50
+
51
+ # Preprocess image
52
+ image_tensor = preprocess_image(image)
53
+
54
+ # Run inference
55
+ model.eval()
56
+ with torch.no_grad():
57
+ preds = model([image_tensor.to(device)])[0]
58
+
59
+ boxes, labels, scores = preds['boxes'].cpu().numpy(), preds['labels'].cpu().numpy(), preds['scores'].cpu().numpy()
60
+
61
+ # Filter based on confidence threshold
62
+ keep = scores >= score_thresh
63
+ boxes, labels, scores = boxes[keep], labels[keep], scores[keep]
64
+
65
+ # Convert tensor back to image
66
+ mean = np.array([0.485, 0.456, 0.406])
67
+ std = np.array([0.229, 0.224, 0.225])
68
+ image_np = image_tensor.cpu().permute(1, 2, 0).numpy() * std + mean
69
+ image_np = np.clip(image_np, 0, 1)
70
+
71
+ # Create figure
72
+ fig, ax = plt.subplots(1, figsize=(12, 12))
73
+ ax.imshow(image_np)
74
+
75
+ # Draw bounding boxes
76
+ for box, label, score in zip(boxes, labels, scores):
77
+ xmin, ymin, xmax, ymax = box
78
+ rect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
79
+ linewidth=3, edgecolor=CLASS_COLORS.get(label, 'blue'),
80
+ facecolor='none')
81
+ ax.add_patch(rect)
82
+ ax.text(xmin, ymin - 10, f"{CLASS_NAMES.get(label, f'class_{label}')} ({score:.2f})",
83
+ fontsize=14, color='white', backgroundcolor='black', weight='bold')
84
+
85
+ plt.axis('off')
86
+
87
+ # Convert plot to image
88
+ buf = io.BytesIO()
89
+ plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
90
+ buf.seek(0)
91
+ result_image = Image.open(buf)
92
+ plt.close()
93
+
94
+ return result_image
95
+
96
+ # Create Gradio interface
97
+ demo = gr.Interface(
98
+ fn=predict,
99
+ inputs=[
100
+ gr.Image(type="pil", label="Upload Breast Image"),
101
+ gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.05, label="Confidence Threshold")
102
+ ],
103
+ outputs=gr.Image(type="pil", label="Detection Results"),
104
+ title="Breast Lumps Detection",
105
+ description="""Upload a breast image to detect lumps and nipples using a Faster R-CNN model.\n\n
106
+ ⚠️ **Important Medical Disclaimer**: This is a screening tool for research and assistive purposes only.
107
+ It should NOT be used as the sole basis for medical diagnosis. All detections must be reviewed and confirmed
108
+ by qualified medical professionals. This model is not FDA approved or certified for clinical diagnosis.""",
109
+ examples=None,
110
+ allow_flagging="never"
111
+ )
112
+
113
+ if __name__ == "__main__":
114
+ demo.launch()