|
|
"""dcode Gradio Space - Text to Gcode inference.""" |
|
|
|
|
|
import re |
|
|
import gradio as gr |
|
|
import torch |
|
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
|
|
|
|
|
|
|
|
# Hugging Face Hub identifier handed to ``from_pretrained`` for both the
# tokenizer and the seq2seq model.
MODEL_ID = "twarner/dcode-flan-t5-base"

# Drawable area of the machine. X coordinates are clamped to [left, right] and
# Y coordinates to [bottom, top]; the same rectangle sizes the SVG preview.
# NOTE(review): units are not stated in this file (presumably millimetres).
BOUNDS = {"left": -420.5, "right": 420.5, "top": 594.5, "bottom": -594.5}

# Pen settings. NOTE(review): not referenced anywhere in the visible code;
# "up"/"down" look like servo angles (the preview treats M280 S-values below
# 50 as pen-down) and "travel"/"draw" look like feed rates -- confirm.
PEN = {"up": 90, "down": 40, "travel": 1000, "draw": 500}
|
|
|
|
|
|
|
|
class GcodeGenerator:
    """Lazily loaded seq2seq model that turns a text prompt into G-code.

    Constructing the object only picks the device; the tokenizer and model
    are created on the first call to :meth:`load` (which :meth:`generate`
    invokes itself).
    """

    # One X or Y word: the axis letter followed by a run of sign, digit and
    # dot characters. The run is deliberately permissive -- ``float()`` is the
    # final judge of whether it is a number.
    _COORD_RE = re.compile(r"([XY])([-+\d.]+)", re.IGNORECASE)

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load(self):
        """Create the tokenizer and model on first use; later calls are no-ops.

        The model is loaded in float16 when running on CUDA and float32
        otherwise, moved to ``self.device`` and switched to eval mode.
        """
        if self.model is None:
            self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                MODEL_ID,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            ).to(self.device)
            self.model.eval()

    def generate(self, prompt: str, max_length: int = 1024, temperature: float = 0.8) -> str:
        """Sample G-code for *prompt* and clamp it to the machine bounds.

        Args:
            prompt: Text description; tokenized with truncation at 128 tokens.
            max_length: Passed to the model as ``max_new_tokens``.
            temperature: Sampling temperature (nucleus sampling, ``top_p=0.9``).

        Returns:
            The decoded first output sequence, passed through :meth:`validate`.
        """
        self.load()

        inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_length,
                do_sample=True,
                temperature=temperature,
                top_p=0.9,
            )

        gcode = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return self.validate(gcode)

    def validate(self, gcode: str, bounds: "dict[str, float] | None" = None) -> str:
        """Clamp every X/Y coordinate in *gcode* to the machine bounds.

        Each X/Y word is handled on its own: it is parsed, clamped to the
        matching axis range and rewritten with two decimals and an upper-case
        axis letter. Words whose numeric part is not a valid float (``X-``,
        ``X1.2.3``) are left untouched instead of raising, and everything
        after a ``;`` (comment text) is passed through unchanged.

        Args:
            gcode: Newline-separated G-code.
            bounds: Optional mapping with ``left``/``right``/``top``/``bottom``
                keys; defaults to the module-level ``BOUNDS``.

        Returns:
            The G-code with the same number of lines and clamped coordinates.
        """
        limits = BOUNDS if bounds is None else bounds
        x_range = (limits["left"], limits["right"])
        y_range = (limits["bottom"], limits["top"])

        def clamp_word(match):
            axis = match.group(1).upper()
            try:
                value = float(match.group(2))
            except ValueError:
                # Sampled output can contain junk such as "X-" or "X1.2.3";
                # keep it verbatim rather than failing the whole generation.
                return match.group(0)
            low, high = x_range if axis == "X" else y_range
            return f"{axis}{max(low, min(high, value)):.2f}"

        fixed = []
        for line in gcode.split("\n"):
            # Only the command part is rewritten; comment text stays as-is.
            code, sep, comment = line.partition(";")
            fixed.append(self._COORD_RE.sub(clamp_word, code) + sep + comment)
        return "\n".join(fixed)
|
|
|
|
|
|
|
|
def _axis_value(axis: str, text: str):
    """Return the float following the first *axis* letter in *text*, or None.

    None is returned both when the axis is absent and when the characters
    after it do not form a valid float (e.g. ``X-``).
    """
    found = re.search(rf"{axis}([-+\d.]+)", text, re.IGNORECASE)
    if found is None:
        return None
    try:
        return float(found.group(1))
    except ValueError:
        return None


