ReyaLabColumbia commited on
Commit
284bdef
·
verified ·
1 Parent(s): 5ac4683

Upload 2 files

Browse files
Files changed (2) hide show
  1. Colony_Analyzer_AI_zstack2_HF.py +356 -0
  2. app.py +50 -27
Colony_Analyzer_AI_zstack2_HF.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Thu Mar 20 14:23:27 2025
5
+
6
+ @author: mattc
7
+ """
8
+
9
+ import os
10
+ import cv2
11
+ #this is the huggingface version
12
def cut_img(img):
    """Split a PIL image into non-overlapping 512x512 tiles.

    Tiles are numbered 1..N in row-major order (left to right, top to
    bottom). Any remainder narrower/shorter than 512 px at the right or
    bottom edge is discarded.

    Args:
        img: a PIL.Image (anything exposing ``.size`` and ``.crop``).

    Returns:
        dict mapping tile number (starting at 1) to the cropped tile.
    """
    tiles = {}
    w, h = img.size
    index = 1
    for row in range(h // 512):
        for col in range(w // 512):
            box = (512 * col, 512 * row, 512 * (col + 1), 512 * (row + 1))
            tiles[index] = img.crop(box)
            index += 1
    return tiles
25
+
26
+ import numpy as np
27
+
28
def stitch(img_map):
    """Reassemble a 3x4 grid of tiles into one array.

    Expects ``img_map`` keyed 1..12 in row-major order (as produced by
    ``cut_img`` on a 2048x1536 image): tiles 1-4 form the first row,
    5-8 the second, 9-12 the third.

    Returns:
        The stitched numpy array.
    """
    rows = [
        np.hstack([img_map[k] for k in range(start, start + 4)])
        for start in (1, 5, 9)
    ]
    return np.vstack(rows)
36
+
37
+
38
+ from PIL import Image
39
+
40
+
41
+ import matplotlib.pyplot as plt
42
+
43
def visualize_segmentation(mask, image=0):
    """Display a segmentation mask, optionally beside its source image.

    Args:
        mask: 2-D array shown as a grayscale image.
        image: optional source image; the scalar default 0 means "not
            provided" (any non-scalar value is treated as an image).
    """
    plt.figure(figsize=(10, 5))

    has_image = not np.isscalar(image)
    if has_image:
        # Put the original image in the left panel when it was supplied.
        plt.subplot(1, 2, 1)
        plt.imshow(image)
        plt.title("Original Image")
        plt.axis("off")

    # The mask always goes in the right panel.
    plt.subplot(1, 2, 2)
    plt.imshow(mask, cmap="gray")
    plt.title("Segmentation Mask")
    plt.axis("off")

    plt.show()
60
+
61
import torch
from transformers import SegformerForSemanticSegmentation

# Module-level model setup: runs once when this module is imported.
# Load fine-tuned colony-segmentation model from the Hugging Face Hub.
model = SegformerForSemanticSegmentation.from_pretrained("ReyaLabColumbia/Segformer_Colony_Counter") # Adjust path
# Use GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval() # Set to evaluation mode

# Load image processor (handles resizing/normalization before inference).
# NOTE(review): SegformerForSemanticSegmentation is re-imported here; the
# import above already covers it.
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
image_processor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b3-finetuned-cityscapes-1024-1024")
72
+
73
def preprocess_image(image):
    """Convert a PIL image to RGB and encode it for the SegFormer model.

    Args:
        image: a PIL.Image.

    Returns:
        (rgb_image, pixel_values): the RGB-converted PIL image and the
        model-ready ``pixel_values`` tensor from the image processor.
    """
    rgb = image.convert("RGB")
    encoded = image_processor(rgb, return_tensors="pt")
    return rgb, encoded["pixel_values"]
77
+
78
def postprocess_mask(logits):
    """Collapse per-class logits to a 2-D label mask.

    Args:
        logits: tensor shaped (batch, num_classes, H, W).

    Returns:
        numpy array of per-pixel class indices with singleton
        dimensions squeezed out.
    """
    labels = logits.argmax(dim=1)
    return labels.squeeze().cpu().numpy()
81
+
82
+
83
def eval_img(image):
    """Segment one image tile with the SegFormer model.

    Args:
        image: a PIL.Image tile.

    Returns:
        A 512x512 numpy label mask (class index per pixel).
    """
    _, pixel_values = preprocess_image(image)
    pixel_values = pixel_values.to(device)
    # Inference only — no gradient bookkeeping needed.
    with torch.no_grad():
        logits = model(pixel_values=pixel_values).logits
    mask = postprocess_mask(logits)
    # The image processor rescales the input, so resize the predicted
    # mask back to the 512x512 tile size.
    mask = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_LINEAR_EXACT)
    return mask
95
+
96
def find_colonies(mask, size_cutoff, circ_cutoff):
    """Extract colony contours (label 1) passing size and circularity filters.

    Args:
        mask: label mask where colony pixels have value 1.
        size_cutoff: minimum contour area in pixels.
        circ_cutoff: minimum circularity, 4*pi*area/perimeter^2
            (1.0 for a perfect circle).

    Returns:
        (contours, areas): parallel lists of kept contours and their areas.
    """
    colony_px = np.where(mask == 1, 255, 0).astype(np.uint8)
    found, _ = cv2.findContours(colony_px, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    kept = []
    kept_areas = []
    for cnt in found:
        area = cv2.contourArea(cnt)
        if area < size_cutoff:
            continue
        perimeter = cv2.arcLength(cnt, True)
        if perimeter == 0:
            # Degenerate contour — circularity is undefined.
            continue
        circularity = (4 * np.pi * area) / (perimeter ** 2)
        if circularity >= circ_cutoff:
            kept.append(cnt)
            kept_areas.append(area)
    return (kept, kept_areas)
117
+
118
def find_necrosis(mask):
    """Return external contours of necrotic regions (label 2) in the mask."""
    necrotic_px = np.where(mask == 2, 255, 0).astype(np.uint8)
    found, _ = cv2.findContours(necrotic_px, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    return found
122
+
123
+ # contour_image = np.zeros_like(p)
124
+ # contours = find_necrosis(p)
125
+ # cv2.drawContours(contour_image, contours, -1, (255), 2)
126
+ # visualize_segmentation(contour_image)
127
+ import pandas as pd
128
def compute_centroid(contour):
    """Return the integer (x, y) centroid of a contour via image moments.

    Returns None when the contour has zero area (m00 == 0), which would
    otherwise cause a division by zero.
    """
    moments = cv2.moments(contour)
    if moments["m00"] == 0:
        return None
    cx = int(moments["m10"] / moments["m00"])
    cy = int(moments["m01"] / moments["m00"])
    return (cx, cy)
135
+
136
+
137
def contours_overlap_using_mask(contour1, contour2, image_shape=(1536, 2048)):
    """Check if two contours overlap using a bitwise AND mask.

    Args:
        contour1, contour2: OpenCV contours (point arrays).
        image_shape: (height, width) of the canvas the contours live in.

    Returns:
        numpy bool — True when the filled contours share at least one pixel.
    """
    # np and cv2 come from the module-level imports.
    mask1 = np.zeros(image_shape, dtype=np.uint8)
    mask2 = np.zeros(image_shape, dtype=np.uint8)

    # Draw each contour as a filled white shape on its own canvas.
    cv2.drawContours(mask1, [contour1], -1, 255, thickness=cv2.FILLED)
    cv2.drawContours(mask2, [contour2], -1, 255, thickness=cv2.FILLED)

    # Any pixel present in both masks means the contours overlap.
    overlap = cv2.bitwise_and(mask1, mask2)

    return np.any(overlap)
154
+
155
def analyze_colonies(mask, size_cutoff, circ_cutoff):
    """Build a per-colony DataFrame from a stitched label mask.

    For every colony contour passing the size/circularity filters, records
    its area, centroid, the summed area of necrosis contours overlapping it,
    the necrotic fraction, and the raw contours themselves.

    Returns:
        pandas DataFrame indexed 1..N with columns: colony_area,
        necrosis_area, centroid, percent_necrosis, contour, nec_contours.
    """
    colony_contours, colony_areas = find_colonies(mask, size_cutoff, circ_cutoff)
    necrosis_contours = find_necrosis(mask)

    records = []
    for colony, colony_area in zip(colony_contours, colony_areas):
        centroid = compute_centroid(colony)

        # Accumulate every necrosis contour that overlaps this colony.
        necrosis_area = 0
        overlapping = []
        for nec in necrosis_contours:
            if contours_overlap_using_mask(colony, nec):
                necrosis_area += cv2.contourArea(nec)
                overlapping.append(nec)

        records.append({
            "colony_area": colony_area,
            "necrosis_area": necrosis_area,
            "centroid": centroid,
            "percent_necrosis": necrosis_area / colony_area,
            "contour": colony,
            "nec_contours": overlapping,
        })

    # 1-based index so colony numbers match the labels drawn on the image.
    df = pd.DataFrame(records)
    df.index = range(1, len(df.index) + 1)
    return df
189
+
190
+
191
def contour_overlap(contour1, contour2, centroid1, centroid2, area1, area2, centroid_thresh=30, area_thresh = .4, img_shape = (1536, 2048)):
    """Classify the overlap between two contours.

    Returns:
        0: no overlap at all
        1: some overlap, but below the ratio/centroid-distance thresholds
        2: strong overlap (ratio >= area_thresh, centroids within
           centroid_thresh) and contour1 is strictly larger
        3: strong overlap and contour2 is at least as large
    """
    # Rasterize each contour, filled, onto its own blank canvas.
    canvas_a = np.zeros(img_shape, dtype=np.uint8)
    canvas_b = np.zeros(img_shape, dtype=np.uint8)
    cv2.drawContours(canvas_a, [contour1], -1, 255, thickness=cv2.FILLED)
    cv2.drawContours(canvas_b, [contour2], -1, 255, thickness=cv2.FILLED)

    # Shared pixel count = overlap area.
    shared = np.count_nonzero(cv2.bitwise_and(canvas_a, canvas_b))
    if shared == 0:
        return 0

    # Euclidean distance between centroids.
    dx = abs(centroid1[0] - centroid2[0])
    dy = abs(centroid1[1] - centroid2[1])
    centroid_distance = float(np.sqrt(dx ** 2 + dy ** 2))

    # Overlap measured against the larger of the two contours.
    overlap_ratio = shared / max(area1, area2)
    if overlap_ratio >= area_thresh and centroid_distance <= centroid_thresh:
        return 2 if area1 > area2 else 3
    return 1
225
+
226
def compare_frames(frame1, frame2):
    """Merge two colony DataFrames, dropping duplicated colonies.

    Each colony in frame1 is compared against each in frame2 with
    contour_overlap: code 2 marks the frame2 colony excluded (frame1's is
    larger), code 3 marks the frame1 colony excluded and stops comparing it.
    Surviving rows from both frames are concatenated and reindexed from 1.

    Note: both input DataFrames must carry an "exclude" boolean column and
    a 1-based index; their "exclude" columns are mutated in place.
    """
    for i in range(1, len(frame1) + 1):
        if frame1.loc[i, "exclude"] == True:
            continue
        for j in range(1, len(frame2) + 1):
            if frame2.loc[j, "exclude"] == True:
                continue
            verdict = contour_overlap(
                frame1.loc[i, "contour"], frame2.loc[j, "contour"],
                frame1.loc[i, "centroid"], frame2.loc[j, "centroid"],
                frame1.loc[i, "colony_area"], frame2.loc[j, "colony_area"])
            if verdict == 2:
                frame2.loc[j, "exclude"] = True
            elif verdict == 3:
                frame1.loc[i, "exclude"] = True
                break

    survivors1 = frame1[frame1["exclude"] == False]
    survivors2 = frame2[frame2["exclude"] == False]
    merged = pd.concat([survivors1, survivors2], axis=0)
    merged.index = range(1, len(merged.index) + 1)
    return merged
244
+
245
def main(args):
    """Run the full z-stack colony analysis pipeline.

    Args:
        args: [files, min_size, min_circ] — files is a list of PIL images
            (one per z-plane), min_size the minimum colony area in pixels,
            min_circ the minimum circularity.

    Returns:
        (annotated_image, png_filename, xlsx_filename). Also writes
        'Group_analysis_results.png' and 'Group_analysis_results.xlsx'
        to the working directory as side effects.
    """
    min_size = args[1]
    min_circ = args[2]
    colonies = {}
    files = args[0]
    # Segment every z-plane and fold its colonies into the running table.
    for idx,x in enumerate(files):
        # Tile the plane, segment each 512x512 tile, stitch the masks back.
        img_map = cut_img(files[idx])
        for z in img_map:
            img_map[z] = eval_img(img_map[z])
        del z
        p = stitch(img_map)
        frame = analyze_colonies(p, min_size, min_circ)
        frame["source"] = idx
        frame["exclude"] = False
        if isinstance(colonies, dict):
            # First plane: colonies is still the {} placeholder.
            colonies = frame
        else:
            # Later planes: de-duplicate against colonies found so far.
            colonies = compare_frames(frame, colonies)
    # Pick the plane that contributed the most surviving colonies as the base image.
    counts = {}
    for x in range(len(files)):
        counts[x] = list(colonies["source"]).count(x)
        best = [x, counts[x]]
    del x
    for x in counts:
        if counts[x] > best[1]:
            best[0] = x
            best[1] = counts[x]
    del x, counts
    best = best[0]
    img = np.array(files[best])
    # Paste colonies that came from other planes onto the base image.
    for x in range(len(files)):
        if x == best:
            continue
        mask = np.zeros_like(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY))
        contours = colonies[colonies["source"]==x]
        contours = list(contours["contour"])
        cv2.drawContours(mask, contours, -1, 255, thickness=cv2.FILLED)
        # Extract all ROIs from the source image at once
        src_image = np.array(files[x])
        roi = cv2.bitwise_and(src_image, src_image, mask=mask)
        # Paste the extracted regions onto the destination image
        np.copyto(img, roi, where=(mask[..., None] == 255))
    try:
        del x, mask, src_image, roi, best, contours
    except:
        pass

    # White margin on the right/bottom so edge labels are not clipped.
    img = cv2.copyMakeBorder(img,top=0, bottom=10,left=0,right=10, borderType=cv2.BORDER_CONSTANT, value=[255, 255, 255])
    colonies = colonies.sort_values(by=["colony_area"], ascending=False)
    colonies = colonies[colonies["colony_area"]>= min_size]
    colonies.index = range(1,len(colonies.index)+1)
    #nearby is a boolean list of whether a colony has overlapping colonies. If so, labelling positions change
    nearby = [False]*len(colonies)
    areas = list(colonies["colony_area"])
    for i in range(len(colonies)):
        # Outline the colony in green and its necrotic regions in red.
        cv2.drawContours(img, [list(colonies["contour"])[i]], -1, (0, 255, 0), 2)
        cv2.drawContours(img, list(colonies['nec_contours'])[i], -1, (0, 0, 255), 2)
        coords = list(list(colonies["centroid"])[i])
        if coords[0] > 1950:
            #if a colony is too close to the right edge, makes the label move to left
            coords[0] = 1950
        # Flag colonies whose centroids are within 40 px (Manhattan) of another.
        for j in range(len(colonies)):
            if j == i:
                continue
            coords2 = list(list(colonies["centroid"])[j])
            if ((abs(coords[0] - coords2[0]) + abs(coords[1] - coords2[1])) <= 40):
                nearby[i] = True
                break
        if nearby[i] ==True:
            #If the colony has nearby colonies, this adjusts the labels so they are smaller and are positioned based on the approximate radius of the colony
            # a random number is generated, and based on that, the label is put at the top or bottom, left or right
            radius= int(np.sqrt(areas[i]/3.1415)*.9)
            n = np.random.random()
            if n >.75:
                new_x = min(coords[0] + radius, 2000)
                new_y = min(coords[1] + radius, 1480)
            elif n >.5:
                new_x = min(coords[0] + radius, 2000)
                new_y = max(coords[1] - radius, 50)
            elif n >.25:
                new_x = max(coords[0] - radius, 0)
                new_y = min(coords[1] + radius, 1480)
            else:
                new_x = max(coords[0] - radius, 0)
                new_y = max(coords[1] - radius, 50)
            cv2.putText(img, str(colonies.index[i]), (new_x,new_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)
            del n, radius, new_x, new_y
        else:
            # NOTE(review): coords is a list here; OpenCV's putText org is
            # normally a tuple — confirm this works on the target cv2 version.
            cv2.putText(img, str(colonies.index[i]), coords, cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 0), 2)
    del nearby, areas
    # Drop contour columns before export (not representable in Excel).
    colonies = colonies.drop('contour', axis=1)
    colonies = colonies.drop('nec_contours', axis=1)
    colonies = colonies.drop('exclude', axis=1)
    # White margin on the top/left to mirror the earlier right/bottom border.
    img = cv2.copyMakeBorder(img,top=10, bottom=0,left=10,right=0, borderType=cv2.BORDER_CONSTANT, value=[255, 255, 255])

    colonies.insert(loc=0, column="Colony Number", value=[str(x) for x in range(1, len(colonies)+1)])
    total_area_dark = sum(colonies['necrosis_area'])
    total_area_light = sum(colonies['colony_area'])
    # +1 in the denominator guards against division by zero when no colonies were found.
    ratio = total_area_dark/(abs(total_area_light)+1)

    # Append a summary "Total" row, then write colony data + parameters to Excel.
    colonies.loc[len(colonies)+1] = ["Total", total_area_light, total_area_dark, None, ratio, None]
    Parameters = pd.DataFrame({"Minimum colony size in pixels":[min_size], "Minimum colony circularity":[min_circ]})
    with pd.ExcelWriter("Group_analysis_results.xlsx") as writer:
        colonies.to_excel(writer, sheet_name="Colony data", index=False)
        Parameters.to_excel(writer, sheet_name="Parameters", index=False)
    # Caption strip appended below the annotated image with the totals.
    caption = np.ones((150, 2068, 3), dtype=np.uint8) * 255 # Multiply by 255 to make it white
    cv2.putText(caption, "Total area necrotic: "+str(total_area_dark)+ ", Total area living: "+str(total_area_light)+", Ratio: "+str(ratio), (40, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 3)

    cv2.imwrite('Group_analysis_results.png', np.vstack((img, caption)))
    return(np.vstack((img, caption)), 'Group_analysis_results.png', 'Group_analysis_results.xlsx')
app.py CHANGED
@@ -1,33 +1,56 @@
1
  import gradio as gr
2
- import Colony_Analyzer_AI2_HF as analyzer
3
  from PIL import Image
4
- import cv2
5
- import numpy as np
6
 
7
- # Analysis function adapted from your Tkinter app
8
  def analyze_image(image, min_size, circularity):
9
- # Assume your analyzer.main accepts [image, params] format, adjust as needed
10
- processed_img,picname, excelname = analyzer.main([image, min_size, circularity])
11
- print(type(processed_img))
12
- # Convert back to RGB for display
13
- #result = cv2.cvtColor(processed_img, cv2.COLOR_BGR2RGB)
14
  return Image.fromarray(processed_img), picname, excelname
15
 
16
- # Create Gradio interface
17
- iface = gr.Interface(
18
- fn=analyze_image,
19
- inputs=[
20
- gr.Image(type="pil", label="Upload Image"),
21
- gr.Number(label="Minimum Colony Size (pixels)", value=1000),
22
- gr.Number(label="Minimum Circularity", value=0.25)
23
- ],
24
- outputs=[
25
- gr.Image(type="pil", label="Analyzed Image"),
26
- gr.File(label="Download Image"),
27
- gr.File(label="Download results (Excel)")
28
- ],
29
- title="AI Colony Analyzer",
30
- description="Upload an image to run the colony analysis."
31
- )
32
-
33
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
2
  from PIL import Image
 
 
3
 
4
+ # Single image analysis function (your existing logic)
5
def analyze_image(image, min_size, circularity):
    """Analyze a single image with the single-frame colony analyzer.

    The analyzer module is imported inside the function body, so it is only
    loaded when this path is actually used.

    Returns:
        (PIL result image, result image filename, Excel results filename).
    """
    import Colony_Analyzer_AI2_HF as analyzer
    result_array, image_path, excel_path = analyzer.main([image, min_size, circularity])
    return Image.fromarray(result_array), image_path, excel_path
9
 
10
+ # Z-stack analysis function (adapt with your own logic)
11
def analyze_zstack(images, min_size, circularity):
    """Analyze a stack of z-plane images with the z-stack colony analyzer.

    The analyzer module is imported inside the function body, so it is only
    loaded when the z-stack path is actually used.

    Args:
        images: list of PIL images, one per z-plane.
        min_size: minimum colony area in pixels.
        circularity: minimum colony circularity.

    Returns:
        (PIL result image, result image filename, Excel results filename).
    """
    import Colony_Analyzer_AI_zstack2_HF as analyzer
    result_array, image_path, excel_path = analyzer.main([images, min_size, circularity])
    return Image.fromarray(result_array), image_path, excel_path
18
+
19
# Gradio UI: one checkbox toggles between single-image and z-stack inputs;
# a shared Process button dispatches to the matching analyzer.
with gr.Blocks() as demo:
    gr.Markdown("# AI Colony Analyzer\nUpload an image (or Z-Stack) to run colony analysis.")

    z_stack_checkbox = gr.Checkbox(label="Enable Z-Stack", value=False)
    image_input_single = gr.Image(type="pil", label="Upload Image", visible=True)
    # NOTE(review): gr.Image does not document a file_count parameter —
    # multi-file upload is normally gr.File/gr.Gallery; confirm this renders.
    image_input_multi = gr.Image(type="pil", label="Upload Z-Stack Images", file_count="multiple", visible=False)
    min_size_input = gr.Number(label="Minimum Colony Size (pixels)", value=1000)
    circularity_input = gr.Number(label="Minimum Circularity", value=0.25)
    output_image = gr.Image(type="pil", label="Analyzed Image")
    output_file_img = gr.File(label="Download Image")
    output_file_excel = gr.File(label="Download results (Excel)")
    process_btn = gr.Button("Process")

    def toggle_inputs(z_stack_enabled):
        # Show exactly one of the two inputs depending on the checkbox.
        return (
            gr.update(visible=not z_stack_enabled),  # single input
            gr.update(visible=z_stack_enabled)  # multi input
        )

    z_stack_checkbox.change(
        toggle_inputs,
        inputs=z_stack_checkbox,
        outputs=[image_input_single, image_input_multi]
    )

    def conditional_analyze(z_stack, single_image, multi_images, min_size, circularity):
        # Dispatch to the z-stack or single-image analyzer based on the checkbox.
        if z_stack:
            return analyze_zstack(multi_images, min_size, circularity)
        else:
            return analyze_image(single_image, min_size, circularity)

    process_btn.click(
        conditional_analyze,
        inputs=[z_stack_checkbox, image_input_single, image_input_multi, min_size_input, circularity_input],
        outputs=[output_image, output_file_img, output_file_excel]
    )

demo.launch()