# dcode / app.py — Hugging Face Space by twarner
# Commit 956dba9: "Update to latent-gcode diffusion model with SD 1.5" (12.9 kB)
"""dcode Gradio Space - Text to Gcode via Latent Diffusion."""
import re
import gradio as gr
import torch
from pathlib import Path
# Drawable limits of the polargraph, in machine coordinates (the UI footer
# describes them as mm). The origin is the centre of the area; "top" is +Y.
BOUNDS = dict(left=-420.5, right=420.5, top=594.5, bottom=-594.5)

# Process-wide model cache, populated lazily by get_generator(); None until then.
_generator = None
def _sub_state_dict(state_dict, prefix):
    """Return the entries of *state_dict* under *prefix*, with that leading prefix removed.

    Only the leading occurrence is stripped: ``"decoder.decoder.layers.0.x"``
    becomes ``"decoder.layers.0.x"``. That is the key a ``GcodeDecoder`` expects,
    because its own ``nn.TransformerDecoder`` lives at ``self.decoder``.
    (``str.replace`` would strip both occurrences and break loading.)
    """
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


def get_generator():
    """Load and cache the latent-gcode generator.

    The first call downloads and builds every model; later calls return the
    cached dict.

    The pipeline is:
      1. Stable Diffusion v1.5 turns a text prompt into an image latent.
      2. ``LatentProjector`` maps the flattened latent to one memory vector.
      3. ``GcodeDecoder`` (a transformer decoder over flan-t5 token ids)
         autoregressively decodes gcode tokens from that memory.

    Returns:
        dict with keys ``pipe``, ``projector``, ``decoder``, ``tokenizer``,
        ``device``, ``dtype`` and ``max_seq_len``.
    """
    global _generator
    if _generator is not None:
        return _generator

    import json

    import torch.nn as nn
    from diffusers import StableDiffusionPipeline
    from huggingface_hub import hf_hub_download
    from transformers import AutoTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision only on GPU; CPU kernels are float32.
    dtype = torch.float16 if device == "cuda" else torch.float32

    print("Loading Stable Diffusion pipeline...")
    # Use SD 1.5 which is more reliably available
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=dtype,
        safety_checker=None,
        use_safetensors=True,
    ).to(device)

    print("Loading gcode decoder...")
    model_path = hf_hub_download("twarner/dcode-latent-gcode", "pytorch_model.bin")
    config_path = hf_hub_download("twarner/dcode-latent-gcode", "config.json")
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")

    class LatentProjector(nn.Module):
        """MLP mapping a flattened diffusion latent to the decoder hidden size."""

        def __init__(self, latent_dim, hidden_size):
            super().__init__()
            self.proj = nn.Sequential(
                nn.Linear(latent_dim, hidden_size * 2),
                nn.GELU(),
                nn.Linear(hidden_size * 2, hidden_size),
                nn.LayerNorm(hidden_size),
            )

        def forward(self, x):
            return self.proj(x)

    class GcodeDecoder(nn.Module):
        """Transformer decoder with learned positional embeddings and an LM head."""

        def __init__(self, hidden_size, vocab_size, num_layers, num_heads, max_seq_len):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, hidden_size)
            self.pos_embed = nn.Embedding(max_seq_len, hidden_size)
            layer = nn.TransformerDecoderLayer(hidden_size, num_heads, hidden_size * 4, batch_first=True)
            self.decoder = nn.TransformerDecoder(layer, num_layers)
            self.head = nn.Linear(hidden_size, vocab_size)
            self.max_seq_len = max_seq_len

        def forward(self, tgt, memory, tgt_mask=None):
            pos = torch.arange(tgt.size(1), device=tgt.device)
            x = self.embed(tgt) + self.pos_embed(pos)
            x = self.decoder(x, memory, tgt_mask=tgt_mask)
            return self.head(x)

    # 4 x 64 x 64 is the SD 1.5 latent for a 512x512 image (the pipeline default size).
    latent_dim = 4 * 64 * 64
    hidden_size = config.get("hidden_size", 512)
    vocab_size = tokenizer.vocab_size
    num_layers = config.get("num_layers", 6)
    num_heads = config.get("num_heads", 8)
    max_seq_len = config.get("max_seq_len", 1024)

    projector = LatentProjector(latent_dim, hidden_size).to(device, dtype)
    decoder = GcodeDecoder(hidden_size, vocab_size, num_layers, num_heads, max_seq_len).to(device, dtype)

    # weights_only=True refuses arbitrary pickled objects in the downloaded file;
    # a plain tensor state dict loads unchanged.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    projector.load_state_dict(_sub_state_dict(state_dict, "projector."))
    decoder.load_state_dict(_sub_state_dict(state_dict, "decoder."))
    projector.eval()
    decoder.eval()

    _generator = {
        "pipe": pipe,
        "projector": projector,
        "decoder": decoder,
        "tokenizer": tokenizer,
        "device": device,
        "dtype": dtype,
        "max_seq_len": max_seq_len,
    }
    print("Models loaded!")
    return _generator
