twarner committed on
Commit
1b8fb6e
·
1 Parent(s): 99753aa

Fix inference: better tokenizer init, gcode cleaning, and debug output

Browse files
Files changed (1) hide show
  1. app.py +129 -75
app.py CHANGED
@@ -304,21 +304,29 @@ def get_model():
304
  try:
305
  # Try loading custom tokenizer from v3 model
306
  tokenizer_path = hf_hub_download("twarner/dcode-sd-gcode-v3", "gcode_tokenizer/tokenizer.json")
307
- gcode_tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)
308
- # Ensure special tokens are set
309
- if gcode_tokenizer.pad_token is None:
310
- gcode_tokenizer.pad_token = "<pad>"
311
- gcode_tokenizer.pad_token_id = 0
312
- if gcode_tokenizer.bos_token is None:
313
- gcode_tokenizer.bos_token = "<s>"
314
- gcode_tokenizer.bos_token_id = 1
315
- if gcode_tokenizer.eos_token is None:
316
- gcode_tokenizer.eos_token = "</s>"
317
- gcode_tokenizer.eos_token_id = 2
318
  print(f"Loaded custom gcode tokenizer (vocab={gcode_tokenizer.vocab_size})")
319
- print(f" BOS={gcode_tokenizer.bos_token_id}, EOS={gcode_tokenizer.eos_token_id}, PAD={gcode_tokenizer.pad_token_id}")
 
 
 
 
 
 
 
 
 
320
  except Exception as e:
321
  print(f"Failed to load custom tokenizer: {e}")
 
 
322
  # Fallback to T5 tokenizer
323
  gcode_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
324
  print("Using fallback T5 tokenizer")
@@ -341,6 +349,72 @@ def get_model():
341
  # GCODE PROCESSING
342
  # ============================================================================
343
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
  def validate_gcode(gcode: str) -> str:
345
  """Clamp coordinates to machine bounds."""
346
  lines = []
@@ -539,41 +613,53 @@ def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, g
539
  with torch.no_grad():
540
  batch_size = latent.shape[0]
541
 
542
- # Start token - use BOS for v3, semicolon for v2
 
 
 
 
 
543
  if is_v3:
544
- start_id = gcode_tokenizer.bos_token_id if gcode_tokenizer.bos_token_id is not None else 1
 
 
 
 
 
545
  else:
546
  start_tokens = gcode_tokenizer.encode(";", add_special_tokens=False)
547
  start_id = start_tokens[0] if start_tokens else 0
 
548
 
549
- print(f"Starting generation with token id: {start_id}")
550
- input_ids = torch.tensor([[start_id]], dtype=torch.long, device=device)
551
 
552
- max_gen = min(max_tokens, gcode_decoder.config.max_seq_len - 1)
553
- eos_id = gcode_tokenizer.eos_token_id if gcode_tokenizer.eos_token_id is not None else 2
554
 
555
- # Track generated content for repetition detection
556
  recent_tokens = []
557
- recent_coords = []
558
- repetition_window = 30
559
 
560
  for step in range(max_gen):
561
  logits = gcode_decoder(latent, input_ids)
562
  next_logits = logits[:, -1, :] / temperature
563
 
564
- # Repetition penalty - reduce probability of recent tokens
 
 
 
 
 
565
  if recent_tokens:
566
- for token_id in set(recent_tokens[-repetition_window:]):
567
- next_logits[:, token_id] *= 0.6 # Stronger penalty
568
 
569
- # Top-k + Top-p sampling for better coherence
570
- top_k = 40
571
- top_p = 0.9
572
 
573
  # Top-k filtering
574
  top_k_logits, top_k_indices = torch.topk(next_logits, top_k, dim=-1)
575
 
576
- # Top-p filtering within top-k
577
  sorted_logits, sorted_idx = torch.sort(top_k_logits, descending=True, dim=-1)
578
  cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
579
  sorted_indices_to_remove = cumulative_probs > top_p
@@ -584,66 +670,34 @@ def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, g
584
  probs = torch.softmax(sorted_logits, dim=-1)
585
  sampled_idx = torch.multinomial(probs, num_samples=1)
586
 
587
- # Map back to vocabulary indices
588
  next_token = top_k_indices.gather(-1, sorted_idx.gather(-1, sampled_idx))
