"""dcode Gradio Space - Text to Gcode via SD-Gcode Diffusion."""
import re
import os
import json
import gradio as gr
import torch
import torch.nn as nn
from pathlib import Path
# Machine limits
BOUNDS = {"left": -420.5, "right": 420.5, "top": 594.5, "bottom": -594.5}
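# Coordinates are in mm; the origin sits at the canvas center (inferred from the
# symmetric bounds and the axes drawn at 0 in the SVG preview)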
# Model cache
_model = None
class GcodeDecoderConfig:
"""Configuration for gcode decoder."""
def __init__(
self,
latent_channels: int = 4,
latent_size: int = 64,
hidden_size: int = 768,
num_layers: int = 6,
num_heads: int = 12,
vocab_size: int = 32128,
max_seq_len: int = 1024,
dropout: float = 0.1,
):
self.latent_channels = latent_channels
self.latent_size = latent_size
self.latent_dim = latent_channels * latent_size * latent_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.num_heads = num_heads
self.vocab_size = vocab_size
self.max_seq_len = max_seq_len
self.dropout = dropout
class GcodeDecoder(nn.Module):
"""Transformer decoder: SD latent -> gcode tokens."""
def __init__(self, config: GcodeDecoderConfig):
super().__init__()
self.config = config
self.latent_proj = nn.Sequential(
nn.Linear(config.latent_dim, config.hidden_size * 4),
nn.GELU(),
nn.Linear(config.hidden_size * 4, config.hidden_size * 16),
nn.LayerNorm(config.hidden_size * 16),
)
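        # The flattened SD latent (4 * 64 * 64 = 16384 values by default) is
        # projected up and later reshaped into 16 "memory" tokens that the
        # decoder cross-attends to (see forward())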
self.token_embed = nn.Embedding(config.vocab_size, config.hidden_size)
self.pos_embed = nn.Embedding(config.max_seq_len, config.hidden_size)
decoder_layer = nn.TransformerDecoderLayer(
d_model=config.hidden_size,
nhead=config.num_heads,
dim_feedforward=config.hidden_size * 4,
dropout=config.dropout,
activation='gelu',
batch_first=True,
norm_first=True,
)
self.decoder = nn.TransformerDecoder(decoder_layer, config.num_layers)
self.ln_f = nn.LayerNorm(config.hidden_size)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
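        # Tie the LM head to the token embedding (weight sharing)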
self.lm_head.weight = self.token_embed.weight
def forward(self, latent: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
batch_size, seq_len = input_ids.shape
device = input_ids.device
        latent_flat = latent.view(batch_size, -1)  # [B, latent_dim]
        memory = self.latent_proj(latent_flat)  # [B, 16 * hidden]
        memory = memory.view(batch_size, 16, self.config.hidden_size)  # [B, 16, H]
positions = torch.arange(seq_len, device=device)
x = self.token_embed(input_ids) + self.pos_embed(positions)
causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=device)
x = self.decoder(x, memory, tgt_mask=causal_mask)
x = self.ln_f(x)
return self.lm_head(x)
@torch.no_grad()
def generate(self, latent, tokenizer, max_length=512, temperature=0.8, top_p=0.9):
device = latent.device
batch_size = latent.shape[0]
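        # T5 convention: the pad token doubles as the decoder start token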
input_ids = torch.full((batch_size, 1), tokenizer.pad_token_id, dtype=torch.long, device=device)
for _ in range(max_length - 1):
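            # Re-runs the full forward pass each step (no KV cache); acceptable
            # at these sequence lengths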
logits = self(latent, input_ids)
            next_logits = logits[:, -1, :] / temperature
            # Nucleus (top-p) sampling: keep the smallest set of tokens whose
            # cumulative probability exceeds top_p and mask out the rest
            sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift right so the first token past the threshold is still kept
            sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
            sorted_indices_to_remove[:, 0] = False
            for b in range(batch_size):
                next_logits[b, sorted_indices[b, sorted_indices_to_remove[b]]] = float('-inf')
probs = torch.softmax(next_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
input_ids = torch.cat([input_ids, next_token], dim=1)
            if next_token.item() == tokenizer.eos_token_id:  # assumes batch_size == 1
break
return tokenizer.decode(input_ids[0], skip_special_tokens=True)
def get_model():
"""Load and cache the SD-Gcode model."""
global _model
if _model is None:
from diffusers import StableDiffusionPipeline
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
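        # fp16 on GPU for speed and memory; fp32 on CPU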
print("Loading SD-Gcode model...")
# Download config and weights
config_path = hf_hub_download("twarner/dcode-sd-gcode", "config.json")
weights_path = hf_hub_download("twarner/dcode-sd-gcode", "pytorch_model.bin")
with open(config_path) as f:
config = json.load(f)
# Load SD pipeline
sd_model_id = config.get("sd_model_id", "runwayml/stable-diffusion-v1-5")
print(f"Loading SD from {sd_model_id}...")
        pipe = StableDiffusionPipeline.from_pretrained(
            sd_model_id,
            torch_dtype=dtype,
            safety_checker=None,  # unused: we stop at the latent and never decode an image
        ).to(device)
# Build gcode decoder
gcode_cfg = config.get("gcode_decoder", {})
decoder_config = GcodeDecoderConfig(
latent_channels=gcode_cfg.get("latent_channels", 4),
latent_size=gcode_cfg.get("latent_size", 64),
hidden_size=gcode_cfg.get("hidden_size", 768),
num_layers=gcode_cfg.get("num_layers", 6),
num_heads=gcode_cfg.get("num_heads", 12),
vocab_size=gcode_cfg.get("vocab_size", 32128),
max_seq_len=gcode_cfg.get("max_seq_len", 1024),
)
gcode_decoder = GcodeDecoder(decoder_config).to(device, dtype)
# Load weights
state_dict = torch.load(weights_path, map_location=device)
# Extract decoder weights
decoder_state = {k.replace("gcode_decoder.", ""): v for k, v in state_dict.items()
if k.startswith("gcode_decoder.")}
gcode_decoder.load_state_dict(decoder_state)
gcode_decoder.eval()
        # Gcode tokenizer: the decoder vocabulary matches flan-t5-base (32128 tokens)
        gcode_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
_model = {
"pipe": pipe,
"gcode_decoder": gcode_decoder,
"gcode_tokenizer": gcode_tokenizer,
"device": device,
"dtype": dtype,
"num_inference_steps": config.get("num_inference_steps", 20),
}
print("Model loaded!")
return _model
def validate_gcode(gcode: str) -> str:
"""Clamp coordinates to machine bounds."""
lines = []
for line in gcode.split("\n"):
corrected = line
x_match = re.search(r"X([-\d.]+)", line, re.IGNORECASE)
if x_match:
try:
x = float(x_match.group(1))
x = max(BOUNDS["left"], min(BOUNDS["right"], x))
corrected = re.sub(r"X[-\d.]+", f"X{x:.2f}", corrected, flags=re.IGNORECASE)
except ValueError:
pass
y_match = re.search(r"Y([-\d.]+)", line, re.IGNORECASE)
if y_match:
try:
y = float(y_match.group(1))
y = max(BOUNDS["bottom"], min(BOUNDS["top"], y))
corrected = re.sub(r"Y[-\d.]+", f"Y{y:.2f}", corrected, flags=re.IGNORECASE)
except ValueError:
pass
lines.append(corrected)
return "\n".join(lines)
def gcode_to_svg(gcode: str) -> str:
"""Convert gcode to SVG for visual preview."""
paths = []
current_path = []
x, y = 0.0, 0.0
pen_down = False
lines = []
for line in gcode.split("\n"):
line = line.strip()
if not line:
continue
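        # Split lines that pack several commands together; the lookahead keeps
        # each command's G/M prefix (e.g. "G1 X10 Y5M280 S90" -> two commands)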
parts = re.split(r'(?=[GM]\d)', line)
for part in parts:
part = part.strip()
if part and not part.startswith(";"):
lines.append(part)
for line in lines:
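        # M280 drives the pen servo; angles below 50 deg count as pen-down
        # (the machine uses 40 deg down / 90 deg up)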
if "M280" in line.upper():
match = re.search(r"S(\d+)", line, re.IGNORECASE)
if match:
angle = int(match.group(1))
was_down = pen_down
pen_down = angle < 50
if was_down and not pen_down and len(current_path) > 1:
paths.append(current_path[:])
current_path = []
x_match = re.search(r"X([-\d.]+)", line, re.IGNORECASE)
y_match = re.search(r"Y([-\d.]+)", line, re.IGNORECASE)
if x_match:
try:
x = float(x_match.group(1))
except ValueError:
pass
if y_match:
try:
y = float(y_match.group(1))
except ValueError:
pass
if (x_match or y_match) and pen_down:
current_path.append((x, y))
if len(current_path) > 1:
paths.append(current_path)
w = BOUNDS["right"] - BOUNDS["left"]
h = BOUNDS["top"] - BOUNDS["bottom"]
padding = 20
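    # Machine Y points up while SVG Y points down, so Y is negated throughout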
svg = f'''<svg xmlns="http://www.w3.org/2000/svg"
viewBox="{BOUNDS["left"] - padding} {-BOUNDS["top"] - padding} {w + 2*padding} {h + 2*padding}"
style="background: #fafafa; width: 100%; height: 500px; border-radius: 8px; border: 1px solid #e5e5e5;">
<rect x="{BOUNDS["left"]}" y="{-BOUNDS["top"]}" width="{w}" height="{h}"
fill="#fff" stroke="#ccc" stroke-width="2"/>
<line x1="0" y1="{-BOUNDS["top"]}" x2="0" y2="{-BOUNDS["bottom"]}" stroke="#ddd" stroke-width="1"/>
<line x1="{BOUNDS["left"]}" y1="0" x2="{BOUNDS["right"]}" y2="0" stroke="#ddd" stroke-width="1"/>
'''
for path in paths:
if len(path) < 2:
continue
d = " ".join(f"{'M' if i == 0 else 'L'}{p[0]:.1f},{-p[1]:.1f}" for i, p in enumerate(path))
svg += f'<path d="{d}" fill="none" stroke="#1a1a1a" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>'
total_points = sum(len(p) for p in paths)
svg += f'''
<text x="{BOUNDS["left"] + 10}" y="{-BOUNDS["top"] + 25}" fill="#666" font-family="monospace" font-size="14">
Paths: {len(paths)} | Points: {total_points}
</text>
'''
svg += "</svg>"
return svg
def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, guidance: float):
"""Generate gcode from text prompt via SD-Gcode diffusion."""
if not prompt or not prompt.strip():
return "Enter a prompt to generate gcode", gcode_to_svg("")
try:
m = get_model()
pipe = m["pipe"]
gcode_decoder = m["gcode_decoder"]
gcode_tokenizer = m["gcode_tokenizer"]
device = m["device"]
dtype = m["dtype"]
# 1. Text -> Latent via full SD diffusion
with torch.no_grad():
            result = pipe(
                prompt,
                num_inference_steps=int(num_steps),  # sliders may return floats
                guidance_scale=guidance,
                output_type="latent",  # skip the VAE decode; keep the raw latent
            )
latent = result.images.to(dtype) # [1, 4, 64, 64]
# 2. Latent -> Gcode via trained decoder
with torch.no_grad():
            gcode = gcode_decoder.generate(
                latent,
                gcode_tokenizer,
                max_length=int(max_tokens),  # sliders may return floats
                temperature=temperature,
            )
gcode = validate_gcode(gcode)
line_count = len(gcode.split("\n"))
svg = gcode_to_svg(gcode)
gcode_with_header = f"; dcode SD-Gcode output - {line_count} lines\n; Prompt: {prompt}\n; Machine validated\n\n{gcode}"
return gcode_with_header, svg
except Exception as e:
import traceback
traceback.print_exc()
return f"; Error: {e}", gcode_to_svg("")
# Custom CSS
custom_css = """
.gradio-container {
max-width: 1200px !important;
}
"""
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="emerald")) as demo:
gr.Markdown("""
# dcode
**Text -> Polargraph Gcode via Stable Diffusion**
Single end-to-end pipeline: text -> CLIP text encoder -> UNet diffusion -> SD latent -> gcode decoder -> gcode
[GitHub](https://github.com/Twarner491/dcode) | [Model](https://huggingface.co/twarner/dcode-sd-gcode) | [Dataset](https://huggingface.co/datasets/twarner/dcode-polargraph-gcode)
""")
with gr.Row():
with gr.Column(scale=1):
prompt = gr.Textbox(
label="Prompt",
placeholder="drawing of a cat, abstract spiral, portrait...",
lines=2
)
with gr.Row():
temperature = gr.Slider(0.5, 1.5, value=0.8, label="Temperature")
max_tokens = gr.Slider(256, 1024, value=512, step=128, label="Max Tokens")
with gr.Row():
num_steps = gr.Slider(10, 50, value=20, step=5, label="Diffusion Steps")
guidance = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="Guidance Scale")
generate_btn = gr.Button("Generate", variant="primary", size="lg")
gr.Examples(
examples=[
["line drawing of a cat"],
["abstract spiral pattern"],
["simple house with chimney"],
["portrait sketch"],
["geometric shapes and lines"],
],
inputs=prompt,
)
with gr.Column(scale=2):
preview = gr.HTML(
value=gcode_to_svg(""),
label="Preview",
)
with gr.Accordion("Gcode Output", open=False):
gcode_output = gr.Code(label="Gcode", language=None, lines=15)
gr.Markdown("""
---
**Machine Bounds**: X: +/-420.5mm, Y: +/-594.5mm | Pen servo: 40 deg (down) / 90 deg (up) | **License**: MIT
""")
generate_btn.click(
generate,
[prompt, temperature, max_tokens, num_steps, guidance],
[gcode_output, preview]
)
prompt.submit(
generate,
[prompt, temperature, max_tokens, num_steps, guidance],
[gcode_output, preview]
)
if __name__ == "__main__":
demo.launch()