def validate_gcode(gcode: str, bounds: "dict | None" = None) -> str:
    """Clamp every X/Y coordinate in *gcode* to the machine bounds.

    Each X and Y word is clamped on its own. A line carrying several commands
    (e.g. ``"G1 X10 Y5 G1 X900 Y5"``) therefore keeps each command's own
    coordinates; previously every word on the line was overwritten with the
    first one's value. Clamped words are rewritten with an upper-case axis
    letter and two decimals. Words whose number cannot be parsed (e.g.
    ``"X-"``) are left untouched.

    Args:
        gcode: Newline-separated gcode text; line structure is preserved.
        bounds: Mapping with ``left``/``right``/``top``/``bottom`` limits.
            Defaults to the module-level ``BOUNDS``.

    Returns:
        The gcode with clamped coordinates.
    """
    if bounds is None:
        bounds = BOUNDS
    limits = {
        "X": (bounds["left"], bounds["right"]),
        "Y": (bounds["bottom"], bounds["top"]),
    }

    def clamp_word(match):
        axis = match.group(1).upper()
        try:
            value = float(match.group(2))
        except ValueError:
            return match.group(0)  # malformed number: leave the word as written
        low, high = limits[axis]
        return f"{axis}{max(low, min(high, value)):.2f}"

    # [-\d.] never matches a newline, so substitutions stay within their line.
    return re.sub(r"([XY])([-\d.]+)", clamp_word, gcode, flags=re.IGNORECASE)
def _extract_pen_paths(gcode):
    """Return the polylines traced while the pen is down, as lists of (x, y) tuples.

    Pen state comes from ``M280 ... S<angle>`` servo commands. An angle below
    50 counts as pen down (the UI footer lists 40° down / 90° up). Each stroke
    starts at the position where the pen was lowered, so the first segment of
    every stroke is kept. Text after ';' is ignored, and a line holding several
    commands is split before each ``G<digit>``/``M<digit>`` word. Coordinates
    are taken as absolute; relative mode (G91) is not interpreted.
    """
    commands = []
    for raw_line in gcode.split("\n"):
        code = raw_line.split(";", 1)[0]
        for part in re.split(r"(?=[GM]\d)", code):
            part = part.strip()
            if part:
                commands.append(part)

    paths = []
    stroke = []
    x, y = 0.0, 0.0
    pen_down = False
    for command in commands:
        if "M280" in command.upper():
            servo = re.search(r"S(\d+)", command, re.IGNORECASE)
            if servo:
                was_down = pen_down
                pen_down = int(servo.group(1)) < 50
                if pen_down and not was_down:
                    # Lowering the pen starts a stroke at the current position.
                    stroke = [(x, y)]
                elif was_down and not pen_down:
                    if len(stroke) > 1:
                        paths.append(stroke)
                    stroke = []
        x_match = re.search(r"X([-\d.]+)", command, re.IGNORECASE)
        y_match = re.search(r"Y([-\d.]+)", command, re.IGNORECASE)
        if x_match:
            try:
                x = float(x_match.group(1))
            except ValueError:
                pass  # malformed number such as "X-": keep the previous x
        if y_match:
            try:
                y = float(y_match.group(1))
            except ValueError:
                pass  # malformed number: keep the previous y
        if (x_match or y_match) and pen_down:
            stroke.append((x, y))
    if len(stroke) > 1:
        paths.append(stroke)
    return paths


