Matiullah00999 commited on
Commit
b7d8234
·
verified ·
1 Parent(s): dffb58b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +164 -221
app.py CHANGED
@@ -1,221 +1,164 @@
1
- # Hugging Face Gradio App: Aggregate Analysis with Feret and Boundary Extension
2
-
3
- import os
4
- import cv2
5
- import torch
6
- import numpy as np
7
- import gradio as gr
8
- import matplotlib.pyplot as plt
9
- import pandas as pd
10
- from glob import glob
11
- from PIL import Image
12
- from skimage.measure import regionprops, label
13
- from scipy.spatial.distance import cdist
14
- from scipy.spatial import Delaunay
15
- from io import BytesIO
16
- from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
17
- import segmentation_models_pytorch as smp
18
-
19
- # Configuration
20
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
21
- DIAMETER_MM = 152.4
22
- MIN_SIZE = 256
23
-
24
- class PetModel(torch.nn.Module):
25
- def __init__(self, arch, encoder_name, in_channels, out_classes, **kwargs):
26
- super().__init__()
27
- self.model = smp.create_model(
28
- arch, encoder_name, in_channels=in_channels, classes=out_classes, **kwargs
29
- )
30
- params = smp.encoders.get_preprocessing_params(encoder_name)
31
- self.register_buffer("std", torch.tensor(params["std"]).view(1, 3, 1, 1))
32
- self.register_buffer("mean", torch.tensor(params["mean"]).view(1, 3, 1, 1))
33
-
34
- def forward(self, image):
35
- image = (image - self.mean) / self.std
36
- return self.model(image)
37
-
38
- def preprocess_image(image, min_size=MIN_SIZE):
39
- image = np.array(image)
40
- if len(image.shape) == 2:
41
- image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
42
- elif image.shape[2] == 4:
43
- image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
44
- elif image.shape[2] == 1:
45
- image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
46
-
47
- original_size = image.shape[:2]
48
- h, w = image.shape[:2]
49
- if h < min_size or w < min_size:
50
- new_size = (max(w, min_size), max(h, min_size))
51
- image = cv2.resize(image, new_size, interpolation=cv2.INTER_LINEAR)
52
-
53
- image = image.astype(np.float32) / 255.0
54
- image = torch.tensor(image).permute(2, 0, 1).unsqueeze(0)
55
- return image, original_size
56
-
57
- def postprocess_output(output, original_size):
58
- prob_mask = output.sigmoid()
59
- pred_mask = (prob_mask > 0.5).float()
60
- pred_mask = pred_mask.squeeze().cpu().numpy()
61
- if pred_mask.shape != original_size:
62
- pred_mask = cv2.resize(pred_mask, (original_size[1], original_size[0]), interpolation=cv2.INTER_NEAREST)
63
- return pred_mask
64
-
65
- def load_model(model_path):
66
- model = PetModel("unet", "efficientnet-b5", in_channels=3, out_classes=1)
67
- model.load_state_dict(torch.load(model_path, map_location=DEVICE))
68
- model = model.to(DEVICE)
69
- model.eval()
70
- return model
71
-
72
- model = load_model("segmentation_model_final.pth")
73
-
74
- def fig_to_image(fig):
75
- buf = BytesIO()
76
- canvas = FigureCanvas(fig)
77
- canvas.print_png(buf)
78
- buf.seek(0)
79
- return Image.open(buf)
80
-
81
- def predict(image):
82
- input_tensor, original_size = preprocess_image(image)
83
- input_tensor = input_tensor.to(DEVICE)
84
-
85
- with torch.no_grad():
86
- output = model(input_tensor)
87
-
88
- prediction_mask = postprocess_output(output, original_size)
89
- image_np = np.array(image)
90
- gray_img = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
91
-
92
- # Calibration using outer boundary
93
- _, bw = cv2.threshold(gray_img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
94
- contours, _ = cv2.findContours(bw, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
95
- pixel_length_mm = 1.0
96
- dia = DIAMETER_MM
97
- summary = {}
98
-
99
- figs = []
100
- if contours:
101
- boundary = contours[0].squeeze()
102
- dist_matrix = cdist(boundary, boundary)
103
- i, j = np.unravel_index(np.argmax(dist_matrix), dist_matrix.shape)
104
- line_pts = np.array([boundary[i], boundary[j]])
105
- pixel_diameter = np.linalg.norm(boundary[i] - boundary[j])
106
- pixels_per_mm = pixel_diameter / dia
107
- pixel_length_mm = 1 / pixels_per_mm
108
- line_length_mm = pixel_diameter * pixel_length_mm
109
-
110
- # Boundary Plot
111
- fig1 = plt.figure(figsize=(6, 6))
112
- plt.imshow(image_np)
113
- plt.plot(boundary[:, 0], boundary[:, 1], 'g', linewidth=2)
114
- plt.plot(line_pts[:, 0], line_pts[:, 1], 'r', linewidth=2)
115
- plt.title(f'Line Length: {line_length_mm:.2f} mm')
116
- plt.axis('off')
117
- figs.append(fig1)
118
-
119
- # Feret Analysis
120
- label_img = (prediction_mask > 0.5).astype(np.uint8)
121
- binary_mask = (label_img * 255).astype(np.uint8)
122
- color_mask = cv2.cvtColor(binary_mask, cv2.COLOR_GRAY2BGR)
123
- feret_lengths, feret_widths, rectangles = [], [], []
124
- contours_mask, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
125
- for cnt in contours_mask:
126
- if len(cnt) >= 5:
127
- rect = cv2.minAreaRect(cnt)
128
- box = cv2.boxPoints(rect).astype(np.intp)
129
- width, height = rect[1]
130
- feret_length = max(width, height)
131
- feret_lengths.append(feret_length)
132
- feret_widths.append(min(width, height))
133
- rectangles.append((box, feret_length))
134
-
135
- thresholds = np.percentile(feret_lengths, [20, 40, 60, 80]) if feret_lengths else [0, 0, 0, 0]
136
- colors = [(0, 0, 255), (0, 128, 255), (0, 255, 255), (0, 255, 0), (255, 0, 0)]
137
-
138
- for box, length in rectangles:
139
- if length <= thresholds[0]: color = colors[0]
140
- elif length <= thresholds[1]: color = colors[1]
141
- elif length <= thresholds[2]: color = colors[2]
142
- elif length <= thresholds[3]: color = colors[3]
143
- else: color = colors[4]
144
- cv2.drawContours(color_mask, [box], 0, color, 3)
145
-
146
- fig2 = plt.figure(figsize=(6, 6))
147
- plt.imshow(cv2.cvtColor(color_mask, cv2.COLOR_BGR2RGB))
148
- plt.title("Feret Rectangles by Size")
149
- plt.axis('off')
150
- figs.append(fig2)
151
-
152
- # Delaunay Triangulation
153
- labeled_img = label(label_img)
154
- props = regionprops(labeled_img)
155
- centroids = np.array([p.centroid for p in props])
156
- edge_lengths = []
157
-
158
- if len(centroids) >= 3:
159
- tri = Delaunay(centroids)
160
- fig3 = plt.figure(figsize=(6, 6))
161
- plt.imshow(label_img, cmap='gray')
162
- plt.triplot(centroids[:, 1], centroids[:, 0], tri.simplices.copy(), color='red')
163
- for simplex in tri.simplices:
164
- for i in range(3):
165
- pt1, pt2 = centroids[simplex[i]], centroids[simplex[(i + 1) % 3]]
166
- dist_mm = np.linalg.norm(pt1 - pt2) * pixel_length_mm
167
- edge_lengths.append(dist_mm)
168
- midpoint = (pt1 + pt2) / 2
169
- plt.text(midpoint[1], midpoint[0], f"{dist_mm:.1f}", color='blue', fontsize=6, ha='center')
170
- plt.title("Delaunay Triangulation")
171
- plt.axis('off')
172
- figs.append(fig3)
173
-
174
- # Summary Stats
175
- area_mask = np.sum(binary_mask > 0)
176
- area_gray = np.count_nonzero(gray_img)
177
- aggregate_area_mm2 = area_mask * (pixel_length_mm ** 2)
178
- total_area_mm2 = area_gray * (pixel_length_mm ** 2)
179
- aggregate_ratio = aggregate_area_mm2 / total_area_mm2 if total_area_mm2 > 0 else 0
180
-
181
- if feret_lengths:
182
- avg_feret_length_mm = np.mean(feret_lengths) * pixel_length_mm
183
- avg_feret_width_mm = np.mean(feret_widths) * pixel_length_mm
184
- max_feret_length_mm = np.max(feret_lengths) * pixel_length_mm
185
- roundness_aggregate = avg_feret_length_mm / avg_feret_width_mm
186
- else:
187
- avg_feret_length_mm = avg_feret_width_mm = max_feret_length_mm = roundness_aggregate = 0
188
-
189
- summary = f"""
190
- → Pixel Size: {pixel_length_mm:.4f} mm/pixel
191
- → Aggregate Area: {aggregate_area_mm2:.2f} mm²
192
- → Aggregate Ratio: {aggregate_ratio:.4f}
193
- → Avg Aggregate Length: {avg_feret_length_mm:.2f} mm
194
- → Avg Aggregate Width: {avg_feret_width_mm:.2f} mm
195
- → Max Aggregate Length: {max_feret_length_mm:.2f} mm
196
- → Avg Aggregate Roundness: {roundness_aggregate:.2f}
197
- """
198
- if edge_lengths:
199
- summary += f"""
200
- → Avg inter_Aggregate Distance: {np.mean(edge_lengths):.2f} mm
201
- → Max inter_Aggregate Distance: {np.max(edge_lengths):.2f} mm
202
- """
203
-
204
- images = [fig_to_image(fig) for fig in figs]
205
- return images[0], images[1], images[2] if len(images) > 2 else images[1], summary
206
-
207
- iface = gr.Interface(
208
- fn=predict,
209
- inputs=[gr.Image(label="Upload Concrete Image")],
210
- outputs=[
211
- gr.Image(label="Boundary and Calibration Line"),
212
- gr.Image(label="Feret Rectangles by Size"),
213
- gr.Image(label="Delaunay Triangulation"),
214
- gr.Textbox(label="Summary Measurements")
215
- ],
216
- title="Concrete Aggregate Analysis App",
217
- description="Upload a concrete cross-section image. The app segments aggregates, displays Feret rectangles, boundary calibration, Delaunay triangulation, and summary measurements."
218
- )
219
-
220
- if __name__ == "__main__":
221
- iface.launch()
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import matplotlib.pyplot as plt
6
+ from scipy.spatial.distance import cdist
7
+ from scipy.spatial import Delaunay
8
+ from skimage.measure import label, regionprops
9
+ import gradio as gr
10
+ import io
11
+ from PIL import Image
12
+
13
+ # Constants
14
+ DIA_MM = 152.4
15
+
16
+ # Main processing function
17
+ def analyze_aggregate(image_pil):
18
+ results = {}
19
+ edge_lengths = []
20
+
21
+ # Convert to OpenCV image
22
+ img = np.array(image_pil)
23
+ img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
24
+ gray_img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
25
+
26
+ # Simulated label (as if predicted by a model)
27
+ label_img = cv2.cvtColor(gray_img, cv2.COLOR_GRAY2BGR) # Dummy label for placeholder
28
+ _, label_gray = cv2.threshold(gray_img, 127, 255, cv2.THRESH_BINARY)
29
+ binary_mask = (label_gray > 0).astype(np.uint8)
30
+ color_mask = cv2.cvtColor(label_gray, cv2.COLOR_GRAY2BGR)
31
+
32
+ # Pixel calibration
33
+ _, bw = cv2.threshold(gray_img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
34
+ contours, _ = cv2.findContours(bw, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
35
+ contours = sorted(contours, key=cv2.contourArea, reverse=True)
36
+ if not contours:
37
+ return "No contours found.", None, None, None, None
38
+
39
+ boundary = contours[0].squeeze()
40
+ dist_matrix = cdist(boundary, boundary)
41
+ i, j = np.unravel_index(np.argmax(dist_matrix), dist_matrix.shape)
42
+ line_pts = np.array([boundary[i], boundary[j]])
43
+ pixel_diameter = np.linalg.norm(boundary[i] - boundary[j])
44
+ pixels_per_mm = pixel_diameter / DIA_MM
45
+ pixel_length_mm = 1 / pixels_per_mm
46
+ line_length_mm = pixel_diameter * pixel_length_mm
47
+
48
+ # Plot 1: Boundary and line
49
+ fig1, ax1 = plt.subplots(figsize=(6, 6))
50
+ ax1.imshow(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
51
+ ax1.plot(boundary[:, 0], boundary[:, 1], 'g', linewidth=2)
52
+ ax1.plot(line_pts[:, 0], line_pts[:, 1], 'r', linewidth=2)
53
+ ax1.set_title(f'Line Length: {line_length_mm:.2f} mm')
54
+ ax1.axis('off')
55
+
56
+ # Aggregate area
57
+ num_white_pixels = np.sum(binary_mask == 1)
58
+ num_nonblack_pixels = np.count_nonzero(gray_img)
59
+ aggregate_area_mm2 = num_white_pixels * (pixel_length_mm ** 2)
60
+ total_area_mm2 = num_nonblack_pixels * (pixel_length_mm ** 2)
61
+ aggregate_ratio = aggregate_area_mm2 / total_area_mm2 if total_area_mm2 > 0 else 0
62
+
63
+ # Feret Rectangles
64
+ feret_lengths, feret_widths = [], []
65
+ rectangles = []
66
+ contours_mask, _ = cv2.findContours(binary_mask * 255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
67
+ for cnt in contours_mask:
68
+ if len(cnt) >= 5:
69
+ rect = cv2.minAreaRect(cnt)
70
+ box = cv2.boxPoints(rect).astype(np.intp)
71
+ width, height = rect[1]
72
+ feret_length = max(width, height)
73
+ feret_lengths.append(feret_length)
74
+ feret_widths.append(min(width, height))
75
+ rectangles.append((box, feret_length))
76
+
77
+ thresholds = np.percentile(feret_lengths, [20, 40, 60, 80]) if feret_lengths else [0, 0, 0, 0]
78
+ colors = [(0, 0, 255), (0, 128, 255), (0, 255, 255), (0, 255, 0), (255, 0, 0)]
79
+
80
+ for box, length in rectangles:
81
+ if length <= thresholds[0]: color = colors[0]
82
+ elif length <= thresholds[1]: color = colors[1]
83
+ elif length <= thresholds[2]: color = colors[2]
84
+ elif length <= thresholds[3]: color = colors[3]
85
+ else: color = colors[4]
86
+ cv2.drawContours(color_mask, [box], 0, color, 3)
87
+
88
+ # Plot 2: Feret rectangles
89
+ fig2, ax2 = plt.subplots(figsize=(6, 6))
90
+ ax2.imshow(cv2.cvtColor(color_mask, cv2.COLOR_BGR2RGB))
91
+ ax2.set_title("Feret Rectangles by Size")
92
+ ax2.axis('off')
93
+
94
+ # Feret Stats
95
+ if feret_lengths:
96
+ avg_feret_length_mm = np.mean(feret_lengths) * pixel_length_mm
97
+ avg_feret_width_mm = np.mean(feret_widths) * pixel_length_mm
98
+ max_feret_length_mm = np.max(feret_lengths) * pixel_length_mm
99
+ roundness_aggregate = avg_feret_length_mm / avg_feret_width_mm
100
+ else:
101
+ avg_feret_length_mm = avg_feret_width_mm = max_feret_length_mm = roundness_aggregate = 0
102
+
103
+ # Delaunay triangulation
104
+ labeled_img = label(binary_mask)
105
+ props = regionprops(labeled_img)
106
+ centroids = np.array([p.centroid for p in props])
107
+
108
+ if len(centroids) >= 3:
109
+ tri = Delaunay(centroids)
110
+ fig3, ax3 = plt.subplots(figsize=(6, 6))
111
+ ax3.imshow(label_gray, cmap='gray')
112
+ ax3.triplot(centroids[:, 1], centroids[:, 0], tri.simplices.copy(), color='red')
113
+
114
+ for simplex in tri.simplices:
115
+ for i in range(3):
116
+ pt1 = centroids[simplex[i]]
117
+ pt2 = centroids[(i + 1) % 3]
118
+ dist_px = np.linalg.norm(pt1 - pt2)
119
+ dist_mm = dist_px * pixel_length_mm
120
+ edge_lengths.append(dist_mm)
121
+ midpoint = (pt1 + pt2) / 2
122
+ ax3.text(midpoint[1], midpoint[0], f"{dist_mm:.1f}", color='blue', fontsize=6, ha='center')
123
+
124
+ ax3.set_title("Delaunay Triangulation")
125
+ ax3.axis('off')
126
+ else:
127
+ fig3 = plt.figure()
128
+ plt.text(0.5, 0.5, 'Not enough centroids for triangulation.', ha='center')
129
+ plt.axis('off')
130
+
131
+ # Summary text
132
+ summary = f"""
133
+ Pixel Size: {pixel_length_mm:.4f} mm/pixel
134
+ → Aggregate Area: {aggregate_area_mm2:.2f} mm²
135
+ Aggregate Ratio: {aggregate_ratio:.4f}
136
+ Avg Aggregate Length: {avg_feret_length_mm:.2f} mm
137
+ → Avg Aggregate Width: {avg_feret_width_mm:.2f} mm
138
+ Max Aggregate Length: {max_feret_length_mm:.2f} mm
139
+ Avg Aggregate Roundness: {roundness_aggregate:.2f}
140
+ """
141
+ if edge_lengths:
142
+ summary += f"""
143
+ → Avg inter-Aggregate Distance: {np.mean(edge_lengths):.2f} mm
144
+ Max inter-Aggregate Distance: {np.max(edge_lengths):.2f} mm
145
+ """
146
+
147
+ return summary.strip(), fig1, fig2, fig3
148
+
149
+ # Gradio UI
150
+ demo = gr.Interface(
151
+ fn=analyze_aggregate,
152
+ inputs=[gr.Image(label="Upload Image")],
153
+ outputs=[
154
+ gr.Textbox(label="Summary Measurements"),
155
+ gr.Plot(label="Boundary and Calibration Line"),
156
+ gr.Plot(label="Feret Rectangles by Size"),
157
+ gr.Plot(label="Delaunay Triangulation")
158
+ ],
159
+ title="Aggregate Analysis from Uploaded Image",
160
+ description="Upload an image with circular calibration. The app will calculate size, aspect ratio, and spacing of aggregates.",
161
+ allow_flagging='never'
162
+ )
163
+
164
+ demo.launch()