Spaces:
Sleeping
Sleeping
| import os | |
| import cv2 | |
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| from glob import glob | |
| from PIL import Image | |
| from skimage.measure import regionprops, label | |
| from scipy.spatial.distance import cdist | |
| from scipy.spatial import Delaunay | |
| from io import BytesIO | |
| from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas | |
| import segmentation_models_pytorch as smp | |
# Configuration
# Run inference on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# Physical length (mm) assigned to the longest chord of the largest outer
# contour when converting pixels to millimetres (152.4 mm = 6 in).
# Presumably the specimen diameter — confirm against the imaging setup.
DIAMETER_MM = 152.4
# Images with either side smaller than this many pixels are upscaled before
# being fed to the segmentation model.
MIN_SIZE = 256
class PetModel(torch.nn.Module):
    """Wrap a segmentation_models_pytorch network with built-in input normalisation.

    Args:
        arch: Architecture name passed to ``smp.create_model`` (e.g. ``"unet"``).
        encoder_name: Encoder backbone name; also selects the normalisation stats.
        in_channels: Number of input channels of the network.
        out_classes: Number of output channels (classes) of the network.
        **kwargs: Forwarded unchanged to ``smp.create_model``.
    """

    def __init__(self, arch, encoder_name, in_channels, out_classes, **kwargs):
        super().__init__()
        self.model = smp.create_model(
            arch,
            encoder_name,
            in_channels=in_channels,
            classes=out_classes,
            **kwargs,
        )
        # Per-channel statistics used by the encoder's preprocessing. Registered
        # as buffers so they follow .to(device) and are saved in the state dict.
        # NOTE(review): reshaped to 3 channels regardless of ``in_channels``.
        stats = smp.encoders.get_preprocessing_params(encoder_name)
        for buffer_name in ("std", "mean"):
            values = torch.tensor(stats[buffer_name]).view(1, 3, 1, 1)
            self.register_buffer(buffer_name, values)

    def forward(self, image):
        """Normalise ``image`` channel-wise and return the network's raw output.

        ``image`` is expected to be an NCHW tensor with 3 channels (the buffers
        are shaped (1, 3, 1, 1)); presumably scaled to [0, 1] — confirm.
        """
        normalised = (image - self.mean) / self.std
        return self.model(normalised)
