# Hugging Face Space: U²-Net background removal + contour-based object measurement.
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import cv2 | |
| from torchvision import transforms | |
| from PIL import Image | |
| from skimage import color | |
| import gradio as gr | |
| import os | |
| import math | |
| # ---------------------------- | |
| # Model Definition (U2NET) | |
| # ---------------------------- | |
| from model.u2net import U2NET # make sure u2net.py is in model/ | |
# Camera and object parameters for the pinhole-camera size estimate below.
# NOTE(review): these are hard-coded for one specific camera setup — the
# cm measurements are only valid when the object really is
# `object_distance_mm` from this exact sensor/lens; confirm for deployment.
sensor_size_mm = (7.4, 5.55)  # physical sensor size in mm (width, height)
focal_length_mm = 5.5  # lens focal length in mm
object_distance_mm = 300  # assumed camera-to-object distance in mm
| # ---------------------------- | |
| # Preprocessing | |
| # ---------------------------- | |
def preprocess_image(pil_img):
    """Resize a PIL image to 320x320, apply ImageNet normalization, and
    return a (1, 3, 320, 320) float tensor ready for U2NET inference."""
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    pipeline = transforms.Compose([
        transforms.Resize((320, 320)),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ])
    tensor = pipeline(pil_img)
    # Add the batch dimension expected by the network.
    return tensor.unsqueeze(0)
| # ---------------------------- | |
| # Postprocessing | |
| # ---------------------------- | |
def postprocess_mask(pred, original_size):
    """Convert a raw U2NET prediction tensor into an 8-bit mask image.

    Args:
        pred: prediction tensor, squeezable to (H, W).
        original_size: (width, height) target size for cv2.resize.

    Returns:
        uint8 numpy array of shape (height, width) with values in 0-255.
    """
    # .detach() replaces the deprecated Variable-era `.data` attribute.
    pred = pred.squeeze().detach().cpu().numpy()
    lo, hi = pred.min(), pred.max()
    if hi > lo:
        pred = (pred - lo) / (hi - lo)
    else:
        # BUGFIX: a constant prediction map previously divided by zero.
        pred = np.zeros_like(pred)
    pred = (pred * 255).astype(np.uint8)
    return cv2.resize(pred, original_size, interpolation=cv2.INTER_LINEAR)
| # ---------------------------- | |
| # Remove Background | |
| # ---------------------------- | |
def remove_background(original_image, mask):
    """Suppress background pixels using a 0-255 soft mask.

    `original_image` may be a PIL image or an ndarray; `mask` is an HxW
    (or HxWx3) array of alpha values in [0, 255]. Returns a uint8 array
    with each pixel scaled by mask/255 (background driven toward black).
    """
    rgb = np.array(original_image)
    alpha = mask
    if alpha.ndim == 2:
        # Broadcast the single-channel mask across the three color channels.
        alpha = np.stack([alpha, alpha, alpha], axis=2)
    return (rgb * (alpha / 255)).astype(np.uint8)
