# dcode / app.py — "Gradio inference" (twarner, commit 1d11bfa, 6.07 kB)
"""dcode Gradio Space - Text to Gcode inference."""
import re
import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Model config - update after training.
# Hub id handed to ``from_pretrained``; the checkpoint is uploaded after training.
MODEL_ID = "twarner/dcode-flan-t5-base"

# Machine limits: edges of the drawable area, used to clamp generated X/Y
# coordinates and to size the SVG preview.
# NOTE(review): 841 x 1189 looks like A0 paper in millimetres - confirm units.
BOUNDS = {
    "left": -420.5,
    "right": 420.5,
    "top": 594.5,
    "bottom": -594.5,
}

# Pen settings.
# NOTE(review): "up"/"down" are presumably servo angles and "travel"/"draw"
# feed rates - confirm against the machine configuration.
PEN = {
    "up": 90,
    "down": 40,
    "travel": 1000,
    "draw": 500,
}
class GcodeGenerator:
    """Text-to-gcode generator backed by a lazily loaded seq2seq model.

    Constructing an instance only picks the device; the tokenizer and model
    named by ``MODEL_ID`` are fetched on the first call to :meth:`load`
    (which :meth:`generate` calls for you).
    """

    # An axis letter followed by a run of sign/digit/dot characters. The run is
    # deliberately permissive so that malformed tokens (e.g. ``X--``) are
    # captured whole and can be left untouched instead of being half-rewritten.
    _COORD_RE = re.compile(r"([XY])([-+\d.]+)", re.IGNORECASE)

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load(self):
        """Load tokenizer and model once; later calls are no-ops."""
        if self.model is not None:
            return
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        # Half precision on CUDA, full precision on CPU.
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            MODEL_ID, torch_dtype=dtype
        ).to(self.device)
        self.model.eval()

    def generate(self, prompt: str, max_length: int = 1024, temperature: float = 0.8) -> str:
        """Sample gcode for *prompt* and clamp it to the machine bounds.

        Args:
            prompt: Text description; tokenized with truncation at 128 tokens.
            max_length: Upper limit on newly generated tokens.
            temperature: Sampling temperature (nucleus sampling, ``top_p=0.9``).

        Returns:
            The decoded gcode after :meth:`validate` has clamped its coordinates.
        """
        self.load()
        inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_length,
                do_sample=True,
                temperature=temperature,
                top_p=0.9,
            )
        gcode = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return self.validate(gcode)

    def validate(self, gcode: str, bounds: "dict[str, float] | None" = None) -> str:
        """Clamp every X/Y coordinate in *gcode* to the machine bounds.

        Each coordinate token is clamped independently and rewritten as an
        upper-case axis letter followed by the value with two decimals.
        Tokens whose numeric part cannot be parsed are left unchanged rather
        than raising, since the text comes straight from a sampled model.

        Args:
            gcode: Gcode text, lines separated by newlines.
            bounds: Optional mapping with ``left``/``right``/``top``/``bottom``
                keys; defaults to the module-level ``BOUNDS``.

        Returns:
            The gcode with in-range coordinates; everything else is untouched.
        """
        limits = BOUNDS if bounds is None else bounds
        ranges = {
            "X": (limits["left"], limits["right"]),
            "Y": (limits["bottom"], limits["top"]),
        }

        def clamp(match):
            axis = match.group(1).upper()
            try:
                value = float(match.group(2))
            except ValueError:
                # Not a number (e.g. "X--" or "Y1.2.3"): keep the original text.
                return match.group(0)
            low, high = ranges[axis]
            return f"{axis}{max(low, min(high, value)):.2f}"

        # The pattern cannot match a newline, so the line structure is preserved.
        return self._COORD_RE.sub(clamp, gcode)
def _axis_value(line: str, axis: str) -> "float | None":
    """Return the number following *axis* in *line*, or None if absent or malformed."""
    match = re.search(rf"{axis}([-+\d.]+)", line, re.IGNORECASE)
    if match is None:
        return None
    try:
        return float(match.group(1))
    except ValueError:
        return None


