dcode / app.py
twarner's picture
Light mode preview with dark lines
32c92a7
raw
history blame
9.5 kB
"""dcode Gradio Space - Text to Gcode inference with visual preview."""
import logging
import re

import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoTokenizer
# Selectable checkpoints: UI dropdown label -> Hugging Face Hub model id.
MODELS = {"flan-t5-base (best)": "twarner/dcode-flan-t5-base"}

# Machine work area in mm. It is symmetric about the origin, so it is described
# by its half-extents (X: ±420.5mm, Y: ±594.5mm).
_HALF_WIDTH_MM = 420.5
_HALF_HEIGHT_MM = 594.5
BOUNDS = {
    "left": -_HALF_WIDTH_MM,
    "right": _HALF_WIDTH_MM,
    "top": _HALF_HEIGHT_MM,
    "bottom": -_HALF_HEIGHT_MM,
}

# (model, tokenizer, device) triples keyed by dropdown label; populated lazily by get_model.
_model_cache = {}
def get_model(model_name: str):
    """Return the cached ``(model, tokenizer, device)`` triple for a dropdown label.

    On the first request for *model_name*, the tokenizer and model are loaded
    from the Hub id in ``MODELS``. They go on CUDA in float16 when available,
    otherwise on CPU in float32. The model is put in eval mode and the triple is
    memoised in ``_model_cache``.

    Raises:
        KeyError: if *model_name* is not cached and is not a key of ``MODELS``.
    """
    if model_name in _model_cache:
        return _model_cache[model_name]

    hub_id = MODELS[model_name]
    on_gpu = torch.cuda.is_available()
    device = "cuda" if on_gpu else "cpu"
    weight_dtype = torch.float16 if on_gpu else torch.float32

    tokenizer = AutoTokenizer.from_pretrained(hub_id)
    # Ids naming a decoder-only family are loaded as causal LMs; everything else as seq2seq.
    decoder_only = any(family in hub_id for family in ("gpt2", "codegen"))
    model_cls = AutoModelForCausalLM if decoder_only else AutoModelForSeq2SeqLM
    model = model_cls.from_pretrained(hub_id, torch_dtype=weight_dtype).to(device)
    model.eval()

    entry = (model, tokenizer, device)
    _model_cache[model_name] = entry
    return entry
def validate_gcode(gcode: str, bounds: "dict[str, float] | None" = None) -> str:
    """Clamp X/Y coordinates in *gcode* to the machine's work area.

    Each ``X``/``Y`` word (letter matched case-insensitively) in the command part
    of a line is clamped on its own. It is rewritten as an uppercase letter with
    two decimals. Text after ``;`` is a comment and passes through unchanged.
    Words whose value is not a plain decimal number (e.g. ``X-``) are left as-is.

    Args:
        gcode: Newline-separated gcode program.
        bounds: Limits with ``left``/``right``/``top``/``bottom`` keys.
            Defaults to the module-level ``BOUNDS``.

    Returns:
        The program with the same lines in the same order, coordinates clamped.
    """
    limits = BOUNDS if bounds is None else bounds

    def clamp_axis(code: str, axis: str, low: float, high: float) -> str:
        # The lookbehind stops matches inside words such as "MAX5". The number
        # part accepts "-3", "12." and ".5" but not a bare "-" or ".", so
        # float() below cannot fail.
        pattern = rf"(?<![A-Za-z]){axis}(-?(?:\d+\.?\d*|\.\d+))"
        return re.sub(
            pattern,
            lambda m: f"{axis}{max(low, min(high, float(m.group(1)))):.2f}",
            code,
            flags=re.IGNORECASE,
        )

    checked = []
    for line in gcode.split("\n"):
        code, sep, comment = line.partition(";")
        code = clamp_axis(code, "X", limits["left"], limits["right"])
        code = clamp_axis(code, "Y", limits["bottom"], limits["top"])
        checked.append(code + sep + comment)
    return "\n".join(checked)
