| """dcode Gradio Space - Text to Gcode via Latent Diffusion.""" |

import re
import gradio as gr
import torch
from pathlib import Path
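
# Polargraph drawing area in millimetres, centred on the canvas origin.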
BOUNDS = {"left": -420.5, "right": 420.5, "top": 594.5, "bottom": -594.5}
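# Cached model bundle, loaded lazily on the first call to get_generator().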
_generator = None


def get_generator():
| """Load and cache the latent-gcode generator.""" |
| global _generator |
| if _generator is None: |
| from diffusers import StableDiffusionPipeline, AutoencoderKL |
| from transformers import AutoTokenizer |
| import torch.nn as nn |
| |
| device = "cuda" if torch.cuda.is_available() else "cpu" |
| dtype = torch.float16 if device == "cuda" else torch.float32 |
| |
| print("Loading Stable Diffusion pipeline...") |
| pipe = StableDiffusionPipeline.from_pretrained( |
| "stabilityai/stable-diffusion-2-1-base", |
| torch_dtype=dtype, |
| safety_checker=None, |
| ).to(device) |
| |
| print("Loading gcode decoder...") |
| from huggingface_hub import hf_hub_download |
| |
| |
| model_path = hf_hub_download("twarner/dcode-latent-gcode", "pytorch_model.bin") |
| config_path = hf_hub_download("twarner/dcode-latent-gcode", "config.json") |

        import json
        with open(config_path) as f:
            config = json.load(f)

        tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")

        class LatentProjector(nn.Module):
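            """Project a flattened Stable Diffusion latent into the decoder's hidden size."""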
            def __init__(self, latent_dim, hidden_size):
                super().__init__()
                self.proj = nn.Sequential(
                    nn.Linear(latent_dim, hidden_size * 2),
                    nn.GELU(),
                    nn.Linear(hidden_size * 2, hidden_size),
                    nn.LayerNorm(hidden_size),
                )

            def forward(self, x):
                return self.proj(x)

        class GcodeDecoder(nn.Module):
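            """Autoregressive transformer decoder that emits gcode tokens, conditioned on the projected latent."""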
            def __init__(self, hidden_size, vocab_size, num_layers, num_heads, max_seq_len):
                super().__init__()
                self.embed = nn.Embedding(vocab_size, hidden_size)
                self.pos_embed = nn.Embedding(max_seq_len, hidden_size)
                layer = nn.TransformerDecoderLayer(hidden_size, num_heads, hidden_size * 4, batch_first=True)
                self.decoder = nn.TransformerDecoder(layer, num_layers)
                self.head = nn.Linear(hidden_size, vocab_size)
                self.max_seq_len = max_seq_len

            def forward(self, tgt, memory, tgt_mask=None):
                pos = torch.arange(tgt.size(1), device=tgt.device)
                x = self.embed(tgt) + self.pos_embed(pos)
                x = self.decoder(x, memory, tgt_mask=tgt_mask)
                return self.head(x)

        # Decoder hyperparameters; latent_dim matches an SD 2.1-base latent (4 x 64 x 64 at 512 px).
        latent_dim = 4 * 64 * 64
        hidden_size = config.get("hidden_size", 512)
        vocab_size = tokenizer.vocab_size
        num_layers = config.get("num_layers", 6)
        num_heads = config.get("num_heads", 8)
        max_seq_len = config.get("max_seq_len", 1024)

        projector = LatentProjector(latent_dim, hidden_size).to(device, dtype)
        decoder = GcodeDecoder(hidden_size, vocab_size, num_layers, num_heads, max_seq_len).to(device, dtype)

        # The checkpoint stores both modules in a single state dict; split it by key prefix.
        state_dict = torch.load(model_path, map_location=device)
        proj_state = {k.replace("projector.", ""): v for k, v in state_dict.items() if k.startswith("projector.")}
        dec_state = {k.replace("decoder.", ""): v for k, v in state_dict.items() if k.startswith("decoder.")}

        projector.load_state_dict(proj_state)
        decoder.load_state_dict(dec_state)

        projector.eval()
        decoder.eval()

        _generator = {
            "pipe": pipe,
            "projector": projector,
            "decoder": decoder,
            "tokenizer": tokenizer,
            "device": device,
            "dtype": dtype,
            "max_seq_len": max_seq_len,
        }
        print("Models loaded!")

    return _generator


def validate_gcode(gcode: str) -> str:
| """Clamp coordinates to machine bounds.""" |
    lines = []
    for line in gcode.split("\n"):
        corrected = line

        x_match = re.search(r"X([-\d.]+)", line, re.IGNORECASE)
        if x_match:
            try:
                x = float(x_match.group(1))
                x = max(BOUNDS["left"], min(BOUNDS["right"], x))
                corrected = re.sub(r"X[-\d.]+", f"X{x:.2f}", corrected, flags=re.IGNORECASE)
            except ValueError:
                pass

        y_match = re.search(r"Y([-\d.]+)", line, re.IGNORECASE)
        if y_match:
            try:
                y = float(y_match.group(1))
                y = max(BOUNDS["bottom"], min(BOUNDS["top"], y))
                corrected = re.sub(r"Y[-\d.]+", f"Y{y:.2f}", corrected, flags=re.IGNORECASE)
            except ValueError:
                pass

        lines.append(corrected)

    return "\n".join(lines)


def gcode_to_svg(gcode: str) -> str:
    """Convert gcode to SVG for visual preview."""
    paths = []
    current_path = []
    x, y = 0.0, 0.0
    pen_down = False
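
    # Normalise the input: split lines that pack several G/M commands together and skip comment lines.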
    lines = []
    for line in gcode.split("\n"):
        line = line.strip()
        if not line:
            continue
        parts = re.split(r'(?=[GM]\d)', line)
        for part in parts:
            part = part.strip()
            if part and not part.startswith(";"):
                lines.append(part)

    for line in lines:
        # M280 drives the pen servo; angles below 50 degrees count as pen down.
        if "M280" in line.upper():
            match = re.search(r"S(\d+)", line, re.IGNORECASE)
            if match:
                angle = int(match.group(1))
                was_down = pen_down
                pen_down = angle < 50
                if was_down and not pen_down and len(current_path) > 1:
                    # Pen lifted: finish the current stroke.
                    paths.append(current_path[:])
                    current_path = []
                elif not was_down and pen_down:
                    # Pen lowered: start a new stroke at the current position.
                    current_path = [(x, y)]

        x_match = re.search(r"X([-\d.]+)", line, re.IGNORECASE)
        y_match = re.search(r"Y([-\d.]+)", line, re.IGNORECASE)

        if x_match:
            try:
                x = float(x_match.group(1))
            except ValueError:
                pass
        if y_match:
            try:
                y = float(y_match.group(1))
            except ValueError:
                pass

        if (x_match or y_match) and pen_down:
            current_path.append((x, y))

    if len(current_path) > 1:
        paths.append(current_path)

    # Gcode Y points up while SVG Y points down, so Y values are negated throughout.
    w = BOUNDS["right"] - BOUNDS["left"]
    h = BOUNDS["top"] - BOUNDS["bottom"]
    padding = 20

    svg = f'''<svg xmlns="http://www.w3.org/2000/svg"
    viewBox="{BOUNDS["left"] - padding} {-BOUNDS["top"] - padding} {w + 2*padding} {h + 2*padding}"
    style="background: #fafafa; width: 100%; height: 500px; border-radius: 8px; border: 1px solid #e5e5e5;">
    <rect x="{BOUNDS["left"]}" y="{-BOUNDS["top"]}" width="{w}" height="{h}"
          fill="#fff" stroke="#ccc" stroke-width="2"/>
    <line x1="0" y1="{-BOUNDS["top"]}" x2="0" y2="{-BOUNDS["bottom"]}" stroke="#ddd" stroke-width="1"/>
    <line x1="{BOUNDS["left"]}" y1="0" x2="{BOUNDS["right"]}" y2="0" stroke="#ddd" stroke-width="1"/>
'''
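
    # Each pen-down stroke becomes one SVG path.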
    for path in paths:
        if len(path) < 2:
            continue
        d = " ".join(f"{'M' if i == 0 else 'L'}{p[0]:.1f},{-p[1]:.1f}" for i, p in enumerate(path))
        svg += f'<path d="{d}" fill="none" stroke="#1a1a1a" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>'

    total_points = sum(len(p) for p in paths)
    svg += f'''
    <text x="{BOUNDS["left"] + 10}" y="{-BOUNDS["top"] + 25}" fill="#666" font-family="monospace" font-size="14">
        Paths: {len(paths)} | Points: {total_points}
    </text>
'''
    svg += "</svg>"
    return svg


def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, guidance: float):
    """Generate gcode from a text prompt via latent diffusion."""
    if not prompt or not prompt.strip():
        return "Enter a prompt to generate gcode", gcode_to_svg("")

    try:
        gen = get_generator()
        pipe = gen["pipe"]
        projector = gen["projector"]
        decoder = gen["decoder"]
        tokenizer = gen["tokenizer"]
        device = gen["device"]
        dtype = gen["dtype"]
        max_seq_len = gen["max_seq_len"]
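
        # Stage 1: run Stable Diffusion and keep the raw latent (output_type="latent" skips the VAE decode).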
        with torch.no_grad():
            result = pipe(
                prompt,
                num_inference_steps=int(num_steps),  # slider values can arrive as floats
                guidance_scale=guidance,
                output_type="latent",
            )
            latent = result.images
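
        # Stage 2: project the flattened latent into a single "memory" vector and decode gcode tokens from it autoregressively.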
        with torch.no_grad():
            latent_flat = latent.view(1, -1).to(dtype)
            memory = projector(latent_flat).unsqueeze(1)  # (batch, 1, hidden_size)

            # flan-t5 defines no BOS token, so fall back to the pad token as the start symbol.
            bos_id = tokenizer.bos_token_id or tokenizer.pad_token_id
            eos_id = tokenizer.eos_token_id

            tokens = torch.tensor([[bos_id]], device=device)

            # Slider values can arrive as floats; range() needs an int.
            for _ in range(min(int(max_tokens), max_seq_len - 1)):
                # Causal mask (True = blocked) so each position only attends to earlier tokens;
                # the decoder's tgt_mask argument is otherwise never used at inference time.
                tgt_mask = torch.triu(
                    torch.ones(tokens.size(1), tokens.size(1), device=device), diagonal=1
                ).bool()
                logits = decoder(tokens, memory, tgt_mask=tgt_mask)
                next_logits = logits[:, -1, :] / temperature
                probs = torch.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(probs, 1)
                tokens = torch.cat([tokens, next_token], dim=1)

                if next_token.item() == eos_id:
                    break
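
        # Stage 3: detokenise, clamp coordinates to the machine bounds, and render the SVG preview.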
        gcode = tokenizer.decode(tokens[0], skip_special_tokens=True)

        gcode = validate_gcode(gcode)
        line_count = len(gcode.split("\n"))
        svg = gcode_to_svg(gcode)

        gcode_with_header = f"; dcode output - {line_count} lines\n; Prompt: {prompt}\n; Machine validated\n\n{gcode}"
        return gcode_with_header, svg

    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"; Error: {e}", gcode_to_svg("")


custom_css = """
.gradio-container {
    max-width: 1200px !important;
}
"""

with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="emerald")) as demo:
    gr.Markdown("""
    # dcode
    **Text → Polargraph Gcode via Latent Diffusion**

    Uses Stable Diffusion to generate latents from text, then decodes to machine gcode.

    [GitHub](https://github.com/Twarner491/dcode) | [Model](https://huggingface.co/twarner/dcode-latent-gcode) | [Dataset](https://huggingface.co/datasets/twarner/dcode-polargraph-gcode)
    """)

    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="drawing of a cat, abstract spiral, portrait...",
                lines=2
            )

            with gr.Row():
                temperature = gr.Slider(0.5, 1.5, value=0.9, label="Temperature")
                max_tokens = gr.Slider(256, 1024, value=512, step=128, label="Max Tokens")

            with gr.Row():
                num_steps = gr.Slider(10, 50, value=25, step=5, label="Diffusion Steps")
                guidance = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="Guidance Scale")

            generate_btn = gr.Button("Generate", variant="primary", size="lg")

            gr.Examples(
                examples=[
                    ["line drawing of a cat"],
                    ["abstract spiral pattern"],
                    ["simple house with chimney"],
                    ["portrait sketch"],
                    ["geometric shapes and lines"],
                ],
                inputs=prompt,
            )

        with gr.Column(scale=2):
            preview = gr.HTML(
                value=gcode_to_svg(""),
                label="Preview",
            )

            with gr.Accordion("Gcode Output", open=False):
                gcode_output = gr.Code(label="Gcode", language=None, lines=15)

    gr.Markdown("""
    ---
    **Machine Bounds**: X: ±420.5mm, Y: ±594.5mm | Pen servo: 40° (down) / 90° (up) | **License**: MIT
    """)

    generate_btn.click(
        generate,
        [prompt, temperature, max_tokens, num_steps, guidance],
        [gcode_output, preview]
    )
    prompt.submit(
        generate,
        [prompt, temperature, max_tokens, num_steps, guidance],
        [gcode_output, preview]
    )


if __name__ == "__main__":
    demo.launch()