3morrrrr commited on
Commit
c0f74fe
·
verified ·
1 Parent(s): 5b17b59

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -57
app.py CHANGED
@@ -3,18 +3,18 @@ import re
3
  import xml.etree.ElementTree as ET
4
  import gradio as gr
5
 
6
- # If you use these in your env, keep them. Otherwise they're optional.
7
- # from huggingface_hub import hf_hub_download
8
- # from handwriting_api import InputData, validate_input
9
 
10
- from hand import Hand # your handwriting model wrapper
11
-
12
- # --- Setup --------------------------------------------------------------------
13
  os.makedirs("img", exist_ok=True)
14
  hand = Hand()
15
 
16
- # --- SVG helpers ---------------------------------------------------------------
17
-
 
18
  def _extract_paths(svg_root_or_group):
19
  """Return a list of <path> elements (namespace-agnostic)."""
20
  return [el for el in svg_root_or_group.iter() if el.tag.endswith('path')]
@@ -41,12 +41,51 @@ def _translate_group(elem, dx, dy):
41
  t = f"translate({dx},{dy})"
42
  elem.set('transform', (prev + " " + t).strip())
43
 
44
- # --- Segment compositor (the key fix) -----------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
 
 
 
 
 
 
 
 
 
 
 
46
  def _render_part_group(text_part, style, bias, color, stroke_width):
47
  """
48
- Render a single text 'part' (no underscores) using the Hand model to a temp SVG,
49
- pull out its paths, wrap them in a <g>, and return (group, bbox).
50
  """
51
  # Avoid empty part causing errors: render a single space
52
  lines = [text_part if text_part else " "]
@@ -78,42 +117,49 @@ def render_line_with_underscores(
78
  underscore_len=22
79
  ):
80
  """
81
- Compose a line that may include underscores by rendering surrounding segments,
82
- measuring their true widths, and drawing the underscore at the baseline.
83
- Returns (list_of_elements, rightmost_x).
84
  """
85
  parts = line.split('_')
86
 
87
- # Render every part and measure bboxes
88
- part_groups, part_bboxes = [], []
 
89
  for part in parts:
90
  g, bbox = _render_part_group(part, style, bias, color, stroke_width)
91
  part_groups.append(g)
92
- part_bboxes.append(bbox)
93
 
94
- # Estimate line metrics from first non-empty part (fallbacks if needed)
 
 
 
 
 
 
 
 
95
  line_height = 40.0
96
- for bbox in part_bboxes:
97
- minx, miny, maxx, maxy = bbox
98
  if maxy - miny > 0.0:
99
  line_height = max(20.0, maxy - miny)
100
  break
101
  base_y = line_y + 0.8 * line_height
102
 
103
- # Compose into positioned elements
104
  composed = []
105
  cursor_x = x_start
106
 
107
- for i, (g, bbox) in enumerate(zip(part_groups, part_bboxes)):
108
  minx, _, maxx, _ = bbox
109
- width = max(0.0, maxx - minx)
110
 
111
- # Place this segment at current cursor_x
112
  _translate_group(g, dx=(cursor_x - minx), dy=line_y)
113
  composed.append(g)
114
- cursor_x += width
115
 
116
- # If there is an underscore after this part, draw it now
117
  if i < len(part_groups) - 1:
118
  x0 = cursor_x + gap_left
119
  x1 = x0 + underscore_len
