3morrrrr commited on
Commit
87bfba4
·
verified ·
1 Parent(s): c0f74fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +236 -218
app.py CHANGED
@@ -1,10 +1,8 @@
1
- import os
2
- import re
3
  import xml.etree.ElementTree as ET
4
  import gradio as gr
5
 
6
- # Your handwriting model wrapper
7
- from hand import Hand
8
 
9
  # -----------------------------------------------------------------------------
10
  # Setup
@@ -13,169 +11,224 @@ os.makedirs("img", exist_ok=True)
13
  hand = Hand()
14
 
15
  # -----------------------------------------------------------------------------
16
- # SVG helpers
17
  # -----------------------------------------------------------------------------
18
- def _extract_paths(svg_root_or_group):
19
- """Return a list of <path> elements (namespace-agnostic)."""
20
- return [el for el in svg_root_or_group.iter() if el.tag.endswith('path')]
 
 
 
21
 
22
- def _bbox_of_paths(paths):
23
- """
24
- Conservative bbox by scanning numeric coords inside each path 'd'.
25
- Not mathematically perfect for curve extrema, but close enough for layout.
26
- """
27
- xs, ys = [], []
28
- num_re = re.compile(r'[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?')
29
- for p in paths:
30
- d = p.get('d', '')
31
- nums = list(map(float, num_re.findall(d)))
32
- for i in range(0, len(nums) - 1, 2):
33
- xs.append(nums[i])
34
- ys.append(nums[i + 1])
35
- if not xs or not ys:
36
- return (0.0, 0.0, 0.0, 0.0)
37
- return (min(xs), min(ys), max(xs), max(ys))
38
 
39
  def _translate_group(elem, dx, dy):
40
- prev = elem.get('transform', '')
41
- t = f"translate({dx},{dy})"
42
- elem.set('transform', (prev + " " + t).strip())
43
-
44
- # -----------------------------------------------------------------------------
45
- # Visual (raster) measurement for exact spacing
46
- # -----------------------------------------------------------------------------
47
- def _tight_width_via_raster(group_elem, stroke_color="#000", stroke_width=2):
48
- """
49
- Render a tiny SVG containing only this group, rasterize to PNG in-memory,
50
- and return the non-transparent width in pixels. Produces pixel-accurate
51
- width that's immune to relative path quirks.
52
- """
53
- import io
54
- from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  import cairosvg
56
- import xml.etree.ElementTree as ET
57
-
58
- # Build a tiny standalone SVG for the group (children are already paths)
59
- svg = ET.Element('svg', {
60
- 'xmlns': 'http://www.w3.org/2000/svg',
61
- 'viewBox': '0 0 1200 400'
62
- })
63
- g = ET.Element('g')
64
- for p in list(group_elem):
65
- g.append(p)
66
- svg.append(g)
67
-
68
- svg_bytes = ET.tostring(svg, encoding='utf-8')
69
-
70
- # Rasterize in memory
71
- png_bytes = cairosvg.svg2png(bytestring=svg_bytes, scale=1.0, background_color="none")
72
- img = Image.open(io.BytesIO(png_bytes)).convert('RGBA')
73
-
74
- # Non-transparent bounding box
75
- alpha = img.split()[-1]
76
- bbox = alpha.getbbox() # (left, top, right, bottom) or None
77
  if not bbox:
78
- return 0.0
79
  left, top, right, bottom = bbox
80
- return float(right - left)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  # -----------------------------------------------------------------------------
83
- # Segment compositor (key fix for underscores)
84
  # -----------------------------------------------------------------------------
85
- def _render_part_group(text_part, style, bias, color, stroke_width):
86
  """
87
- Render one underscore-free 'part' using the Hand model to a temp SVG,
88
- extract its path group, and return (group, path_bbox).
89
  """
90
- # Avoid empty part causing errors: render a single space
91
- lines = [text_part if text_part else " "]
 
 
 
92
  hand.write(
93
- filename='img/part.tmp.svg',
94
- lines=lines,
95
  biases=[bias],
96
  styles=[style],
97
  stroke_colors=[color],
98
  stroke_widths=[stroke_width]
99
  )
