# U2_net_Arc / app.py
# (HuggingFace Space page residue converted to comments:
#  author uchihamadara1816, commit dca0904 "Font size", verified)
import torch
import torch.nn.functional as F
import numpy as np
import cv2
from torchvision import transforms
from PIL import Image
from skimage import color
import gradio as gr
import os
import math
# ----------------------------
# Model Definition (U2NET)
# ----------------------------
from model.u2net import U2NET # make sure u2net.py is in model/
# Camera and object parameters for pixel→mm calibration (pinhole model).
# NOTE(review): measurements are only valid when the object actually sits
# at object_distance_mm from the camera, flat and parallel to the sensor —
# confirm these values match the deployed camera.
sensor_size_mm = (7.4, 5.55) # sensor size in mm (width, height)
focal_length_mm = 5.5 # focal length in mm
object_distance_mm = 300 # distance from camera in mm
# ----------------------------
# Preprocessing
# ----------------------------
def preprocess_image(pil_img):
    """Resize to the network's 320x320 input, convert to a tensor, and
    apply ImageNet mean/std normalization.

    Returns a (1, 3, 320, 320) float batch tensor.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    pipeline = transforms.Compose([
        transforms.Resize((320, 320)),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ])
    tensor = pipeline(pil_img)
    return tensor.unsqueeze(0)  # add batch dimension
# ----------------------------
# Postprocessing
# ----------------------------
def postprocess_mask(pred, original_size):
    """Convert a raw U2NET probability tensor into a uint8 mask.

    Args:
        pred: torch tensor squeezable to (H, W).
        original_size: (width, height) target size (cv2.resize order).

    Returns:
        uint8 numpy array of shape (height, width), values in [0, 255].
    """
    pred = pred.squeeze().detach().cpu().numpy()
    # Min-max normalize; guard against a flat prediction, which would
    # otherwise divide by zero and fill the mask with NaNs.
    lo, hi = pred.min(), pred.max()
    if hi > lo:
        pred = (pred - lo) / (hi - lo)
    else:
        pred = np.zeros_like(pred)
    pred = (pred * 255).astype(np.uint8)
    return cv2.resize(pred, original_size, interpolation=cv2.INTER_LINEAR)
# ----------------------------
# Remove Background
# ----------------------------
def remove_background(original_image, mask):
    """Multiply *original_image* by *mask* (uint8, 0-255) channel-wise,
    returning the extracted foreground as a uint8 array.

    A 2-D mask is broadcast to three channels before blending.
    """
    pixels = np.array(original_image)
    if mask.ndim == 2:
        mask = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
    blended = pixels * (mask / 255)
    return blended.astype(np.uint8)
# ----------------------------
# Measure Object (Contour-based)
# ----------------------------
def measure_object(image_np, original_resolution):
    """Find the dominant object in *image_np*, measure its four edges and
    area in centimeters, and draw the measurements on a copy of the image.

    Calibration uses the module-level pinhole-camera constants
    (sensor_size_mm, focal_length_mm, object_distance_mm), so results are
    only meaningful at that fixed object distance.

    Returns:
        (annotated_image, summary_text), or the untouched image and an
        error message when no contour is found.
    """
    # Grayscale (skimage returns floats in [0, 1]); keep only RGB channels.
    if image_np.ndim == 3:
        gray = color.rgb2gray(image_np[..., :3])
    else:
        gray = image_np
    gray8 = (255 * gray).astype(np.uint8)

    # Otsu binarization followed by a small closing to seal holes.
    _, binary = cv2.threshold(gray8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
    binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel, iterations=1)

    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return image_np, "No object found."
    outline = max(contours, key=cv2.contourArea)

    # Prefer a 4-corner polygon approximation; fall back to the rotated
    # bounding box so we always end up with exactly four corners.
    corners = cv2.approxPolyDP(outline, 0.02 * cv2.arcLength(outline, True), True)
    if len(corners) != 4:
        corners = cv2.boxPoints(cv2.minAreaRect(outline)).astype(int)
    corners = corners.reshape(-1, 2)  # (4, 2)

    # mm-per-pixel from the pinhole model, averaged over both axes.
    h_img, w_img = image_np.shape[:2]
    sensor_w_mm, sensor_h_mm = sensor_size_mm
    scale_x = (sensor_w_mm * object_distance_mm) / (focal_length_mm * w_img)
    scale_y = (sensor_h_mm * object_distance_mm) / (focal_length_mm * h_img)
    mm_per_px = (scale_x + scale_y) / 2.0

    # Edge lengths in cm, plus midpoints for label placement.
    lengths_cm = []
    midpoints = []
    for idx in range(4):
        a = corners[idx]
        b = corners[(idx + 1) % 4]
        edge_px = math.hypot(b[0] - a[0], b[1] - a[1])
        lengths_cm.append(edge_px * mm_per_px / 10.0)
        midpoints.append(((a[0] + b[0]) // 2, (a[1] + b[1]) // 2))

    # True contour area, converted mm² → cm².
    area_cm2 = cv2.contourArea(outline) * (mm_per_px ** 2) / 100.0

    annotated = image_np.copy()
    cv2.polylines(annotated, [corners.astype(int)], True, (0,255,0), 2)
    for (mx, my), length in zip(midpoints, lengths_cm):
        cv2.putText(annotated, f"{length:.2f} cm", (int(mx), int(my)),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.2, (255,255,0), 2, cv2.LINE_AA)
    cv2.putText(annotated, f"Area: {area_cm2:.2f} cm^2", (30,30),
                cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0,0,255), 2, cv2.LINE_AA)

    summary = f"Edges: {[f'{length:.2f}' for length in lengths_cm]} cm | Area: {area_cm2:.2f} cm²"
    return annotated, summary
# ----------------------------
# Pipeline
# ----------------------------
# Cache the loaded network so the weights are read from disk only once,
# not on every request (the original reloaded the model per call, which
# dominated request latency).
_model_cache = {}


def _load_u2net(model_path="u2net.pth"):
    """Load U2NET(3, 1) weights once, move to the available device, and
    reuse the eval-mode instance on subsequent calls."""
    if "net" not in _model_cache:
        net = U2NET(3, 1)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # map_location handles both CPU-only and GPU hosts uniformly.
        net.load_state_dict(torch.load(model_path, map_location=device))
        net.to(device)
        net.eval()
        _model_cache["net"] = net
    return _model_cache["net"]


def process(image):
    """Full pipeline: U2NET background removal, then contour-based
    edge/area measurement.

    Args:
        image: PIL.Image uploaded by the user.

    Returns:
        (annotated PIL.Image, measurements summary string).
    """
    image = image.convert("RGB")
    original_resolution = image.size  # (W, H)

    # NOTE(review): this WebP copy is written but never read back or
    # deleted — kept for compatibility, but it looks like dead output.
    image.save("temp.webp", "WEBP", quality=80)

    net = _load_u2net()

    image_tensor = preprocess_image(image)
    if torch.cuda.is_available():
        image_tensor = image_tensor.cuda()

    with torch.no_grad():
        d1, _, _, _, _, _, _ = net(image_tensor)
        pred_mask = d1[:, 0, :, :]
        # F.interpolate replaces the deprecated F.upsample; size takes
        # (H, W), hence the reversed PIL (W, H) tuple.
        pred_mask = F.interpolate(pred_mask.unsqueeze(1),
                                  size=original_resolution[::-1],
                                  mode='bilinear', align_corners=False)

    mask = postprocess_mask(pred_mask, original_resolution)
    result = remove_background(image, mask)
    annotated, measurements_text = measure_object(result, original_resolution)
    return Image.fromarray(annotated), measurements_text
# ----------------------------
# Gradio App
# ----------------------------
# Wire up the web UI: one uploaded image in, annotated image + text out.
_input_image = gr.Image(type="pil", label="Upload Image (JPG/PNG)")
_output_widgets = [
    gr.Image(type="pil", label="Annotated Result"),
    gr.Textbox(label="Measurements"),
]
demo = gr.Interface(
    fn=process,
    inputs=_input_image,
    outputs=_output_widgets,
    title="U²-Net Background Removal + Object Measurement",
    description="Uploads JPG/PNG → Removes background with U²-Net → Finds contour → Measures all 4 edges & area in cm",
)

if __name__ == "__main__":
    demo.launch()