potato committed on
Commit
1670782
·
1 Parent(s): 1de7538

add utils.py

Browse files
Files changed (1) hide show
  1. utils.py +216 -0
utils.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import numpy as np
4
+ import vtracer
5
+ import svgpathtools
6
+ import cairosvg
7
+ import io
8
+ import cv2
9
+ from lxml import etree
10
+ from scipy.cluster.hierarchy import linkage, fcluster
11
+ from scipy.spatial.distance import cdist
12
+ from python_tsp.heuristics import solve_tsp_local_search
13
+ from fast_tsp import find_tour
14
+ from svgpathtools import Path
15
+ from tqdm import tqdm
16
+
17
def parse_transform(transform_str):
    """Convert an SVG ``transform`` attribute into a 3x3 homogeneous matrix.

    The transform functions are applied in the order they are written, as
    the SVG specification requires: ``"translate(10,0) scale(2)"`` yields
    ``T @ S``, i.e. the point is scaled first and then translated.
    Supported functions are ``matrix``, ``translate``, ``scale``,
    ``rotate`` (with an optional centre), ``skewX`` and ``skewY``.
    Arguments may be separated by commas and/or whitespace.

    Args:
        transform_str: Value of the ``transform`` attribute, or a falsy
            value (``None`` / ``""``) when the attribute is absent.

    Returns:
        A 3x3 ``numpy.ndarray``. The identity matrix is returned when
        nothing could be parsed; functions with an unusable argument count
        are skipped.
    """
    if not transform_str:
        return np.eye(3)

    number = r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?"
    function_call = r"(matrix|translate|scale|rotate|skewX|skewY)\s*\(([^)]*)\)"

    matrix = np.eye(3)
    for name, raw_args in re.findall(function_call, transform_str):
        args = [float(v) for v in re.findall(number, raw_args)]
        count = len(args)

        if name == "matrix" and count == 6:
            a, b, c, d, e, f = args
            m = np.array([[a, c, e], [b, d, f], [0, 0, 1]])
        elif name == "translate" and count >= 1:
            # A single argument means a horizontal shift only.
            tx = args[0]
            ty = args[1] if count == 2 else 0
            m = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]])
        elif name == "scale" and count >= 1:
            # A single argument means uniform scaling.
            sx = args[0]
            sy = args[1] if count == 2 else sx
            m = np.array([[sx, 0, 0], [0, sy, 0], [0, 0, 1]])
        elif name == "rotate" and count in (1, 3):
            angle = np.radians(args[0])
            cos_a, sin_a = np.cos(angle), np.sin(angle)
            m = np.array([[cos_a, -sin_a, 0], [sin_a, cos_a, 0], [0, 0, 1]])
            if count == 3:
                # rotate(a, cx, cy) == translate(cx, cy) rotate(a) translate(-cx, -cy)
                cx, cy = args[1], args[2]
                to_centre = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
                from_centre = np.array([[1, 0, -cx], [0, 1, -cy], [0, 0, 1]])
                m = to_centre @ m @ from_centre
        elif name == "skewX" and count == 1:
            m = np.array([[1, np.tan(np.radians(args[0])), 0], [0, 1, 0], [0, 0, 1]])
        elif name == "skewY" and count == 1:
            m = np.array([[1, 0, 0], [np.tan(np.radians(args[0])), 1, 0], [0, 0, 1]])
        else:
            # Malformed argument list: ignore this function.
            continue

        # Right-multiply so that the right-most function acts on the point first.
        matrix = matrix @ m

    return matrix
44
+
45
def get_global_transform(element):
    """Compose the ``transform`` attributes of *element* and all its ancestors.

    The chain is walked from the element up to the root via ``getparent()``.
    Each ancestor's matrix is multiplied on the left of what has been
    accumulated so far, so outer transforms are applied after inner ones.

    Args:
        element: A node exposing ``get(name)`` and ``getparent()``, or
            ``None``.

    Returns:
        The accumulated 3x3 ``numpy.ndarray`` (identity when no node in the
        chain carries a ``transform`` attribute).
    """
    # Collect the element followed by its ancestors, innermost first.
    lineage = []
    node = element
    while node is not None:
        lineage.append(node)
        node = node.getparent()

    accumulated = np.eye(3)
    for node in lineage:
        attr = node.get("transform")
        if attr:
            accumulated = parse_transform(attr) @ accumulated
    return accumulated