| # ---------------------------- | |
| # Measure Object (Contour-based) | |
| # ---------------------------- | |
def measure_object(image_np, original_resolution):
    """Find the dominant object contour and annotate its edge lengths and area.

    Args:
        image_np: RGB (HxWx3) or single-channel (HxW) image array.
        original_resolution: unused; kept for interface compatibility.

    Returns:
        (annotated_image, measurements_text): the input image with the
        quadrilateral outline and per-edge labels drawn on it, plus a text
        summary of the four edge lengths (cm) and contour area (cm²).
    """
    # Build an 8-bit grayscale image.
    if image_np.ndim == 3:
        gray = color.rgb2gray(image_np[..., :3])  # float in [0, 1]
        gray8 = (255 * gray).astype(np.uint8)
    elif image_np.dtype == np.uint8:
        # BUGFIX: a 2-D uint8 input was previously multiplied by 255 again,
        # overflowing uint8 and corrupting the threshold input.
        gray8 = image_np
    else:
        gray8 = (255 * image_np).astype(np.uint8)

    # Binary mask via Otsu threshold, then morphological close to fill holes.
    _, mask = cv2.threshold(gray8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=1)

    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return image_np, "No object found."
    cnt = max(contours, key=cv2.contourArea)

    # Approximate the contour as a polygon; if it is not a clean
    # quadrilateral, fall back to the minimum-area bounding rectangle.
    epsilon = 0.02 * cv2.arcLength(cnt, True)
    approx = cv2.approxPolyDP(cnt, epsilon, True)
    if len(approx) != 4:
        approx = cv2.boxPoints(cv2.minAreaRect(cnt)).astype(int)
    approx = approx.reshape(-1, 2)  # (4, 2)

    # Pinhole-camera calibration: mm per pixel at the configured distance.
    h_img, w_img = image_np.shape[:2]
    sensor_width_mm, sensor_height_mm = sensor_size_mm
    mm_per_px_x = (sensor_width_mm * object_distance_mm) / (focal_length_mm * w_img)
    mm_per_px_y = (sensor_height_mm * object_distance_mm) / (focal_length_mm * h_img)
    mm_per_px = 0.5 * (mm_per_px_x + mm_per_px_y)

    # Measure each of the four edges and record midpoints for label placement.
    edge_lengths_cm = []
    edge_midpoints = []
    for i in range(4):
        p1 = approx[i]
        p2 = approx[(i + 1) % 4]
        d_px = math.hypot(p2[0] - p1[0], p2[1] - p1[1])
        edge_lengths_cm.append((d_px * mm_per_px) / 10.0)  # mm → cm
        edge_midpoints.append(((p1[0] + p2[0]) // 2, (p1[1] + p2[1]) // 2))

    # Area of the true contour (not the approximating polygon).
    area_px2 = cv2.contourArea(cnt)
    area_cm2 = (area_px2 * (mm_per_px ** 2)) / 100.0  # mm² → cm²

    # Annotate a copy of the input image.
    annotated = image_np.copy()
    cv2.polylines(annotated, [approx.astype(int)], True, (0, 255, 0), 2)
    for (mx, my), length_cm in zip(edge_midpoints, edge_lengths_cm):
        cv2.putText(annotated, f"{length_cm:.2f} cm", (int(mx), int(my)),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.2, (255, 255, 0), 2, cv2.LINE_AA)
    cv2.putText(annotated, f"Area: {area_cm2:.2f} cm^2", (30, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0, 0, 255), 2, cv2.LINE_AA)

    measurements_text = f"Edges: {[f'{length_cm:.2f}' for length_cm in edge_lengths_cm]} cm | Area: {area_cm2:.2f} cm²"
    return annotated, measurements_text
| # ---------------------------- | |
| # Pipeline | |
| # ---------------------------- | |
def process(image):
    """Full pipeline: U2NET background removal, then object measurement.

    Args:
        image: input PIL image (any mode; converted to RGB).

    Returns:
        (annotated PIL image, measurements text) for the Gradio interface.
    """
    image = image.convert("RGB")
    original_resolution = image.size  # (width, height)

    # Persist a WebP copy (kept from the original pipeline; the file is
    # not read back afterwards).
    image.save("temp.webp", "WEBP", quality=80)

    # BUGFIX(perf): the model was previously reloaded from disk on every
    # request; the helper caches it after the first call.
    net = _load_u2net()

    image_tensor = preprocess_image(image)
    if torch.cuda.is_available():
        image_tensor = image_tensor.cuda()

    with torch.no_grad():
        d1, _, _, _, _, _, _ = net(image_tensor)
        pred_mask = d1[:, 0, :, :]
        # F.interpolate replaces the long-deprecated F.upsample; PIL size is
        # (W, H) while interpolate expects (H, W), hence the [::-1].
        pred_mask = F.interpolate(pred_mask.unsqueeze(1),
                                  size=original_resolution[::-1],
                                  mode='bilinear', align_corners=False)

    mask = postprocess_mask(pred_mask, original_resolution)
    result = remove_background(image, mask)
    annotated, measurements_text = measure_object(result, original_resolution)
    return Image.fromarray(annotated), measurements_text


def _load_u2net(model_path="u2net.pth"):
    """Load the U2NET weights once and cache the eval-mode network."""
    cached = getattr(_load_u2net, "_net", None)
    if cached is not None:
        return cached
    net = U2NET(3, 1)
    if torch.cuda.is_available():
        net.load_state_dict(torch.load(model_path))
        net.cuda()
    else:
        net.load_state_dict(torch.load(model_path, map_location='cpu'))
    net.eval()
    _load_u2net._net = net
    return net
| # ---------------------------- | |
| # Gradio App | |
| # ---------------------------- | |
# Gradio UI wiring: one image input, annotated image + text summary output.
image_input = gr.Image(type="pil", label="Upload Image (JPG/PNG)")
result_outputs = [
    gr.Image(type="pil", label="Annotated Result"),
    gr.Textbox(label="Measurements"),
]
demo = gr.Interface(
    fn=process,
    inputs=image_input,
    outputs=result_outputs,
    title="U²-Net Background Removal + Object Measurement",
    description="Uploads JPG/PNG → Removes background with U²-Net → Finds contour → Measures all 4 edges & area in cm",
)

if __name__ == "__main__":
    demo.launch()