twarner committed on
Commit
99753aa
·
1 Parent(s): 6dbfa53

Fix gcode tokenizer config and add post-processing to restore newlines

Browse files
Files changed (1) hide show
  1. app.py +69 -27
app.py CHANGED
@@ -305,8 +305,20 @@ def get_model():
305
  # Try loading custom tokenizer from v3 model
306
  tokenizer_path = hf_hub_download("twarner/dcode-sd-gcode-v3", "gcode_tokenizer/tokenizer.json")
307
  gcode_tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)
308
- print("Loaded custom gcode tokenizer")
309
- except Exception:
 
 
 
 
 
 
 
 
 
 
 
 
310
  # Fallback to T5 tokenizer
311
  gcode_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
312
  print("Using fallback T5 tokenizer")
@@ -365,19 +377,21 @@ def gcode_to_svg(gcode: str) -> str:
365
  x, y = 0.0, 0.0
366
  pen_down = False
367
 
368
- # Split on newlines, newline tokens, or command boundaries
369
- lines = []
370
  # Replace newline tokens with actual newlines
371
  gcode = gcode.replace("<newline>", "\n")
372
 
373
- for line in gcode.replace(";", "\n;").split("\n"):
374
- line = line.strip()
375
- if not line:
 
 
 
376
  continue
377
- parts = re.split(r'(?=[GM]\d)', line)
 
378
  for part in parts:
379
  part = part.strip()
380
- if part and not part.startswith(";"):
381
  lines.append(part)
382
 
383
  for line in lines:
@@ -525,22 +539,23 @@ def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, g
525
  with torch.no_grad():
526
  batch_size = latent.shape[0]
527
 
528
- # Start token
529
  if is_v3:
530
- # V3 uses custom tokenizer with BOS
531
- start_id = gcode_tokenizer.bos_token_id or 0
532
  else:
533
- # V2 uses semicolon as start
534
  start_tokens = gcode_tokenizer.encode(";", add_special_tokens=False)
535
- start_id = start_tokens[0] if start_tokens else gcode_tokenizer.pad_token_id
536
 
 
537
  input_ids = torch.tensor([[start_id]], dtype=torch.long, device=device)
538
 
539
  max_gen = min(max_tokens, gcode_decoder.config.max_seq_len - 1)
 
540
 
541
  # Track generated content for repetition detection
542
  recent_tokens = []
543
- repetition_window = 50
 
544
 
545
  for step in range(max_gen):
546
  logits = gcode_decoder(latent, input_ids)
@@ -549,11 +564,11 @@ def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, g
549
  # Repetition penalty - reduce probability of recent tokens
550
  if recent_tokens:
551
  for token_id in set(recent_tokens[-repetition_window:]):
552
- next_logits[:, token_id] *= 0.7
553
 
554
  # Top-k + Top-p sampling for better coherence
555
- top_k = 50
556
- top_p = 0.85
557
 
558
  # Top-k filtering
559
  top_k_logits, top_k_indices = torch.topk(next_logits, top_k, dim=-1)
@@ -575,33 +590,60 @@ def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, g
575
  recent_tokens.append(next_token.item())
576
 
577
  # Check EOS
578
- if next_token.item() == gcode_tokenizer.eos_token_id:
 
579
  break
580
 
581
  # Early stop on excessive repetition
582
  if len(recent_tokens) > 20:
583
  last_20 = recent_tokens[-20:]
584
- if len(set(last_20)) < 5: # Less than 5 unique tokens in last 20
585
- print("Stopping due to repetition")
586
  break
587
 
588
  print(f"Generated {input_ids.shape[1]} tokens")
 
 
589
  gcode = gcode_tokenizer.decode(input_ids[0], skip_special_tokens=True)
 
590
 
591
- # Post-process for v3: restore newlines
592
  if is_v3:
593
  gcode = gcode.replace("<newline>", "\n")
594
 
595
- # Filter out training metadata lines
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
596
  filtered_lines = []
597
  for line in gcode.split("\n"):
598
- # Skip metadata headers from training data
 
 
 
599
  if line.startswith("Source:") or line.startswith(";Generated"):
600
  continue
601
- filtered_lines.append(line)
602
- gcode = "\n".join(filtered_lines)
 
 
 
603
 
604
- print(f"Decoded gcode length: {len(gcode)} chars")
 
605
 
606
  gcode = validate_gcode(gcode)
607
  line_count = len([l for l in gcode.split("\n") if l.strip()])
 
305
  # Try loading custom tokenizer from v3 model
306
  tokenizer_path = hf_hub_download("twarner/dcode-sd-gcode-v3", "gcode_tokenizer/tokenizer.json")
307
  gcode_tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)
308
+ # Ensure special tokens are set
309
+ if gcode_tokenizer.pad_token is None:
310
+ gcode_tokenizer.pad_token = "<pad>"
311
+ gcode_tokenizer.pad_token_id = 0
312
+ if gcode_tokenizer.bos_token is None:
313
+ gcode_tokenizer.bos_token = "<s>"
314
+ gcode_tokenizer.bos_token_id = 1
315
+ if gcode_tokenizer.eos_token is None:
316
+ gcode_tokenizer.eos_token = "</s>"
317
+ gcode_tokenizer.eos_token_id = 2
318
+ print(f"Loaded custom gcode tokenizer (vocab={gcode_tokenizer.vocab_size})")
319
+ print(f" BOS={gcode_tokenizer.bos_token_id}, EOS={gcode_tokenizer.eos_token_id}, PAD={gcode_tokenizer.pad_token_id}")
320
+ except Exception as e:
321
+ print(f"Failed to load custom tokenizer: {e}")
322
  # Fallback to T5 tokenizer
