twarner committed on
Commit
9a32c26
·
1 Parent(s): 89d775f

Fix model loading, minimal monochrome design

Browse files
Files changed (1) hide show
  1. app.py +100 -71
app.py CHANGED
@@ -1,4 +1,4 @@
1
- """dcode Gradio Space - Text to Gcode via SD-Gcode Diffusion."""
2
 
3
  import re
4
  import os
@@ -17,7 +17,6 @@ _model = None
17
 
18
 
19
  class GcodeDecoderConfig:
20
- """Configuration for gcode decoder."""
21
  def __init__(
22
  self,
23
  latent_channels: int = 4,
@@ -41,8 +40,6 @@ class GcodeDecoderConfig:
41
 
42
 
43
  class GcodeDecoder(nn.Module):
44
- """Transformer decoder: SD latent -> gcode tokens."""
45
-
46
  def __init__(self, config: GcodeDecoderConfig):
47
  super().__init__()
48
  self.config = config
@@ -120,7 +117,7 @@ class GcodeDecoder(nn.Module):
120
 
121
 
122
  def get_model():
123
- """Load and cache the SD-Gcode model."""
124
  global _model
125
  if _model is None:
126
  from diffusers import StableDiffusionPipeline
@@ -139,7 +136,7 @@ def get_model():
139
  with open(config_path) as f:
140
  config = json.load(f)
141
 
142
- # Load SD pipeline
143
  sd_model_id = config.get("sd_model_id", "runwayml/stable-diffusion-v1-5")
144
  print(f"Loading SD from {sd_model_id}...")
145
  pipe = StableDiffusionPipeline.from_pretrained(
@@ -161,13 +158,31 @@ def get_model():
161
  )
162
  gcode_decoder = GcodeDecoder(decoder_config).to(device, dtype)
163
 
164
- # Load weights
165
- state_dict = torch.load(weights_path, map_location=device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
- # Extract decoder weights
168
  decoder_state = {k.replace("gcode_decoder.", ""): v for k, v in state_dict.items()
169
  if k.startswith("gcode_decoder.")}
170
- gcode_decoder.load_state_dict(decoder_state)
 
 
 
171
  gcode_decoder.eval()
172
 
173
  # Gcode tokenizer
@@ -268,25 +283,24 @@ def gcode_to_svg(gcode: str) -> str:
268
  h = BOUNDS["top"] - BOUNDS["bottom"]
269
  padding = 20
270
 
 
271
  svg = f'''<svg xmlns="http://www.w3.org/2000/svg"
272
  viewBox="{BOUNDS["left"] - padding} {-BOUNDS["top"] - padding} {w + 2*padding} {h + 2*padding}"
273
- style="background: #fafafa; width: 100%; height: 500px; border-radius: 8px; border: 1px solid #e5e5e5;">
274
  <rect x="{BOUNDS["left"]}" y="{-BOUNDS["top"]}" width="{w}" height="{h}"
275
- fill="#fff" stroke="#ccc" stroke-width="2"/>
276
- <line x1="0" y1="{-BOUNDS["top"]}" x2="0" y2="{-BOUNDS["bottom"]}" stroke="#ddd" stroke-width="1"/>
277
- <line x1="{BOUNDS["left"]}" y1="0" x2="{BOUNDS["right"]}" y2="0" stroke="#ddd" stroke-width="1"/>
278
  '''
279
 
280
  for path in paths:
281
  if len(path) < 2:
282
  continue
283
  d = " ".join(f"{'M' if i == 0 else 'L'}{p[0]:.1f},{-p[1]:.1f}" for i, p in enumerate(path))
284
- svg += f'<path d="{d}" fill="none" stroke="#1a1a1a" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>'
285
 
286
  total_points = sum(len(p) for p in paths)
287
  svg += f'''
288
- <text x="{BOUNDS["left"] + 10}" y="{-BOUNDS["top"] + 25}" fill="#666" font-family="monospace" font-size="14">
289
- Paths: {len(paths)} | Points: {total_points}
290
  </text>
291
  '''
292
  svg += "</svg>"
@@ -295,7 +309,7 @@ def gcode_to_svg(gcode: str) -> str:
295
 
296
  @spaces.GPU
297
  def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, guidance: float):
298
- """Generate gcode from text prompt via SD-Gcode diffusion."""
299
  if not prompt or not prompt.strip():
300
  return "Enter a prompt to generate gcode", gcode_to_svg("")
301
 
@@ -307,7 +321,7 @@ def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, g
307
  device = m["device"]
308
  dtype = m["dtype"]
309
 
310
- # 1. Text -> Latent via full SD diffusion
311
  with torch.no_grad():
312
  result = pipe(
313
  prompt,
@@ -315,9 +329,9 @@ def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, g
315
  guidance_scale=guidance,
316
  output_type="latent",
317
  )
318
- latent = result.images.to(dtype) # [1, 4, 64, 64]
319
 
320
- # 2. Latent -> Gcode via trained decoder
321
  with torch.no_grad():
322
  gcode = gcode_decoder.generate(
323
  latent,
@@ -327,11 +341,11 @@ def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, g
327
  )
328
 
329
  gcode = validate_gcode(gcode)
330
- line_count = len(gcode.split("\n"))
331
  svg = gcode_to_svg(gcode)
332
 
333
- gcode_with_header = f"; dcode SD-Gcode output - {line_count} lines\n; Prompt: {prompt}\n; Machine validated\n\n{gcode}"
334
- return gcode_with_header, svg
335
 
336
  except Exception as e:
337
  import traceback
@@ -339,76 +353,91 @@ def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, g
339
  return f"; Error: {e}", gcode_to_svg("")
340
 
341
 
342
- # Custom CSS
343
- custom_css = """
 
 
 
 
 
 
344
  .gradio-container {
345
- max-width: 1200px !important;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
346
  }
347
  """
348
 
349
- with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="emerald")) as demo:
350
- gr.Markdown("""
351
- # dcode
352
- **Text -> Polargraph Gcode via Stable Diffusion**
353
-
354
- Single end-to-end diffusion model: text -> CLIP -> UNet -> latent -> gcode decoder -> gcode
355
-
356
- [GitHub](https://github.com/Twarner491/dcode) | [Model](https://huggingface.co/twarner/dcode-sd-gcode) | [Dataset](https://huggingface.co/datasets/twarner/dcode-polargraph-gcode)
357
- """)
358
 
