"""dcode Gradio Space - Text to Gcode inference."""

import re

import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

MODEL_ID = "twarner/dcode-flan-t5-base"

# Machine work-area limits; generated coordinates are clamped to this box.
BOUNDS = {"left": -420.5, "right": 420.5, "top": 594.5, "bottom": -594.5}

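
# Wraps the model and tokenizer; weights are loaded lazily on the first request.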
class GcodeGenerator:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load(self):
        """Load the tokenizer and model once; later calls reuse them."""
        if self.model is None:
            self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                MODEL_ID,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            ).to(self.device)
            self.model.eval()

    def generate(self, prompt: str, max_new_tokens: int = 1024, temperature: float = 0.8) -> str:
        """Generate gcode for a prompt and return it clamped to machine bounds."""
        self.load()

        # Prompts are short; truncate anything past 128 tokens.
        inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=0.9,
            )

        gcode = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return self.validate(gcode)

    def validate(self, gcode: str) -> str:
        """Clamp coordinates to machine bounds."""
        # e.g. "G1 X999.0 Y-700.0" becomes "G1 X420.50 Y-594.50".
        lines = []
        for line in gcode.split("\n"):
            corrected = line

            x_match = re.search(r"X([-\d.]+)", line, re.IGNORECASE)
            if x_match:
                x = float(x_match.group(1))
                x = max(BOUNDS["left"], min(BOUNDS["right"], x))
                corrected = re.sub(r"X[-\d.]+", f"X{x:.2f}", corrected, flags=re.IGNORECASE)

            y_match = re.search(r"Y([-\d.]+)", line, re.IGNORECASE)
            if y_match:
                y = float(y_match.group(1))
                y = max(BOUNDS["bottom"], min(BOUNDS["top"], y))
                corrected = re.sub(r"Y[-\d.]+", f"Y{y:.2f}", corrected, flags=re.IGNORECASE)

            lines.append(corrected)

        return "\n".join(lines)
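

# Shared instance so the model is loaded once and reused across requests.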
generator = GcodeGenerator()
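

# Gradio callback: errors are returned as gcode comments so the output box stays readable.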
def generate(prompt: str, temperature: float) -> str:
    """Generate gcode from a text prompt."""
    if not prompt or not prompt.strip():
        return "Enter a prompt to generate gcode"

    try:
        gcode = generator.generate(prompt, temperature=temperature)
        line_count = len(gcode.split("\n"))
        return f"; dcode output - {line_count} lines\n; Machine validated\n\n{gcode}"
    except Exception as e:
        return f"; Error: {e}"


demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="drawing of a cat..."),
        gr.Slider(0.1, 1.5, value=0.8, label="Temperature"),
    ],
    outputs=gr.Textbox(label="Gcode", lines=20, show_copy_button=True),
    title="dcode",
    description="Text prompt → Polargraph Gcode. Generate machine-compatible gcode from natural language descriptions.",
    examples=[
        ["drawing of a cat", 0.8],
        ["abstract spiral pattern", 0.9],
        ["simple house with chimney", 0.7],
    ],
    theme=gr.themes.Base(primary_hue="green"),
)


if __name__ == "__main__":
    demo.launch()
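
# Assumed local usage (treating this file as the Space's app.py):
#   pip install gradio torch transformers
#   python app.py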