twarner committed on
Commit
9f5b647
·
1 Parent(s): 5d7555b

Improve inference: prompt enhancement, better sampling, repetition penalty

Browse files
Files changed (1) hide show
  1. app.py +65 -16
app.py CHANGED
@@ -456,6 +456,21 @@ def gcode_to_svg(gcode: str) -> str:
456
  # GENERATION
457
  # ============================================================================
458
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
459
  @spaces.GPU
460
  def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, guidance: float):
461
  """Generate gcode from text prompt."""
@@ -471,10 +486,16 @@ def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, g
471
  dtype = m["dtype"]
472
  is_v3 = m.get("is_v3", False)
473
 
 
 
 
 
474
  # Text -> Latent via SD diffusion
475
  with torch.no_grad():
 
476
  result = pipe(
477
- prompt,
 
478
  num_inference_steps=num_steps,
479
  guidance_scale=guidance,
480
  output_type="latent",
@@ -499,27 +520,52 @@ def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, g
499
 
500
  max_gen = min(max_tokens, gcode_decoder.config.max_seq_len - 1)
501
 
 
 
 
 
502
  for step in range(max_gen):
503
  logits = gcode_decoder(latent, input_ids)
504
  next_logits = logits[:, -1, :] / temperature
505
 
506
- # Top-p sampling
507
- sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
 
 
 
 
 
 
 
 
 
 
 
 
508
  cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
509
- sorted_indices_to_remove = cumulative_probs > 0.9
510
  sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
511
  sorted_indices_to_remove[:, 0] = False
 
512
 
513
- for b in range(batch_size):
514
- next_logits[b, sorted_indices[b, sorted_indices_to_remove[b]]] = float('-inf')
515
 
516
- probs = torch.softmax(next_logits, dim=-1)
517
- next_token = torch.multinomial(probs, num_samples=1)
518
  input_ids = torch.cat([input_ids, next_token], dim=1)
 
519
 
520
  # Check EOS
521
  if next_token.item() == gcode_tokenizer.eos_token_id:
522
  break
 
 
 
 
 
 
 
523
 
524
  print(f"Generated {input_ids.shape[1]} tokens")
525
  gcode = gcode_tokenizer.decode(input_ids[0], skip_special_tokens=True)
@@ -659,22 +705,25 @@ with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
659
  )
660
 
661
  with gr.Accordion("settings", open=False):
662
- temperature = gr.Slider(0.5, 1.5, value=0.8, label="temperature", step=0.1)
663
- max_tokens = gr.Slider(256, 2048, value=1024, step=256, label="max tokens")
664
- num_steps = gr.Slider(10, 50, value=20, step=5, label="diffusion steps")
665
- guidance = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="guidance")
666
 
667
  generate_btn = gr.Button("generate", variant="secondary")
668
 
669
  gr.Examples(
670
  examples=[
671
- ["a line drawing of a horse"],
672
- ["portrait sketch"],
673
- ["geometric shapes"],
 
 
 
674
  ],
675
  inputs=prompt,
676
  label=None,
677
- examples_per_page=3,
678
  )
679
 
680
  with gr.Column(scale=2):
 
456
  # GENERATION
457
  # ============================================================================
458
 
459
def enhance_prompt(prompt: str) -> str:
    """Enhance prompt for better SD line drawing generation."""
    cleaned = prompt.strip().lower()

    # If the user already asked for a drawing-like style, keep their wording;
    # otherwise wrap the subject in a line-drawing framing phrase.
    style_markers = ("drawing", "sketch", "line", "illustration")
    already_styled = any(marker in cleaned for marker in style_markers)
    base = cleaned if already_styled else f"a simple line drawing of {cleaned}"

    # Append style suffixes that steer SD toward clean, plottable output.
    suffix = ", black ink on white paper, single continuous line, minimalist sketch, vector art style"
    return base + suffix
472
+
473
+
474
  @spaces.GPU
475
  def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, guidance: float):
476
  """Generate gcode from text prompt."""
 
