# manga_translation/utils/ocr_utils.py
# OCR + grouping + translation utilities for manga image translation.
from paddleocr import PaddleOCR
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from scipy.spatial import ConvexHull
from utils.azure_translate import translate_text_azure
from math import dist
import numpy as np
from shapely.geometry import box as shapely_box
from shapely.geometry import Polygon
from shapely.ops import unary_union
import networkx as nx
from shapely.ops import unary_union
# Module-level OCR engine, loaded once at import time (model load is slow).
# Text-line orientation detection enabled; recognition language is Chinese.
ocr_model = PaddleOCR(use_textline_orientation=True, lang='ch')
def inflate_polygon(polygon_points, percent=0.05):
    """Buffer a polygon outward by a fraction of its bounding-box diagonal.

    ``polygon_points`` is a sequence of (x, y) vertices. An invalid
    (e.g. self-intersecting) polygon is replaced by its convex hull before
    buffering. Returns the inflated shapely geometry.
    """
    shape = Polygon(polygon_points)
    if not shape.is_valid:
        # Repair self-intersections by falling back to the convex hull.
        shape = shape.convex_hull
    min_x, min_y, max_x, max_y = shape.bounds
    bbox_diagonal = dist((min_x, min_y), (max_x, max_y))
    return shape.buffer(bbox_diagonal * percent)
def group_nearby_boxes(lines, inflation_percent=0.05):
    """Cluster OCR lines whose inflated polygons touch or overlap.

    Parameters
    ----------
    lines : list[tuple[list[tuple[int, int]], str]]
        Each item is (polygon vertices, recognized text).
    inflation_percent : float
        Fraction of each polygon's bounding-box diagonal used to inflate it
        before the intersection test (controls grouping aggressiveness).

    Returns
    -------
    list[dict]
        One dict per connected component, with keys "polygons" (exterior
        coordinates of the original polygons) and "texts". Within a group,
        entries keep their original OCR order so merged text reads correctly.
    """
    from collections import defaultdict
    n = len(lines)
    inflated_polys = []
    original_polys = []
    texts = []
    for poly_pts, text in lines:
        inflated_polys.append(inflate_polygon(poly_pts, percent=inflation_percent))
        original_polys.append(Polygon(poly_pts))
        texts.append(text)
    # Build connectivity graph: two lines are neighbors when their inflated
    # polygons intersect.
    adjacency = defaultdict(set)
    for i in range(n):
        for j in range(i + 1, n):
            if inflated_polys[i].intersects(inflated_polys[j]):
                adjacency[i].add(j)
                adjacency[j].add(i)
    # Iterative DFS for connected components. The previous recursive DFS
    # could hit Python's recursion limit on pages with long chains of boxes.
    visited = [False] * n
    groups = []
    for start in range(n):
        if visited[start]:
            continue
        group = []
        stack = [start]
        visited[start] = True
        while stack:
            node = stack.pop()
            group.append(node)
            for neighbor in adjacency[node]:
                if not visited[neighbor]:
                    visited[neighbor] = True
                    stack.append(neighbor)
        # Sort so group members keep their OCR (reading) order; DFS traversal
        # order would otherwise scramble the merged text downstream.
        group.sort()
        groups.append(group)
    # Construct output groups.
    grouped = []
    for group in groups:
        grouped.append({
            "polygons": [list(original_polys[i].exterior.coords) for i in group],
            "texts": [texts[i] for i in group],
        })
    return grouped
def extract_and_translate_chunk(image: Image.Image):
    """Run OCR on *image*, group nearby text boxes, and translate each group.

    Parameters
    ----------
    image : PIL.Image.Image
        The image chunk to process.

    Returns
    -------
    list[dict]
        One dict per text group with keys "original" (merged OCR text),
        "translated" (Azure translation, or "" on failure), and "polygon"
        (convex hull of the group's boxes as int (x, y) tuples).
        Returns [] when OCR yields nothing usable.
    """
    # Local import keeps the module importable on older scipy layouts.
    from scipy.spatial import QhullError

    np_img = np.array(image)
    results = ocr_model.ocr(np_img)
    # PaddleOCR (new API) returns a list with one result dict per image.
    if not results or not isinstance(results[0], dict):
        return []
    result_dict = results[0]
    polygons = result_dict.get("rec_polys", [])
    texts = result_dict.get("rec_texts", [])
    if not polygons or not texts or len(polygons) != len(texts):
        return []
    lines = list(zip([[(int(x), int(y)) for x, y in poly] for poly in polygons], texts))
    print("🔍 OCR Raw Output:", lines)
    grouped = group_nearby_boxes(lines)
    translations = []
    for group in grouped:
        # Distinct names here — the original shadowed the outer
        # `polygons`/`texts` variables inside this loop.
        group_polys = group["polygons"]
        merged_text = "".join(group["texts"]).strip()
        if not merged_text:
            continue
        try:
            translated = translate_text_azure(merged_text)
        except Exception as e:
            # Best-effort: keep the OCR result even if translation fails.
            print("⚠️ Translation failed:", e)
            translated = ""
        all_points = np.array([pt for polygon in group_polys for pt in polygon])
        if len(all_points) < 3:
            continue
        try:
            hull_indices = ConvexHull(all_points).vertices
        except QhullError:
            # Degenerate input (collinear/duplicate points) has no 2-D hull;
            # skip this group rather than crash the whole page.
            continue
        hull = [tuple(map(int, all_points[i])) for i in hull_indices]
        translations.append({
            "original": merged_text,
            "translated": translated,
            "polygon": hull,
        })
    return translations