Spaces:
Runtime error
Runtime error
| # utils/bubble_detect_rtdetr.py | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| import cv2 | |
| from shapely.geometry import Polygon | |
| from shapely.ops import unary_union | |
| from transformers import AutoImageProcessor, RTDetrForObjectDetection | |
| from utils.polygon_utils import sanitize_polygon | |
# Hugging Face Hub checkpoint used for comic text / bubble detection.
MODEL_NAME = "ogkalu/comic-text-and-bubble-detector"

# Module-level cache, populated lazily by load_rtdetr_model().
_processor = None
_model = None


# ------------------------------------------------------------
# Load model (cached)
# ------------------------------------------------------------
def load_rtdetr_model():
    """Return the cached ``(processor, model)`` pair, loading it on first use.

    The image processor and the detection model are each loaded from
    ``MODEL_NAME`` at most once per process and stored in the module globals
    ``_processor`` / ``_model``. The model is put in eval mode and moved to
    CUDA when a GPU is available.

    Returns:
        tuple: ``(processor, model)``.
    """
    global _processor, _model

    if _processor is None:
        print("π Loading RT-DETR-v2 processor...")
        _processor = AutoImageProcessor.from_pretrained(MODEL_NAME)

    if _model is None:
        # Let the checkpoint config choose the architecture: the v1
        # RTDetrForObjectDetection class does not match an RT-DETR-v2
        # checkpoint, while the Auto class resolves either version.
        from transformers import AutoModelForObjectDetection

        print("π Loading RT-DETR-v2 model...")
        model = AutoModelForObjectDetection.from_pretrained(MODEL_NAME)
        model.eval()
        if torch.cuda.is_available():
            model.to("cuda")
        # Publish only after full initialisation so a failed device move
        # never leaves a half-configured model in the cache.
        _model = model
        print("β RT-DETR-v2 loaded.")

    return _processor, _model
# ------------------------------------------------------------
# Run detector
# ------------------------------------------------------------
def detect_bubbles_rtdetr(image_pil, conf_threshold=0.30):
    """Run the RT-DETR detector on a PIL image.

    Args:
        image_pil: PIL image. Non-RGB modes ("L", "P", "RGBA", ...) are
            converted to RGB first, because the image processor expects
            three colour channels.
        conf_threshold: Score threshold forwarded to the processor's
            ``post_process_object_detection``.

    Returns:
        list[dict]: one dict per detection with ``"class"`` (int label id),
        ``"score"`` (float) and ``"bbox"`` (``[x1, y1, x2, y2]`` floats in the
        pixel coordinates of ``image_pil``).
    """
    processor, model = load_rtdetr_model()

    if image_pil.mode != "RGB":
        image_pil = image_pil.convert("RGB")

    # Follow the model's real device rather than re-probing CUDA, so inputs
    # and weights can never end up on different devices.
    device = model.device
    inputs = processor(images=image_pil, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (W, H); the post-processor expects (H, W) per image.
    target_sizes = torch.tensor([image_pil.size[::-1]], device=device)
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=conf_threshold
    )[0]

    # One host transfer per field instead of one per scalar element.
    scores = results["scores"].tolist()
    labels = results["labels"].tolist()
    boxes = results["boxes"].tolist()

    return [
        {
            "class": int(label),
            "score": float(score),
            "bbox": [float(v) for v in box],
        }
        for score, label, box in zip(scores, labels, boxes)
    ]
# ------------------------------------------------------------
# Bubble box → refined outer + inner safe polygon
# ------------------------------------------------------------
def _bubble_seed_point(gray, radius=15, min_brightness=100):
    """Pick a flood-fill seed: the brightest pixel near the centre of ``gray``.

    Seeding on the exact centre can land on a dark text stroke (the
    "seed trap"), so a window of +/- ``radius`` px around the centre is
    searched for its brightest pixel instead. If even that pixel is darker
    than ``min_brightness``, the exact centre is returned as a last resort.
    """
    h, w = gray.shape[:2]
    cx, cy = w // 2, h // 2
    left, top = max(0, cx - radius), max(0, cy - radius)
    window = gray[top:min(h, cy + radius), left:min(w, cx + radius)]
    _, brightest, _, (bx, by) = cv2.minMaxLoc(window)
    if brightest < min_brightness:
        return cx, cy
    return left + bx, top + by