589
  input_ids = torch.cat([input_ids, next_token], dim=1)
590
  recent_tokens.append(next_token.item())
591
 
 
 
 
 
 
592
  # Check EOS
593
- if next_token.item() == eos_id:
594
  print(f"Hit EOS at step {step}")
595
  break
596
 
597
- # Early stop on excessive repetition
598
- if len(recent_tokens) > 20:
599
- last_20 = recent_tokens[-20:]
600
- if len(set(last_20)) < 4: # Less than 4 unique tokens in last 20
601
- print(f"Stopping due to token repetition at step {step}")
602
  break
603
 
604
- print(f"Generated {input_ids.shape[1]} tokens")
605
 
606
- # Decode - skip special tokens
607
  gcode = gcode_tokenizer.decode(input_ids[0], skip_special_tokens=True)
608
- print(f"Raw decoded (first 200): {gcode[:200]}")
609
-
610
- # Post-process for v3: restore newlines from <newline> token
611
- if is_v3:
612
- gcode = gcode.replace("<newline>", "\n")
613
-
614
- # If still no newlines, try to split on command boundaries
615
- if "\n" not in gcode or gcode.count("\n") < 5:
616
- print("No newlines found, splitting on command boundaries...")
617
- # Split before G0, G1, G28, M280 commands
618
- gcode = re.sub(r'(G0\s)', r'\n\1', gcode)
619
- gcode = re.sub(r'(G1\s)', r'\n\1', gcode)
620
- gcode = re.sub(r'(G1X)', r'\nG1 X', gcode)
621
- gcode = re.sub(r'(G0X)', r'\nG0 X', gcode)
622
- gcode = re.sub(r'(G28)', r'\nG28', gcode)
623
- gcode = re.sub(r'(G21)', r'\nG21', gcode)
624
- gcode = re.sub(r'(G90)', r'\nG90', gcode)
625
- gcode = re.sub(r'(M280)', r'\nM280', gcode)
626
- # Split on F speed values that are followed by another command
627
- gcode = re.sub(r'(F\d+)(G)', r'\1\n\2', gcode)
628
- gcode = re.sub(r'(F\d+)(M)', r'\1\n\2', gcode)
629
-
630
- # Filter out training metadata and garbage lines
631
- filtered_lines = []
632
- for line in gcode.split("\n"):
633
- line = line.strip()
634
- # Skip empty lines and metadata
635
- if not line:
636
- continue
637
- if line.startswith("Source:") or line.startswith(";Generated"):
638
- continue
639
- if line.lower() in ["dcode", "gcode", "code"]: # Skip garbage words
640
- continue
641
- # Only keep lines that look like gcode (start with G, M, or ;)
642
- if line[0] in "GMgm;":
643
- filtered_lines.append(line)
644
 
645
- gcode = "\n".join(filtered_lines)
646
- print(f"Filtered gcode: {len(filtered_lines)} lines, {len(gcode)} chars")
647
 
648
  gcode = validate_gcode(gcode)
649
  line_count = len([l for l in gcode.split("\n") if l.strip()])
 
304
  try:
305
  # Try loading custom tokenizer from v3 model
306
  tokenizer_path = hf_hub_download("twarner/dcode-sd-gcode-v3", "gcode_tokenizer/tokenizer.json")
307
+ gcode_tokenizer = PreTrainedTokenizerFast(
308
+ tokenizer_file=tokenizer_path,
309
+ pad_token="<pad>",
310
+ unk_token="<unk>",
311
+ bos_token="<s>",
312
+ eos_token="</s>",
313
+ )
314
+ # Verify special tokens
 
 
 
315
  print(f"Loaded custom gcode tokenizer (vocab={gcode_tokenizer.vocab_size})")
316
+ print(f" BOS='{gcode_tokenizer.bos_token}' (id={gcode_tokenizer.bos_token_id})")
317
+ print(f" EOS='{gcode_tokenizer.eos_token}' (id={gcode_tokenizer.eos_token_id})")
318
+ print(f" PAD='{gcode_tokenizer.pad_token}' (id={gcode_tokenizer.pad_token_id})")
319
+
320
+ # Test encode/decode
321
+ test = "G0 X100 Y200\nG1 X150 Y250"
322
+ enc = gcode_tokenizer.encode(test)
323
+ dec = gcode_tokenizer.decode(enc)
324
+ print(f" Test encode: {len(enc)} tokens")
325
+ print(f" Test decode: '{dec[:50]}...'")
326
  except Exception as e:
