|
|
"""dcode Gradio Space - Text to Gcode inference with visual preview.""" |
|
|
|
|
|
import re |
|
|
import gradio as gr |
|
|
import torch |
|
|
from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
|
|
|
|
# Selectable checkpoints: dropdown label -> Hugging Face Hub repository id.
MODELS = {
    "flan-t5-base (best)": "twarner/dcode-flan-t5-base",
}

# Drawable area of the polargraph, symmetric about the origin. The UI footer
# documents these extents as millimetres (X: ±420.5mm, Y: ±594.5mm).
_HALF_WIDTH = 420.5
_HALF_HEIGHT = 594.5
BOUNDS = {
    "left": -_HALF_WIDTH,
    "right": _HALF_WIDTH,
    "top": _HALF_HEIGHT,
    "bottom": -_HALF_HEIGHT,
}

# Dropdown label -> (model, tokenizer, device); populated lazily by get_model().
_model_cache = {}
|
|
|
|
|
|
|
|
def get_model(model_name: str):
    """Load (once) and return the model, tokenizer and device for *model_name*.

    Results are memoised in ``_model_cache`` so each checkpoint is downloaded
    and moved to the device only on first use.

    Args:
        model_name: A key of ``MODELS`` (the label shown in the dropdown).

    Returns:
        A ``(model, tokenizer, device)`` tuple, where ``device`` is ``"cuda"``
        when a GPU is available and ``"cpu"`` otherwise. The model is in eval mode.

    Raises:
        KeyError: if *model_name* is not a key of ``MODELS``.
        Any error raised by ``from_pretrained`` while loading the checkpoint.
    """
    cached = _model_cache.get(model_name)
    if cached is not None:
        return cached

    # Local import keeps this change self-contained; AutoConfig ships in the
    # same package as the Auto* classes imported at the top of the file.
    from transformers import AutoConfig

    model_id = MODELS[model_name]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision on GPU only; the CPU path runs in float32.
    dtype = torch.float16 if device == "cuda" else torch.float32

    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Pick the model head from the checkpoint's own config instead of guessing
    # from substrings of the repo id: encoder-decoder checkpoints (e.g. T5)
    # need the seq2seq class, decoder-only ones the causal-LM class.
    config = AutoConfig.from_pretrained(model_id)
    model_cls = AutoModelForSeq2SeqLM if config.is_encoder_decoder else AutoModelForCausalLM
    model = model_cls.from_pretrained(model_id, config=config, torch_dtype=dtype).to(device)
    model.eval()

    _model_cache[model_name] = (model, tokenizer, device)
    return _model_cache[model_name]
|
|
|
|
|
|
|
|
def validate_gcode(gcode: str, bounds=None) -> str:
    """Clamp every X/Y coordinate in *gcode* to the machine bounds.

    Each ``X<number>`` / ``Y<number>`` token (case-insensitive) is clamped on
    its own and rewritten as an upper-case axis letter followed by the value
    with two decimals, e.g. ``x500`` -> ``X420.50`` for the default bounds.
    Tokens whose characters do not parse as a float (``X-``, ``X1.2.3``) are
    left untouched. All other text, including line breaks, is preserved.

    Args:
        gcode: G-code text.
        bounds: Optional dict with ``left``/``right``/``top``/``bottom`` keys;
            defaults to the module-level ``BOUNDS``.

    Returns:
        The G-code with out-of-range coordinates clamped.
    """
    if bounds is None:
        bounds = BOUNDS
    limits = {
        "X": (bounds["left"], bounds["right"]),
        "Y": (bounds["bottom"], bounds["top"]),
    }

    def clamp_token(match):
        axis = match.group(1).upper()
        try:
            value = float(match.group(2))
        except ValueError:
            return match.group(0)  # malformed number: leave the token as written
        low, high = limits[axis]
        return f"{axis}{max(low, min(high, value)):.2f}"

    # Substitute per match so each coordinate is clamped against its own
    # value; a line carrying several X (or Y) words keeps them distinct.
    # The character class excludes newlines, so matches never span lines.
    return re.sub(r"([XY])([-\d.]+)", clamp_token, gcode, flags=re.IGNORECASE)
