twarner committed on
Commit
7505a9b
·
1 Parent(s): 9a32c26

Debug decoder output, fix button color, inline examples

Browse files
Files changed (1) hide show
  1. app.py +55 -11
app.py CHANGED
@@ -330,15 +330,48 @@ def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, g
330
  output_type="latent",
331
  )
332
  latent = result.images.to(dtype)
 
 
333
 
334
- # Latent -> Gcode via trained decoder
335
  with torch.no_grad():
336
- gcode = gcode_decoder.generate(
337
- latent,
338
- gcode_tokenizer,
339
- max_length=max_tokens,
340
- temperature=temperature,
341
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342
 
343
  gcode = validate_gcode(gcode)
344
  line_count = len([l for l in gcode.split("\n") if l.strip()])
@@ -368,14 +401,23 @@ css = """
368
  }
369
 
370
  .gr-button-primary {
371
- background: #000 !important;
372
- border: none !important;
373
- color: #fff !important;
374
  font-weight: 500 !important;
375
  }
376
 
377
  .gr-button-primary:hover {
378
- background: #333 !important;
 
 
 
 
 
 
 
 
 
379
  }
380
 
381
  footer {
@@ -425,6 +467,8 @@ with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
425
  ["geometric shapes"],
426
  ],
427
  inputs=prompt,
 
 
428
  )
429
 
430
  with gr.Column(scale=2):
 
330
  output_type="latent",
331
  )
332
  latent = result.images.to(dtype)
333
+ print(f"Latent shape: {latent.shape}, dtype: {latent.dtype}")
334
+ print(f"Latent stats: min={latent.min():.3f}, max={latent.max():.3f}, mean={latent.mean():.3f}")
335
 
336
+ # Latent -> Gcode via trained decoder (with debug)
337
  with torch.no_grad():
338
+ batch_size = latent.shape[0]
339
+ input_ids = torch.full((batch_size, 1), gcode_tokenizer.pad_token_id, dtype=torch.long, device=device)
340
+
341
+ generated_tokens = []
342
+ for step in range(min(max_tokens, 1024) - 1):
343
+ logits = gcode_decoder(latent, input_ids)
344
+ next_logits = logits[:, -1, :] / temperature
345
+
346
+ # Top-p sampling
347
+ sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
348
+ cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
349
+ sorted_indices_to_remove = cumulative_probs > 0.9
350
+ sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
351
+ sorted_indices_to_remove[:, 0] = False
352
+
353
+ for b in range(batch_size):
354
+ next_logits[b, sorted_indices[b, sorted_indices_to_remove[b]]] = float('-inf')
355
+
356
+ probs = torch.softmax(next_logits, dim=-1)
357
+ next_token = torch.multinomial(probs, num_samples=1)
358
+ input_ids = torch.cat([input_ids, next_token], dim=1)
359
+
360
+ token_id = next_token.item()
361
+ generated_tokens.append(token_id)
362
+
363
+ # Debug first few tokens
364
+ if step < 5:
365
+ token_str = gcode_tokenizer.decode([token_id])
366
+ print(f"Step {step}: token_id={token_id}, token='{token_str}'")
367
+
368
+ if token_id == gcode_tokenizer.eos_token_id:
369
+ print(f"Hit EOS at step {step}")
370
+ break
371
+
372
+ print(f"Generated {len(generated_tokens)} tokens")
373
+ gcode = gcode_tokenizer.decode(input_ids[0], skip_special_tokens=True)
374
+ print(f"Decoded gcode length: {len(gcode)} chars")
375
 
376
  gcode = validate_gcode(gcode)
377
  line_count = len([l for l in gcode.split("\n") if l.strip()])
 
401
  }
402
 
403
  .gr-button-primary {
404
+ background: #e8e8e8 !important;
405
+ border: 1px solid #ccc !important;
406
+ color: #333 !important;
407
  font-weight: 500 !important;
408
  }
409
 
410
  .gr-button-primary:hover {
411
+ background: #d8d8d8 !important;
412
+ }
413
+
414
+ .gr-examples {
415
+ margin-top: 8px !important;
416
+ }
417
+
418
+ .gr-examples .gr-sample-textbox {
419
+ display: inline-block !important;
420
+ margin-right: 8px !important;
421
  }
422
 
423
  footer {
 
467
  ["geometric shapes"],
468
  ],
469
  inputs=prompt,
470
+ label=None,
471
+ examples_per_page=3,
472
  )
473
 
474
  with gr.Column(scale=2):