323
  gcode_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
324
  print("Using fallback T5 tokenizer")
 
377
  x, y = 0.0, 0.0
378
  pen_down = False
379
 
 
 
380
  # Replace newline tokens with actual newlines
381
  gcode = gcode.replace("<newline>", "\n")
382
 
383
+ # Split concatenated gcode into separate commands
384
+ # First split on explicit newlines
385
+ lines = []
386
+ for raw_line in gcode.split("\n"):
387
+ raw_line = raw_line.strip()
388
+ if not raw_line:
389
  continue
390
+ # Split on command boundaries (G0, G1, M280, etc)
391
+ parts = re.split(r'(?=[GM]\d)', raw_line)
392
  for part in parts:
393
  part = part.strip()
394
+ if part and not part.startswith(";") and part[0] in "GMgm":
395
  lines.append(part)
396
 
397
  for line in lines:
 
539
  with torch.no_grad():
540
  batch_size = latent.shape[0]
541
 
542
+ # Start token - use BOS for v3, semicolon for v2
543
  if is_v3:
544
+ start_id = gcode_tokenizer.bos_token_id if gcode_tokenizer.bos_token_id is not None else 1
 
545
  else:
 
546
  start_tokens = gcode_tokenizer.encode(";", add_special_tokens=False)
547
+ start_id = start_tokens[0] if start_tokens else 0
548
 
549
+ print(f"Starting generation with token id: {start_id}")
550
  input_ids = torch.tensor([[start_id]], dtype=torch.long, device=device)
551
 
552
  max_gen = min(max_tokens, gcode_decoder.config.max_seq_len - 1)
553
+ eos_id = gcode_tokenizer.eos_token_id if gcode_tokenizer.eos_token_id is not None else 2
554
 
555
  # Track generated content for repetition detection
556
  recent_tokens = []
557
+ recent_coords = []
558
+ repetition_window = 30
559
 
560
  for step in range(max_gen):
561
  logits = gcode_decoder(latent, input_ids)
 
564
  # Repetition penalty - reduce probability of recent tokens
565
  if recent_tokens:
566
  for token_id in set(recent_tokens[-repetition_window:]):
567
+ next_logits[:, token_id] *= 0.6 # Stronger penalty
568
 
569
  # Top-k + Top-p sampling for better coherence
570
+ top_k = 40
571
+ top_p = 0.9
572
 
573
  # Top-k filtering
574
  top_k_logits, top_k_indices = torch.topk(next_logits, top_k, dim=-1)
 
590
  recent_tokens.append(next_token.item())
591
 
592
  # Check EOS
593
+ if next_token.item() == eos_id:
594
+ print(f"Hit EOS at step {step}")
595
  break
596
 
597
  # Early stop on excessive repetition
598
  if len(recent_tokens) > 20:
599
  last_20 = recent_tokens[-20:]
600
+ if len(set(last_20)) < 4: # Less than 4 unique tokens in last 20
601
+ print(f"Stopping due to token repetition at step {step}")
602
  break
603
 
604
  print(f"Generated {input_ids.shape[1]} tokens")
605
+
606
+ # Decode - skip special tokens
607
  gcode = gcode_tokenizer.decode(input_ids[0], skip_special_tokens=True)
608
+ print(f"Raw decoded (first 200): {gcode[:200]}")
609
 
610
+ # Post-process for v3: restore newlines from <newline> token
611
  if is_v3:
612
  gcode = gcode.replace("<newline>", "\n")
613
 
614
+ # If still no newlines, try to split on command boundaries
615
+ if "\n" not in gcode or gcode.count("\n") < 5:
616
+ print("No newlines found, splitting on command boundaries...")
617
+ # Split before G0, G1, G28, M280 commands
618
+ gcode = re.sub(r'(G0\s)', r'\n\1', gcode)
619
+ gcode = re.sub(r'(G1\s)', r'\n\1', gcode)
620
+ gcode = re.sub(r'(G1X)', r'\nG1 X', gcode)
621
+ gcode = re.sub(r'(G0X)', r'\nG0 X', gcode)
622
+ gcode = re.sub(r'(G28)', r'\nG28', gcode)
623
+ gcode = re.sub(r'(G21)', r'\nG21', gcode)
624
+ gcode = re.sub(r'(G90)', r'\nG90', gcode)
625
+ gcode = re.sub(r'(M280)', r'\nM280', gcode)
626
+ # Split on F speed values that are followed by another command
627
+ gcode = re.sub(r'(F\d+)(G)', r'\1\n\2', gcode)
628
+ gcode = re.sub(r'(F\d+)(M)', r'\1\n\2', gcode)
629
+
630
+ # Filter out training metadata and garbage lines
631
  filtered_lines = []
632
  for line in gcode.split("\n"):
633
+ line = line.strip()
634
+ # Skip empty lines and metadata
635
+ if not line:
636
+ continue
637
  if line.startswith("Source:") or line.startswith(";Generated"):
638
  continue
639
+ if line.lower() in ["dcode", "gcode", "code"]: # Skip garbage words
640
+ continue
641
+ # Only keep lines that look like gcode (start with G, M, or ;)
642
+ if line[0] in "GMgm;":
643
+ filtered_lines.append(line)
644
 
645
+ gcode = "\n".join(filtered_lines)
646
+ print(f"Filtered gcode: {len(filtered_lines)} lines, {len(gcode)} chars")
647
 
648
  gcode = validate_gcode(gcode)
649
  line_count = len([l for l in gcode.split("\n") if l.strip()])