# manga_translation / utils / bubble_detect_rtdetr.py
# Hugging Face Space page header (uploader: qqwjq1981, commit 6a100df, verified),
# retained here as comments so the file parses as valid Python.
# utils/bubble_detect_rtdetr.py
import torch
import numpy as np
from PIL import Image
import cv2
from shapely.geometry import Polygon
from shapely.ops import unary_union
from transformers import AutoImageProcessor, RTDetrForObjectDetection
from utils.polygon_utils import sanitize_polygon
MODEL_NAME = "ogkalu/comic-text-and-bubble-detector"
_processor = None
_model = None
# ------------------------------------------------------------
# Load model (cached)
# ------------------------------------------------------------
def load_rtdetr_model():
    """Return the cached ``(processor, model)`` pair, loading on first use.

    The processor and model are module-level singletons; the model is put in
    eval mode and moved to CUDA when a GPU is available.
    """
    global _processor, _model

    if _processor is None:
        print("πŸ”„ Loading RT-DETR-v2 processor...")
        _processor = AutoImageProcessor.from_pretrained(MODEL_NAME)

    if _model is None:
        print("πŸ”„ Loading RT-DETR-v2 model...")
        _model = RTDetrForObjectDetection.from_pretrained(MODEL_NAME)
        _model.eval()
        if torch.cuda.is_available():
            _model.to("cuda")
        print("βœ… RT-DETR-v2 loaded.")

    return _processor, _model
