Improve inference: prompt enhancement, better sampling, repetition penalty
app.py
CHANGED
@@ -456,6 +456,21 @@ def gcode_to_svg(gcode: str) -> str:
 # GENERATION
 # ============================================================================
 
+def enhance_prompt(prompt: str) -> str:
+    """Enhance prompt for better SD line drawing generation."""
+    prompt = prompt.strip().lower()
+
+    # Skip if already detailed
+    if any(x in prompt for x in ["drawing", "sketch", "line", "illustration"]):
+        enhanced = prompt
+    else:
+        enhanced = f"a simple line drawing of {prompt}"
+
+    # Add style suffixes for better SD output
+    enhanced += ", black ink on white paper, single continuous line, minimalist sketch, vector art style"
+    return enhanced
+
+
 @spaces.GPU
 def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, guidance: float):
     """Generate gcode from text prompt."""
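The new enhance_prompt is pure string templating, so its behavior can be pinned down with a couple of spot checks. A minimal sketch, assuming the function can be imported from the app module (that import path is an assumption):

# Hypothetical spot checks; the "from app import ..." path is an assumption.
from app import enhance_prompt

SUFFIX = ", black ink on white paper, single continuous line, minimalist sketch, vector art style"

# A bare subject gets the line-drawing template plus the style suffix.
assert enhance_prompt("Horse") == "a simple line drawing of horse" + SUFFIX

# A prompt that already names a drawing style keeps its own wording.
assert enhance_prompt("pencil sketch of a cat") == "pencil sketch of a cat" + SUFFIX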
@@ -471,10 +486,16 @@ def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, guidance: float):
     dtype = m["dtype"]
     is_v3 = m.get("is_v3", False)
 
+    # Enhance prompt for better line drawing generation
+    enhanced = enhance_prompt(prompt)
+    print(f"Enhanced prompt: {enhanced}")
+
     # Text -> Latent via SD diffusion
     with torch.no_grad():
+        # Use negative prompt to avoid unwanted styles
         result = pipe(
-            prompt,
+            enhanced,
+            negative_prompt="color, shading, gradient, photorealistic, 3d, complex, detailed texture",
             num_inference_steps=num_steps,
             guidance_scale=guidance,
             output_type="latent",
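For context, output_type="latent" makes a diffusers StableDiffusionPipeline return the raw latent tensor instead of a decoded image, which is what the gcode decoder consumes downstream. A standalone sketch of the same call pattern (the model id and dtype are assumptions; app.py loads its own pipeline elsewhere):

# Standalone sketch of the text -> latent step; the model id is an assumption.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

with torch.no_grad():
    result = pipe(
        "a simple line drawing of a horse, black ink on white paper",
        negative_prompt="color, shading, gradient, photorealistic, 3d",
        num_inference_steps=30,
        guidance_scale=12.0,
        output_type="latent",
    )

latent = result.images  # [1, 4, 64, 64] latent tensor, not a PIL image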
@@ -499,27 +520,52 @@ def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, guidance: float):
 
     max_gen = min(max_tokens, gcode_decoder.config.max_seq_len - 1)
 
+    # Track generated content for repetition detection
+    recent_tokens = []
+    repetition_window = 50
+
     for step in range(max_gen):
         logits = gcode_decoder(latent, input_ids)
         next_logits = logits[:, -1, :] / temperature
 
-        # Top-p sampling
-        sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
+        # Repetition penalty - reduce probability of recent tokens
+        if recent_tokens:
+            for token_id in set(recent_tokens[-repetition_window:]):
+                next_logits[:, token_id] *= 0.7
+
+        # Top-k + Top-p sampling for better coherence
+        top_k = 50
+        top_p = 0.85
+
+        # Top-k filtering
+        top_k_logits, top_k_indices = torch.topk(next_logits, top_k, dim=-1)
+
+        # Top-p filtering within top-k
+        sorted_logits, sorted_idx = torch.sort(top_k_logits, descending=True, dim=-1)
         cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
-        sorted_indices_to_remove = cumulative_probs >
+        sorted_indices_to_remove = cumulative_probs > top_p
         sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
         sorted_indices_to_remove[:, 0] = False
+        sorted_logits[sorted_indices_to_remove] = float('-inf')
 
-        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-        next_logits[indices_to_remove] = float('-inf')
+        probs = torch.softmax(sorted_logits, dim=-1)
+        sampled_idx = torch.multinomial(probs, num_samples=1)
 
-        probs = torch.softmax(next_logits, dim=-1)
-        next_token = torch.multinomial(probs, num_samples=1)
+        # Map back to vocabulary indices
+        next_token = top_k_indices.gather(-1, sorted_idx.gather(-1, sampled_idx))
         input_ids = torch.cat([input_ids, next_token], dim=1)
+        recent_tokens.append(next_token.item())
 
         # Check EOS
         if next_token.item() == gcode_tokenizer.eos_token_id:
             break
+
+        # Early stop on excessive repetition
+        if len(recent_tokens) > 20:
+            last_20 = recent_tokens[-20:]
+            if len(set(last_20)) < 5:  # Less than 5 unique tokens in last 20
+                print("Stopping due to repetition")
+                break
 
     print(f"Generated {input_ids.shape[1]} tokens")
     gcode = gcode_tokenizer.decode(input_ids[0], skip_special_tokens=True)
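One caveat in the committed penalty: next_logits[:, token_id] *= 0.7 only suppresses tokens whose logit is positive. A negative logit shrinks toward zero, so a repeated token can actually become more likely. The CTRL-style penalty avoids this by dividing positive logits and multiplying negative ones. A self-contained sketch of the sampling step with that variant (function and parameter names are illustrative, not from app.py):

# Illustrative refactor with a sign-aware (CTRL-style) repetition penalty.
import torch

def sample_next_token(next_logits, recent_tokens, top_k=50, top_p=0.85, penalty=1.3):
    """next_logits: [batch, vocab] tensor; returns [batch, 1] sampled token ids."""
    # Penalize recent tokens on both sides of zero: divide positive logits,
    # multiply negative ones, so the token always becomes less likely.
    for token_id in set(recent_tokens):
        logit = next_logits[:, token_id]
        next_logits[:, token_id] = torch.where(logit > 0, logit / penalty, logit * penalty)

    # Top-k: torch.topk already returns values sorted in descending order,
    # so the extra torch.sort pass in the committed version is redundant.
    top_k_logits, top_k_indices = torch.topk(next_logits, top_k, dim=-1)

    # Nucleus (top-p) filtering within the top-k candidates.
    cumulative = torch.cumsum(torch.softmax(top_k_logits, dim=-1), dim=-1)
    remove = cumulative > top_p
    remove[:, 1:] = remove[:, :-1].clone()  # shift so the boundary token survives
    remove[:, 0] = False                    # always keep the single best token
    top_k_logits = top_k_logits.masked_fill(remove, float("-inf"))

    # Sample among the survivors, then map back to vocabulary ids.
    sampled = torch.multinomial(torch.softmax(top_k_logits, dim=-1), num_samples=1)
    return top_k_indices.gather(-1, sampled)

Because topk output is already sorted, the sorted_idx.gather indirection from the commit is not needed here; top_k_indices.gather(-1, sampled) maps straight back to vocabulary ids.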
@@ -659,22 +705,25 @@ with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
             )
 
             with gr.Accordion("settings", open=False):
-                temperature = gr.Slider(0.
-                max_tokens = gr.Slider(256, 2048, value=
-                num_steps = gr.Slider(
-                guidance = gr.Slider(
+                temperature = gr.Slider(0.3, 1.2, value=0.6, label="temperature", step=0.1)
+                max_tokens = gr.Slider(256, 2048, value=1536, step=256, label="max tokens")
+                num_steps = gr.Slider(20, 50, value=30, step=5, label="diffusion steps")
+                guidance = gr.Slider(5.0, 20.0, value=12.0, step=0.5, label="guidance")
 
             generate_btn = gr.Button("generate", variant="secondary")
 
             gr.Examples(
                 examples=[
-                    ["
-                    ["
-                    ["
+                    ["horse"],
+                    ["cat face"],
+                    ["spiral"],
+                    ["star"],
+                    ["tree"],
+                    ["flower"],
                 ],
                 inputs=prompt,
                 label=None,
-                examples_per_page=
+                examples_per_page=6,
             )
 
         with gr.Column(scale=2):
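The new slider ranges only take effect if the click handler forwards them to generate; for reference, a minimal wiring sketch (the output component names are assumptions; the real ones are defined elsewhere in app.py):

# Hypothetical wiring; gcode_output and svg_output are assumed names.
generate_btn.click(
    fn=generate,
    inputs=[prompt, temperature, max_tokens, num_steps, guidance],
    outputs=[gcode_output, svg_output],
)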