Stroke-ia commited on
Commit
f823b59
·
verified ·
1 Parent(s): 0152739

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +137 -165
main.py CHANGED
@@ -1,191 +1,163 @@
1
  import os
2
- os.environ["TORCH_HOME"] = "/tmp/torch"
3
- os.environ["HF_HOME"] = "/tmp/huggingface"
4
- os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
5
- os.environ["XDG_CACHE_HOME"] = "/tmp"
6
-
7
- from pipline import Transformer_Regression, extract_regions_Last, compute_ratios
8
- import torch
9
- import torchvision.transforms as transforms
10
- import gradio as gr
11
  import cv2
 
 
12
  import numpy as np
 
13
  from PIL import Image
14
-
15
-
16
- from pipline import Transformer_Regression, extract_regions_Last , compute_ratios
17
- import torch
18
  import torchvision.transforms as transforms
19
  from torch.nn import functional as F
20
- import cv2
21
- import gradio as gr
22
- import numpy as np
23
- from PIL import Image
24
 
 
 
 
 
 
 
 
 
25
 
26
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
27
 
28
- ## Define some parameters
29
- image_shape = 384 #### 512 got 87
30
- batch_size=1
31
- dim_patch=4
32
- num_classes=3
33
- label_smoothing=0.1
34
- scale=1
35
- import time
36
- start = time.time()
37
- torch.manual_seed(0)
38
- #import random
39
 
 
 
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  tfms = transforms.Compose([
42
  transforms.Resize((image_shape, image_shape)),
43
  transforms.ToTensor(),
44
- transforms.Normalize(0.5,0.5)
45
- #transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
46
- #transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
47
-
48
  ])
49
 
50
- def Final_Compute_regression_results_Sample(Model, batch_sampler,num_head=2):
 
 
 
51
  Model.eval()
52
- score_cup = []
53
- score_disc = []
54
- yreg_pred = []
55
- yreg_true = []
56
  with torch.no_grad():
57
- #for batch_sampler in loader:
58
- train_batch_tfms = batch_sampler['image'].to(device=device)
59
- #ytrue_seg = batch_sampler['image_original'] #.detach().cpu().numpy()
60
- ytrue_seg = batch_sampler['image_original'] # .detach().cpu().numpy()
61
- scores = Model(train_batch_tfms.unsqueeze(0))
62
-
63
- yseg_pred = F.interpolate(scores['seg'], size=(ytrue_seg.shape[0], ytrue_seg.shape[1]), mode='bilinear',
64
- align_corners=True)
65
-
66
-
67
- # Regions_crop=extract_regions_Last(np.array(batch_sampler['image_original'][0]),yseg_pred[0].detach().cpu().numpy())
68
- Regions_crop = extract_regions_Last(np.array(batch_sampler['image_original']),
69
- yseg_pred.argmax(1).long()[0].detach().cpu().numpy())
70
- Regions_crop['image'] = Image.fromarray(np.uint8(Regions_crop['image'])).convert('RGB')
71
-
72
- ### Get back if two heads
73
- ytrue_seg_crop = ytrue_seg[Regions_crop['cord'][0]:Regions_crop['cord'][1],
74
- Regions_crop['cord'][2]:Regions_crop['cord'][3]]
75
- ytrue_seg_crop = np.expand_dims(ytrue_seg_crop, axis=0)
76
-
77
- if num_head==2:
78
- scores = Model((tfms(Regions_crop['image']).unsqueeze(0)).to(device))
79
- yseg_pred_crop = F.interpolate(scores['seg_aux_1'], size=(ytrue_seg_crop.shape[1], ytrue_seg_crop.shape[2]),
80
  mode='bilinear', align_corners=True)