100
- svg = ET.parse('img/part.tmp.svg').getroot()
101
- g = ET.Element('g')
102
- for p in _extract_paths(svg):
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  g.append(p)
104
- bbox = _bbox_of_paths(_extract_paths(g))
105
- return g, bbox
106
 
107
- def render_line_with_underscores(
108
- line,
109
- style,
110
- bias,
111
- color,
112
- stroke_width,
113
- line_y,
114
- x_start=40,
115
- gap_left=8,
116
- gap_right=8,
117
- underscore_len=22
118
- ):
119
- """
120
- Compose a line with underscores by rendering surrounding segments,
121
- measuring their *visual* widths via rasterization, and drawing the
122
- underscore at the baseline. Returns (list_of_elements, rightmost_x).
123
- """
124
- parts = line.split('_')
125
-
126
- part_groups, visual_widths, path_bboxes = [], [], []
127
-
128
- # 1) Render each part to a <g>
129
- for part in parts:
130
- g, bbox = _render_part_group(part, style, bias, color, stroke_width)
131
- part_groups.append(g)
132
- path_bboxes.append(bbox)
133
-
134
- # 2) Measure visual widths robustly (pixel-accurate)
135
- for g in part_groups:
136
- w = _tight_width_via_raster(g, stroke_color=color, stroke_width=stroke_width)
137
- if w <= 0.1: # fallback to path bbox width if needed
138
- minx, _, maxx, _ = _bbox_of_paths(_extract_paths(g))
139
- w = max(0.0, maxx - minx)
140
- visual_widths.append(w)
141
-
142
- # 3) Estimate line metrics for baseline
143
- line_height = 40.0
144
- for minx, miny, maxx, maxy in path_bboxes:
145
- if maxy - miny > 0.0:
146
- line_height = max(20.0, maxy - miny)
147
- break
148
- base_y = line_y + 0.8 * line_height
149
-
150
- # 4) Compose all elements left-to-right using measured widths
151
- composed = []
152
- cursor_x = x_start
153
-
154
- for i, (g, bbox, visw) in enumerate(zip(part_groups, path_bboxes, visual_widths)):
155
- minx, _, maxx, _ = bbox
156
-
157
- # Align this group's left edge at cursor_x
158
- _translate_group(g, dx=(cursor_x - minx), dy=line_y)
159
- composed.append(g)
160
- cursor_x += visw # advance by visual width
161
-
162
- # Draw underscore between parts
163
- if i < len(part_groups) - 1:
164
- x0 = cursor_x + gap_left
165
- x1 = x0 + underscore_len
166
- underscore = ET.Element('path')
167
- underscore.set('d', f"M{x0},{base_y} L{x1},{base_y}")
168
- underscore.set('stroke', color)
169
- underscore.set('stroke-width', str(max(1, stroke_width)))
170
- underscore.set('fill', 'none')
171
- underscore.set('stroke-linecap', 'round')
172
- composed.append(underscore)
173
- cursor_x = x1 + gap_right
174
-
175
- return composed, cursor_x
176
 
177
  # -----------------------------------------------------------------------------
178
- # Handwriting generation (multi-line + composition)
179
  # -----------------------------------------------------------------------------