def _bubble_flood_region(gray, seed, tolerance=30, close_size=15):
    """Flood-fill the bubble background from ``seed``; return a 0/1 uint8 mask.

    ``tolerance`` is the fixed-range brightness band around the seed value.
    A morphological close with a ``close_size`` ellipse then bridges the
    holes left where text strokes interrupted the fill.
    """
    h, w = gray.shape[:2]
    # A light blur keeps paper grain / noise from stopping the fill.
    blurred = cv2.GaussianBlur(gray, (3, 3), 0)
    # floodFill requires a mask 2 px larger than the image in each dimension;
    # filled pixels are marked with 1 in it.
    mask = np.zeros((h + 2, w + 2), np.uint8)
    cv2.floodFill(
        blurred,
        mask,
        seedPoint=seed,
        newVal=255,
        loDiff=tolerance,
        upDiff=tolerance,
        flags=cv2.FLOODFILL_FIXED_RANGE,
    )
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (close_size, close_size))
    return cv2.morphologyEx(mask[1:-1, 1:-1], cv2.MORPH_CLOSE, kernel)


def _bubble_seed_contour(contours, seed):
    """Return the largest positive-area contour containing ``seed``.

    Falls back to the largest contour overall when none contains the seed.
    """
    # pointPolygonTest takes a Point2f, so pass the seed as floats.
    point = (float(seed[0]), float(seed[1]))
    best, best_area = None, 0.0
    for cnt in contours:
        area = cv2.contourArea(cnt)
        if area > best_area and cv2.pointPolygonTest(cnt, point, False) >= 0:
            best, best_area = cnt, area
    if best is None:
        best = max(contours, key=cv2.contourArea)
    return best


def _bubble_inner_polygon(poly, ratio=0.05, max_shrink=15.0, min_keep=0.4, fallback=3.0):
    """Inset ``poly`` to a text-safe interior; return int vertices or ``None``.

    The inset is ``ratio`` of the perimeter, capped at ``max_shrink`` px. If
    that empties the polygon or keeps less than ``min_keep`` of its area, a
    minimal ``fallback`` px inset is used instead. ``None`` is returned when
    even the fallback inset is empty (tiny or degenerate polygons).
    """
    inner = poly.buffer(-min(ratio * poly.length, max_shrink))
    if inner.is_empty or inner.area < poly.area * min_keep:
        inner = poly.buffer(-fallback)
    # An empty result may not expose an ``exterior`` at all, so stop here.
    if inner.is_empty:
        return None
    if inner.geom_type == "MultiPolygon":
        inner = max(inner.geoms, key=lambda g: g.area)
    return [(int(x), int(y)) for x, y in inner.exterior.coords[:-1]]