def preprocess_image(image, min_size=MIN_SIZE, multiple=32):
    """Convert an image into a normalised 1x3xHxW float tensor for the model.

    Grayscale, single-channel and RGBA inputs are converted to 3-channel RGB.
    The image is resized so that each side is at least ``min_size`` pixels and,
    because U-Net style encoders downsample by 32, rounded up to a multiple of
    ``multiple``. Without that rounding, inputs such as 1000x750 make the model
    raise on a shape mismatch.

    Args:
        image: PIL image or array-like convertible with ``np.array``; assumed to
            be 8-bit (values are divided by 255).
        min_size: Minimum side length in pixels after resizing.
        multiple: Both sides are rounded up to a multiple of this value
            (``1`` disables the rounding).

    Returns:
        ``(tensor, original_size)`` where ``tensor`` has shape (1, 3, H', W') with
        values in [0, 1], and ``original_size`` is the ``(height, width)`` of the
        input. ``postprocess_output`` uses it to restore the mask's resolution.
    """
    image = np.array(image)
    if image.ndim == 2 or image.shape[2] == 1:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

    original_size = image.shape[:2]
    h, w = original_size
    target_h = max(h, min_size)
    target_w = max(w, min_size)
    if multiple > 1:
        # Ceil-round to the next multiple (negated floor division == ceil).
        target_h = -(-target_h // multiple) * multiple
        target_w = -(-target_w // multiple) * multiple
    if (target_h, target_w) != (h, w):
        # cv2.resize takes (width, height).
        image = cv2.resize(image, (target_w, target_h), interpolation=cv2.INTER_LINEAR)

    image = image.astype(np.float32) / 255.0
    tensor = torch.tensor(image).permute(2, 0, 1).unsqueeze(0)
    return tensor, original_size
def postprocess_output(output, original_size):
    """Turn raw model logits into a binary float mask at the original resolution.

    Args:
        output: Model output tensor of logits; all singleton dimensions are
            squeezed away, so a (1, 1, H, W) batch becomes (H, W).
        original_size: ``(height, width)`` to restore the mask to.

    Returns:
        float32 NumPy array of 0.0/1.0 values (probability > 0.5 -> 1.0).
    """
    mask = (torch.sigmoid(output) > 0.5).float().squeeze().cpu().numpy()
    if mask.shape == original_size:
        return mask
    height, width = original_size
    # Nearest-neighbour keeps the mask strictly binary while resizing.
    return cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
def load_model(model_path):
    """Build the U-Net / EfficientNet-B5 segmenter and load its weights.

    Args:
        model_path: Path to a saved ``state_dict`` for a single-class,
            3-channel-input ``PetModel``.

    Returns:
        The model on ``DEVICE``, switched to evaluation mode.
    """
    net = PetModel("unet", "efficientnet-b5", in_channels=3, out_classes=1)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    # (consider weights_only=True where the installed torch supports it).
    state = torch.load(model_path, map_location=DEVICE)
    net.load_state_dict(state)
    return net.to(DEVICE).eval()
# Destination of the one-row measurement table written by analyze().
csv_output_path = "measurement_summary.csv"
# Segmentation network shared by every request; loaded once at import time.
MODEL_PATH = "segmentation_model_final.pth"
model = load_model(MODEL_PATH)
def fig_to_image(fig):
    """Render a Matplotlib figure to PNG in memory and return it as a PIL image.

    The figure is closed afterwards. Figures created with ``plt.figure()`` stay
    registered with pyplot, and keep their memory, until explicitly closed.
    Without closing them, a long-running server accumulates figures on every
    request. Closing a figure that pyplot does not manage is a no-op.

    Args:
        fig: A ``matplotlib.figure.Figure``; it should not be reused after this call.

    Returns:
        A fully decoded ``PIL.Image.Image``.
    """
    buf = BytesIO()
    FigureCanvas(fig).print_png(buf)
    plt.close(fig)
    buf.seek(0)
    img = Image.open(buf)
    img.load()  # decode now instead of lazily on first access
    return img
def _new_figure():
    """Create a 6x6-inch figure with one axes that is not registered with pyplot."""
    # A bare Figure avoids pyplot's global figure registry. That registry is
    # shared across Gradio's request handlers and keeps every figure alive until
    # it is closed. A bare Figure is freed once it is no longer referenced.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(6, 6))
    return fig, fig.add_subplot()


def _find_calibration_line(gray_img):
    """Locate the outer outline in ``gray_img`` and its longest chord.

    Otsu-thresholds the grayscale image, takes the largest-area contour and finds
    the farthest-apart pair of its points.

    Returns:
        ``None`` when no contour is found, otherwise
        ``(boundary, line_pts, pixel_diameter)``: the contour as an (N, 2) array
        of (x, y) points, the chord endpoints as a (2, 2) array, and the chord
        length in pixels.
    """
    _, bw = cv2.threshold(gray_img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    contours, _ = cv2.findContours(bw, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    if not contours:
        return None
    largest = max(contours, key=cv2.contourArea)
    boundary = largest.reshape(-1, 2)
    # The farthest pair of any point set lies on its convex hull, so only hull
    # vertices need comparing. A full cdist over a CHAIN_APPROX_NONE contour is
    # O(N^2) memory (about 1.1 GB of float64 for a 12,000-point boundary).
    hull = cv2.convexHull(largest).reshape(-1, 2)
    dist = cdist(hull, hull)
    i, j = np.unravel_index(np.argmax(dist), dist.shape)
    line_pts = np.array([hull[i], hull[j]])
    return boundary, line_pts, float(dist[i, j])


def _feret_rectangles(binary_mask):
    """Fit a minimum-area rectangle to each external blob of ``binary_mask``.

    Returns:
        ``(lengths, widths, boxes)`` lists in pixels: the longer and shorter
        rectangle side, and the 4 integer corner points of each rectangle.
    """
    contours, _ = cv2.findContours(
        binary_mask * 255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    lengths, widths, boxes = [], [], []
    for cnt in contours:
        # NOTE(review): contours with fewer than 5 vertices are skipped. After
        # CHAIN_APPROX_SIMPLE this includes axis-aligned rectangular blobs. The
        # threshold looks inherited from cv2.fitEllipse; confirm it is intended.
        if len(cnt) < 5:
            continue
        rect = cv2.minAreaRect(cnt)
        w, h = rect[1]
        lengths.append(max(w, h))
        widths.append(min(w, h))
        boxes.append(cv2.boxPoints(rect).astype(np.intp))
    return lengths, widths, boxes


def _delaunay_edge_lengths(centroids):
    """Delaunay-triangulate ``centroids`` and measure each unique edge.

    Returns:
        ``(tri, lengths_px)``. ``tri`` is ``None`` when there are fewer than 3
        points or they are degenerate (e.g. collinear). ``lengths_px`` holds one
        pixel length per distinct edge, and is empty when ``tri`` is ``None``.
    """
    from scipy.spatial import QhullError

    if len(centroids) < 3:
        return None, np.empty(0)
    try:
        tri = Delaunay(centroids)
    except QhullError:
        # Qhull rejects flat input such as collinear centroids.
        return None, np.empty(0)
    # Each interior edge is shared by two triangles. Deduplicate so every
    # neighbouring pair contributes exactly once to the average distance.
    edges = tri.simplices[:, [0, 1, 1, 2, 2, 0]].reshape(-1, 2)
    edges = np.unique(np.sort(edges, axis=1), axis=0)
    lengths_px = np.linalg.norm(centroids[edges[:, 0]] - centroids[edges[:, 1]], axis=1)
    return tri, lengths_px


def analyze(image):
    """Segment aggregates in a concrete image and report size/distribution metrics.

    Steps:
      1. Calibration: the longest chord of the largest Otsu contour is taken to
         be ``DIAMETER_MM`` long, which fixes the mm-per-pixel scale.
      2. Segmentation with the global ``model``.
      3. A minimum-area rectangle for each segmented blob (coloured by length
         quintile).
      4. Delaunay triangulation of blob centroids, used for inter-aggregate
         distances.

    Args:
        image: Input image (a PIL image from the Gradio ``type="pil"`` input),
            or ``None`` when nothing was uploaded.

    Returns:
        ``(img1, img2, img3, summary)``: three PIL images (calibration overlay,
        rectangles, triangulation) and a Markdown summary string. On failure
        the three images are ``None`` and ``summary`` holds the error message.

    Side effect:
        Overwrites ``csv_output_path`` with a one-row table of the measurements.
    """
    if image is None:
        return None, None, None, "No image provided."
    if isinstance(image, Image.Image):
        # Palette, grayscale and RGBA uploads become plain RGB so RGB2GRAY and
        # imshow below always receive a 3-channel array.
        image = image.convert("RGB")
    image_np = np.array(image)
    gray_img = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)

    # 1. Scale calibration (done before inference so failures are cheap).
    calibration = _find_calibration_line(gray_img)
    if calibration is None:
        return None, None, None, "No contour found."
    boundary, line_pts, pixel_diameter = calibration
    if pixel_diameter <= 0:
        return None, None, None, "Calibration failed: specimen boundary is degenerate."
    pixels_per_mm = pixel_diameter / DIAMETER_MM
    pixel_length_mm = 1 / pixels_per_mm
    line_length_mm = pixel_diameter * pixel_length_mm

    # 2. Segmentation.
    input_tensor, original_size = preprocess_image(image)
    with torch.no_grad():
        output = model(input_tensor.to(DEVICE))
    prediction_mask = postprocess_output(output, original_size)
    label_img = (prediction_mask * 255).astype(np.uint8)
    binary_mask = (label_img > 127).astype(np.uint8)

    fig1, ax = _new_figure()
    ax.imshow(image_np)
    ax.plot(boundary[:, 0], boundary[:, 1], "g", linewidth=2)
    ax.plot(line_pts[:, 0], line_pts[:, 1], "r", linewidth=2)
    ax.set_title(f"Calibration Line: {line_length_mm:.2f} mm")
    ax.axis("off")
    img1 = fig_to_image(fig1)

    # 3. Minimum-area ("Feret") rectangles, coloured by length quintile.
    feret_lengths, feret_widths, boxes = _feret_rectangles(binary_mask)
    thresholds = np.percentile(feret_lengths, [20, 40, 60, 80]) if feret_lengths else [0] * 4
    colors = [(0, 0, 255), (0, 128, 255), (0, 255, 255), (0, 255, 0), (255, 0, 0)]  # BGR
    color_mask = cv2.cvtColor(label_img, cv2.COLOR_GRAY2BGR)
    for box, length in zip(boxes, feret_lengths):
        # The count of thresholds strictly below `length` selects the bin. This
        # matches picking the first k with length <= thresholds[k], else bin 4.
        bin_index = int(np.searchsorted(thresholds, length, side="left"))
        cv2.drawContours(color_mask, [box], 0, colors[bin_index], 8)

    fig2, ax = _new_figure()
    ax.imshow(cv2.cvtColor(color_mask, cv2.COLOR_BGR2RGB))
    ax.set_title("Feret Rectangles (Colored by Size)")
    ax.axis("off")
    img2 = fig_to_image(fig2)

    # 4. Spatial distribution: Delaunay triangulation of (row, col) centroids.
    props = regionprops(label(binary_mask))
    centroids = np.array([p.centroid for p in props])
    tri, edge_lengths_px = _delaunay_edge_lengths(centroids)
    edge_lengths = edge_lengths_px * pixel_length_mm

    fig3, ax = _new_figure()
    ax.imshow(label_img, cmap="gray")
    if tri is not None:
        ax.triplot(centroids[:, 1], centroids[:, 0], tri.simplices.copy(), color="red", linewidth=1)
        ax.set_title("Delaunay Triangulation")
    else:
        ax.set_title("Not Enough Aggregates for Triangulation")
    ax.axis("off")
    img3 = fig_to_image(fig3)

    # Areas. NOTE(review): every non-black pixel of the input is treated as
    # specimen area, which assumes a pure-black background. Confirm.
    mm2_per_pixel = pixel_length_mm ** 2
    aggregate_area_mm2 = np.sum(binary_mask == 1) * mm2_per_pixel
    total_area_mm2 = np.count_nonzero(gray_img) * mm2_per_pixel
    aggregate_ratio = aggregate_area_mm2 / total_area_mm2 if total_area_mm2 > 0 else 0

    if feret_lengths:
        avg_feret_length_mm = np.mean(feret_lengths) * pixel_length_mm
        avg_feret_width_mm = np.mean(feret_widths) * pixel_length_mm
        max_feret_length_mm = np.max(feret_lengths) * pixel_length_mm
        # Length/width ratio; guard against blobs whose rectangles all have zero width.
        roundness_aggregate = (
            avg_feret_length_mm / avg_feret_width_mm if avg_feret_width_mm > 0 else 0
        )
    else:
        avg_feret_length_mm = avg_feret_width_mm = max_feret_length_mm = roundness_aggregate = 0

    has_edges = edge_lengths.size > 0
    avg_dist_mm = np.mean(edge_lengths) if has_edges else 0
    max_dist_mm = np.max(edge_lengths) if has_edges else 0

    # Save to CSV
    data = {
        "Pixel_Size_mm_per_pixel": [pixel_length_mm],
        "Aggregate_Area_mm2": [aggregate_area_mm2],
        "Aggregate_Ratio": [aggregate_ratio],
        "Avg_Length_mm": [avg_feret_length_mm],
        "Avg_Width_mm": [avg_feret_width_mm],
        "Max_Length_mm": [max_feret_length_mm],
        "Roundness": [roundness_aggregate],
        "Avg_Dist_mm": [avg_dist_mm],
        "Max_Dist_mm": [max_dist_mm],
    }
    pd.DataFrame(data).to_csv(csv_output_path, index=False)

    summary = (
        "📏 **Measurements Summary**:\n"
        f"- Pixel Size: `{pixel_length_mm:.4f}` mm/pixel\n"
        f"- Aggregate Area: `{aggregate_area_mm2:.2f}` mm²\n"
        f"- Aggregate Ratio: `{aggregate_ratio:.4f}`\n"
        f"- Avg Aggregate Length: `{avg_feret_length_mm:.2f}` mm\n"
        f"- Avg Aggregate Width: `{avg_feret_width_mm:.2f}` mm\n"
        f"- Max Aggregate Length: `{max_feret_length_mm:.2f}` mm\n"
        f"- Aggregate Roundness: `{roundness_aggregate:.2f}`\n"
    )
    if has_edges:
        summary += f"- Avg Inter-Aggregate Distance: `{avg_dist_mm:.2f}` mm\n"
        summary += f"- Max Inter-Aggregate Distance: `{max_dist_mm:.2f}` mm\n"
    return img1, img2, img3, summary
# Gradio UI: one uploaded image in; three rendered figures plus the text summary
# out, in the same order as the 4-tuple returned by analyze().
input_component = gr.Image(type="pil", label="Upload Concrete Image")
output_components = [
    gr.Image(label="Boundary & Calibration Line"),
    gr.Image(label="Feret Rectangles"),
    gr.Image(label="Delaunay Triangulation"),
    gr.Textbox(label="Measurements Summary"),
]
demo = gr.Interface(
    fn=analyze,
    inputs=input_component,
    outputs=output_components,
    title="Concrete Aggregate Analysis App",
    description="Upload a concrete image. The model will segment aggregates and analyze their distribution and shape.",
)

if __name__ == "__main__":
    demo.launch()