|
|
|
|
|
|
|
|
def _parse_pen_strokes(gcode: str) -> list:
    """Trace *gcode* and return each pen-down stroke as a list of (x, y) points.

    The pen starts up at (0, 0). ``M280 ... S<angle>`` commands move the pen
    servo: angles below 50 lower the pen, anything else raises it. Every line
    with an X and/or Y word made while the pen is down extends the current
    stroke. Only strokes with at least two points (something visible) are
    returned.
    """
    strokes = []
    stroke = []
    x, y = 0.0, 0.0
    pen_down = False

    for raw_line in gcode.split("\n"):
        # Ignore everything after ';' so coordinates quoted in comments are
        # not mistaken for moves.
        line = raw_line.split(";", 1)[0].strip()
        if not line:
            continue

        if "M280" in line.upper():
            servo = re.search(r"S(\d+)", line, re.IGNORECASE)
            if servo:
                was_down = pen_down
                pen_down = int(servo.group(1)) < 50
                if pen_down and not was_down:
                    # The stroke begins where the pen touches the paper, so the
                    # first move after lowering already draws a segment.
                    stroke = [(x, y)]
                elif was_down and not pen_down:
                    if len(stroke) > 1:
                        strokes.append(stroke)
                    # Always reset, so a lone point is never joined to the next stroke.
                    stroke = []

        x_match = re.search(r"X([-\d.]+)", line, re.IGNORECASE)
        y_match = re.search(r"Y([-\d.]+)", line, re.IGNORECASE)
        if x_match:
            try:
                x = float(x_match.group(1))
            except ValueError:
                pass  # malformed number: keep the previous X
        if y_match:
            try:
                y = float(y_match.group(1))
            except ValueError:
                pass  # malformed number: keep the previous Y

        if pen_down and (x_match or y_match):
            stroke.append((x, y))

    if len(stroke) > 1:
        strokes.append(stroke)
    return strokes


def gcode_to_svg(gcode: str, bounds=None) -> str:
    """Render *gcode* as an SVG preview of the drawing on the work area.

    Args:
        gcode: G-code text; pen state and moves are interpreted as described
            in ``_parse_pen_strokes``.
        bounds: Optional dict with ``left``/``right``/``top``/``bottom`` keys;
            defaults to the module-level ``BOUNDS``.

    Returns:
        An ``<svg>`` string containing the work-area frame, a centre
        crosshair, a 100-unit grid, one ``<path>`` per stroke and a
        "Paths | Points" counter. Y values are negated because the SVG y axis
        points down while G-code Y points up.
    """
    if bounds is None:
        bounds = BOUNDS
    paths = _parse_pen_strokes(gcode)

    left, right = bounds["left"], bounds["right"]
    top, bottom = bounds["top"], bounds["bottom"]
    w = right - left
    h = top - bottom
    padding = 20

    parts = [
        f'''<svg xmlns="http://www.w3.org/2000/svg"
     viewBox="{left - padding} {-top - padding} {w + 2 * padding} {h + 2 * padding}"
     style="background: #fafafa; width: 100%; height: 500px; border-radius: 8px; border: 1px solid #e5e5e5;">
  <!-- Work area border -->
  <rect x="{left}" y="{-top}" width="{w}" height="{h}"
        fill="#fff" stroke="#ccc" stroke-width="2"/>
  <!-- Center crosshair -->
  <line x1="0" y1="{-top}" x2="0" y2="{-bottom}" stroke="#ddd" stroke-width="1"/>
  <line x1="{left}" y1="0" x2="{right}" y2="0" stroke="#ddd" stroke-width="1"/>
  <!-- Grid -->
  <defs>
    <pattern id="grid" width="100" height="100" patternUnits="userSpaceOnUse">
      <path d="M 100 0 L 0 0 0 100" fill="none" stroke="#eee" stroke-width="0.5"/>
    </pattern>
  </defs>
  <rect x="{left}" y="{-top}" width="{w}" height="{h}" fill="url(#grid)"/>
'''
    ]

    for path in paths:
        d = " ".join(
            f"{'M' if i == 0 else 'L'}{px:.1f},{-py:.1f}" for i, (px, py) in enumerate(path)
        )
        parts.append(
            f'<path d="{d}" fill="none" stroke="#1a1a1a" stroke-width="1.5" '
            f'stroke-linecap="round" stroke-linejoin="round"/>'
        )

    total_points = sum(len(p) for p in paths)
    parts.append(
        f'''
  <text x="{left + 10}" y="{-top + 25}" fill="#666" font-family="monospace" font-size="14">
    Paths: {len(paths)} | Points: {total_points}
  </text>
'''
    )
    parts.append("</svg>")
    return "".join(parts)