327
  print(f"Failed to load custom tokenizer: {e}")
328
+ import traceback
329
+ traceback.print_exc()
330
  # Fallback to T5 tokenizer
331
  gcode_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
332
  print("Using fallback T5 tokenizer")
 
349
  # GCODE PROCESSING
350
  # ============================================================================
351
 
352
# Whole-line tokens that show up in model output but are never valid gcode.
_GCODE_JUNK_WORDS = frozenset({"dcode", "gcode", "code", "output"})
# Prefixes of metadata lines that must be stripped from the output.
_GCODE_METADATA_PREFIXES = ("Source:", ";Generated", "Workarea:", "Algorithm:")
# Number of distinct recent XY positions remembered for duplicate detection.
_GCODE_COORD_WINDOW = 50


def clean_gcode(gcode: str) -> str:
    """Clean up generated gcode: fix formatting and remove garbage lines.

    Processing steps, in order:

    1. Literal ``<newline>`` markers are replaced with real newlines.
    2. If the text has fewer than 10 newlines, a newline is inserted before
       every ``G<digits>`` / ``M<digits>`` command.
    3. Each line is stripped; empty lines, junk words, metadata lines and
       lines that do not start with ``G``, ``M`` (either case) or ``;`` are
       dropped.
    4. Malformed coordinates (``X-X-X-100``, ``X-361.X-390``) and missing
       spaces (``G1X10``) are repaired.
    5. A line whose (X, Y) pair, rounded to one decimal, is already among the
       remembered recent positions is dropped once more than 5 distinct
       positions have been seen.

    Args:
        gcode: Raw decoded text.

    Returns:
        The cleaned gcode, one command per line, joined with ``"\\n"``.
    """
    gcode = gcode.replace("<newline>", "\n")

    # Too few line breaks: split before each command instead.
    if gcode.count("\n") < 10:
        gcode = re.sub(r"([GM]\d+)", r"\n\1", gcode)

    cleaned_lines = []
    # Dicts keep insertion order, so the first key is always the oldest
    # position. A set has no order, which made "keep the last 50" evict
    # arbitrary entries.
    recent_coords = {}

    for raw_line in gcode.split("\n"):
        line = raw_line.strip()
        if not line:
            continue
        if line.lower() in _GCODE_JUNK_WORDS:
            continue
        if line.startswith(_GCODE_METADATA_PREFIXES):
            continue
        # Reject non-gcode lines before coordinate tracking so that garbage
        # cannot register a position and suppress a later real move.
        if line[0] not in "GMgm;":
            continue

        # Collapse any run of repeated axis prefixes: X-X-X-100 -> X-100.
        line = re.sub(r"([XY])-(?:\1-)+", r"\1-", line)
        # Drop a truncated number followed by a restart: X-361.X-390 -> X-390.
        line = re.sub(r"([XY])-\d+\.\1-", r"\1-", line)
        # Restore the missing space after the command: G1X10 -> G1 X10.
        line = re.sub(r"(G[01])([XY])", r"\1 \2", line)

        x_match = re.search(r"X([-\d.]+)", line)
        y_match = re.search(r"Y([-\d.]+)", line)
        if x_match and y_match:
            try:
                coord = (
                    round(float(x_match.group(1)), 1),
                    round(float(y_match.group(1)), 1),
                )
            except ValueError:
                # Unparseable number (e.g. "1.2.3"): keep the line untracked.
                coord = None

            if coord is not None:
                if coord in recent_coords:
                    # Only treat repeats as "stuck" once a few distinct
                    # positions have been seen.
                    if len(recent_coords) > 5:
                        continue
                else:
                    recent_coords[coord] = None
                    if len(recent_coords) > _GCODE_COORD_WINDOW:
                        # Forget the oldest remembered position.
                        del recent_coords[next(iter(recent_coords))]

        cleaned_lines.append(line)

    result = "\n".join(cleaned_lines)
    print(f"Cleaned gcode: {len(cleaned_lines)} lines")
    return result