def refine_bubble_from_bbox(image_pil, bbox):
    """Refine a detector box into ``(outer, inner)`` bubble polygons.

    The box is padded (15% per side, at least 20 px) and clamped to the image.
    The bubble background is flood-filled from a bright seed near the ROI
    centre. The convex hull of the filled region becomes the outer polygon,
    and shrinking it yields an inner "safe" polygon.

    Args:
        image_pil: PIL image of the full page. Any mode is accepted; the ROI is
            converted to RGB before processing.
        bbox: ``(x1, y1, x2, y2)`` in image pixels (truncated to int).

    Returns:
        ``(outer, inner)`` as lists of ``(x, y)`` int tuples in image
        coordinates:

        - ``(None, None)`` when the padded box does not overlap the image.
        - The raw bbox rectangle for both when no region is found.
        - ``inner == outer`` when the outer polygon has fewer than 3 points or
          cannot be shrunk.
    """
    W, H = image_pil.size
    x1, y1, x2, y2 = map(int, bbox)

    # 1. Padded ROI, clamped to the image.
    pad_x = max(20, int((x2 - x1) * 0.15))
    pad_y = max(20, int((y2 - y1) * 0.15))
    px1, py1 = max(0, x1 - pad_x), max(0, y1 - pad_y)
    px2, py2 = min(W, x2 + pad_x), min(H, y2 + pad_y)
    # Explicit emptiness check: a negative end index would wrap when slicing.
    if px2 <= px1 or py2 <= py1:
        return None, None

    # Crop before converting, so each bubble only pays for its own ROI. The
    # RGB conversion also lets "L" / "P" / "RGBA" pages through cvtColor.
    roi = image_pil.crop((px1, py1, px2, py2))
    if roi.mode != "RGB":
        roi = roi.convert("RGB")
    gray = cv2.cvtColor(np.asarray(roi), cv2.COLOR_RGB2GRAY)

    # 2-3. Seeded flood fill plus hole closing.
    seed = _bubble_seed_point(gray)
    region = _bubble_flood_region(gray, seed)

    # 4. Outer contour(s) of the filled region.
    contours, _ = cv2.findContours(region, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        rect = [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
        return rect, rect

    # 5. Convex hull, lightly simplified, shifted back to image coordinates.
    hull = cv2.convexHull(_bubble_seed_contour(contours, seed))
    approx = cv2.approxPolyDP(hull, 0.002 * cv2.arcLength(hull, True), True)
    outer = [(int(p[0][0] + px1), int(p[0][1] + py1)) for p in approx]
    if len(outer) < 3:
        return outer, outer

    # 6. Inset for the text-safe interior; fall back to the outer polygon.
    inner = _bubble_inner_polygon(Polygon(outer))
    return outer, (inner if inner else outer)
# ------------------------------------------------------------
# Public: detect → refine → return polygons
# ------------------------------------------------------------
def detect_and_refine_bubbles(full_img, conf_threshold=0.30):
    """Detect bubbles and refine each box into outer / inner polygons.

    Args:
        full_img: PIL image of the page, passed to both the detector and
            ``refine_bubble_from_bbox``.
        conf_threshold: Detector score threshold.

    Returns:
        tuple ``(bubble_polygons, interior_polygons, bubble_boxes)`` of
        equal-length lists:

        - ``bubble_polygons``: sanitized outer polygons, or the bbox rectangle
          when refinement or sanitizing fails.
        - ``interior_polygons``: sanitized inner polygons, falling back to the
          outer polygon.
        - ``bubble_boxes``: the raw ``[x1, y1, x2, y2]`` detector boxes with
          label id 0.
    """
    detections = detect_bubbles_rtdetr(full_img, conf_threshold)
    # Raw RT-DETR boxes for label id 0 (presumably the bubble class;
    # confirm against the checkpoint's id2label).
    bubble_boxes = [d["bbox"] for d in detections if d["class"] == 0]

    bubble_polygons = []
    interior_polygons = []
    for i, bbox in enumerate(bubble_boxes):
        # Per-bubble boundary: one bad ROI must not abort the whole page.
        try:
            outer, inner = refine_bubble_from_bbox(full_img, bbox)
        except Exception as exc:
            print(f"Bubble {i}: refinement failed ({exc!r})")
            outer = inner = None

        # -----------------------------
        # Sanitize outer polygon
        # -----------------------------
        outer = sanitize_polygon(outer) if outer is not None else None
        if outer is None:
            print(f"β οΈ Bubble {i}: outer invalid β fallback to rectangle")
            x1, y1, x2, y2 = map(int, bbox)
            outer = [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]

        # -----------------------------
        # Sanitize inner polygon
        # -----------------------------
        inner = sanitize_polygon(inner) if inner is not None else None
        if inner is None:
            print(f"β οΈ Bubble {i}: inner invalid β using outer")
            inner = outer

        bubble_polygons.append(outer)
        interior_polygons.append(inner)

    print(f"β¨ RT-DETR refined bubbles: {len(bubble_polygons)}")

    # Debug summary
    for i, poly in enumerate(interior_polygons):
        if poly is None or len(poly) < 4:
            print(f"β interior[{i}] INVALID ({poly})")
        else:
            print(f"β interior[{i}] OK ({len(poly)} pts)")

    return bubble_polygons, interior_polygons, bubble_boxes
# ------------------------------------------------------------
# Polygon → mask
# ------------------------------------------------------------
def polygon_to_mask(image_size, polygon):
    """Rasterise ``polygon`` into a single-channel uint8 mask.

    Args:
        image_size: ``(W, H)``, i.e. PIL ``Image.size`` order.
        polygon: sequence of ``(x, y)`` vertices, either a list of tuples or an
            ``(N, 2)`` numpy array. Coordinates are cast to int32 (truncated).

    Returns:
        np.ndarray of shape ``(H, W)`` and dtype uint8: 255 inside the polygon
        and 0 elsewhere. An all-zero mask is returned for ``None``, for fewer
        than 3 vertices, or when rasterisation fails (the error is printed).
    """
    W, H = image_size
    mask = np.zeros((H, W), dtype=np.uint8)

    # Use len() rather than truthiness: `not arr` raises for numpy arrays.
    if polygon is None or len(polygon) < 3:
        return mask

    try:
        pts = np.asarray(polygon, dtype=np.int32).reshape((-1, 1, 2))
        cv2.fillPoly(mask, [pts], 255)
    except Exception as e:
        # Deliberate best-effort: log and return the (empty) mask.
        print(f"β οΈ polygon_to_mask failure: {e}")
    return mask