def gcode_to_svg(gcode: str, bounds: "dict | None" = None) -> str:
    """Convert gcode to SVG for visual preview.

    Draws the machine frame and centre axes, then every pen-down stroke. The
    Y axis is flipped because SVG y grows downward. A footer reports the path
    and point counts.

    Args:
        gcode: Gcode text; may be empty, which renders an empty frame.
        bounds: Mapping with ``left``/``right``/``top``/``bottom`` limits.
            Defaults to the module-level ``BOUNDS``.

    Returns:
        An ``<svg>`` element as a string.
    """
    if bounds is None:
        bounds = BOUNDS
    left, right = bounds["left"], bounds["right"]
    top, bottom = bounds["top"], bounds["bottom"]
    paths = _extract_pen_paths(gcode)

    w = right - left
    h = top - bottom
    padding = 20
    svg = f'''<svg xmlns="http://www.w3.org/2000/svg"
viewBox="{left - padding} {-top - padding} {w + 2*padding} {h + 2*padding}"
style="background: #fafafa; width: 100%; height: 500px; border-radius: 8px; border: 1px solid #e5e5e5;">
<rect x="{left}" y="{-top}" width="{w}" height="{h}"
fill="#fff" stroke="#ccc" stroke-width="2"/>
<line x1="0" y1="{-top}" x2="0" y2="{-bottom}" stroke="#ddd" stroke-width="1"/>
<line x1="{left}" y1="0" x2="{right}" y2="0" stroke="#ddd" stroke-width="1"/>
'''
    for path in paths:
        d = " ".join(
            f"{'M' if i == 0 else 'L'}{px:.1f},{-py:.1f}" for i, (px, py) in enumerate(path)
        )
        svg += f'<path d="{d}" fill="none" stroke="#1a1a1a" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>'
    total_points = sum(len(p) for p in paths)
    svg += f'''
<text x="{left + 10}" y="{-top + 25}" fill="#666" font-family="monospace" font-size="14">
Paths: {len(paths)} | Points: {total_points}
</text>
'''
    svg += "</svg>"
    return svg
def _causal_mask(size, device):
    """Return a ``(size, size)`` bool mask; True marks future positions a token may not attend to."""
    return torch.triu(torch.ones(size, size, dtype=torch.bool, device=device), diagonal=1)


def _one_command_per_line(gcode):
    """Put every G/M command on its own line.

    NOTE(review): the flan-t5 tokenizer appears not to round-trip newlines, so
    decoded output can arrive as one long line. Only whitespace directly before
    an upper-case ``G<digit>``/``M<digit>`` word is turned into a newline.
    """
    return re.sub(r"\s+(?=[GM]\d)", "\n", gcode.strip())


def _with_header(gcode, prompt):
    """Prefix *gcode* with the ';' comment header shown in the UI."""
    line_count = len(gcode.split("\n"))
    # Collapse all whitespace (including newlines) so a multi-line prompt
    # cannot escape its ';' comment and end up as executable gcode.
    prompt_comment = " ".join(prompt.split())
    return (
        f"; dcode output - {line_count} lines\n"
        f"; Prompt: {prompt_comment}\n"
        "; Machine validated\n\n"
        f"{gcode}"
    )


