Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -71,18 +71,18 @@ if not os.path.exists(reference_model_path):
|
|
| 71 |
reference_detector_global = YOLO(reference_model_path)
|
| 72 |
print("YOLO reference model loaded in {:.2f} seconds".format(time.time() - start_time))
|
| 73 |
|
| 74 |
-
print("Loading U²-Net model for reference background removal (U2NETP)...")
|
| 75 |
-
start_time = time.time()
|
| 76 |
-
u2net_model_path = os.path.join(CACHE_DIR, "u2netp.pth")
|
| 77 |
-
if not os.path.exists(u2net_model_path):
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
u2net_global = U2NETP(3, 1)
|
| 81 |
-
u2net_global.load_state_dict(torch.load(u2net_model_path, map_location="cpu"))
|
| 82 |
-
device = "cpu"
|
| 83 |
-
u2net_global.to(device)
|
| 84 |
-
u2net_global.eval()
|
| 85 |
-
print("U²-Net model loaded in {:.2f} seconds".format(time.time() - start_time))
|
| 86 |
|
| 87 |
print("Loading BiRefNet model...")
|
| 88 |
start_time = time.time()
|
|
@@ -119,16 +119,16 @@ def unload_and_reload_models():
|
|
| 119 |
new_birefnet = AutoModelForImageSegmentation.from_pretrained(
|
| 120 |
"zhengpeng7/BiRefNet", trust_remote_code=True, cache_dir=CACHE_DIR
|
| 121 |
)
|
| 122 |
-
new_birefnet.to(device)
|
| 123 |
-
new_birefnet.eval()
|
| 124 |
-
new_u2net = U2NETP(3, 1)
|
| 125 |
-
new_u2net.load_state_dict(torch.load(os.path.join(CACHE_DIR, "u2netp.pth"), map_location="cpu"))
|
| 126 |
-
new_u2net.to(device)
|
| 127 |
-
new_u2net.eval()
|
| 128 |
drawer_detector_global = new_drawer_detector
|
| 129 |
reference_detector_global = new_reference_detector
|
| 130 |
birefnet_global = new_birefnet
|
| 131 |
-
u2net_global =
|
| 132 |
print("Models reloaded in {:.2f} seconds".format(time.time() - start_time))
|
| 133 |
|
| 134 |
# ---------------------
|
|
@@ -159,23 +159,27 @@ def detect_reference_square(img: np.ndarray):
|
|
| 159 |
res[0].cpu().boxes.xyxy[0]
|
| 160 |
)
|
| 161 |
|
| 162 |
-
#
|
| 163 |
-
def
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
def remove_bg(image: np.ndarray) -> np.ndarray:
|
| 180 |
t = time.time()
|
| 181 |
image_pil = Image.fromarray(image)
|
|
@@ -187,7 +191,7 @@ def remove_bg(image: np.ndarray) -> np.ndarray:
|
|
| 187 |
scale_ratio = 1024 / max(image_pil.size)
|
| 188 |
scaled_size = (int(image_pil.size[0] * scale_ratio), int(image_pil.size[1] * scale_ratio))
|
| 189 |
result = np.array(pred_pil.resize(scaled_size))
|
| 190 |
-
print("BiRefNet
|
| 191 |
return result
|
| 192 |
|
| 193 |
def make_square(img: np.ndarray):
|
|
@@ -469,6 +473,7 @@ def predict(
|
|
| 469 |
print("Drawer detection completed in {:.2f} seconds".format(time.time() - t))
|
| 470 |
except DrawerNotDetectedError as e:
|
| 471 |
return None, None, None, None, f"Error: {str(e)}"
|
|
|
|
| 472 |
t = time.time()
|
| 473 |
shrunked_img = make_square(shrink_bbox(drawer_img, 0.90))
|
| 474 |
del drawer_img
|
|
@@ -490,9 +495,9 @@ def predict(
|
|
| 490 |
# ---------------------
|
| 491 |
t = time.time()
|
| 492 |
reference_obj_img = make_square(reference_obj_img)
|
| 493 |
-
|
| 494 |
-
reference_square_mask = remove_bg_reference(reference_obj_img)
|
| 495 |
print("Reference image processing completed in {:.2f} seconds".format(time.time() - t))
|
|
|
|
| 496 |
t = time.time()
|
| 497 |
try:
|
| 498 |
cv2.imwrite("mask.jpg", cv2.cvtColor(reference_obj_img, cv2.COLOR_RGB2GRAY))
|
|
@@ -565,6 +570,7 @@ def predict(
|
|
| 565 |
del objects_mask
|
| 566 |
gc.collect()
|
| 567 |
print("Mask dilation completed in {:.2f} seconds".format(time.time() - t))
|
|
|
|
| 568 |
Image.fromarray(dilated_mask).save("./outputs/scaled_mask_new.jpg")
|
| 569 |
|
| 570 |
# ---------------------
|
|
@@ -573,12 +579,16 @@ def predict(
|
|
| 573 |
t = time.time()
|
| 574 |
outlines, contours = extract_outlines(dilated_mask)
|
| 575 |
print("Outline extraction completed in {:.2f} seconds".format(time.time() - t))
|
|
|
|
| 576 |
output_img = shrunked_img.copy()
|
| 577 |
del shrunked_img
|
| 578 |
gc.collect()
|
|
|
|
| 579 |
t = time.time()
|
| 580 |
use_finger_clearance = True if finger_clearance.lower() == "yes" else False
|
| 581 |
-
doc, final_polygons_inch = save_dxf_spline(
|
|
|
|
|
|
|
| 582 |
del contours
|
| 583 |
gc.collect()
|
| 584 |
print("DXF generation completed in {:.2f} seconds".format(time.time() - t))
|
|
@@ -623,8 +633,14 @@ def predict(
|
|
| 623 |
text_x = (inner_min_x + inner_max_x) / 2.0
|
| 624 |
text_height_dxf = 0.5
|
| 625 |
text_y_dxf = inner_min_y - 0.125 - text_height_dxf
|
| 626 |
-
text_entity = msp.add_text(
|
| 627 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 628 |
text_entity.dxf.insert = (text_x, text_y_dxf)
|
| 629 |
|
| 630 |
# Save the DXF
|
|
@@ -644,8 +660,27 @@ def predict(
|
|
| 644 |
text_y_in = inner_min_y - 0.125 - text_height_cv
|
| 645 |
text_y_img = int(processed_size[0] - (text_y_in / scaling_factor))
|
| 646 |
org = (text_x_img - int(len(annotation_text.strip()) * 6), text_y_img)
|
| 647 |
-
|
| 648 |
-
cv2.putText(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 649 |
|
| 650 |
# Restore brightness for display purposes:
|
| 651 |
# Since we reduced brightness by 0.5 during preprocessing,
|
|
@@ -656,11 +691,14 @@ def predict(
|
|
| 656 |
|
| 657 |
outlines_color = cv2.cvtColor(new_outlines, cv2.COLOR_BGR2RGB)
|
| 658 |
print("Total prediction time: {:.2f} seconds".format(time.time() - overall_start))
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
|
|
|
|
|
|
|
|
|
|
| 664 |
|
| 665 |
# ---------------------
|
| 666 |
# Gradio Interface
|
|
|
|
| 71 |
reference_detector_global = YOLO(reference_model_path)
|
| 72 |
print("YOLO reference model loaded in {:.2f} seconds".format(time.time() - start_time))
|
| 73 |
|
| 74 |
+
# print("Loading U²-Net model for reference background removal (U2NETP)...")
|
| 75 |
+
# start_time = time.time()
|
| 76 |
+
# u2net_model_path = os.path.join(CACHE_DIR, "u2netp.pth")
|
| 77 |
+
# if not os.path.exists(u2net_model_path):
|
| 78 |
+
# print("Caching U²-Net model to", u2net_model_path)
|
| 79 |
+
# shutil.copy("u2netp.pth", u2net_model_path)
|
| 80 |
+
# u2net_global = U2NETP(3, 1)
|
| 81 |
+
# u2net_global.load_state_dict(torch.load(u2net_model_path, map_location="cpu"))
|
| 82 |
+
# device = "cpu"
|
| 83 |
+
# u2net_global.to(device)
|
| 84 |
+
# u2net_global.eval()
|
| 85 |
+
# print("U²-Net model loaded in {:.2f} seconds".format(time.time() - start_time))
|
| 86 |
|
| 87 |
print("Loading BiRefNet model...")
|
| 88 |
start_time = time.time()
|
|
|
|
| 119 |
new_birefnet = AutoModelForImageSegmentation.from_pretrained(
|
| 120 |
"zhengpeng7/BiRefNet", trust_remote_code=True, cache_dir=CACHE_DIR
|
| 121 |
)
|
| 122 |
+
# new_birefnet.to(device)
|
| 123 |
+
# new_birefnet.eval()
|
| 124 |
+
# new_u2net = U2NETP(3, 1)
|
| 125 |
+
# new_u2net.load_state_dict(torch.load(os.path.join(CACHE_DIR, "u2netp.pth"), map_location="cpu"))
|
| 126 |
+
# new_u2net.to(device)
|
| 127 |
+
# new_u2net.eval()
|
| 128 |
drawer_detector_global = new_drawer_detector
|
| 129 |
reference_detector_global = new_reference_detector
|
| 130 |
birefnet_global = new_birefnet
|
| 131 |
+
u2net_global = new_birefnet
|
| 132 |
print("Models reloaded in {:.2f} seconds".format(time.time() - start_time))
|
| 133 |
|
| 134 |
# ---------------------
|
|
|
|
| 159 |
res[0].cpu().boxes.xyxy[0]
|
| 160 |
)
|
| 161 |
|
| 162 |
+
# Use U2NETP for reference background removal.
|
| 163 |
+
# def remove_bg_u2netp(image: np.ndarray) -> np.ndarray:
|
| 164 |
+
# t = time.time()
|
| 165 |
+
# image_pil = Image.fromarray(image)
|
| 166 |
+
# transform_u2netp = transforms.Compose([
|
| 167 |
+
# transforms.Resize((320, 320)),
|
| 168 |
+
# transforms.ToTensor(),
|
| 169 |
+
# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
| 170 |
+
# ])
|
| 171 |
+
# input_tensor = transform_u2netp(image_pil).unsqueeze(0).to("cpu")
|
| 172 |
+
# with torch.no_grad():
|
| 173 |
+
# outputs = u2net_global(input_tensor)
|
| 174 |
+
# pred = outputs[0]
|
| 175 |
+
# pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
|
| 176 |
+
# pred_np = pred.squeeze().cpu().numpy()
|
| 177 |
+
# pred_np = cv2.resize(pred_np, (image_pil.width, image_pil.height))
|
| 178 |
+
# pred_np = (pred_np * 255).astype(np.uint8)
|
| 179 |
+
# print("U2NETP background removal completed in {:.2f} seconds".format(time.time() - t))
|
| 180 |
+
# return pred_np
|
| 181 |
+
|
| 182 |
+
# Use BiRefNet for main object background removal.
|
| 183 |
def remove_bg(image: np.ndarray) -> np.ndarray:
|
| 184 |
t = time.time()
|
| 185 |
image_pil = Image.fromarray(image)
|
|
|
|
| 191 |
scale_ratio = 1024 / max(image_pil.size)
|
| 192 |
scaled_size = (int(image_pil.size[0] * scale_ratio), int(image_pil.size[1] * scale_ratio))
|
| 193 |
result = np.array(pred_pil.resize(scaled_size))
|
| 194 |
+
print("BiRefNet background removal completed in {:.2f} seconds".format(time.time() - t))
|
| 195 |
return result
|
| 196 |
|
| 197 |
def make_square(img: np.ndarray):
|
|
|
|
| 473 |
print("Drawer detection completed in {:.2f} seconds".format(time.time() - t))
|
| 474 |
except DrawerNotDetectedError as e:
|
| 475 |
return None, None, None, None, f"Error: {str(e)}"
|
| 476 |
+
# Ensure that shrunked_img is defined only after successful detection.
|
| 477 |
t = time.time()
|
| 478 |
shrunked_img = make_square(shrink_bbox(drawer_img, 0.90))
|
| 479 |
del drawer_img
|
|
|
|
| 495 |
# ---------------------
|
| 496 |
t = time.time()
|
| 497 |
reference_obj_img = make_square(reference_obj_img)
|
| 498 |
+
reference_square_mask = remove_bg(reference_obj_img)
|
|
|
|
| 499 |
print("Reference image processing completed in {:.2f} seconds".format(time.time() - t))
|
| 500 |
+
|
| 501 |
t = time.time()
|
| 502 |
try:
|
| 503 |
cv2.imwrite("mask.jpg", cv2.cvtColor(reference_obj_img, cv2.COLOR_RGB2GRAY))
|
|
|
|
| 570 |
del objects_mask
|
| 571 |
gc.collect()
|
| 572 |
print("Mask dilation completed in {:.2f} seconds".format(time.time() - t))
|
| 573 |
+
|
| 574 |
Image.fromarray(dilated_mask).save("./outputs/scaled_mask_new.jpg")
|
| 575 |
|
| 576 |
# ---------------------
|
|
|
|
| 579 |
t = time.time()
|
| 580 |
outlines, contours = extract_outlines(dilated_mask)
|
| 581 |
print("Outline extraction completed in {:.2f} seconds".format(time.time() - t))
|
| 582 |
+
|
| 583 |
output_img = shrunked_img.copy()
|
| 584 |
del shrunked_img
|
| 585 |
gc.collect()
|
| 586 |
+
|
| 587 |
t = time.time()
|
| 588 |
use_finger_clearance = True if finger_clearance.lower() == "yes" else False
|
| 589 |
+
doc, final_polygons_inch = save_dxf_spline(
|
| 590 |
+
contours, scaling_factor, processed_size[0], finger_clearance=use_finger_clearance
|
| 591 |
+
)
|
| 592 |
del contours
|
| 593 |
gc.collect()
|
| 594 |
print("DXF generation completed in {:.2f} seconds".format(time.time() - t))
|
|
|
|
| 633 |
text_x = (inner_min_x + inner_max_x) / 2.0
|
| 634 |
text_height_dxf = 0.5
|
| 635 |
text_y_dxf = inner_min_y - 0.125 - text_height_dxf
|
| 636 |
+
text_entity = msp.add_text(
|
| 637 |
+
annotation_text.strip(),
|
| 638 |
+
dxfattribs={
|
| 639 |
+
"height": text_height_dxf,
|
| 640 |
+
"layer": "ANNOTATION",
|
| 641 |
+
"style": "Bold"
|
| 642 |
+
}
|
| 643 |
+
)
|
| 644 |
text_entity.dxf.insert = (text_x, text_y_dxf)
|
| 645 |
|
| 646 |
# Save the DXF
|
|
|
|
| 660 |
text_y_in = inner_min_y - 0.125 - text_height_cv
|
| 661 |
text_y_img = int(processed_size[0] - (text_y_in / scaling_factor))
|
| 662 |
org = (text_x_img - int(len(annotation_text.strip()) * 6), text_y_img)
|
| 663 |
+
|
| 664 |
+
cv2.putText(
|
| 665 |
+
output_img,
|
| 666 |
+
annotation_text.strip(),
|
| 667 |
+
org,
|
| 668 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
| 669 |
+
1.3,
|
| 670 |
+
(0, 0, 255),
|
| 671 |
+
3,
|
| 672 |
+
cv2.LINE_AA
|
| 673 |
+
)
|
| 674 |
+
cv2.putText(
|
| 675 |
+
new_outlines,
|
| 676 |
+
annotation_text.strip(),
|
| 677 |
+
org,
|
| 678 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
| 679 |
+
1.3,
|
| 680 |
+
(0, 0, 255),
|
| 681 |
+
3,
|
| 682 |
+
cv2.LINE_AA
|
| 683 |
+
)
|
| 684 |
|
| 685 |
# Restore brightness for display purposes:
|
| 686 |
# Since we reduced brightness by 0.5 during preprocessing,
|
|
|
|
| 691 |
|
| 692 |
outlines_color = cv2.cvtColor(new_outlines, cv2.COLOR_BGR2RGB)
|
| 693 |
print("Total prediction time: {:.2f} seconds".format(time.time() - overall_start))
|
| 694 |
+
|
| 695 |
+
return (
|
| 696 |
+
cv2.cvtColor(output_img, cv2.COLOR_BGR2RGB),
|
| 697 |
+
outlines_color,
|
| 698 |
+
dxf_filepath,
|
| 699 |
+
dilated_mask,
|
| 700 |
+
str(scaling_factor)
|
| 701 |
+
)
|
| 702 |
|
| 703 |
# ---------------------
|
| 704 |
# Gradio Interface
|