52
+
53
def get_transformed_paths_and_coords(svg_content, mode):
    """
    Parses the SVG, extracts path elements and their global transforms,
    and returns data structures for sequencing and later rendering.

    Args:
        svg_content: When ``mode == 'file'`` this must be a ``str`` holding
            the SVG markup; it is UTF-8 encoded and wrapped in a ``BytesIO``.
            For any other ``mode`` the value is passed to ``etree.parse``
            unchanged (presumably a path or file object; verify against
            callers).
        mode: Selects how ``svg_content`` is interpreted (see above).

    Returns:
        A tuple ``(paths_data, coords, width, height, viewBox)`` where

        * ``paths_data`` is a list of dicts with keys ``'path'`` (parsed
          path), ``'transform'`` (3x3 global matrix), ``'element'`` (the XML
          element) and ``'coord'`` (transformed start point as ``(x, y)``);
        * ``coords`` is a numpy array of those start points;
        * ``width`` / ``height`` are the raw root attributes (may be
          ``None``);
        * ``viewBox`` is the root attribute, or ``"0 0 <width> <height>"``
          synthesised from the unit-stripped width/height when it is missing.
    """
    parser = etree.XMLParser(remove_blank_text=True)
    if mode == 'file':
        svg_content = io.BytesIO(svg_content.encode('utf-8'))
    tree = etree.parse(svg_content, parser)
    root = tree.getroot()

    width_str = root.get("width")
    height_str = root.get("height")
    viewBox = root.get("viewBox")
    path_elements = root.findall(".//{*}path")

    if not viewBox and width_str and height_str:
        # Remove potential units like 'px' to get clean numbers
        bare_width = re.sub(r'[a-zA-Z%]', '', width_str)
        bare_height = re.sub(r'[a-zA-Z%]', '', height_str)
        viewBox = f"0 0 {bare_width} {bare_height}"
        print(f"SVG missing viewBox. Created a default: '{viewBox}'")

    transformed_coords = []
    paths_data = []

    for elem in path_elements:
        d_string = elem.get('d')
        if not d_string:
            continue
        path = svgpathtools.parse_path(d_string)
        if len(path) == 0:
            # A 'd' made only of move-to commands has no segments, so it
            # has no start point and nothing to draw.
            continue

        transform = get_global_transform(elem)

        start_vec = np.array([[path.start.real], [path.start.imag], [1]])
        transformed_start = transform @ start_vec

        coord = (transformed_start[0, 0], transformed_start[1, 0])
        transformed_coords.append(coord)

        paths_data.append({
            'path': path,
            'transform': transform,
            'element': elem,
            'coord': coord,
        })

    print(f"Extracted {len(transformed_coords)} paths with their elements.")
    # width/height are returned exactly as written in the SVG (units kept).
    return paths_data, np.array(transformed_coords), width_str, height_str, viewBox
104
+
105
def transform_path(path, matrix):
    """Apply a 3x3 numpy transform to an svgpathtools Path object.

    Every point that defines a segment is transformed: ``start`` and
    ``end``, plus the Bezier control points (``control`` for quadratic,
    ``control1`` / ``control2`` for cubic segments) so curves keep their
    shape under the transform. The input path is left untouched; the
    segments of the returned path are copies.

    NOTE(review): for elliptical arcs only the end points are moved; radii
    and rotation are not adjusted, so arcs are exact only under pure
    translation.

    Args:
        path: Iterable of svgpathtools segments (typically a ``Path``).
        matrix: 3x3 homogeneous affine transform.

    Returns:
        A new ``Path`` built from the transformed segment copies.
    """
    import copy

    def apply(point):
        # Complex point -> homogeneous vector -> transformed complex point.
        vec = matrix @ np.array([point.real, point.imag, 1.0])
        return complex(vec[0], vec[1])

    point_attrs = ("start", "control", "control1", "control2", "end")

    new_segments = []
    for seg in path:
        moved = copy.copy(seg)  # do not mutate the caller's segment
        for attr in point_attrs:
            point = getattr(seg, attr, None)
            if point is not None:
                setattr(moved, attr, apply(point))
        new_segments.append(moved)
    return Path(*new_segments)
