Start decoding from a semicolon token; add top-5 logits debug output
Browse files
app.py
CHANGED
|
@@ -375,7 +375,15 @@ def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, g
|
|
| 375 |
# Latent -> Gcode via trained decoder (with debug)
|
| 376 |
with torch.no_grad():
|
| 377 |
batch_size = latent.shape[0]
|
| 378 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
|
| 380 |
generated_tokens = []
|
| 381 |
for step in range(min(max_tokens, 1024) - 1):
|
|
@@ -402,7 +410,10 @@ def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, g
|
|
| 402 |
# Debug first few tokens
|
| 403 |
if step < 5:
|
| 404 |
token_str = gcode_tokenizer.decode([token_id])
|
| 405 |
-
|
|
|
|
|
|
|
|
|
|
| 406 |
|
| 407 |
if token_id == gcode_tokenizer.eos_token_id:
|
| 408 |
print(f"Hit EOS at step {step}")
|
|
|
|
| 375 |
# Latent -> Gcode via trained decoder (with debug)
|
| 376 |
with torch.no_grad():
|
| 377 |
batch_size = latent.shape[0]
|
| 378 |
+
# Seed the decoder input with the ";" token (gcode comment start) rather than the pad token; pad is only the fallback
|
| 379 |
+
# Gcode files start with "; Source: ..."
|
| 380 |
+
start_tokens = gcode_tokenizer.encode(";", add_special_tokens=False)
|
| 381 |
+
print(f"Start tokens for ';': {start_tokens}")
|
| 382 |
+
if start_tokens:
|
| 383 |
+
start_id = start_tokens[0]
|
| 384 |
+
else:
|
| 385 |
+
start_id = gcode_tokenizer.pad_token_id
|
| 386 |
+
input_ids = torch.tensor([[start_id]], dtype=torch.long, device=device)
|
| 387 |
|
| 388 |
generated_tokens = []
|
| 389 |
for step in range(min(max_tokens, 1024) - 1):
|
|
|
|
| 410 |
# Debug first few tokens
|
| 411 |
if step < 5:
|
| 412 |
token_str = gcode_tokenizer.decode([token_id])
|
| 413 |
+
# Inspect the top-5 logits at the last position (batch item 0 only)
|
| 414 |
+
top5_vals, top5_ids = torch.topk(logits[0, -1, :], 5)
|
| 415 |
+
top5_tokens = [gcode_tokenizer.decode([i.item()]) for i in top5_ids]
|
| 416 |
+
print(f"Step {step}: token_id={token_id}, token='{token_str}', top5={list(zip(top5_tokens, top5_vals.tolist()))}")
|
| 417 |
|
| 418 |
if token_id == gcode_tokenizer.eos_token_id:
|
| 419 |
print(f"Hit EOS at step {step}")
|