twarner committed
Commit 916a1f7 · Parent(s): 242131a

Start with semicolon, debug logits

Files changed (1):
  app.py +13 -2

app.py CHANGED
@@ -375,7 +375,15 @@ def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, g
     # Latent -> Gcode via trained decoder (with debug)
     with torch.no_grad():
         batch_size = latent.shape[0]
-        input_ids = torch.full((batch_size, 1), gcode_tokenizer.pad_token_id, dtype=torch.long, device=device)
+        # Start with semicolon (gcode comment start) instead of pad
+        # Gcode files start with "; Source: ..."
+        start_tokens = gcode_tokenizer.encode(";", add_special_tokens=False)
+        print(f"Start tokens for ';': {start_tokens}")
+        if start_tokens:
+            start_id = start_tokens[0]
+        else:
+            start_id = gcode_tokenizer.pad_token_id
+        input_ids = torch.tensor([[start_id]], dtype=torch.long, device=device)
 
         generated_tokens = []
         for step in range(min(max_tokens, 1024) - 1):
@@ -402,7 +410,10 @@ def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, g
             # Debug first few tokens
             if step < 5:
                 token_str = gcode_tokenizer.decode([token_id])
-                print(f"Step {step}: token_id={token_id}, token='{token_str}'")
+                # Check logits distribution
+                top5_vals, top5_ids = torch.topk(logits[0, -1, :], 5)
+                top5_tokens = [gcode_tokenizer.decode([i.item()]) for i in top5_ids]
+                print(f"Step {step}: token_id={token_id}, token='{token_str}', top5={list(zip(top5_tokens, top5_vals.tolist()))}")
 
             if token_id == gcode_tokenizer.eos_token_id:
                 print(f"Hit EOS at step {step}")