def gcode_to_svg(gcode: str, bounds: "dict[str, float] | None" = None) -> str:
    """Convert gcode to an SVG string for preview.

    Pen state is read from ``M280 ... S<angle>`` lines: an angle below 50
    means the pen is down. Every pen-down stretch becomes one ``<path>``.
    The path starts at the position where the pen was lowered, so the first
    drawn segment of each stroke is included. Y is negated because SVG's
    y axis points down.

    Args:
        gcode: Gcode text, lines separated by newlines. Blank lines and lines
            starting with ``;`` are ignored; malformed coordinates are skipped.
        bounds: Optional mapping with ``left``/``right``/``top``/``bottom``
            keys; defaults to the module-level ``BOUNDS``.

    Returns:
        A complete ``<svg>`` element containing a frame and the drawn paths.
    """
    limits = BOUNDS if bounds is None else bounds

    paths = []
    current_path = []
    x, y = 0, 0
    pen_down = False

    for raw_line in gcode.split("\n"):
        line = raw_line.strip()
        if not line or line.startswith(";"):
            continue

        # Pen state
        if "M280" in line:
            match = re.search(r"S(\d+)", line)
            if match:
                was_down = pen_down
                pen_down = int(match.group(1)) < 50
                if was_down and not pen_down:
                    # Pen lifted: keep the stroke if it has at least one
                    # segment, and always start afresh so stray points never
                    # leak into the next stroke.
                    if len(current_path) > 1:
                        paths.append(current_path)
                    current_path = []
                elif pen_down and not was_down:
                    # Pen lowered: the stroke begins where the head is now.
                    current_path = [(x, y)]

        # Position
        new_x = _axis_value(line, "X")
        new_y = _axis_value(line, "Y")
        if new_x is not None:
            x = new_x
        if new_y is not None:
            y = new_y
        if pen_down and (new_x is not None or new_y is not None):
            current_path.append((x, y))

    # The gcode may end with the pen still down.
    if len(current_path) > 1:
        paths.append(current_path)

    # Build SVG
    left, top = limits["left"], limits["top"]
    w = limits["right"] - left
    h = top - limits["bottom"]
    parts = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="{left} {-top} {w} {h}" style="background:#111">',
        f'<rect x="{left}" y="{-top}" width="{w}" height="{h}" fill="none" stroke="#333" stroke-width="2"/>',
    ]
    for path in paths:
        d = " ".join(
            f"{'M' if i == 0 else 'L'}{px:.1f},{-py:.1f}"
            for i, (px, py) in enumerate(path)
        )
        parts.append(
            f'<path d="{d}" fill="none" stroke="#4ade80" stroke-width="1.5" stroke-linecap="round"/>'
        )
    parts.append("</svg>")
    return "".join(parts)
# Module-level generator shared by the UI callback. Constructing it only picks
# the device; the model itself is loaded on the first generate() call.
generator = GcodeGenerator()
def generate(prompt: str, temperature: float) -> tuple[str, str, str]:
    """Gradio callback: turn *prompt* into gcode, an SVG preview and a status line.

    Returns a ``(gcode, svg, status)`` tuple. A blank prompt or any exception
    raised during generation yields empty gcode and SVG, with the reason in
    the status text.
    """
    if prompt.strip():
        try:
            text = generator.generate(prompt, temperature=temperature)
            preview_svg = gcode_to_svg(text)
            # Number of newline-separated lines in the generated gcode.
            line_count = text.count("\n") + 1
            return text, preview_svg, f"✓ Generated {line_count} lines, machine compatible"
        except Exception as exc:
            # UI boundary: report the failure in the status box instead of raising.
            return "", "", f"Error: {exc}"
    return "", "", "Enter a prompt"
# Custom CSS passed to gr.Blocks: "#workplane" styles the HTML preview
# component (elem_id="workplane") and ".status" the status textbox
# (elem_classes=["status"]).
custom_css = """
#workplane {
background: #111;
border-radius: 8px;
min-height: 500px;
}
#workplane svg {
width: 100%;
height: 100%;
}
.status {
font-family: monospace;
}
"""
# Gradio UI: prompt controls and status on top, then the SVG preview, the
# generated gcode, and a (currently hidden) download button.
with gr.Blocks(css=custom_css, theme=gr.themes.Base(primary_hue="green")) as demo:
    gr.Markdown("# dcode\nText prompt → Polargraph Gcode")

    # Input controls (wide column) next to the status box (narrow column).
    with gr.Row():
        with gr.Column(scale=2):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="drawing of a cat...",
                lines=1,
            )
            temperature = gr.Slider(0.1, 1.5, value=0.8, label="Temperature")
            generate_btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=1):
            status = gr.Textbox(
                label="Status",
                interactive=False,
                elem_classes=["status"],
            )

    # SVG preview of the toolpath.
    with gr.Row():
        preview = gr.HTML(elem_id="workplane")

    # Raw gcode text.
    with gr.Row():
        gcode_output = gr.Textbox(
            label="Gcode",
            lines=10,
            show_copy_button=True,
        )

    # Hidden download button; it is not wired to any event.
    with gr.Row():
        gr.DownloadButton("Download Gcode", value=lambda: None, visible=False)

    # Clicking the button and pressing Enter in the prompt run the same callback.
    generate_btn.click(
        fn=generate,
        inputs=[prompt, temperature],
        outputs=[gcode_output, preview, status],
    )
    prompt.submit(
        fn=generate,
        inputs=[prompt, temperature],
        outputs=[gcode_output, preview, status],
    )

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()