def _gcode_pen_paths(gcode: str) -> list:
    """Return the polylines (lists of ``(x, y)`` tuples) traced while the pen is down.

    Pen state comes from ``M280 ... S<angle>`` servo commands: an angle below 50
    means down (40 = down, 90 = up). A stroke starts at the position where the
    pen is lowered. Strokes with fewer than two points are discarded. Text after
    ``;`` is ignored, and the position starts at ``(0, 0)``.
    """
    paths = []
    stroke = []
    x = y = 0.0
    pen_down = False
    for raw_line in gcode.split("\n"):
        # Drop inline comments so words inside them are never interpreted.
        line = raw_line.split(";", 1)[0].strip()
        if not line:
            continue

        if "M280" in line.upper():
            angle_match = re.search(r"S(\d+)", line, re.IGNORECASE)
            if angle_match:
                was_down = pen_down
                pen_down = int(angle_match.group(1)) < 50  # 40 = down, 90 = up
                if pen_down and not was_down:
                    # Start where the pen touches the paper so the first segment is drawn.
                    stroke = [(x, y)]
                elif was_down and not pen_down:
                    if len(stroke) > 1:
                        paths.append(stroke)
                    # Always reset so a lone point cannot be joined to the next stroke.
                    stroke = []

        x_match = re.search(r"(?<![A-Za-z])X(-?(?:\d+\.?\d*|\.\d+))", line, re.IGNORECASE)
        y_match = re.search(r"(?<![A-Za-z])Y(-?(?:\d+\.?\d*|\.\d+))", line, re.IGNORECASE)
        if x_match:
            x = float(x_match.group(1))
        if y_match:
            y = float(y_match.group(1))
        if pen_down and (x_match or y_match):
            stroke.append((x, y))

    if len(stroke) > 1:
        paths.append(stroke)
    return paths


def gcode_to_svg(gcode: str, bounds: "dict[str, float] | None" = None) -> str:
    """Convert gcode to an SVG preview of its pen-down strokes (light background, dark lines).

    Args:
        gcode: Newline-separated gcode program; ``""`` renders an empty work area.
        bounds: Work-area limits with ``left``/``right``/``top``/``bottom`` keys.
            Defaults to the module-level ``BOUNDS``.

    Returns:
        A standalone ``<svg>`` element string with the work area, a grid, the
        strokes and a "Paths | Points" caption. Gcode Y grows upward and SVG Y
        grows downward, so every Y value is negated.
    """
    limits = BOUNDS if bounds is None else bounds
    left, right = limits["left"], limits["right"]
    top, bottom = limits["top"], limits["bottom"]
    w = right - left
    h = top - bottom
    padding = 20
    paths = _gcode_pen_paths(gcode)

    parts = [f'''<svg xmlns="http://www.w3.org/2000/svg"
viewBox="{left - padding} {-top - padding} {w + 2*padding} {h + 2*padding}"
style="background: #fafafa; width: 100%; height: 500px; border-radius: 8px; border: 1px solid #e5e5e5;">
<!-- Work area border -->
<rect x="{left}" y="{-top}" width="{w}" height="{h}"
fill="#fff" stroke="#ccc" stroke-width="2"/>
<!-- Center crosshair -->
<line x1="0" y1="{-top}" x2="0" y2="{-bottom}" stroke="#ddd" stroke-width="1"/>
<line x1="{left}" y1="0" x2="{right}" y2="0" stroke="#ddd" stroke-width="1"/>
<!-- Grid -->
<defs>
<pattern id="grid" width="100" height="100" patternUnits="userSpaceOnUse">
<path d="M 100 0 L 0 0 0 100" fill="none" stroke="#eee" stroke-width="0.5"/>
</pattern>
</defs>
<rect x="{left}" y="{-top}" width="{w}" height="{h}" fill="url(#grid)"/>
''']

    # Strokes in dark ink; the parser only returns polylines with >= 2 points.
    for path in paths:
        d = " ".join(
            f"{'M' if i == 0 else 'L'}{px:.1f},{-py:.1f}" for i, (px, py) in enumerate(path)
        )
        parts.append(
            f'<path d="{d}" fill="none" stroke="#1a1a1a" stroke-width="1.5" '
            f'stroke-linecap="round" stroke-linejoin="round"/>'
        )

    total_points = sum(len(p) for p in paths)
    parts.append(f'''
<text x="{left + 10}" y="{-top + 25}" fill="#666" font-family="monospace" font-size="14">
Paths: {len(paths)} | Points: {total_points}
</text>
''')
    parts.append("</svg>")
    return "".join(parts)