416
+
417
+
418
  def validate_gcode(gcode: str) -> str:
419
  """Clamp coordinates to machine bounds."""
420
  lines = []
 
613
  with torch.no_grad():
614
  batch_size = latent.shape[0]
615
 
616
+ # Get proper token IDs
617
+ bos_id = gcode_tokenizer.bos_token_id
618
+ eos_id = gcode_tokenizer.eos_token_id
619
+ pad_id = gcode_tokenizer.pad_token_id
620
+
621
+ # For v3, start with BOS token; for v2, encode gcode header
622
  if is_v3:
623
+ # Use the gcode header as the starting prompt
624
+ start_text = "G21\nG90\nM280 P0 S90\nG28\n"
625
+ start_tokens = gcode_tokenizer.encode(start_text, add_special_tokens=False)
626
+ if bos_id is not None:
627
+ start_tokens = [bos_id] + start_tokens
628
+ input_ids = torch.tensor([start_tokens], dtype=torch.long, device=device)
629
  else:
630
  start_tokens = gcode_tokenizer.encode(";", add_special_tokens=False)
631
  start_id = start_tokens[0] if start_tokens else 0
632
+ input_ids = torch.tensor([[start_id]], dtype=torch.long, device=device)
633
 
634
+ print(f"Starting with {input_ids.shape[1]} tokens, BOS={bos_id}, EOS={eos_id}")
 
635
 
636
+ max_gen = min(max_tokens, gcode_decoder.config.max_seq_len - input_ids.shape[1])
 
637
 
638
+ # Track for repetition detection
639
  recent_tokens = []
 
 
640
 
641
  for step in range(max_gen):
642
  logits = gcode_decoder(latent, input_ids)
643
  next_logits = logits[:, -1, :] / temperature
644
 
645
+ # Suppress pad and unk tokens
646
+ if pad_id is not None:
647
+ next_logits[:, pad_id] = float('-inf')
648
+ next_logits[:, 1] = float('-inf') # <unk>
649
+
650
+ # Repetition penalty
651
  if recent_tokens:
652
+ for token_id in set(recent_tokens[-30:]):
653
+ next_logits[:, token_id] *= 0.7
654
 
655
+ # Top-k + Top-p sampling
656
+ top_k = 50
657
+ top_p = 0.92
658
 
659
  # Top-k filtering
660
  top_k_logits, top_k_indices = torch.topk(next_logits, top_k, dim=-1)
661
 
662
+ # Top-p filtering
663
  sorted_logits, sorted_idx = torch.sort(top_k_logits, descending=True, dim=-1)
664
  cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
665
  sorted_indices_to_remove = cumulative_probs > top_p
 
670
  probs = torch.softmax(sorted_logits, dim=-1)
671
  sampled_idx = torch.multinomial(probs, num_samples=1)
672
 
 
673
  next_token = top_k_indices.gather(-1, sorted_idx.gather(-1, sampled_idx))
674
  input_ids = torch.cat([input_ids, next_token], dim=1)
675
  recent_tokens.append(next_token.item())
676
 
677
+ # Debug first few tokens
678
+ if step < 5:
679
+ tok_str = gcode_tokenizer.decode([next_token.item()])
680
+ print(f" Step {step}: token={next_token.item()}, str='{tok_str}'")
681
+
682
  # Check EOS
683
+ if eos_id is not None and next_token.item() == eos_id:
684
  print(f"Hit EOS at step {step}")
685
  break
686
 
687
+ # Early stop on repetition
688
+ if len(recent_tokens) > 30:
689
+ if len(set(recent_tokens[-30:])) < 5:
690
+ print(f"Stopping due to repetition at step {step}")
 
691
  break
692
 
693
+ print(f"Generated {input_ids.shape[1]} total tokens")
694
 
695
+ # Decode
696
  gcode = gcode_tokenizer.decode(input_ids[0], skip_special_tokens=True)
697
+ print(f"Raw decoded (first 300 chars): {repr(gcode[:300])}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
698
 
699
+ # Clean up the gcode
700
+ gcode = clean_gcode(gcode)
701
 
702
  gcode = validate_gcode(gcode)
703
  line_count = len([l for l in gcode.split("\n") if l.strip()])