# ------------------------------------------------------------
# Run detector
# ------------------------------------------------------------
def detect_bubbles_rtdetr(image_pil, conf_threshold=0.30):
    """Run the RT-DETR detector on a PIL image.

    Returns a list of dicts ``{"class": int, "score": float, "bbox":
    [x1, y1, x2, y2]}`` for every detection scoring above *conf_threshold*,
    with box coordinates in full-image pixels.
    """
    processor, model = load_rtdetr_model()

    inputs = processor(images=image_pil, return_tensors="pt")
    if torch.cuda.is_available():
        inputs = {name: tensor.to("cuda") for name, tensor in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # PIL size is (W, H); post-processing wants (H, W).
    target_sizes = torch.tensor([image_pil.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=conf_threshold
    )[0]

    return [
        {
            "class": int(label),
            "score": float(score),
            "bbox": [float(coord) for coord in box],
        }
        for score, label, box in zip(
            results["scores"], results["labels"], results["boxes"]
        )
    ]
# ------------------------------------------------------------
# Bubble box β†’ refined outer + inner safe polygon
# ------------------------------------------------------------
def refine_bubble_from_bbox(image_pil, bbox):
    """Refine a detected bubble bbox into (outer, inner) polygons.

    Flood-fills from a bright seed pixel near the bbox center (to avoid
    seeding on black text — the "seed trap"), traces the filled region, and
    returns:
      * outer: convex-hull polygon of the bubble, full-image coordinates
      * inner: the outer polygon shrunk inward, safe for text placement
    Returns (None, None) when the padded ROI is empty, and falls back to the
    raw rectangle when no usable contour can be found.
    """
    # NOTE: cv2, np and Polygon come from the module-level imports; the
    # previous redundant function-local imports have been removed.
    img = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
    H, W = img.shape[:2]
    x1, y1, x2, y2 = map(int, bbox)

    # 1. Pad the ROI so the bubble outline is fully contained.
    w_box = x2 - x1
    h_box = y2 - y1
    pad_x = max(20, int(w_box * 0.15))
    pad_y = max(20, int(h_box * 0.15))
    px1 = max(0, x1 - pad_x)
    py1 = max(0, y1 - pad_y)
    px2 = min(W, x2 + pad_x)
    py2 = min(H, y2 + pad_y)
    roi = img[py1:py2, px1:px2]
    if roi.size == 0:
        return None, None
    h, w = roi.shape[:2]
    gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)

    # Smart seed search: scan a small window around the center for the
    # brightest pixel (background) instead of blindly using the center,
    # which may land on dark text.
    cx, cy = w // 2, h // 2
    search_radius = 15
    sx1 = max(0, cx - search_radius)
    sy1 = max(0, cy - search_radius)
    sx2 = min(w, cx + search_radius)
    sy2 = min(h, cy + search_radius)
    center_patch = gray[sy1:sy2, sx1:sx2]
    _min_val, max_val, _min_loc, max_loc = cv2.minMaxLoc(center_patch)
    # Translate patch-local coordinates back into ROI coordinates.
    seed = (sx1 + max_loc[0], sy1 + max_loc[1])
    # If even the brightest pixel is dark (e.g. a night scene), fall back
    # to the geometric center.
    if max_val < 100:
        seed = (cx, cy)

    # 2. Flood fill from the seed; blur first to ignore paper grain/noise.
    gray_blur = cv2.GaussianBlur(gray, (3, 3), 0)
    mask = np.zeros((h + 2, w + 2), np.uint8)
    flood_img = gray_blur.copy()
    # Tolerance of 30 captures mild gradients in the bubble background.
    cv2.floodFill(
        flood_img,
        mask,
        seedPoint=seed,
        newVal=255,
        loDiff=30,
        upDiff=30,
        flags=cv2.FLOODFILL_FIXED_RANGE,
    )
    filled_mask = mask[1:-1, 1:-1]  # strip floodFill's 1-px border

    # 3. Morphological close: a large (15x15) kernel bridges gaps over text.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
    filled_mask = cv2.morphologyEx(filled_mask, cv2.MORPH_CLOSE, kernel)

    # 4. Contours; fall back to the raw rectangle when none are found.
    contours, _ = cv2.findContours(
        filled_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    if not contours:
        outer = [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
        return outer, outer

    # Prefer the largest contour that actually contains the seed point.
    best_cnt = None
    max_area = 0
    for cnt in contours:
        area = cv2.contourArea(cnt)
        if area > max_area and cv2.pointPolygonTest(cnt, seed, False) >= 0:
            max_area = area
            best_cnt = cnt
    if best_cnt is None:
        best_cnt = max(contours, key=cv2.contourArea)

    # 5. Convex hull + light smoothing, then shift back to global coords.
    hull = cv2.convexHull(best_cnt)
    peri = cv2.arcLength(hull, True)
    approx = cv2.approxPolyDP(hull, 0.002 * peri, True)
    outer = [(int(p[0][0] + px1), int(p[0][1] + py1)) for p in approx]

    # 6. Shrink the outer polygon to get a text-safe inner polygon.
    if len(outer) < 3:
        return outer, outer
    poly = Polygon(outer)
    # Dynamic shrink: 5% of the perimeter, clamped to at most 15 px inward.
    shrink_px = max(-0.05 * poly.length, -15.0)
    inner_poly = poly.buffer(shrink_px)
    if inner_poly.is_empty or inner_poly.area < poly.area * 0.4:
        inner_poly = poly.buffer(-3)  # minimal-shrink fallback
    if inner_poly.geom_type == "MultiPolygon":
        inner_poly = max(inner_poly.geoms, key=lambda g: g.area)
    # FIX: the fallback buffer can still produce an empty (or non-Polygon)
    # geometry for very small bubbles, which would crash on .exterior below.
    # Degrade gracefully to the outer polygon instead.
    if inner_poly.is_empty or inner_poly.geom_type != "Polygon":
        return outer, outer
    inner = [(int(x), int(y)) for x, y in inner_poly.exterior.coords[:-1]]
    return outer, inner
# ------------------------------------------------------------
# Public: detect β†’ refine β†’ return polygons
# ------------------------------------------------------------
def detect_and_refine_bubbles(full_img, conf_threshold=0.30):
    """Detect speech bubbles and refine each into outer/inner polygons.

    Returns ``(bubble_polygons, interior_polygons, bubble_boxes)``, where
    entry *i* in each list corresponds to the i-th class-0 detection.
    """
    detections = detect_bubbles_rtdetr(full_img, conf_threshold)
    # Keep only class-0 (bubble) boxes from the raw detector output.
    bubble_boxes = [det["bbox"] for det in detections if det["class"] == 0]

    bubble_polygons = []
    interior_polygons = []

    for i, bbox in enumerate(bubble_boxes):
        outer, inner = refine_bubble_from_bbox(full_img, bbox)

        # Outer polygon must be valid; otherwise use the raw rectangle.
        outer = sanitize_polygon(outer)
        if outer is None:
            print(f"⚠️ Bubble {i}: outer invalid β†’ fallback to rectangle")
            x1, y1, x2, y2 = map(int, bbox)
            outer = [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]

        # Inner polygon must be valid; otherwise reuse the outer one.
        inner = sanitize_polygon(inner)
        if inner is None:
            print(f"⚠️ Bubble {i}: inner invalid β†’ using outer")
            inner = outer

        bubble_polygons.append(outer)
        interior_polygons.append(inner)

    print(f"✨ RT-DETR refined bubbles: {len(bubble_polygons)}")

    # Debug summary of interior polygon health.
    for i, poly in enumerate(interior_polygons):
        if poly is None or len(poly) < 4:
            print(f"❌ interior[{i}] INVALID ({poly})")
        else:
            print(f"βœ“ interior[{i}] OK ({len(poly)} pts)")

    return bubble_polygons, interior_polygons, bubble_boxes
# ------------------------------------------------------------
# Polygon β†’ mask
# ------------------------------------------------------------
def polygon_to_mask(image_size, polygon):
    """Rasterize *polygon* into a uint8 mask (255 inside, 0 elsewhere).

    ``image_size`` is (W, H) in PIL order. Returns an all-zero mask for a
    missing/degenerate polygon, and best-effort (possibly empty) on a
    rasterization failure.
    """
    width, height = image_size
    mask = np.zeros((height, width), dtype=np.uint8)
    if polygon and len(polygon) >= 3:
        try:
            contour = np.asarray(polygon, dtype=np.int32).reshape((-1, 1, 2))
            cv2.fillPoly(mask, [contour], 255)
        except Exception as e:
            # Deliberate best-effort: log and return whatever we have.
            print(f"⚠️ polygon_to_mask failure: {e}")
    return mask