qqwjq1981 committed on
Commit
d5f595e
·
verified ·
1 Parent(s): a479a6f

Update utils/ocr_utils.py

Browse files
Files changed (1) hide show
  1. utils/ocr_utils.py +50 -27
utils/ocr_utils.py CHANGED
@@ -9,48 +9,71 @@ from shapely.geometry import box as shapely_box
9
  from shapely.geometry import Polygon
10
  from shapely.ops import unary_union
11
  import networkx as nx
 
12
 
13
  # Single PaddleOCR instance created by this top-level statement, configured
  # with text-line orientation enabled and lang='ch'.
  ocr_model = PaddleOCR(use_textline_orientation=True, lang='ch')
14
 
15
def inflate_bbox(polygon, inflation_ratio=0.05):
    """Return the axis-aligned bounding box of *polygon*, enlarged on every side.

    Args:
        polygon: Iterable of ``(x, y)`` points.
        inflation_ratio: Fraction of the box width (for left/right) and of the
            box height (for top/bottom) added to each side.

    Returns:
        A shapely box covering the enlarged bounding rectangle.
    """
    xs, ys = zip(*polygon)
    left, right = min(xs), max(xs)
    bottom, top = min(ys), max(ys)

    # Padding is proportional to the extent along each axis.
    pad_x = (right - left) * inflation_ratio
    pad_y = (top - bottom) * inflation_ratio

    return shapely_box(left - pad_x, bottom - pad_y, right + pad_x, top + pad_y)
 
29
 
30
def group_nearby_boxes(lines, inflation_ratio=0.05):
    """Cluster OCR lines whose enlarged bounding boxes touch, transitively.

    Args:
        lines: Sequence of ``(polygon, text)`` pairs.
        inflation_ratio: Passed to ``inflate_bbox`` for every polygon.

    Returns:
        A list of dicts, one per cluster, each with ``"polygons"`` (the
        original polygons of the members) and ``"texts"`` (their texts).
    """
    count = len(lines)
    padded = [inflate_bbox(points, inflation_ratio) for points, _ in lines]

    # One node per line; an edge links two lines whose padded boxes intersect.
    overlap_graph = nx.Graph()
    overlap_graph.add_nodes_from(range(count))
    for a in range(count):
        box_a = padded[a]
        for b in range(a + 1, count):
            if box_a.intersects(padded[b]):
                overlap_graph.add_edge(a, b)

    # Each connected component of the graph is one group of nearby lines.
    return [
        {
            "polygons": [lines[k][0] for k in members],
            "texts": [lines[k][1] for k in members],
        }
        for members in nx.connected_components(overlap_graph)
    ]
54
 
55
  def extract_and_translate_chunk(image: Image.Image):
56
  np_img = np.array(image)
 
9
  from shapely.geometry import Polygon
10
  from shapely.ops import unary_union
11
  import networkx as nx
12
+ from shapely.ops import unary_union
13
 
14
  # Single PaddleOCR instance created by this top-level statement, configured
  # with text-line orientation enabled and lang='ch'.
  ocr_model = PaddleOCR(use_textline_orientation=True, lang='ch')
15
 
 
 
 
 
 
 
 
16
 
17
def inflate_polygon(polygon_points, percent=0.05):
    """Return the polygon buffered by a fraction of its bounding-box diagonal.

    Args:
        polygon_points: Coordinates accepted by ``shapely.geometry.Polygon``.
        percent: Buffer distance expressed as a fraction of the diagonal of
            the polygon's bounding box.

    Returns:
        The shapely geometry produced by ``buffer`` on the (possibly
        hull-replaced) polygon.
    """
    shape = Polygon(polygon_points)
    # An invalid polygon is swapped for its convex hull before measuring.
    shape = shape if shape.is_valid else shape.convex_hull

    left, bottom, right, top = shape.bounds
    span_x = right - left
    span_y = top - bottom
    margin = (span_x**2 + span_y**2)**0.5 * percent
    return shape.buffer(margin)
25
 
26
def group_nearby_boxes(lines, inflation_percent=0.05):
    """Cluster OCR lines whose inflated polygons intersect, transitively.

    Every polygon is inflated with ``inflate_polygon``; two lines are linked
    when their inflated shapes intersect, and each connected set of linked
    lines becomes one group.

    Args:
        lines: Sequence of ``(polygon_points, text)`` pairs.
        inflation_percent: Forwarded to ``inflate_polygon`` as ``percent``.

    Returns:
        A list of dicts, one per group, with keys ``"polygons"`` (for each
        member, ``list(Polygon(points).exterior.coords)``) and ``"texts"``
        (the member texts, in the same order). Groups are ordered by their
        lowest line index.
    """
    from collections import defaultdict

    n = len(lines)
    inflated_polys = []
    original_polys = []
    texts = []

    for poly_pts, text in lines:
        inflated_polys.append(inflate_polygon(poly_pts, percent=inflation_percent))
        original_polys.append(Polygon(poly_pts))
        texts.append(text)

    # Build connectivity graph: an undirected edge for every intersecting pair.
    adjacency = defaultdict(set)
    for i in range(n):
        for j in range(i + 1, n):
            if inflated_polys[i].intersects(inflated_polys[j]):
                adjacency[i].add(j)
                adjacency[j].add(i)

    # Depth-first search for connected components, done with an explicit
    # stack of neighbor iterators rather than recursion: a long chain of
    # overlapping lines would otherwise exceed Python's recursion limit.
    # Resuming the parent's iterator after a child is exhausted gives the
    # same visiting order as the recursive formulation.
    visited = [False] * n
    groups = []
    for start in range(n):
        if visited[start]:
            continue
        visited[start] = True
        group = [start]
        stack = [iter(adjacency[start])]
        while stack:
            for neighbor in stack[-1]:
                if not visited[neighbor]:
                    visited[neighbor] = True
                    group.append(neighbor)
                    stack.append(iter(adjacency[neighbor]))
                    break
            else:
                # Current node has no unvisited neighbors left; backtrack.
                stack.pop()
        groups.append(group)

    # Construct output groups.
    return [
        {
            "polygons": [list(original_polys[i].exterior.coords) for i in group],
            "texts": [texts[i] for i in group],
        }
        for group in groups
    ]
77
 
78
  def extract_and_translate_chunk(image: Image.Image):
79
  np_img = np.array(image)