def generate(prompt: str, model_name: str, temperature: float, max_tokens: int):
    """Generate gcode for *prompt* and return it together with an SVG preview.

    Args:
        prompt: Natural-language description of the drawing.
        model_name: Dropdown label; a key of ``MODELS``.
        temperature: Sampling temperature forwarded to ``model.generate``.
        max_tokens: Cap on newly generated tokens. It is cast to ``int`` because
            the UI slider value may arrive as a float.

    Returns:
        ``(gcode_text, svg_markup)``. For a blank prompt the text is a hint; on
        any failure it is a single ``; Error: ...`` line. In both of those cases
        the preview is the empty work area.
    """
    if not prompt or not prompt.strip():
        return "Enter a prompt to generate gcode", gcode_to_svg("")
    try:
        model, tokenizer, device = get_model(model_name)
        inputs = tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Use the tokenizer's own pad id; fall back to EOS for tokenizers that define none.
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            pad_id = tokenizer.eos_token_id

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=int(max_tokens),
                do_sample=True,
                temperature=temperature,
                top_p=0.9,
                pad_token_id=pad_id,
            )

        tokens = outputs[0]
        # Decoder-only models return prompt + continuation, so drop the prompt tokens.
        # Ask the loaded model's config rather than pattern-matching its hub id.
        if not getattr(model.config, "is_encoder_decoder", False):
            tokens = tokens[inputs["input_ids"].shape[1]:]
        gcode = validate_gcode(tokenizer.decode(tokens, skip_special_tokens=True))

        # splitlines() reports 0 for empty output and ignores a trailing newline.
        line_count = len(gcode.splitlines())
        svg = gcode_to_svg(gcode)
        header = f"; dcode output - {line_count} lines\n; Model: {model_name}\n; Machine validated\n\n"
        return header + gcode, svg
    except Exception as e:
        # UI boundary: surface the error in the output box, but keep the traceback in the logs.
        logging.getLogger(__name__).exception("gcode generation failed")
        return f"; Error: {e}", gcode_to_svg("")
# Custom CSS
# The preview SVG renders in light mode (#fafafa background), so its container uses the
# same colour instead of the old dark-mode #0a0a0a backdrop showing around it.
custom_css = """
#preview-container {
background: #fafafa;
border-radius: 8px;
padding: 0;
}
.gradio-container {
max-width: 1200px !important;
}
"""
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="emerald")) as demo:
    # Page header with project links.
    gr.Markdown("""
# dcode
**Text → Polargraph Gcode** | Generate machine-compatible gcode from natural language.
[GitHub](https://github.com/Twarner491/dcode) | [Model](https://huggingface.co/twarner/dcode-flan-t5-base) | [Dataset](https://huggingface.co/datasets/twarner/dcode-polargraph-gcode)
""")

    with gr.Row():
        # Left column: prompt, model choice, sampling controls and example prompts.
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="drawing of a cat, abstract spiral, portrait...",
                lines=2,
            )
            model_dropdown = gr.Dropdown(
                choices=list(MODELS),
                value="flan-t5-base (best)",
                label="Model",
            )
            with gr.Row():
                temperature = gr.Slider(0.1, 1.5, value=0.8, label="Temperature", info="Higher = more creative")
                max_tokens = gr.Slider(256, 2048, value=1024, step=256, label="Max Tokens")
            generate_btn = gr.Button("Generate", variant="primary", size="lg")

            example_prompts = (
                "drawing of a cat",
                "abstract spiral pattern",
                "simple house with chimney",
                "portrait of a woman",
                "geometric shapes",
            )
            gr.Examples(
                examples=[[text] for text in example_prompts],
                inputs=prompt,
            )

        # Right column: SVG preview (starts as the empty work area) and the raw gcode.
        with gr.Column(scale=2):
            preview = gr.HTML(
                value=gcode_to_svg(""),
                label="Preview",
                elem_id="preview-container",
            )
            with gr.Accordion("Gcode Output", open=False):
                gcode_output = gr.Code(label="Gcode", language=None, lines=15)

    gr.Markdown("""
---
**Machine Bounds**: X: ±420.5mm, Y: ±594.5mm | Pen servo: 40° (down) / 90° (up) | **License**: MIT
""")

    # Clicking the button and pressing Enter in the prompt box run the same handler.
    io_inputs = [prompt, model_dropdown, temperature, max_tokens]
    io_outputs = [gcode_output, preview]
    for trigger in (generate_btn.click, prompt.submit):
        trigger(generate, list(io_inputs), list(io_outputs))


if __name__ == "__main__":
    demo.launch()