180
  def generate_handwriting(
181
  text,
@@ -185,53 +238,37 @@ def generate_handwriting(
185
  stroke_width=2,
186
  multiline=True
187
  ):
188
- """
189
- Generate a composed SVG that places underscores geometrically between
190
- segments with pixel-accurate spacing.
191
- """
192
  try:
193
- lines = text.split('\n') if multiline else [text]
194
 
195
- # Validation + normalize slashes; KEEP underscores intact
196
  for idx, ln in enumerate(lines):
197
  if len(ln) > 75:
198
- return f"Error: Line {idx + 1} is too long (max 75 characters)"
199
- lines[idx] = ln.replace('/', '-').replace('\\', '-')
200
 
201
- # Fresh SVG root (no white background -> transparent)
202
- svg_root = ET.Element('svg', {
203
- 'xmlns': 'http://www.w3.org/2000/svg',
204
- 'viewBox': '0 0 1200 800'
205
  })
206
 
207
  y0 = 80.0
208
- line_gap = 100.0
209
- max_x = 0.0
210
-
211
- for i, original_line in enumerate(lines):
212
- line_y = y0 + i * line_gap
213
- elems, right_x = render_line_with_underscores(
214
- original_line,
215
- style,
216
- bias,
217
- color,
218
- stroke_width,
219
- line_y,
220
- x_start=40,
221
- gap_left=8,
222
- gap_right=8,
223
- underscore_len=22
224
  )
225
- for el in elems:
226
- svg_root.append(el)
227
- max_x = max(max_x, right_x)
 
228
 
229
- # Tighten viewBox to content width + margin
230
- width = max(300, int(max_x + 40))
231
- height = int(y0 + len(lines) * line_gap)
232
- svg_root.set('viewBox', f"0 0 {width} {height}")
233
 
234
- svg_content = ET.tostring(svg_root, encoding='unicode')
235
  with open("img/output.svg", "w", encoding="utf-8") as f:
236
  f.write(svg_content)
237
  return svg_content
@@ -243,7 +280,6 @@ def generate_handwriting(
243
  # PNG export (transparent)
244
  # -----------------------------------------------------------------------------
245
  def export_to_png(svg_content):
246
- """Convert SVG to transparent PNG using CairoSVG and Pillow."""
247
  try:
248
  import cairosvg
249
  from PIL import Image
@@ -255,27 +291,16 @@ def export_to_png(svg_content):
255
  with open(tmp_svg, "w", encoding="utf-8") as f:
256
  f.write(svg_content)
257
 
258
- # High scale for crisp strokes
259
- cairosvg.svg2png(
260
- url=tmp_svg,
261
- write_to="img/output_temp.png",
262
- scale=2.0,
263
- background_color="none"
264
- )
265
 
266
  img = Image.open("img/output_temp.png")
267
- if img.mode != 'RGBA':
268
- img = img.convert('RGBA')
269
 
270
  # Ensure any near-white is transparent (safety)
271
  datas = img.getdata()
272
- new_data = []
273
- for item in datas:
274
- if item[0] > 240 and item[1] > 240 and item[2] > 240:
275
- new_data.append((255, 255, 255, 0))
276
- else:
277
- new_data.append(item)
278
- img.putdata(new_data)
279
 
280
  out_path = "img/output.png"
281
  img.save(out_path, "PNG")
@@ -296,8 +321,8 @@ def export_to_png(svg_content):
296
  # -----------------------------------------------------------------------------
297
  def generate_handwriting_wrapper(text, style, bias, color, stroke_width, multiline=True):
298
  svg = generate_handwriting(text, style, bias, color, stroke_width, multiline)
299
- png_path = export_to_png(svg)
300
- return svg, png_path, "img/output.svg"
301
 
302
  css = """
303
  .container {max-width: 900px; margin: auto;}
@@ -307,31 +332,29 @@ css = """
307
  """
308
 
309
  with gr.Blocks(css=css) as demo:
310
- gr.Markdown("# 🖋️ Handwriting Synthesis (Underscore-safe)")
311
- gr.Markdown("Generate realistic handwritten text with **perfect underscore alignment**—no retraining required.")
312
 
313
  with gr.Row():
314
  with gr.Column(scale=2):
315
  text_input = gr.Textbox(
316
  label="Text Input",
317
- placeholder="Try: zeb_3asba, user_name, long__double__underscore",
318
- lines=5,
319
- max_lines=10,
320
  )
321
  with gr.Row():
322
  with gr.Column(scale=1):
323
- style_select = gr.Slider(minimum=0, maximum=12, step=1, value=9, label="Handwriting Style")
324
  with gr.Column(scale=1):
325
- bias_slider = gr.Slider(minimum=0.5, maximum=1.0, step=0.05, value=0.75, label="Neatness (Higher = Neater)")
326
  with gr.Row():
327
  with gr.Column(scale=1):
328
  color_picker = gr.ColorPicker(label="Ink Color", value="#000000")
329
  with gr.Column(scale=1):
330
- stroke_width = gr.Slider(minimum=1, maximum=4, step=0.5, value=2, label="Stroke Width")
331
  with gr.Row():
332
  generate_btn = gr.Button("Generate Handwriting", variant="primary")
333
  clear_btn = gr.Button("Clear")
334
-
335
  with gr.Column(scale=3):
336
  output_svg = gr.HTML(label="Generated Handwriting (SVG)", elem_classes=["output-container"])
337
  output_png = gr.Image(type="filepath", label="Generated Handwriting (PNG)", elem_classes=["output-container"])
@@ -341,9 +364,9 @@ with gr.Blocks(css=css) as demo:
341
 
342
  gr.Markdown("""
343
  ### Notes
344
- - Underscores are drawn between segments using **pixel-accurate measurements**.
345
- - Slashes (/, \\) are normalized to dashes (-) for model stability.
346
- - Each line ≤ 75 characters. Transparent PNG export included.
347
  """)
348
 
349
  generate_btn.click(
@@ -367,8 +390,6 @@ with gr.Blocks(css=css) as demo:
367
  # -----------------------------------------------------------------------------
368
  if __name__ == "__main__":
369
  port = int(os.environ.get("PORT", 7860))
370
-
371
- # Soft check for optional deps used here and in measurer
372
  missing = []
373
  try:
374
  import cairosvg # noqa
@@ -378,9 +399,6 @@ if __name__ == "__main__":
378
  from PIL import Image # noqa
379
  except ImportError:
380
  missing.append("pillow")
381
-
382
  if missing:
383
- print(f"WARNING: Missing packages for underscore alignment & PNG export: {', '.join(missing)}")
384
- print("Install with: pip install " + " ".join(missing))
385
-
386
  demo.launch(server_name="0.0.0.0", server_port=port)
 
1
+ import os, re, io
 
2
  import xml.etree.ElementTree as ET
3
  import gradio as gr
4
 
5
+ from hand import Hand # your model
 
6
 
7
  # -----------------------------------------------------------------------------
8
  # Setup
 
11
  hand = Hand()
12
 
13
  # -----------------------------------------------------------------------------
14
+ # Small helpers
15
  # -----------------------------------------------------------------------------
16
+ def _parse_viewbox(root):
17
+ vb = root.get("viewBox")
18
+ if not vb:
19
+ return (0.0, 0.0, 1200.0, 400.0)
20
+ x, y, w, h = map(float, vb.split())
21
+ return (x, y, w, h)
22
 
23
+ def _extract_paths(elem):
24
+ return [e for e in elem.iter() if e.tag.endswith("path")]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  def _translate_group(elem, dx, dy):
27
+ prev = elem.get("transform", "")
28
+ elem.set("transform", (prev + f" translate({dx},{dy})").strip())
29
+
30
+ # Tokenize a line into runs of text / spaces / underscores (order matters)
31
+ def _tokenize_line(line):
32
+ tokens = []
33
+ i, n = 0, len(line)
34
+ while i < n:
35
+ ch = line[i]
36
+ if ch == '_':
37
+ j = i
38
+ while j < n and line[j] == '_': j += 1
39
+ tokens.append(("sep_underscore", line[i:j]))
40
+ i = j
41
+ elif ch.isspace():
42
+ j = i
43
+ while j < n and line[j].isspace(): j += 1
44
+ tokens.append(("sep_space", line[i:j]))
45
+ i = j
46
+ else:
47
+ j = i
48
+ while j < n and (line[j] != '_' and not line[j].isspace()):
49
+ j += 1
50
+ tokens.append(("text", line[i:j]))
51
+ i = j
52
+ return tokens
53
+
54
+ # Build the text that the model will actually render (underscores & spaces -> space)
55
+ def _display_text_from_tokens(tokens):
56
+ out = []
57
+ for t, v in tokens:
58
+ if t == "text":
59
+ out.append(v)
60
+ else:
61
+ out.append(" ") # normalize separators
62
+ return "".join(out)
63
+
64
+ # Rasterize an SVG string to RGBA PIL.Image and return (img, vw, vh, vx, vy)
65
+ def _rasterize_svg(svg_str, scale=2.0):
66
  import cairosvg
67
+ from PIL import Image
68
+ png = cairosvg.svg2png(bytestring=svg_str.encode("utf-8"), scale=scale, background_color="none")
69
+ img = Image.open(io.BytesIO(png)).convert("RGBA")
70
+ return img
71
+
72
+ # From alpha image, find left-to-right blobs & gaps within the content bbox
73
+ def _find_blobs_and_gaps(alpha_img):
74
+ w, h = alpha_img.size
75
+ bbox = alpha_img.getbbox()
 
 
 
 
 
 
 
 
 
 
 
 
76
  if not bbox:
77
+ return [], [], (0, 0, w, h) # nothing
78
  left, top, right, bottom = bbox
79
+
80
+ # columns with any ink
81
+ def col_has_ink(x):
82
+ # check a cropped band to speed up
83
+ for y in range(top, bottom):
84
+ if alpha_img.getpixel((x, y)) > 0:
85
+ return True
86
+ return False
87
+
88
+ blobs, gaps = [], []
89
+ x = left
90
+ in_blob = col_has_ink(x)
91
+ start = x
92
+ while x < right:
93
+ has = col_has_ink(x)
94
+ if has != in_blob:
95
+ if in_blob:
96
+ blobs.append((start, x)) # [start, x)
97
+ else:
98
+ gaps.append((start, x))
99
+ start = x
100
+ in_blob = has
101
+ x += 1
102
+ # close last run
103
+ if in_blob:
104
+ blobs.append((start, right))
105
+ else:
106
+ gaps.append((start, right))
107
+
108
+ # We only want gaps **between** blobs, not leading/trailing margins:
109
+ core_gaps = []
110
+ for i in range(len(blobs) - 1):
111
+ core_gaps.append((blobs[i][1], blobs[i + 1][0]))
112
+
113
+ return blobs, core_gaps, (left, top, right, bottom)
114
+
115
+ # Map pixel x/y to SVG coords via viewBox
116
+ def _px_to_svg_x(x_px, img_w, vb):
117
+ vx, vy, vw, vh = vb
118
+ return vx + (x_px / float(img_w)) * vw
119
+
120
+ def _px_to_svg_y(y_px, img_h, vb):
121
+ vx, vy, vw, vh = vb
122
+ return vy + (y_px / float(img_h)) * vh
123
+
124
+ # Draw N underscores filling a gap (pixel coords -> converted to SVG coords)
125
+ def _draw_underscores_in_gap(root, gap_px, baseline_px, img_w, img_h, vb,
126
+ color, stroke_width, n, pad_px=3, between_px=4):
127
+ gap_w = max(0, gap_px[1] - gap_px[0] - 2 * pad_px)
128
+ if gap_w <= 6 or n <= 0:
129
+ return
130
+ # fit underscores nicely within the gap
131
+ # give each underscore ~85% of its slot
132
+ slot = gap_w / n
133
+ line_len = max(8, int(slot * 0.85) - between_px)
134
+ offset = (slot - line_len) / 2.0
135
+
136
+ for i in range(n):
137
+ x0_px = gap_px[0] + pad_px + i * slot + offset
138
+ x1_px = x0_px + line_len
139
+ y_px = baseline_px
140
+
141
+ x0 = _px_to_svg_x(x0_px, img_w, vb)
142
+ x1 = _px_to_svg_x(x1_px, img_w, vb)
143
+ y = _px_to_svg_y(y_px, img_h, vb)
144
+
145
+ p = ET.Element("path")
146
+ p.set("d", f"M{x0},{y} L{x1},{y}")
147
+ p.set("stroke", color)
148
+ p.set("stroke-width", str(max(1, stroke_width)))
149
+ p.set("fill", "none")
150
+ p.set("stroke-linecap", "round")
151
+ root.append(p)
152
 
153
  # -----------------------------------------------------------------------------
154
+ # Render ONE line with correct underscores by analyzing the full-line mask
155
  # -----------------------------------------------------------------------------
156
+ def render_line_svg_with_underscores(line, style, bias, color, stroke_width):
157
  """
158
+ Returns (line_group, width_estimate). The group contains model strokes + our underscores.
 
159
  """
160
+ # 1) Tokenize and build the display line (underscores & spaces -> spaces)
161
+ tokens = _tokenize_line(line)
162
+ display_line = _display_text_from_tokens(tokens).replace("/", "-").replace("\\", "-")
163
+
164
+ # 2) Ask the model to render this single line to a temp SVG
165
  hand.write(
166
+ filename="img/line.tmp.svg",
167
+ lines=[display_line if display_line.strip() else " "],
168
  biases=[bias],
169
  styles=[style],
170
  stroke_colors=[color],
171
  stroke_widths=[stroke_width]
172
  )
173
+ root = ET.parse("img/line.tmp.svg").getroot()
174
+ vb = _parse_viewbox(root)
175
+
176
+ # 3) Rasterize the exact SVG we will augment, then find blobs/gaps
177
+ from PIL import Image
178
+ img = _rasterize_svg(ET.tostring(root, encoding="unicode"))
179
+ alpha = img.split()[-1]
180
+ blobs, gaps, content_bbox = _find_blobs_and_gaps(alpha)
181
+ img_w, img_h = img.size
182
+ left, top, right, bottom = content_bbox
183
+ line_h_px = max(20, bottom - top)
184
+ baseline_px = bottom - int(0.18 * line_h_px) # tuck slightly above bottom
185
+
186
+ # 4) Build a <g> that contains all original paths
187
+ g = ET.Element("g")
188
+ for p in _extract_paths(root):
189
  g.append(p)
 
 
190
 
191
+ # 5) Determine which gaps correspond to underscores (not regular spaces)
192
+ # Walk tokens: whenever we have TEXT then SEPs then TEXT, consume one gap.
193
+ gap_idx = 0
194
+ i = 0
195
+ text_run_count = sum(1 for t, v in tokens if t == "text" and len(v) > 0)
196
+ # Defensive: if model merged two words, blobs may be fewer—clamp at min
197
+ max_gaps_we_can_use = max(0, min(len(gaps), max(0, text_run_count - 1)))
198
+
199
+ while i < len(tokens):
200
+ t, v = tokens[i]
201
+ if t == "text" and len(v) > 0:
202
+ # Count any following separators as one logical gap between this and the next text
203
+ j = i + 1
204
+ underscore_count = 0
205
+ saw_sep = False
206
+ while j < len(tokens) and tokens[j][0].startswith("sep_"):
207
+ saw_sep = True
208
+ if tokens[j][0] == "sep_underscore":
209
+ underscore_count += len(tokens[j][1])
210
+ j += 1
211
+
212
+ if saw_sep and gap_idx < max_gaps_we_can_use:
213
+ if underscore_count > 0:
214
+ _draw_underscores_in_gap(
215
+ g, gaps[gap_idx], baseline_px, img_w, img_h, vb,
216
+ color, stroke_width, underscore_count
217
+ )
218
+ gap_idx += 1
219
+ i = j
220
+ else:
221
+ i += 1
222
+
223
+ # 6) Estimate width (use content bbox)
224
+ width_estimate = right - left
225
+ if width_estimate <= 0:
226
+ width_estimate = img_w // 2
227
+
228
+ return g, width_estimate
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
  # -----------------------------------------------------------------------------
231
+ # Generate full SVG (multi-line stacking)
232
  # -----------------------------------------------------------------------------
233
  def generate_handwriting(
234
  text,
 
238
  stroke_width=2,
239
  multiline=True
240
  ):
 
 
 
 
241
  try:
242
+ lines = text.split("\n") if multiline else [text]
243
 
 
244
  for idx, ln in enumerate(lines):
245
  if len(ln) > 75:
246
+ return f"Error: Line {idx+1} is too long (max 75 characters)"
 
247
 
248
+ # Compose final SVG
249
+ svg_root = ET.Element("svg", {
250
+ "xmlns": "http://www.w3.org/2000/svg",
251
+ "viewBox": "0 0 1200 800"
252
  })
253
 
254
  y0 = 80.0
255
+ line_gap = 110.0
256
+ max_right = 0.0
257
+
258
+ for i, line in enumerate(lines):
259
+ g, w = render_line_svg_with_underscores(
260
+ line, style, bias, color, stroke_width
 
 
 
 
 
 
 
 
 
 
261
  )
262
+ # place this line group lower on the page
263
+ _translate_group(g, dx=40, dy=y0 + i * line_gap)
264
+ svg_root.append(g)
265
+ max_right = max(max_right, 40 + w)
266
 
267
+ height = int(y0 + len(lines) * line_gap + 80)
268
+ width = max(300, int(max_right + 40))
269
+ svg_root.set("viewBox", f"0 0 {width} {height}")
 
270
 
271
+ svg_content = ET.tostring(svg_root, encoding="unicode")
272
  with open("img/output.svg", "w", encoding="utf-8") as f:
273
  f.write(svg_content)
274
  return svg_content
 
280
  # PNG export (transparent)
281
  # -----------------------------------------------------------------------------
282
  def export_to_png(svg_content):
 
283
  try:
284
  import cairosvg
285
  from PIL import Image
 
291
  with open(tmp_svg, "w", encoding="utf-8") as f:
292
  f.write(svg_content)
293
 
294
+ cairosvg.svg2png(url=tmp_svg, write_to="img/output_temp.png", scale=2.0, background_color="none")
 
 
 
 
 
 
295
 
296
  img = Image.open("img/output_temp.png")
297
+ if img.mode != "RGBA":
298
+ img = img.convert("RGBA")
299
 
300
  # Ensure any near-white is transparent (safety)
301
  datas = img.getdata()
302
+ img.putdata([(255, 255, 255, 0) if r > 240 and g > 240 and b > 240 else (r, g, b, a)
303
+ for (r, g, b, a) in datas])
 
 
 
 
 
304
 
305
  out_path = "img/output.png"
306
  img.save(out_path, "PNG")
 
321
  # -----------------------------------------------------------------------------
322
  def generate_handwriting_wrapper(text, style, bias, color, stroke_width, multiline=True):
323
  svg = generate_handwriting(text, style, bias, color, stroke_width, multiline)
324
+ png = export_to_png(svg)
325
+ return svg, png, "img/output.svg"
326
 
327
  css = """
328
  .container {max-width: 900px; margin: auto;}
 
332
  """
333
 
334
  with gr.Blocks(css=css) as demo:
335
+ gr.Markdown("# 🖋️ Handwriting Synthesis (Underscore-aware)")
336
+ gr.Markdown("Underscores are placed only where typed, aligned to real gaps from the model output.")
337
 
338
  with gr.Row():
339
  with gr.Column(scale=2):
340
  text_input = gr.Textbox(
341
  label="Text Input",
342
+ placeholder="Try: user_name, zeb_3asba, or zeb aasba (no underscore should be drawn)",
343
+ lines=5, max_lines=10,
 
344
  )
345
  with gr.Row():
346
  with gr.Column(scale=1):
347
+ style_select = gr.Slider(0, 12, step=1, value=9, label="Handwriting Style")
348
  with gr.Column(scale=1):
349
+ bias_slider = gr.Slider(0.5, 1.0, step=0.05, value=0.75, label="Neatness (Higher = Neater)")
350
  with gr.Row():
351
  with gr.Column(scale=1):
352
  color_picker = gr.ColorPicker(label="Ink Color", value="#000000")
353
  with gr.Column(scale=1):
354
+ stroke_width = gr.Slider(1, 4, step=0.5, value=2, label="Stroke Width")
355
  with gr.Row():
356
  generate_btn = gr.Button("Generate Handwriting", variant="primary")
357
  clear_btn = gr.Button("Clear")
 
358
  with gr.Column(scale=3):
359
  output_svg = gr.HTML(label="Generated Handwriting (SVG)", elem_classes=["output-container"])
360
  output_png = gr.Image(type="filepath", label="Generated Handwriting (PNG)", elem_classes=["output-container"])
 
364
 
365
  gr.Markdown("""
366
  ### Notes
367
+ - Only underscores you type are drawn; normal spaces remain spaces.
368
+ - Alignment uses the *actual gaps* between model-drawn word blobs.
369
+ - Lines ≤ 75 chars. Slashes (/ and \\) normalized to dashes (-).
370
  """)
371
 
372
  generate_btn.click(
 
390
  # -----------------------------------------------------------------------------
391
  if __name__ == "__main__":
392
  port = int(os.environ.get("PORT", 7860))
 
 
393
  missing = []
394
  try:
395
  import cairosvg # noqa
 
399
  from PIL import Image # noqa
400
  except ImportError:
401
  missing.append("pillow")
 
402
  if missing:
403
+ print("Install:", " ".join(missing))
 
 
404
  demo.launch(server_name="0.0.0.0", server_port=port)