|
|
|
|
|
|
|
|
def generate(prompt: str, model_name: str, temperature: float, max_tokens: int):
    """Generate G-code for *prompt* and return it with an SVG preview.

    Args:
        prompt: Natural-language description of the drawing.
        model_name: A key of ``MODELS``.
        temperature: Sampling temperature (the UI slider spans 0.1-1.5).
        max_tokens: Maximum number of new tokens to sample.

    Returns:
        ``(gcode_text, svg)``. On success ``gcode_text`` is the clamped G-code
        prefixed with a ``;`` comment header. For an empty prompt or any
        failure it is a short message/``; Error: ...`` line, and ``svg`` is
        the empty work area.
    """
    if not prompt or not prompt.strip():
        return "Enter a prompt to generate gcode", gcode_to_svg("")

    try:
        model, tokenizer, device = get_model(model_name)

        inputs = tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Use the tokenizer's real pad id when it has one (T5 does); fall back
        # to EOS only for tokenizers without a pad token (e.g. GPT-2).
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                # Slider values arrive as JSON numbers; generate() expects an int here.
                max_new_tokens=int(max_tokens),
                do_sample=True,
                temperature=float(temperature),
                top_p=0.9,
                pad_token_id=pad_token_id,
            )

        generated = outputs[0]
        if not model.config.is_encoder_decoder:
            # Decoder-only models return prompt + continuation; keep only the
            # newly generated tokens.
            generated = generated[inputs["input_ids"].shape[1]:]
        gcode = tokenizer.decode(generated, skip_special_tokens=True)

        gcode = validate_gcode(gcode)
        # splitlines() reports 0 for empty output (split("\n") would say 1).
        line_count = len(gcode.splitlines())
        svg = gcode_to_svg(gcode)

        header = f"; dcode output - {line_count} lines\n; Model: {model_name}\n; Machine validated\n\n"
        return header + gcode, svg

    except Exception as e:
        # UI boundary: report the failure in the output pane instead of
        # failing the Gradio event.
        return f"; Error: {e}", gcode_to_svg("")
|
|
|
|
|
|
|
|
|
|
|
custom_css = """ |
|
|
#preview-container { |
|
|
background: #0a0a0a; |
|
|
border-radius: 8px; |
|
|
padding: 0; |
|
|
} |
|
|
.gradio-container { |
|
|
max-width: 1200px !important; |
|
|
} |
|
|
""" |
|
|
|
|
|
# Gradio UI: prompt and generation controls on the left, SVG preview and the
# raw G-code (collapsed) on the right. Both the button and pressing submit in
# the prompt box run generate().
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="emerald")) as demo:
    gr.Markdown("""
    # dcode
    **Text → Polargraph Gcode** | Generate machine-compatible gcode from natural language.

    [GitHub](https://github.com/Twarner491/dcode) | [Model](https://huggingface.co/twarner/dcode-flan-t5-base) | [Dataset](https://huggingface.co/datasets/twarner/dcode-polargraph-gcode)
    """)

    with gr.Row():
        # Left column: inputs.
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="drawing of a cat, abstract spiral, portrait...",
                lines=2
            )
            model_dropdown = gr.Dropdown(
                choices=list(MODELS.keys()),
                value="flan-t5-base (best)",  # must stay in sync with a key of MODELS
                label="Model"
            )
            with gr.Row():
                temperature = gr.Slider(0.1, 1.5, value=0.8, label="Temperature", info="Higher = more creative")
                max_tokens = gr.Slider(256, 2048, value=1024, step=256, label="Max Tokens")

            generate_btn = gr.Button("Generate", variant="primary", size="lg")

            # Clicking an example only fills the prompt box; it does not generate.
            gr.Examples(
                examples=[
                    ["drawing of a cat"],
                    ["abstract spiral pattern"],
                    ["simple house with chimney"],
                    ["portrait of a woman"],
                    ["geometric shapes"],
                ],
                inputs=prompt,
            )

        # Right column (twice as wide): rendered preview plus raw output.
        with gr.Column(scale=2):
            preview = gr.HTML(
                value=gcode_to_svg(""),  # empty work area until the first generation
                label="Preview",
                elem_id="preview-container"
            )

            with gr.Accordion("Gcode Output", open=False):
                gcode_output = gr.Code(label="Gcode", language=None, lines=15)

    gr.Markdown("""
    ---
    **Machine Bounds**: X: ±420.5mm, Y: ±594.5mm | Pen servo: 40° (down) / 90° (up) | **License**: MIT
    """)

    # generate() returns (gcode_text, svg), which map onto these outputs in order.
    generate_btn.click(
        generate,
        [prompt, model_dropdown, temperature, max_tokens],
        [gcode_output, preview]
    )
    prompt.submit(
        generate,
        [prompt, model_dropdown, temperature, max_tokens],
        [gcode_output, preview]
    )


# Launch the app when the file is executed directly.
if __name__ == "__main__":
    demo.launch()
|
|
|