486
  dtype = m["dtype"]
487
  is_v3 = m.get("is_v3", False)
488
 
489
+ # Enhance prompt for better line drawing generation
490
+ enhanced = enhance_prompt(prompt)
491
+ print(f"Enhanced prompt: {enhanced}")
492
+
493
  # Text -> Latent via SD diffusion
494
  with torch.no_grad():
495
+ # Use negative prompt to avoid unwanted styles
496
  result = pipe(
497
+ enhanced,
498
+ negative_prompt="color, shading, gradient, photorealistic, 3d, complex, detailed texture",
499
  num_inference_steps=num_steps,
500
  guidance_scale=guidance,
501
  output_type="latent",
 
520
 
521
  max_gen = min(max_tokens, gcode_decoder.config.max_seq_len - 1)
522
 
523
+ # Track generated content for repetition detection
524
+ recent_tokens = []
525
+ repetition_window = 50
526
+
527
  for step in range(max_gen):
528
  logits = gcode_decoder(latent, input_ids)
529
  next_logits = logits[:, -1, :] / temperature
530
 
531
+ # Repetition penalty - reduce probability of recent tokens
532
+ if recent_tokens:
533
+ for token_id in set(recent_tokens[-repetition_window:]):
534
+ next_logits[:, token_id] *= 0.7
535
+
536
+ # Top-k + Top-p sampling for better coherence
537
+ top_k = 50
538
+ top_p = 0.85
539
+
540
+ # Top-k filtering
541
+ top_k_logits, top_k_indices = torch.topk(next_logits, top_k, dim=-1)
542
+
543
+ # Top-p filtering within top-k
544
+ sorted_logits, sorted_idx = torch.sort(top_k_logits, descending=True, dim=-1)
545
  cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
546
+ sorted_indices_to_remove = cumulative_probs > top_p
547
  sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
548
  sorted_indices_to_remove[:, 0] = False
549
+ sorted_logits[sorted_indices_to_remove] = float('-inf')
550
 
551
+ probs = torch.softmax(sorted_logits, dim=-1)
552
+ sampled_idx = torch.multinomial(probs, num_samples=1)
553
 
554
+ # Map back to vocabulary indices
555
+ next_token = top_k_indices.gather(-1, sorted_idx.gather(-1, sampled_idx))
556
  input_ids = torch.cat([input_ids, next_token], dim=1)
557
+ recent_tokens.append(next_token.item())
558
 
559
  # Check EOS
560
  if next_token.item() == gcode_tokenizer.eos_token_id:
561
  break
562
+
563
+ # Early stop on excessive repetition
564
+ if len(recent_tokens) > 20:
565
+ last_20 = recent_tokens[-20:]
566
+ if len(set(last_20)) < 5: # Less than 5 unique tokens in last 20
567
+ print("Stopping due to repetition")
568
+ break
569
 
570
  print(f"Generated {input_ids.shape[1]} tokens")
571
  gcode = gcode_tokenizer.decode(input_ids[0], skip_special_tokens=True)
 
705
  )
706
 
707
  with gr.Accordion("settings", open=False):
708
+ temperature = gr.Slider(0.3, 1.2, value=0.6, label="temperature", step=0.1)
709
+ max_tokens = gr.Slider(256, 2048, value=1536, step=256, label="max tokens")
710
+ num_steps = gr.Slider(20, 50, value=30, step=5, label="diffusion steps")
711
+ guidance = gr.Slider(5.0, 20.0, value=12.0, step=0.5, label="guidance")
712
 
713
  generate_btn = gr.Button("generate", variant="secondary")
714
 
715
  gr.Examples(
716
  examples=[
717
+ ["horse"],
718
+ ["cat face"],
719
+ ["spiral"],
720
+ ["star"],
721
+ ["tree"],
722
+ ["flower"],
723
  ],
724
  inputs=prompt,
725
  label=None,
726
+ examples_per_page=6,
727
  )
728
 
729
  with gr.Column(scale=2):