Matiullah00999 commited on
Commit
dffb58b
·
verified ·
1 Parent(s): 051da8f

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +221 -0
app.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Hugging Face Gradio App: Aggregate Analysis with Feret and Boundary Extension
2
+
3
+ import os
4
+ import cv2
5
+ import torch
6
+ import numpy as np
7
+ import gradio as gr
8
+ import matplotlib.pyplot as plt
9
+ import pandas as pd
10
+ from glob import glob
11
+ from PIL import Image
12
+ from skimage.measure import regionprops, label
13
+ from scipy.spatial.distance import cdist
14
+ from scipy.spatial import Delaunay
15
+ from io import BytesIO
16
+ from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
17
+ import segmentation_models_pytorch as smp
18
+
19
# Configuration
# Run inference on GPU when available; model and tensors are moved here.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Physical diameter (mm) of the specimen's outer boundary (152.4 mm = 6 in);
# used in predict() to calibrate mm-per-pixel from the detected outline.
DIAMETER_MM = 152.4
# Minimum side length (px) images are upscaled to before inference.
MIN_SIZE = 256
23
+
24
class PetModel(torch.nn.Module):
    """Wrapper around a ``segmentation_models_pytorch`` model.

    Stores the encoder's preprocessing mean/std as buffers and applies
    per-channel standardization to the input before the forward pass.
    """

    def __init__(self, arch, encoder_name, in_channels, out_classes, **kwargs):
        super().__init__()
        self.model = smp.create_model(
            arch, encoder_name, in_channels=in_channels, classes=out_classes, **kwargs
        )
        preprocessing = smp.encoders.get_preprocessing_params(encoder_name)
        # Shape (1, 3, 1, 1) so the stats broadcast over a batched NCHW image.
        self.register_buffer("mean", torch.tensor(preprocessing["mean"]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor(preprocessing["std"]).view(1, 3, 1, 1))

    def forward(self, image):
        # Standardize, then delegate to the wrapped segmentation model.
        normalized = (image - self.mean) / self.std
        return self.model(normalized)
37
+
38
def preprocess_image(image, min_size=MIN_SIZE):
    """Convert an input image into a normalized 1x3xHxW float tensor.

    Grayscale (2-D or single-channel) and RGBA inputs are converted to
    3-channel RGB, and the image is upscaled so both sides are at least
    ``min_size`` pixels before inference.

    Returns:
        (tensor, original_size): the model-ready tensor and the pre-resize
        (height, width), used later to restore the mask to input size.
    """
    arr = np.array(image)

    # Normalize the channel layout to 3-channel RGB.
    if arr.ndim == 2 or arr.shape[2] == 1:
        arr = cv2.cvtColor(arr, cv2.COLOR_GRAY2RGB)
    elif arr.shape[2] == 4:
        arr = cv2.cvtColor(arr, cv2.COLOR_RGBA2RGB)

    original_size = arr.shape[:2]
    height, width = original_size
    if min(height, width) < min_size:
        # cv2.resize takes (width, height) ordering.
        arr = cv2.resize(
            arr,
            (max(width, min_size), max(height, min_size)),
            interpolation=cv2.INTER_LINEAR,
        )

    # Scale to [0, 1] floats, then HWC -> 1CHW. Assumes 8-bit input values
    # (0-255) — TODO confirm Gradio always supplies uint8 arrays here.
    tensor = torch.tensor(arr.astype(np.float32) / 255.0)
    return tensor.permute(2, 0, 1).unsqueeze(0), original_size
56
+
57
def postprocess_output(output, original_size):
    """Turn raw model logits into a binary mask at the original image size.

    Applies a sigmoid, thresholds at 0.5, and — when the image was upscaled
    for inference — nearest-neighbor resizes the mask back to
    ``original_size`` (height, width) so it stays strictly binary.
    """
    mask = (output.sigmoid() > 0.5).float().squeeze().cpu().numpy()
    if mask.shape != original_size:
        # cv2.resize takes (width, height); INTER_NEAREST keeps values 0/1.
        mask = cv2.resize(
            mask,
            (original_size[1], original_size[0]),
            interpolation=cv2.INTER_NEAREST,
        )
    return mask
64
+
65
def load_model(model_path):
    """Build the U-Net / EfficientNet-B5 model and load trained weights.

    Args:
        model_path: Path to a saved ``state_dict`` checkpoint.

    Returns:
        The model on ``DEVICE`` in eval mode.
    """
    model = PetModel("unet", "efficientnet-b5", in_channels=3, out_classes=1)
    # weights_only=True restricts unpickling to tensors/primitives — all a
    # state_dict needs — and blocks arbitrary code execution via pickle.
    state_dict = torch.load(model_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    model = model.to(DEVICE)
    model.eval()
    return model
71
+
72
# Load the segmentation model once at import time so all requests reuse it.
model = load_model("segmentation_model_final.pth")
73
+
74
def fig_to_image(fig):
    """Render a Matplotlib figure to a PIL image via an in-memory PNG."""
    png_buffer = BytesIO()
    FigureCanvas(fig).print_png(png_buffer)
    png_buffer.seek(0)
    return Image.open(png_buffer)
80
+
81
def predict(image):
    """Run the full aggregate-analysis pipeline on one uploaded image.

    Steps:
      1. Segment aggregates with the global ``model``.
      2. Calibrate mm-per-pixel from the specimen's outer boundary,
         assumed to span ``DIAMETER_MM`` at its widest point.
      3. Draw Feret (min-area) rectangles colored by size quintile.
      4. Triangulate aggregate centroids (Delaunay) with edge lengths in mm.
      5. Compose a text summary of the measurements.

    Returns:
        (boundary_img, feret_img, delaunay_img, summary_text) for the four
        Gradio outputs. Missing figures are padded (see end of function).
    """
    input_tensor, original_size = preprocess_image(image)
    input_tensor = input_tensor.to(DEVICE)

    with torch.no_grad():
        output = model(input_tensor)

    prediction_mask = postprocess_output(output, original_size)
    image_np = np.array(image)
    gray_img = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)

    # --- Calibration using the specimen's outer boundary ---
    _, bw = cv2.threshold(gray_img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    contours, _ = cv2.findContours(bw, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    pixel_length_mm = 1.0  # fallback scale when no usable boundary is found

    figs = []
    if contours:
        # BUGFIX: findContours ordering is arbitrary, so contours[0] may be a
        # small interior contour. The specimen outline is the largest region.
        boundary = max(contours, key=cv2.contourArea).squeeze()
        # Guard degenerate (single-point) contours, which squeeze to 1-D and
        # would crash cdist.
        if boundary.ndim == 2 and len(boundary) >= 2:
            # The maximum pairwise distance is the pixel diameter.
            dist_matrix = cdist(boundary, boundary)
            i, j = np.unravel_index(np.argmax(dist_matrix), dist_matrix.shape)
            line_pts = np.array([boundary[i], boundary[j]])
            pixel_diameter = np.linalg.norm(boundary[i] - boundary[j])
            pixel_length_mm = DIAMETER_MM / pixel_diameter
            line_length_mm = pixel_diameter * pixel_length_mm

            # Boundary plot with the calibration line overlaid.
            fig1 = plt.figure(figsize=(6, 6))
            plt.imshow(image_np)
            plt.plot(boundary[:, 0], boundary[:, 1], 'g', linewidth=2)
            plt.plot(line_pts[:, 0], line_pts[:, 1], 'r', linewidth=2)
            plt.title(f'Line Length: {line_length_mm:.2f} mm')
            plt.axis('off')
            figs.append(fig1)

    # --- Feret analysis of the segmented aggregates ---
    label_img = (prediction_mask > 0.5).astype(np.uint8)
    binary_mask = (label_img * 255).astype(np.uint8)
    color_mask = cv2.cvtColor(binary_mask, cv2.COLOR_GRAY2BGR)
    feret_lengths, feret_widths, rectangles = [], [], []
    contours_mask, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    for cnt in contours_mask:
        if len(cnt) >= 5:  # skip tiny fragments that can't form a stable rect
            rect = cv2.minAreaRect(cnt)
            box = cv2.boxPoints(rect).astype(np.intp)
            width, height = rect[1]
            feret_length = max(width, height)
            feret_lengths.append(feret_length)
            feret_widths.append(min(width, height))
            rectangles.append((box, feret_length))

    # Color each rectangle by the size quintile its Feret length falls into.
    thresholds = np.percentile(feret_lengths, [20, 40, 60, 80]) if feret_lengths else [0, 0, 0, 0]
    colors = [(0, 0, 255), (0, 128, 255), (0, 255, 255), (0, 255, 0), (255, 0, 0)]

    for box, length in rectangles:
        if length <= thresholds[0]: color = colors[0]
        elif length <= thresholds[1]: color = colors[1]
        elif length <= thresholds[2]: color = colors[2]
        elif length <= thresholds[3]: color = colors[3]
        else: color = colors[4]
        cv2.drawContours(color_mask, [box], 0, color, 3)

    fig2 = plt.figure(figsize=(6, 6))
    plt.imshow(cv2.cvtColor(color_mask, cv2.COLOR_BGR2RGB))
    plt.title("Feret Rectangles by Size")
    plt.axis('off')
    figs.append(fig2)

    # --- Delaunay triangulation of aggregate centroids ---
    labeled_img = label(label_img)
    props = regionprops(labeled_img)
    centroids = np.array([p.centroid for p in props])
    edge_lengths = []

    if len(centroids) >= 3:  # Delaunay needs at least 3 non-collinear points
        tri = Delaunay(centroids)
        fig3 = plt.figure(figsize=(6, 6))
        plt.imshow(label_img, cmap='gray')
        # centroid is (row, col), so swap to (x, y) for plotting.
        plt.triplot(centroids[:, 1], centroids[:, 0], tri.simplices.copy(), color='red')
        for simplex in tri.simplices:
            for k in range(3):
                pt1, pt2 = centroids[simplex[k]], centroids[simplex[(k + 1) % 3]]
                dist_mm = np.linalg.norm(pt1 - pt2) * pixel_length_mm
                edge_lengths.append(dist_mm)
                midpoint = (pt1 + pt2) / 2
                plt.text(midpoint[1], midpoint[0], f"{dist_mm:.1f}", color='blue', fontsize=6, ha='center')
        plt.title("Delaunay Triangulation")
        plt.axis('off')
        figs.append(fig3)

    # --- Summary statistics ---
    area_mask = np.sum(binary_mask > 0)
    area_gray = np.count_nonzero(gray_img)
    aggregate_area_mm2 = area_mask * (pixel_length_mm ** 2)
    total_area_mm2 = area_gray * (pixel_length_mm ** 2)
    aggregate_ratio = aggregate_area_mm2 / total_area_mm2 if total_area_mm2 > 0 else 0

    if feret_lengths:
        avg_feret_length_mm = np.mean(feret_lengths) * pixel_length_mm
        avg_feret_width_mm = np.mean(feret_widths) * pixel_length_mm
        max_feret_length_mm = np.max(feret_lengths) * pixel_length_mm
        roundness_aggregate = avg_feret_length_mm / avg_feret_width_mm
    else:
        avg_feret_length_mm = avg_feret_width_mm = max_feret_length_mm = roundness_aggregate = 0

    summary = f"""
→ Pixel Size: {pixel_length_mm:.4f} mm/pixel
→ Aggregate Area: {aggregate_area_mm2:.2f} mm²
→ Aggregate Ratio: {aggregate_ratio:.4f}
→ Avg Aggregate Length: {avg_feret_length_mm:.2f} mm
→ Avg Aggregate Width: {avg_feret_width_mm:.2f} mm
→ Max Aggregate Length: {max_feret_length_mm:.2f} mm
→ Avg Aggregate Roundness: {roundness_aggregate:.2f}
"""
    if edge_lengths:
        summary += f"""
→ Avg inter_Aggregate Distance: {np.mean(edge_lengths):.2f} mm
→ Max inter_Aggregate Distance: {np.max(edge_lengths):.2f} mm
"""

    # Render figures and close them immediately: pyplot keeps figures alive
    # otherwise, leaking memory across requests.
    images = []
    for fig in figs:
        images.append(fig_to_image(fig))
        plt.close(fig)

    # BUGFIX: the old return indexed images[0]/images[1] unconditionally and
    # raised IndexError when the boundary or triangulation figure was missing.
    # Pad with the last available figure (or None) so all outputs get a value.
    while len(images) < 3:
        images.append(images[-1] if images else None)
    return images[0], images[1], images[2], summary
206
+
207
# Gradio UI wiring: one uploaded image in, three annotated figures and a
# text summary out. Output order must match the tuple returned by predict().
iface = gr.Interface(
    fn=predict,
    inputs=[gr.Image(label="Upload Concrete Image")],
    outputs=[
        gr.Image(label="Boundary and Calibration Line"),
        gr.Image(label="Feret Rectangles by Size"),
        gr.Image(label="Delaunay Triangulation"),
        gr.Textbox(label="Summary Measurements")
    ],
    title="Concrete Aggregate Analysis App",
    description="Upload a concrete cross-section image. The app segments aggregates, displays Feret rectangles, boundary calibration, Delaunay triangulation, and summary measurements."
)
219
+
220
if __name__ == "__main__":
    # Start the Gradio server only when this file is run as a script.
    iface.launch()