@@ -124,12 +170,13 @@ def render_line_with_underscores(
124
  underscore.set('fill', 'none')
125
  underscore.set('stroke-linecap', 'round')
126
  composed.append(underscore)
127
- cursor_x = x1 + gap_right # advance cursor
128
 
129
  return composed, cursor_x
130
 
131
- # --- Handwriting generation (multi-line + composition) ------------------------
132
-
 
133
  def generate_handwriting(
134
  text,
135
  style,
@@ -139,18 +186,19 @@ def generate_handwriting(
139
  multiline=True
140
  ):
141
  """
142
- Generate a composed SVG that places underscores geometrically between segments.
 
143
  """
144
  try:
145
  lines = text.split('\n') if multiline else [text]
146
 
147
- # Light validation and slash normalization (your old behavior)
148
  for idx, ln in enumerate(lines):
149
  if len(ln) > 75:
150
  return f"Error: Line {idx + 1} is too long (max 75 characters)"
151
- lines[idx] = ln.replace('/', '-').replace('\\', '-') # keep underscores intact
152
 
153
- # Create a fresh SVG root we control (no background rect -> transparent)
154
  svg_root = ET.Element('svg', {
155
  'xmlns': 'http://www.w3.org/2000/svg',
156
  'viewBox': '0 0 1200 800'
@@ -184,7 +232,6 @@ def generate_handwriting(
184
  svg_root.set('viewBox', f"0 0 {width} {height}")
185
 
186
  svg_content = ET.tostring(svg_root, encoding='unicode')
187
- # Persist for download
188
  with open("img/output.svg", "w", encoding="utf-8") as f:
189
  f.write(svg_content)
190
  return svg_content
@@ -192,8 +239,9 @@ def generate_handwriting(
192
  except Exception as e:
193
  return f"Error: {str(e)}"
194
 
195
- # --- PNG export (transparent) --------------------------------------------------
196
-
 
197
  def export_to_png(svg_content):
198
  """Convert SVG to transparent PNG using CairoSVG and Pillow."""
199
  try:
@@ -203,12 +251,11 @@ def export_to_png(svg_content):
203
  if not svg_content or svg_content.startswith("Error:"):
204
  return None
205
 
206
- # Ensure we write the current svg to disk (CairoSVG can read from file)
207
  tmp_svg = "img/temp.svg"
208
  with open(tmp_svg, "w", encoding="utf-8") as f:
209
  f.write(svg_content)
210
 
211
- # Render at higher scale for crisp strokes
212
  cairosvg.svg2png(
213
  url=tmp_svg,
214
  write_to="img/output_temp.png",
@@ -220,7 +267,7 @@ def export_to_png(svg_content):
220
  if img.mode != 'RGBA':
221
  img = img.convert('RGBA')
222
 
223
- # Optional: force near-white background fully transparent (safety)
224
  datas = img.getdata()
225
  new_data = []
226
  for item in datas:
@@ -233,7 +280,6 @@ def export_to_png(svg_content):
233
  out_path = "img/output.png"
234
  img.save(out_path, "PNG")
235
 
236
- # cleanup
237
  try:
238
  os.remove("img/output_temp.png")
239
  except:
@@ -245,12 +291,12 @@ def export_to_png(svg_content):
245
  print(f"Error converting to PNG: {str(e)}")
246
  return None
247
 
248
- # --- Gradio UI ----------------------------------------------------------------
249
-
 
250
  def generate_handwriting_wrapper(text, style, bias, color, stroke_width, multiline=True):
251
  svg = generate_handwriting(text, style, bias, color, stroke_width, multiline)
252
  png_path = export_to_png(svg)
253
- # Display SVG inline; return file path for PNG
254
  return svg, png_path, "img/output.svg"
255
 
256
  css = """
@@ -262,13 +308,13 @@ css = """
262
 
263
  with gr.Blocks(css=css) as demo:
264
  gr.Markdown("# 🖋️ Handwriting Synthesis (Underscore-safe)")
265
- gr.Markdown("Generate realistic handwritten text with **true** underscore alignment—no retraining required.")
266
 
267
  with gr.Row():
268
  with gr.Column(scale=2):
269
  text_input = gr.Textbox(
270
  label="Text Input",
271
- placeholder="Try: zeb_3asba or user_name → underscores render perfectly",
272
  lines=5,
273
  max_lines=10,
274
  )
@@ -295,10 +341,9 @@ with gr.Blocks(css=css) as demo:
295
 
296
  gr.Markdown("""
297
  ### Notes
298
- - Underscores are drawn **between** segments using real geometry → perfect alignment.
299
  - Slashes (/, \\) are normalized to dashes (-) for model stability.
300
- - Each line ≤ 75 characters.
301
- - Transparent PNG export included.
302
  """)
303
 
304
  generate_btn.click(
@@ -317,24 +362,25 @@ with gr.Blocks(css=css) as demo:
317
  outputs=[text_input, style_select, bias_slider, color_picker, stroke_width]
318
  )
319
 
320
- # --- Main ---------------------------------------------------------------------
321
-
 
322
  if __name__ == "__main__":
323
  port = int(os.environ.get("PORT", 7860))
324
 
325
- # Soft check for optional deps
326
- missing_packages = []
327
  try:
328
  import cairosvg # noqa
329
  except ImportError:
330
- missing_packages.append("cairosvg")
331
  try:
332
  from PIL import Image # noqa
333
  except ImportError:
334
- missing_packages.append("pillow")
335
 
336
- if missing_packages:
337
- print(f"WARNING: Missing packages for transparent PNG export: {', '.join(missing_packages)}")
338
- print("Install with: pip install " + " ".join(missing_packages))
339
 
340
  demo.launch(server_name="0.0.0.0", server_port=port)
 
3
  import xml.etree.ElementTree as ET
4
  import gradio as gr
5
 
6
+ # Your handwriting model wrapper
7
+ from hand import Hand
 
8
 
9
+ # -----------------------------------------------------------------------------
10
+ # Setup
11
+ # -----------------------------------------------------------------------------
12
  os.makedirs("img", exist_ok=True)
13
  hand = Hand()
14
 
15
+ # -----------------------------------------------------------------------------
16
+ # SVG helpers
17
+ # -----------------------------------------------------------------------------
18
  def _extract_paths(svg_root_or_group):
19
  """Return a list of <path> elements (namespace-agnostic)."""
20
  return [el for el in svg_root_or_group.iter() if el.tag.endswith('path')]
 
41
  t = f"translate({dx},{dy})"
42
  elem.set('transform', (prev + " " + t).strip())
43
 
44
+ # -----------------------------------------------------------------------------
45
+ # Visual (raster) measurement for exact spacing
46
+ # -----------------------------------------------------------------------------
47
+ def _tight_width_via_raster(group_elem, stroke_color="#000", stroke_width=2):
48
+ """
49
+ Render a tiny SVG containing only this group, rasterize to PNG in-memory,
50
+ and return the non-transparent width in pixels. Produces pixel-accurate
51
+ width that's immune to relative path quirks.
52
+ """
53
+ import io
54
+ from PIL import Image
55
+ import cairosvg
56
+ import xml.etree.ElementTree as ET
57
+
58
+ # Build a tiny standalone SVG for the group (children are already paths)
59
+ svg = ET.Element('svg', {
60
+ 'xmlns': 'http://www.w3.org/2000/svg',
61
+ 'viewBox': '0 0 1200 400'
62
+ })
63
+ g = ET.Element('g')
64
+ for p in list(group_elem):
65
+ g.append(p)
66
+ svg.append(g)
67
+
68
+ svg_bytes = ET.tostring(svg, encoding='utf-8')
69
+
70
+ # Rasterize in memory
71
+ png_bytes = cairosvg.svg2png(bytestring=svg_bytes, scale=1.0, background_color="none")
72
+ img = Image.open(io.BytesIO(png_bytes)).convert('RGBA')
73
 
74
+ # Non-transparent bounding box
75
+ alpha = img.split()[-1]
76
+ bbox = alpha.getbbox() # (left, top, right, bottom) or None
77
+ if not bbox:
78
+ return 0.0
79
+ left, top, right, bottom = bbox
80
+ return float(right - left)
81
+
82
+ # -----------------------------------------------------------------------------
83
+ # Segment compositor (key fix for underscores)
84
+ # -----------------------------------------------------------------------------
85
  def _render_part_group(text_part, style, bias, color, stroke_width):
86
  """
87
+ Render one underscore-free 'part' using the Hand model to a temp SVG,
88
+ extract its path group, and return (group, path_bbox).
89
  """
90
  # Avoid empty part causing errors: render a single space
91
  lines = [text_part if text_part else " "]
 
117
  underscore_len=22
118
  ):
119
  """
120
+ Compose a line with underscores by rendering surrounding segments,
121
+ measuring their *visual* widths via rasterization, and drawing the
122
+ underscore at the baseline. Returns (list_of_elements, rightmost_x).
123
  """
124
  parts = line.split('_')
125
 
126
+ part_groups, visual_widths, path_bboxes = [], [], []
127
+
128
+ # 1) Render each part to a <g>
129
  for part in parts:
130
  g, bbox = _render_part_group(part, style, bias, color, stroke_width)
131
  part_groups.append(g)
132
+ path_bboxes.append(bbox)
133
 
134
+ # 2) Measure visual widths robustly (pixel-accurate)
135
+ for g in part_groups:
136
+ w = _tight_width_via_raster(g, stroke_color=color, stroke_width=stroke_width)
137
+ if w <= 0.1: # fallback to path bbox width if needed
138
+ minx, _, maxx, _ = _bbox_of_paths(_extract_paths(g))
139
+ w = max(0.0, maxx - minx)
140
+ visual_widths.append(w)
141
+
142
+ # 3) Estimate line metrics for baseline
143
  line_height = 40.0
144
+ for minx, miny, maxx, maxy in path_bboxes:
 
145
  if maxy - miny > 0.0:
146
  line_height = max(20.0, maxy - miny)
147
  break
148
  base_y = line_y + 0.8 * line_height
149
 
150
+ # 4) Compose all elements left-to-right using measured widths
151
  composed = []
152
  cursor_x = x_start
153
 
154
+ for i, (g, bbox, visw) in enumerate(zip(part_groups, path_bboxes, visual_widths)):
155
  minx, _, maxx, _ = bbox
 
156
 
157
+ # Align this group's left edge at cursor_x
158
  _translate_group(g, dx=(cursor_x - minx), dy=line_y)
159
  composed.append(g)
160
+ cursor_x += visw # advance by visual width
161
 
162
+ # Draw underscore between parts
163
  if i < len(part_groups) - 1:
164
  x0 = cursor_x + gap_left
165
  x1 = x0 + underscore_len
 
170
  underscore.set('fill', 'none')
171
  underscore.set('stroke-linecap', 'round')
172
  composed.append(underscore)
173
+ cursor_x = x1 + gap_right
174
 
175
  return composed, cursor_x
176
 
177
+ # -----------------------------------------------------------------------------
178
+ # Handwriting generation (multi-line + composition)
179
+ # -----------------------------------------------------------------------------
180
  def generate_handwriting(
181
  text,
182
  style,
 
186
  multiline=True
187
  ):
188
  """
189
+ Generate a composed SVG that places underscores geometrically between
190
+ segments with pixel-accurate spacing.
191
  """
192
  try:
193
  lines = text.split('\n') if multiline else [text]
194
 
195
+ # Validation + normalize slashes; KEEP underscores intact
196
  for idx, ln in enumerate(lines):
197
  if len(ln) > 75:
198
  return f"Error: Line {idx + 1} is too long (max 75 characters)"
199
+ lines[idx] = ln.replace('/', '-').replace('\\', '-')
200
 
201
+ # Fresh SVG root (no white background -> transparent)
202
  svg_root = ET.Element('svg', {
203
  'xmlns': 'http://www.w3.org/2000/svg',
204
  'viewBox': '0 0 1200 800'
 
232
  svg_root.set('viewBox', f"0 0 {width} {height}")
233
 
234
  svg_content = ET.tostring(svg_root, encoding='unicode')
 
235
  with open("img/output.svg", "w", encoding="utf-8") as f:
236
  f.write(svg_content)
237
  return svg_content
 
239
  except Exception as e:
240
  return f"Error: {str(e)}"
241
 
242
+ # -----------------------------------------------------------------------------
243
+ # PNG export (transparent)
244
+ # -----------------------------------------------------------------------------
245
  def export_to_png(svg_content):
246
  """Convert SVG to transparent PNG using CairoSVG and Pillow."""
247
  try:
 
251
  if not svg_content or svg_content.startswith("Error:"):
252
  return None
253
 
 
254
  tmp_svg = "img/temp.svg"
255
  with open(tmp_svg, "w", encoding="utf-8") as f:
256
  f.write(svg_content)
257
 
258
+ # High scale for crisp strokes
259
  cairosvg.svg2png(
260
  url=tmp_svg,
261
  write_to="img/output_temp.png",
 
267
  if img.mode != 'RGBA':
268
  img = img.convert('RGBA')
269
 
270
+ # Ensure any near-white is transparent (safety)
271
  datas = img.getdata()
272
  new_data = []
273
  for item in datas:
 
280
  out_path = "img/output.png"
281
  img.save(out_path, "PNG")
282
 
 
283
  try:
284
  os.remove("img/output_temp.png")
285
  except:
 
291
  print(f"Error converting to PNG: {str(e)}")
292
  return None
293
 
294
+ # -----------------------------------------------------------------------------
295
+ # Gradio UI
296
+ # -----------------------------------------------------------------------------
297
  def generate_handwriting_wrapper(text, style, bias, color, stroke_width, multiline=True):
298
  svg = generate_handwriting(text, style, bias, color, stroke_width, multiline)
299
  png_path = export_to_png(svg)
 
300
  return svg, png_path, "img/output.svg"
301
 
302
  css = """
 
308
 
309
  with gr.Blocks(css=css) as demo:
310
  gr.Markdown("# 🖋️ Handwriting Synthesis (Underscore-safe)")
311
+ gr.Markdown("Generate realistic handwritten text with **perfect underscore alignment**—no retraining required.")
312
 
313
  with gr.Row():
314
  with gr.Column(scale=2):
315
  text_input = gr.Textbox(
316
  label="Text Input",
317
+ placeholder="Try: zeb_3asba, user_name, long__double__underscore",
318
  lines=5,
319
  max_lines=10,
320
  )
 
341
 
342
  gr.Markdown("""
343
  ### Notes
344
+ - Underscores are drawn between segments using **pixel-accurate measurements**.
345
  - Slashes (/, \\) are normalized to dashes (-) for model stability.
346
+ - Each line ≤ 75 characters. Transparent PNG export included.
 
347
  """)
348
 
349
  generate_btn.click(
 
362
  outputs=[text_input, style_select, bias_slider, color_picker, stroke_width]
363
  )
364
 
365
+ # -----------------------------------------------------------------------------
366
+ # Main
367
+ # -----------------------------------------------------------------------------
368
  if __name__ == "__main__":
369
  port = int(os.environ.get("PORT", 7860))
370
 
371
+ # Soft check for optional deps used here and in measurer
372
+ missing = []
373
  try:
374
  import cairosvg # noqa
375
  except ImportError:
376
+ missing.append("cairosvg")
377
  try:
378
  from PIL import Image # noqa
379
  except ImportError:
380
+ missing.append("pillow")
381
 
382
+ if missing:
383
+ print(f"WARNING: Missing packages for underscore alignment & PNG export: {', '.join(missing)}")
384
+ print("Install with: pip install " + " ".join(missing))
385
 
386
  demo.launch(server_name="0.0.0.0", server_port=port)