359
  with gr.Row():
360
  with gr.Column(scale=1):
361
  prompt = gr.Textbox(
362
- label="Prompt",
363
- placeholder="drawing of a cat, abstract spiral, portrait...",
364
- lines=2
 
365
  )
366
 
367
- with gr.Row():
368
- temperature = gr.Slider(0.5, 1.5, value=0.8, label="Temperature")
369
- max_tokens = gr.Slider(256, 1024, value=512, step=128, label="Max Tokens")
 
 
370
 
371
- with gr.Row():
372
- num_steps = gr.Slider(10, 50, value=20, step=5, label="Diffusion Steps")
373
- guidance = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="Guidance Scale")
374
-
375
- generate_btn = gr.Button("Generate", variant="primary", size="lg")
376
 
377
  gr.Examples(
378
  examples=[
379
- ["line drawing of a cat"],
380
- ["abstract spiral pattern"],
381
- ["simple house with chimney"],
382
  ["portrait sketch"],
383
- ["geometric shapes and lines"],
384
  ],
385
  inputs=prompt,
386
  )
387
 
388
  with gr.Column(scale=2):
389
- preview = gr.HTML(
390
- value=gcode_to_svg(""),
391
- label="Preview",
392
- )
393
 
394
- with gr.Accordion("Gcode Output", open=False):
395
- gcode_output = gr.Code(label="Gcode", language=None, lines=15)
396
 
397
- gr.Markdown("""
398
- ---
399
- **Machine Bounds**: X: +/-420.5mm, Y: +/-594.5mm | Pen servo: 40 deg (down) / 90 deg (up) | **License**: MIT
400
- """)
401
 
402
- generate_btn.click(
403
- generate,
404
- [prompt, temperature, max_tokens, num_steps, guidance],
405
- [gcode_output, preview]
406
- )
407
- prompt.submit(
408
- generate,
409
- [prompt, temperature, max_tokens, num_steps, guidance],
410
- [gcode_output, preview]
411
- )
412
 
413
  if __name__ == "__main__":
414
  demo.launch()
 