def gcode_to_svg(gcode: str, bounds: "dict[str, float] | None" = None) -> str:
    """Convert gcode to an SVG preview of the pen-down strokes.

    ``M280`` lines with an ``S<angle>`` word switch the pen: angles below 50
    mean pen-down. While the pen is down, every line carrying an X and/or Y
    word extends the current stroke. A stroke starts at the position where
    the pen was lowered (when a position has been seen), and it ends when
    the pen is raised. Strokes with fewer than two points are not drawn.
    Text after ``;`` is ignored and unparsable coordinates are skipped.

    Args:
        gcode: Newline-separated G-code.
        bounds: Optional mapping with ``left``/``right``/``top``/``bottom``
            keys describing the drawable area; defaults to the module-level
            ``BOUNDS``.

    Returns:
        An ``<svg>`` document string with a border rectangle and one
        ``<path>`` per stroke. Y is negated because SVG's Y axis points down.
    """
    limits = BOUNDS if bounds is None else bounds

    strokes = []
    stroke = []
    x = y = 0.0
    position_known = False  # no X/Y word seen yet, so (x, y) is a placeholder
    pen_down = False

    for raw in gcode.split("\n"):
        # Drop comment text so it can never be mistaken for a command.
        line = raw.partition(";")[0].strip()
        if not line:
            continue

        if "M280" in line:
            servo = re.search(r"S(\d+)", line)
            if servo:
                was_down = pen_down
                pen_down = int(servo.group(1)) < 50
                if was_down and not pen_down:
                    if len(stroke) > 1:
                        strokes.append(stroke)
                    # Always start fresh, so a lone point cannot leak into
                    # the next stroke and draw a line that was never plotted.
                    stroke = []
                elif pen_down and not was_down:
                    # The pen touches down where the head currently is; that
                    # point is the start of the first drawn segment.
                    stroke = [(x, y)] if position_known else []

        new_x = _axis_value("X", line)
        new_y = _axis_value("Y", line)
        if new_x is None and new_y is None:
            continue
        if new_x is not None:
            x = new_x
        if new_y is not None:
            y = new_y
        position_known = True
        if pen_down:
            stroke.append((x, y))

    if len(stroke) > 1:
        strokes.append(stroke)

    left, top = limits["left"], limits["top"]
    w = limits["right"] - left
    h = top - limits["bottom"]

    parts = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="{left} {-top} {w} {h}" style="background:#111">',
        f'<rect x="{left}" y="{-top}" width="{w}" height="{h}" fill="none" stroke="#333" stroke-width="2"/>',
    ]
    for points in strokes:
        d = " ".join(
            f"{'M' if i == 0 else 'L'}{px:.1f},{-py:.1f}"
            for i, (px, py) in enumerate(points)
        )
        parts.append(
            f'<path d="{d}" fill="none" stroke="#4ade80" stroke-width="1.5" stroke-linecap="round"/>'
        )
    parts.append("</svg>")
    return "".join(parts)
|
|
|
|
|
|
|
|
|
|
|
# Module-level generator shared by every UI callback.
generator = GcodeGenerator()


def generate(prompt: str, temperature: float) -> tuple[str, str, str]:
    """Run the model for the UI and return ``(gcode, svg_preview, status)``.

    A blank prompt short-circuits with a hint in the status slot. Any
    exception raised while generating or rendering is reported in the status
    slot instead of propagating, with empty gcode and preview.
    """
    if prompt.strip() == "":
        return "", "", "Enter a prompt"

    try:
        text = generator.generate(prompt, temperature=temperature)
        preview_svg = gcode_to_svg(text)
    except Exception as exc:
        return "", "", f"Error: {exc}"

    # Number of newline-separated lines in the output.
    line_total = text.count("\n") + 1
    return text, preview_svg, f"✓ Generated {line_total} lines, machine compatible"
|
|
|
|
|
|
|
|
|
|
|
# Stylesheet handed to ``gr.Blocks(css=...)``: dark, rounded container for the
# element with id "workplane" (the SVG preview), an SVG that fills that
# container, and a monospace font for elements with the "status" class.
custom_css = """
#workplane {
    background: #111;
    border-radius: 8px;
    min-height: 500px;
}
#workplane svg {
    width: 100%;
    height: 100%;
}
.status {
    font-family: monospace;
}
"""
|
|
|
|
|
|
|
|
# UI layout: prompt controls and status on top, SVG preview in the middle,
# raw G-code at the bottom. Button click and Enter in the prompt box both run
# the same callback.
with gr.Blocks(css=custom_css, theme=gr.themes.Base(primary_hue="green")) as demo:
    gr.Markdown("# dcode\nText prompt → Polargraph Gcode")

    with gr.Row():
        with gr.Column(scale=2):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="drawing of a cat...",
                lines=1,
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=1.5,
                value=0.8,
                label="Temperature",
            )
            generate_btn = gr.Button("Generate", variant="primary")

        with gr.Column(scale=1):
            status = gr.Textbox(
                label="Status",
                interactive=False,
                elem_classes=["status"],
            )

    with gr.Row():
        preview = gr.HTML(elem_id="workplane")

    with gr.Row():
        gcode_output = gr.Textbox(label="Gcode", lines=10, show_copy_button=True)

    with gr.Row():
        # Hidden button with no event wiring and a callable that yields None.
        gr.DownloadButton("Download Gcode", value=lambda: None, visible=False)

    callback_inputs = [prompt, temperature]
    callback_outputs = [gcode_output, preview, status]
    for register in (generate_btn.click, prompt.submit):
        register(generate, callback_inputs, callback_outputs)
|
|
|
|
|
|
|
|
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|
|
|
|
|
|