117
+
118
def sequence_strokes(paths_data, coords, proximity_threshold=40):
    """
    Group strokes into spatial clusters and order the clusters with a TSP tour.

    Strokes are clustered on ``coords`` using Ward linkage, cut at
    ``proximity_threshold``. The cluster centroids are then passed to
    ``find_tour`` (on an integer-truncated distance matrix), and the strokes
    are emitted cluster by cluster in tour order. The order of strokes
    inside a cluster is left as it was in ``paths_data``.

    Args:
        paths_data: Sequence of stroke records, parallel to ``coords``.
        coords: One point per stroke.
        proximity_threshold: Distance cut-off handed to ``fcluster``.

    Returns:
        ``paths_data`` itself when there are fewer than two strokes or only
        one cluster; otherwise a new list with the strokes reordered.
    """
    if len(coords) < 2:
        print("Fewer than 2 strokes, no sequencing needed.")
        return paths_data

    print("Clustering strokes by proximity...")

    hierarchy = linkage(coords, method='ward')
    labels = fcluster(hierarchy, t=proximity_threshold, criterion='distance')
    num_clusters = len(set(labels))
    print(f"{num_clusters} clusters detected.")

    if num_clusters <= 1:
        print("All strokes are in a single cluster, no reordering needed.")
        return paths_data

    # Cluster labels are 1-based; keep strokes and points in parallel buckets.
    label_range = range(1, num_clusters + 1)
    strokes_by_label = {label: [] for label in label_range}
    points_by_label = {label: [] for label in label_range}
    for position, label in enumerate(labels):
        strokes_by_label[label].append(paths_data[position])
        points_by_label[label].append(coords[position])

    centroids = [np.mean(points_by_label[label], axis=0) for label in label_range]

    print("Solving TSP for optimal cluster drawing order...")
    # The centroid distances are truncated to int32 before the solver call.
    integer_distance_matrix = cdist(centroids, centroids).astype(np.int32).tolist()
    permutation = find_tour(integer_distance_matrix)

    # Tour indices are 0-based, cluster labels 1-based.
    final_sequence = [
        stroke
        for cluster_idx in permutation
        for stroke in strokes_by_label[cluster_idx + 1]
    ]

    print("Final stroke sequence created.")
    return final_sequence
158
+
159
def serialize_paths(paths_data):
    """Reduce stroke records to plain dicts holding only drawing attributes.

    Args:
        paths_data: Iterable of dicts whose ``'element'`` entry exposes
            ``get(name[, default])`` for the attributes ``d``, ``transform``,
            ``fill`` and ``style``.

    Returns:
        A list of ``{"d", "transform", "fill"}`` dicts in input order.
        ``fill`` is taken from the ``fill`` attribute, falling back to a
        ``fill:`` declaration inside ``style``, and finally to ``"#000000"``.
        ``transform`` is the element's own attribute (``None`` if absent).
    """
    # Compiled once; the pattern does not depend on the element.
    fill_in_style = re.compile(r'fill:\s*([^;]+)')

    serialized = []
    for data in paths_data:
        elem = data['element']
        fill = elem.get("fill")
        if not fill:
            fill_match = fill_in_style.search(elem.get("style", ""))
            fill = fill_match.group(1).strip() if fill_match else "#000000"

        serialized.append({
            "d": elem.get("d"),
            "transform": elem.get("transform"),
            "fill": fill,
        })
    return serialized
173
+
174
def process_svg(svg_content, mode):
    """Parse an SVG and split its paths into layers, fills and details.

    Each path is ranked by the area of its (untransformed) bounding box:

    * ``layers``: area at or above the value found at the 99% position of
      the sorted areas; kept in document order;
    * ``fills``: area at or above the median; ordered by ``sequence_strokes``;
    * ``details``: everything smaller; ordered by ``sequence_strokes``.

    Args:
        svg_content: SVG input, forwarded to
            ``get_transformed_paths_and_coords``.
        mode: Input mode, forwarded to ``get_transformed_paths_and_coords``.

    Returns:
        A dict with keys ``"width"``, ``"height"``, ``"viewbox"``,
        ``"layers"``, ``"fills"`` and ``"details"``; the last three are lists
        produced by ``serialize_paths`` (all empty when the SVG contains no
        usable paths).
    """
    paths_data, _, width, height, viewBox = get_transformed_paths_and_coords(svg_content, mode)

    layer_strokes = []
    fill_strokes = []
    detail_strokes = []

    # With no paths there are no thresholds to compute; indexing the empty
    # sorted list below would raise IndexError.
    if paths_data:
        for data in paths_data:
            xmin, xmax, ymin, ymax = data['path'].bbox()
            data['area'] = (xmax - xmin) * (ymax - ymin)

        areas_sorted = sorted(p['area'] for p in paths_data)
        areaThreshold = areas_sorted[len(areas_sorted) // 2]
        layerThreshold = areas_sorted[int(len(areas_sorted) * 0.99)]

        for data in paths_data:
            if data['area'] >= layerThreshold:
                layer_strokes.append(data)
            elif data['area'] >= areaThreshold:
                fill_strokes.append(data)
            else:
                detail_strokes.append(data)

    ordered_fills = sequence_strokes(fill_strokes, np.array([p['coord'] for p in fill_strokes]))
    ordered_details = sequence_strokes(detail_strokes, np.array([p['coord'] for p in detail_strokes]))

    return {
        "width": width,
        "height": height,
        "viewbox": viewBox,
        "layers": serialize_paths(layer_strokes),
        "fills": serialize_paths(ordered_fills),
        "details": serialize_paths(ordered_details)
    }
213
+
214
+
215
+
216
+