1
+ """dcode - Text to Polargraph Gcode via Stable Diffusion"""
2
 
3
  import re
4
  import os
 
17
 
18
 
19
  class GcodeDecoderConfig:
 
20
  def __init__(
21
  self,
22
  latent_channels: int = 4,
 
40
 
41
 
42
  class GcodeDecoder(nn.Module):
 
 
43
  def __init__(self, config: GcodeDecoderConfig):
44
  super().__init__()
45
  self.config = config
 
117
 
118
 
119
  def get_model():
120
+ """Load and cache the SD-Gcode model with full finetuned weights."""
121
  global _model
122
  if _model is None:
123
  from diffusers import StableDiffusionPipeline
 
136
  with open(config_path) as f:
137
  config = json.load(f)
138
 
139
+ # Load SD pipeline (we'll replace weights with finetuned ones)
140
  sd_model_id = config.get("sd_model_id", "runwayml/stable-diffusion-v1-5")
141
  print(f"Loading SD from {sd_model_id}...")
142
  pipe = StableDiffusionPipeline.from_pretrained(
 
158
  )
159
  gcode_decoder = GcodeDecoder(decoder_config).to(device, dtype)
160
 
161
+ # Load ALL finetuned weights
162
+ print("Loading finetuned weights...")
163
+ state_dict = torch.load(weights_path, map_location=device, weights_only=False)
164
+
165
+ # Load text encoder weights
166
+ text_encoder_state = {k.replace("text_encoder.", ""): v for k, v in state_dict.items()
167
+ if k.startswith("text_encoder.")}
168
+ if text_encoder_state:
169
+ pipe.text_encoder.load_state_dict(text_encoder_state, strict=False)
170
+ print(f"Loaded {len(text_encoder_state)} text encoder weights")
171
+
172
+ # Load UNet weights
173
+ unet_state = {k.replace("unet.", ""): v for k, v in state_dict.items()
174
+ if k.startswith("unet.")}
175
+ if unet_state:
176
+ pipe.unet.load_state_dict(unet_state, strict=False)
177
+ print(f"Loaded {len(unet_state)} UNet weights")
178
 
179
+ # Load gcode decoder weights
180
  decoder_state = {k.replace("gcode_decoder.", ""): v for k, v in state_dict.items()
181
  if k.startswith("gcode_decoder.")}
182
+ if decoder_state:
183
+ gcode_decoder.load_state_dict(decoder_state, strict=False)
184
+ print(f"Loaded {len(decoder_state)} decoder weights")
185
+
186
  gcode_decoder.eval()
187
 
188
  # Gcode tokenizer
 
283
  h = BOUNDS["top"] - BOUNDS["bottom"]
284
  padding = 20
285
 
286
+ # Minimal monochrome styling
287
  svg = f'''<svg xmlns="http://www.w3.org/2000/svg"
288
  viewBox="{BOUNDS["left"] - padding} {-BOUNDS["top"] - padding} {w + 2*padding} {h + 2*padding}"
289
+ style="background: #fff; width: 100%; height: 480px; border: 1px solid #e0e0e0;">
290
  <rect x="{BOUNDS["left"]}" y="{-BOUNDS["top"]}" width="{w}" height="{h}"
291
+ fill="#fafafa" stroke="#ccc" stroke-width="1"/>
 
 
292
  '''
293
 
294
  for path in paths:
295
  if len(path) < 2:
296
  continue
297
  d = " ".join(f"{'M' if i == 0 else 'L'}{p[0]:.1f},{-p[1]:.1f}" for i, p in enumerate(path))
298
+ svg += f'<path d="{d}" fill="none" stroke="#000" stroke-width="1" stroke-linecap="round" stroke-linejoin="round"/>'
299
 
300
  total_points = sum(len(p) for p in paths)
301
  svg += f'''
302
+ <text x="{BOUNDS["left"] + 8}" y="{-BOUNDS["top"] + 20}" fill="#999" font-family="monospace" font-size="12">
303
+ {len(paths)} paths / {total_points} points
304
  </text>
305
  '''
306
  svg += "</svg>"
 
309
 
310
  @spaces.GPU
311
  def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, guidance: float):
312
+ """Generate gcode from text prompt."""
313
  if not prompt or not prompt.strip():
314
  return "Enter a prompt to generate gcode", gcode_to_svg("")