81
- yseg_pred[:, :, Regions_crop['cord'][0]:Regions_crop['cord'][1],
82
- Regions_crop['cord'][2]:Regions_crop['cord'][3]] = yseg_pred_crop
83
- # yseg_pred[:, :, Regions_crop['cord'][0]:Regions_crop['cord'][1],
84
- # Regions_crop['cord'][2]:Regions_crop['cord'][3]]+yseg_pred_crop
85
- yseg_pred = torch.softmax(yseg_pred, dim=1)
86
- yseg_pred = yseg_pred.argmax(1).long()
87
- yseg_pred = ((yseg_pred).long()).detach().cpu().numpy()
88
- ratios = compute_ratios(yseg_pred[0])
89
- yreg_pred.append(ratios.vcdr)
90
-
91
- ### Plot
92
- p_img = batch_sampler['image'].to(device=device).unsqueeze(0)
93
- p_img = F.interpolate(p_img, size=(yseg_pred.shape[1], yseg_pred.shape[2]),
94
- mode='bilinear', align_corners=True)
95
- ### Get reversed image
96
- image_orig = (p_img[0] * 0.5 + 0.5).permute(1, 2, 0).detach().cpu().numpy()
97
- image_orig=np.uint8(image_orig*255)
98
- ####
99
- # train_batch_tfms
100
- #plt.imshow(image_orig)
101
- # make a copy as these operations are destructive
102
- image_cont = image_orig.copy()
103
- ###### plot for Prediction....
104
- # threshold for 2 value
105
- ret, thresh = cv2.threshold(np.uint8(yseg_pred[0]), 1, 2, 0)
106
- # find and draw contour for 2 value (red)
107
- conts, hir = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
108
- cv2.drawContours(image_cont, conts, -1, (0, 255, 0), 2)
109
- #threshold for 1 value
110
- ret, thresh = cv2.threshold(np.uint8(yseg_pred[0]), 0, 2, 0)
111
- #find and draw contour for 1 value (blue)
112
- conts, hir = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
113
- cv2.drawContours(image_cont, conts, -1, (0, 0, 255), 2)
114
- #plot contoured image
115
-
116
- # plt.imshow(image_cont)
117
- # plt.axis('off')
118
-
119
- # print('Vertical cup to disc ratio:')
120
- # print(ratios.vcdr)
121
- if ratios.vcdr < 0.6:
122
- glaucoma = 'None'
123
- else:
124
- glaucoma = 'May be there is a risk of Glaucoma'
125
-
126
- # print('Galucoma:')
127
-
128
 
129
  return image_cont, ratios.vcdr, glaucoma, Regions_crop
130
 
131
- #load model
132
- DeepLab=Transformer_Regression(image_dim=image_shape,dim_patch=dim_patch,num_classes=3,scale=scale,feat_dim=128)
133
- DeepLab.to(device=device)
134
- DeepLab.load_state_dict(torch.load("TrainAll_Maghrabi84_50iteration_SWIN.pth.tar", map_location=torch.device(device)))
135
-
136
- def infer(img):
137
- # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
138
-
139
- sample_batch = dict()
140
- sample_batch['image_original'] = img
141
 
142
- im_retina_pil = Image.fromarray(img)
 
 
 
 
143
 
144
- im_retina_pil = tfms(im_retina_pil)
145
- sample_batch['image'] = im_retina_pil
 
 
 
146
 
147
- # plt.figure('Head2')
148
  result, ratio, diagnosis, cropped = Final_Compute_regression_results_Sample(DeepLab, sample_batch, num_head=2)
149
-
150
- # cropped = cv2.cvtColor(np.asarray(cropped), cv2.COLOR_BGR2RGB)
151
- cropped = result[cropped['cord'][0] :cropped['cord'][1] ,
152
- cropped['cord'][2] :cropped['cord'][3] ]
153
-
154
- return ratio, diagnosis, result, cropped
155
-
156
-
157
- title = "Glaucoma Detection in Retinal Fundus Images"
158
- description = "The method detects disc and cup in the retinal image, then it computes the Vertical cup to disc ratio"
159
-
160
- outputs = [gr.Textbox(label="Vertical cup to disc ratio:"), gr.Textbox(label="predicted diagnosis (Rule of thumb ~0.6 or greater is suspicious)"), gr.Image(label='labeled image'), gr.Image(label='zoomed in')]
161
- with gr.Blocks(css='#title {text-align : center;} ') as demo:
162
- with gr.Row():
163
- gr.Markdown(
164
- f'''
165
- # {title}
166
- {description}
167
-
168
- ''',
169
- elem_id='title'
170
- )
171
- with gr.Row():
172
- with gr.Column():
173
- prompt = gr.Image(label="Upload Your Retinal Fundus Image")
174
- btn = gr.Button(value='Submit')
175
- examples = gr.Examples(
176
- ['M00027.png','M00056.png','M00073.png','M00093.png', 'M00018.png', 'M00034.png'],
177
- inputs=[prompt], fn=infer, outputs=[outputs], cache_examples=False)
178
- with gr.Column():
179
- with gr.Row():
180
- text1 = gr.Textbox(label="Vertical Cup to Disc Ratio:")
181
- text2 = gr.Textbox(label="Predicted Diagnosis (Rule of thumb ~0.6 or greater is suspicious)")
182
- img = gr.Image(label='Detected disc and cup')
183
- zoom = gr.Image(label='Croppped')
184
-
185
- outputs = [text1,text2,img,zoom]
186
-
187
- btn.click(fn=infer, inputs=prompt, outputs=outputs)
188
-
189
-
190
- if __name__ == '__main__':
191
- demo.launch()
 