def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, guidance: float):
    """Generate gcode from text prompt via latent diffusion.

    Args:
        prompt: Text description of the drawing.
        temperature: Sampling temperature for the gcode decoder; ``<= 0``
            selects greedy decoding.
        max_tokens: Upper bound on generated tokens, further capped by the
            decoder's positional range (``max_seq_len - 1``).
        num_steps: Number of Stable Diffusion inference steps.
        guidance: Classifier-free guidance scale for Stable Diffusion.

    Returns:
        ``(gcode_text, svg_preview)``. On failure the gcode text is a
        ``"; Error: ..."`` comment and the preview is an empty frame.
    """
    if not prompt or not prompt.strip():
        return "Enter a prompt to generate gcode", gcode_to_svg("")
    try:
        gen = get_generator()
        pipe = gen["pipe"]
        projector = gen["projector"]
        decoder = gen["decoder"]
        tokenizer = gen["tokenizer"]
        device = gen["device"]
        dtype = gen["dtype"]
        max_seq_len = gen["max_seq_len"]

        # Gradio sliders can deliver floats; range() and schedulers need ints.
        steps = min(int(max_tokens), max_seq_len - 1)
        temperature = float(temperature)

        with torch.no_grad():
            # 1. Text -> Latent via Stable Diffusion
            result = pipe(
                prompt,
                num_inference_steps=int(num_steps),
                guidance_scale=float(guidance),
                output_type="latent",
            )
            latent = result.images  # expected [1, 4, 64, 64]

            # 2. Latent -> Gcode via decoder
            latent_flat = latent.reshape(1, -1).to(dtype)  # [1, 4*64*64]
            memory = projector(latent_flat).unsqueeze(1)  # [1, 1, hidden]

            bos_id = tokenizer.bos_token_id
            if bos_id is None:  # T5 has no BOS token; it starts decoding from pad
                bos_id = tokenizer.pad_token_id
            eos_id = tokenizer.eos_token_id

            tokens = torch.tensor([[bos_id]], device=device)
            for _ in range(steps):
                # Causal mask keeps each prefix position from attending to later
                # tokens, matching left-to-right autoregressive decoding.
                logits = decoder(tokens, memory, tgt_mask=_causal_mask(tokens.size(1), device))
                next_logits = logits[:, -1, :].float()  # float32 softmax is safer than fp16
                if temperature > 0:
                    probs = torch.softmax(next_logits / temperature, dim=-1)
                    next_token = torch.multinomial(probs, 1)
                else:
                    next_token = next_logits.argmax(dim=-1, keepdim=True)
                tokens = torch.cat([tokens, next_token], dim=1)
                if next_token.item() == eos_id:
                    break

        gcode = tokenizer.decode(tokens[0], skip_special_tokens=True)
        gcode = validate_gcode(_one_command_per_line(gcode))
        svg = gcode_to_svg(gcode)
        return _with_header(gcode, prompt), svg
    except Exception as e:
        # UI boundary: report the failure in the output box instead of crashing.
        import traceback
        traceback.print_exc()
        return f"; Error: {e}", gcode_to_svg("")
# Custom CSS: cap the app width so the two-column layout stays compact on wide screens.
custom_css = """
.gradio-container {
max-width: 1200px !important;
}
"""
# UI layout. Components render in the order they are created inside the Blocks
# context. The left column holds the prompt and sampling controls; the right
# column holds the SVG preview and a collapsed gcode text view.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="emerald")) as demo:
    gr.Markdown("""
# dcode
**Text → Polargraph Gcode via Latent Diffusion**
Uses Stable Diffusion to generate latents from text, then decodes to machine gcode.
[GitHub](https://github.com/Twarner491/dcode) | [Model](https://huggingface.co/twarner/dcode-latent-gcode) | [Dataset](https://huggingface.co/datasets/twarner/dcode-polargraph-gcode)
""")
    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="drawing of a cat, abstract spiral, portrait...",
                lines=2
            )
            # Decoder sampling controls (passed to generate() as temperature / max_tokens).
            with gr.Row():
                temperature = gr.Slider(0.5, 1.5, value=0.9, label="Temperature")
                max_tokens = gr.Slider(256, 1024, value=512, step=128, label="Max Tokens")
            # Stable Diffusion controls (passed as num_steps / guidance).
            with gr.Row():
                num_steps = gr.Slider(10, 50, value=25, step=5, label="Diffusion Steps")
                guidance = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="Guidance Scale")
            generate_btn = gr.Button("Generate", variant="primary", size="lg")
            # Clicking an example only fills the prompt box; it does not trigger generation here.
            gr.Examples(
                examples=[
                    ["line drawing of a cat"],
                    ["abstract spiral pattern"],
                    ["simple house with chimney"],
                    ["portrait sketch"],
                    ["geometric shapes and lines"],
                ],
                inputs=prompt,
            )
        with gr.Column(scale=2):
            # Starts as an empty machine frame until the first generation.
            preview = gr.HTML(
                value=gcode_to_svg(""),
                label="Preview",
            )
            with gr.Accordion("Gcode Output", open=False):
                gcode_output = gr.Code(label="Gcode", language=None, lines=15)
    gr.Markdown("""
---
**Machine Bounds**: X: ±420.5mm, Y: ±594.5mm | Pen servo: 40° (down) / 90° (up) | **License**: MIT
""")
    # Both the button and submitting the prompt box run generate() with the same
    # inputs; its (gcode_text, svg) result fills the code view and the preview.
    generate_btn.click(
        generate,
        [prompt, temperature, max_tokens, num_steps, guidance],
        [gcode_output, preview]
    )
    prompt.submit(
        generate,
        [prompt, temperature, max_tokens, num_steps, guidance],
        [gcode_output, preview]
    )
if __name__ == "__main__":
    demo.launch()