315
 
 
321
  device = m["device"]
322
  dtype = m["dtype"]
323
 
324
+ # Text -> Latent via SD diffusion
325
  with torch.no_grad():
326
  result = pipe(
327
  prompt,
 
329
  guidance_scale=guidance,
330
  output_type="latent",
331
  )
332
+ latent = result.images.to(dtype)
333
 
334
+ # Latent -> Gcode via trained decoder
335
  with torch.no_grad():
336
  gcode = gcode_decoder.generate(
337
  latent,
 
341
  )
342
 
343
  gcode = validate_gcode(gcode)
344
+ line_count = len([l for l in gcode.split("\n") if l.strip()])
345
  svg = gcode_to_svg(gcode)
346
 
347
+ header = f"; dcode output\n; prompt: {prompt}\n; {line_count} commands\n\n"
348
+ return header + gcode, svg
349
 
350
  except Exception as e:
351
  import traceback
 
353
  return f"; Error: {e}", gcode_to_svg("")
354
 
355
 
356
+ # Minimal monochrome CSS
357
+ css = """
358
+ @import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;500&display=swap');
359
+
360
+ * {
361
+ font-family: 'IBM Plex Mono', monospace !important;
362
+ }
363
+
364
  .gradio-container {
365
+ max-width: 900px !important;
366
+ margin: auto;
367
+ background: #fff !important;
368
+ }
369
+
370
+ .gr-button-primary {
371
+ background: #000 !important;
372
+ border: none !important;
373
+ color: #fff !important;
374
+ font-weight: 500 !important;
375
+ }
376
+
377
+ .gr-button-primary:hover {
378
+ background: #333 !important;
379
+ }
380
+
381
+ footer {
382
+ display: none !important;
383
+ }
384
+
385
+ h1 {
386
+ font-weight: 500 !important;
387
+ letter-spacing: -0.02em !important;
388
+ }
389
+
390
+ .gr-box {
391
+ border-radius: 0 !important;
392
+ border: 1px solid #e0e0e0 !important;
393
+ }
394
+
395
+ input, textarea {
396
+ border-radius: 0 !important;
397
  }
398
  """
399
 
400
+ with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
401
+ gr.Markdown("# dcode")
402
+ gr.Markdown("text → polargraph gcode via stable diffusion")
 
 
 
 
 
 
403
 
404
  with gr.Row():
405
  with gr.Column(scale=1):
406
  prompt = gr.Textbox(
407
+ label="prompt",
408
+ placeholder="describe what to draw...",
409
+ lines=2,
410
+ show_label=True,
411
  )
412
 
413
+ with gr.Accordion("settings", open=False):
414
+ temperature = gr.Slider(0.5, 1.5, value=0.8, label="temperature", step=0.1)
415
+ max_tokens = gr.Slider(256, 1024, value=512, step=128, label="max tokens")
416
+ num_steps = gr.Slider(10, 50, value=20, step=5, label="diffusion steps")
417
+ guidance = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="guidance")
418
 
419
+ generate_btn = gr.Button("generate", variant="primary")
 
 
 
 
420
 
421
  gr.Examples(
422
  examples=[
423
+ ["a line drawing of a horse"],
 
 
424
  ["portrait sketch"],
425
+ ["geometric shapes"],
426
  ],
427
  inputs=prompt,
428
  )
429
 
430
  with gr.Column(scale=2):
431
+ preview = gr.HTML(value=gcode_to_svg(""))
 
 
 
432
 
433
+ with gr.Accordion("gcode", open=False):
434
+ gcode_output = gr.Code(label=None, language=None, lines=12)
435
 
436
+ gr.Markdown("---")
437
+ gr.Markdown("machine: 841×1189mm / pen servo 40-90° / [github](https://github.com/Twarner491/dcode) / [model](https://huggingface.co/twarner/dcode-sd-gcode) / mit")
 
 
438
 
439
+ generate_btn.click(generate, [prompt, temperature, max_tokens, num_steps, guidance], [gcode_output, preview])
440
+ prompt.submit(generate, [prompt, temperature, max_tokens, num_steps, guidance], [gcode_output, preview])
 
 
 
 
 
 
 
 
441
 
442
  if __name__ == "__main__":
443
  demo.launch()