1
  import os
 
 
 
 
 
 
 
 
 
2
  import cv2
3
+ import time
4
+ import torch
5
  import numpy as np
6
+ import threading
7
  from PIL import Image
8
+ from datetime import datetime
9
+ from fastapi import FastAPI, UploadFile, File
10
+ from fastapi.staticfiles import StaticFiles
 
11
  import torchvision.transforms as transforms
12
  from torch.nn import functional as F
 
 
 
 
13
 
14
+ # -----------------------------
15
+ # 1. Environnement et config
16
+ # -----------------------------
17
+ os.environ["TORCH_HOME"] = "/tmp/torch"
18
+ os.environ["HF_HOME"] = "/tmp/huggingface"
19
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
20
+ os.environ["XDG_CACHE_HOME"] = "/tmp"
21
+ os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"
22
 
23
+ from pipline import Transformer_Regression, extract_regions_Last, compute_ratios
24
 
25
+ MODEL_PATH = "TrainAll_Maghrabi84_50iteration_SWIN.pth.tar"
26
+ OUTPUT_DIR = "/tmp/outputs"
27
+ BASE_URL = "https://stroke-ia-detect-glocom.hf.space" # ⚠️ à adapter à ton domaine
 
 
 
 
 
 
 
 
28
 
29
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
30
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
31
 
32
+ # -----------------------------
33
+ # 2. Initialisation modèle
34
+ # -----------------------------
35
+ image_shape = 384
36
+ dim_patch = 4
37
+ scale = 1
38
+
39
+ DeepLab = Transformer_Regression(
40
+ image_dim=image_shape, dim_patch=dim_patch, num_classes=3, scale=scale, feat_dim=128
41
+ )
42
+ DeepLab.to(device)
43
+ DeepLab.load_state_dict(torch.load(MODEL_PATH, map_location=device))
44
+ DeepLab.eval()
45
+
46
+ # -----------------------------
47
+ # 3. Prétraitement
48
+ # -----------------------------
49
  tfms = transforms.Compose([
50
  transforms.Resize((image_shape, image_shape)),
51
  transforms.ToTensor(),
52
+ transforms.Normalize(0.5, 0.5)
 
 
 
53
  ])
54
 
55
+ # -----------------------------
56
+ # 4. Inférence
57
+ # -----------------------------
58
+ def Final_Compute_regression_results_Sample(Model, batch_sampler, num_head=2):
59
  Model.eval()
 
 
 
 
60
  with torch.no_grad():
61
+ train_batch_tfms = batch_sampler['image'].to(device)
62
+ ytrue_seg = batch_sampler['image_original']
63
+
64
+ scores = Model(train_batch_tfms.unsqueeze(0))
65
+ yseg_pred = F.interpolate(scores['seg'],
66
+ size=(ytrue_seg.shape[0], ytrue_seg.shape[1]),
67
+ mode='bilinear', align_corners=True)
68
+
69
+ Regions_crop = extract_regions_Last(np.array(batch_sampler['image_original']),
70
+ yseg_pred.argmax(1).long()[0].cpu().numpy())
71
+ Regions_crop['image'] = Image.fromarray(np.uint8(Regions_crop['image'])).convert('RGB')
72
+
73
+ if num_head == 2:
74
+ scores = Model((tfms(Regions_crop['image']).unsqueeze(0)).to(device))
75
+ yseg_pred_crop = F.interpolate(scores['seg_aux_1'],
76
+ size=(Regions_crop['image'].size[1],
77
+ Regions_crop['image'].size[0]),
 
 
 
 
 
 
78
  mode='bilinear', align_corners=True)
79
+ yseg_pred[:, :, Regions_crop['cord'][0]:Regions_crop['cord'][1],
80
+ Regions_crop['cord'][2]:Regions_crop['cord'][3]] = yseg_pred_crop
81
+
82
+ yseg_pred = torch.softmax(yseg_pred, dim=1)
83
+ yseg_pred = yseg_pred.argmax(1).long().cpu().numpy()
84
+ ratios = compute_ratios(yseg_pred[0])
85
+
86
+ p_img = batch_sampler['image'].to(device).unsqueeze(0)
87
+ p_img = F.interpolate(p_img, size=(yseg_pred.shape[1], yseg_pred.shape[2]),
88
+ mode='bilinear', align_corners=True)
89
+ image_orig = (p_img[0] * 0.5 + 0.5).permute(1, 2, 0).cpu().numpy()
90
+ image_orig = np.uint8(image_orig * 255)
91
+ image_cont = image_orig.copy()
92
+
93
+ # Contours
94
+ ret, thresh = cv2.threshold(np.uint8(yseg_pred[0]), 1, 2, 0)
95
+ conts, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
96
+ cv2.drawContours(image_cont, conts, -1, (0, 255, 0), 2)
97
+ ret, thresh = cv2.threshold(np.uint8(yseg_pred[0]), 0, 2, 0)
98
+ conts, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
99
+ cv2.drawContours(image_cont, conts, -1, (0, 0, 255), 2)
100
+
101
+ if ratios.vcdr < 0.6:
102
+ glaucoma = "None"
103
+ else:
104
+ glaucoma = "May be there is a risk of Glaucoma"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  return image_cont, ratios.vcdr, glaucoma, Regions_crop
107
 
108
+ # -----------------------------
109
+ # 5. FastAPI app
110
+ # -----------------------------
111
+ app = FastAPI(title="Glaucoma Detection API")
112
+ app.mount("/files", StaticFiles(directory=OUTPUT_DIR), name="files")
 
 
 
 
 
113
 
114
+ @app.post("/predict/")
115
+ async def predict(image_file: UploadFile = File(...)):
116
+ tmp_path = f"/tmp/{image_file.filename}"
117
+ with open(tmp_path, "wb") as f:
118
+ f.write(await image_file.read())
119
 
120
+ img = np.array(Image.open(tmp_path).convert("RGB"))
121
+ sample_batch = {
122
+ "image_original": img,
123
+ "image": tfms(Image.fromarray(img))
124
+ }
125
 
 
126
  result, ratio, diagnosis, cropped = Final_Compute_regression_results_Sample(DeepLab, sample_batch, num_head=2)
127
+ cropped_img = result[cropped['cord'][0]:cropped['cord'][1],
128
+ cropped['cord'][2]:cropped['cord'][3]]
129
+
130
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
131
+ out_img_name = f"glaucoma_result_{timestamp}.png"
132
+ out_zoom_name = f"glaucoma_zoom_{timestamp}.png"
133
+ out_img_path = os.path.join(OUTPUT_DIR, out_img_name)
134
+ out_zoom_path = os.path.join(OUTPUT_DIR, out_zoom_name)
135
+ cv2.imwrite(out_img_path, cv2.cvtColor(result, cv2.COLOR_RGB2BGR))
136
+ cv2.imwrite(out_zoom_path, cv2.cvtColor(cropped_img, cv2.COLOR_RGB2BGR))
137
+
138
+ os.remove(tmp_path)
139
+
140
+ return {
141
+ "ratio": round(float(ratio), 3),
142
+ "diagnosis": diagnosis,
143
+ "overlay_url": f"{BASE_URL}/files/{out_img_name}",
144
+ "zoom_url": f"{BASE_URL}/files/{out_zoom_name}",
145
+ "message": "✅ Glaucoma analysis complete"
146
+ }
147
+
148
+ # -----------------------------
149
+ # 6. Auto-cleanup (toutes les 10 min)
150
+ # -----------------------------
151
+ def auto_cleanup(interval_minutes=10):
152
+ while True:
153
+ time.sleep(interval_minutes * 60)
154
+ for filename in os.listdir(OUTPUT_DIR):
155
+ path = os.path.join(OUTPUT_DIR, filename)
156
+ try:
157
+ if os.path.isfile(path):
158
+ os.remove(path)
159
+ print(f"[CLEANUP] Removed {path}")
160
+ except Exception as e:
161
+ print(f"[CLEANUP] Error removing {path}: {e}")
162
+
163
+ threading.Thread(target=auto_cleanup, daemon=True).start()