twarner committed
Commit 0073cbe · Parent: 2f64c24

Add model dropdown, improve UI

Files changed (1)
  1. app.py +89 -63
app.py CHANGED
@@ -3,82 +3,94 @@
  import re
  import gradio as gr
  import torch
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+ from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoTokenizer
 
- # Model config
- MODEL_ID = "twarner/dcode-flan-t5-base"
+ # Available models
+ MODELS = {
+     "flan-t5-base (best)": "twarner/dcode-flan-t5-base",
+ }
 
  # Machine limits
  BOUNDS = {"left": -420.5, "right": 420.5, "top": 594.5, "bottom": -594.5}
 
+ # Cache loaded models
+ _model_cache = {}
+
 
- class GcodeGenerator:
-     def __init__(self):
-         self.model = None
-         self.tokenizer = None
-         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-
-     def load(self):
-         if self.model is None:
-             self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-             self.model = AutoModelForSeq2SeqLM.from_pretrained(
-                 MODEL_ID, torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
-             ).to(self.device)
-             self.model.eval()
-
-     def generate(self, prompt: str, max_length: int = 1024, temperature: float = 0.8) -> str:
-         self.load()
-
-         inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True)
-         inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-         with torch.no_grad():
-             outputs = self.model.generate(
-                 **inputs,
-                 max_new_tokens=max_length,
-                 do_sample=True,
-                 temperature=temperature,
-                 top_p=0.9,
-             )
-
-         gcode = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-         return self.validate(gcode)
-
-     def validate(self, gcode: str) -> str:
-         """Clamp coordinates to machine bounds."""
-         lines = []
-         for line in gcode.split("\n"):
-             corrected = line
-
-             x_match = re.search(r"X([-\d.]+)", line, re.IGNORECASE)
-             if x_match:
-                 x = float(x_match.group(1))
-                 x = max(BOUNDS["left"], min(BOUNDS["right"], x))
-                 corrected = re.sub(r"X[-\d.]+", f"X{x:.2f}", corrected, flags=re.IGNORECASE)
-
-             y_match = re.search(r"Y([-\d.]+)", line, re.IGNORECASE)
-             if y_match:
-                 y = float(y_match.group(1))
-                 y = max(BOUNDS["bottom"], min(BOUNDS["top"], y))
-                 corrected = re.sub(r"Y[-\d.]+", f"Y{y:.2f}", corrected, flags=re.IGNORECASE)
-
-             lines.append(corrected)
-
-         return "\n".join(lines)
-
-
- generator = GcodeGenerator()
-
-
- def generate(prompt: str, temperature: float) -> str:
+ def get_model(model_name: str):
+     """Load and cache model."""
+     if model_name not in _model_cache:
+         model_id = MODELS[model_name]
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         dtype = torch.float16 if device == "cuda" else torch.float32
+
+         tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+         if "gpt2" in model_id or "codegen" in model_id:
+             model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype).to(device)
+         else:
+             model = AutoModelForSeq2SeqLM.from_pretrained(model_id, torch_dtype=dtype).to(device)
+
+         model.eval()
+         _model_cache[model_name] = (model, tokenizer, device)
+
+     return _model_cache[model_name]
+
+
+ def validate_gcode(gcode: str) -> str:
+     """Clamp coordinates to machine bounds."""
+     lines = []
+     for line in gcode.split("\n"):
+         corrected = line
+
+         x_match = re.search(r"X([-\d.]+)", line, re.IGNORECASE)
+         if x_match:
+             x = float(x_match.group(1))
+             x = max(BOUNDS["left"], min(BOUNDS["right"], x))
+             corrected = re.sub(r"X[-\d.]+", f"X{x:.2f}", corrected, flags=re.IGNORECASE)
+
+         y_match = re.search(r"Y([-\d.]+)", line, re.IGNORECASE)
+         if y_match:
+             y = float(y_match.group(1))
+             y = max(BOUNDS["bottom"], min(BOUNDS["top"], y))
+             corrected = re.sub(r"Y[-\d.]+", f"Y{y:.2f}", corrected, flags=re.IGNORECASE)
+
+         lines.append(corrected)
+
+     return "\n".join(lines)
+
+
+ def generate(prompt: str, model_name: str, temperature: float, max_tokens: int) -> str:
      """Generate gcode from prompt."""
      if not prompt or not prompt.strip():
          return "Enter a prompt to generate gcode"
 
      try:
-         gcode = generator.generate(prompt, temperature=temperature)
+         model, tokenizer, device = get_model(model_name)
+         model_id = MODELS[model_name]
+
+         inputs = tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True)
+         inputs = {k: v.to(device) for k, v in inputs.items()}
+
+         with torch.no_grad():
+             outputs = model.generate(
+                 **inputs,
+                 max_new_tokens=max_tokens,
+                 do_sample=True,
+                 temperature=temperature,
+                 top_p=0.9,
+                 pad_token_id=tokenizer.eos_token_id,
+             )
+
+         # For causal models, skip the input tokens
+         if "gpt2" in model_id or "codegen" in model_id:
+             gcode = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
+         else:
+             gcode = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+         gcode = validate_gcode(gcode)
          line_count = len(gcode.split("\n"))
-         return f"; dcode output - {line_count} lines\n; Machine validated\n\n{gcode}"
+         return f"; dcode output - {line_count} lines\n; Model: {model_name}\n; Machine validated\n\n{gcode}"
      except Exception as e:
          return f"; Error: {e}"
 
@@ -86,18 +98,32 @@ def generate(prompt: str, temperature: float) -> str:
  demo = gr.Interface(
      fn=generate,
      inputs=[
-         gr.Textbox(label="Prompt", placeholder="drawing of a cat..."),
-         gr.Slider(0.1, 1.5, value=0.8, label="Temperature"),
+         gr.Textbox(label="Prompt", placeholder="drawing of a cat...", lines=2),
+         gr.Dropdown(choices=list(MODELS.keys()), value="flan-t5-base (best)", label="Model"),
+         gr.Slider(0.1, 1.5, value=0.8, label="Temperature", info="Higher = more creative"),
+         gr.Slider(256, 2048, value=1024, step=256, label="Max Tokens"),
      ],
-     outputs=gr.Textbox(label="Gcode", lines=20, show_copy_button=True),
+     outputs=gr.Code(label="Gcode", language=None, lines=25),
      title="dcode",
-     description="Text prompt → Polargraph Gcode. Generate machine-compatible gcode from natural language descriptions.",
+     description="**Text → Polargraph Gcode** | Generate machine-compatible gcode from natural language. [GitHub](https://github.com/Twarner491/dcode) | [Model](https://huggingface.co/twarner/dcode-flan-t5-base) | [Dataset](https://huggingface.co/datasets/twarner/dcode-polargraph-gcode)",
      examples=[
-         ["drawing of a cat", 0.8],
-         ["abstract spiral pattern", 0.9],
-         ["simple house with chimney", 0.7],
+         ["drawing of a cat", "flan-t5-base (best)", 0.8, 1024],
+         ["abstract spiral pattern", "flan-t5-base (best)", 0.9, 1024],
+         ["simple house with chimney", "flan-t5-base (best)", 0.7, 512],
+         ["portrait of a woman", "flan-t5-base (best)", 0.8, 1024],
      ],
-     theme=gr.themes.Base(primary_hue="green"),
+     theme=gr.themes.Soft(primary_hue="emerald"),
+     article="""
+ ## About
+
+ dcode finetunes text-to-text models to directly output polargraph-compatible gcode from natural language descriptions.
+
+ **Training**: Flan-T5-base trained on 175,952 art-caption-gcode triplets for 20 epochs on H100.
+
+ **Machine Bounds**: X: ±420.5mm, Y: ±594.5mm | Pen servo: 40° (down) / 90° (up)
+
+ **License**: MIT
+ """,
  )
 
  if __name__ == "__main__":
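
For reference, a minimal standalone sketch of the bounds-clamping that `validate_gcode` performs, adapted from the diff above; the sample gcode line and the `print` demo are hypothetical, not part of app.py:

```python
import re

# Machine limits, as defined in app.py
BOUNDS = {"left": -420.5, "right": 420.5, "top": 594.5, "bottom": -594.5}

def validate_gcode(gcode: str) -> str:
    """Clamp X/Y coordinates in each gcode line to the machine bounds."""
    lines = []
    for line in gcode.split("\n"):
        corrected = line
        x_match = re.search(r"X([-\d.]+)", line, re.IGNORECASE)
        if x_match:
            x = max(BOUNDS["left"], min(BOUNDS["right"], float(x_match.group(1))))
            corrected = re.sub(r"X[-\d.]+", f"X{x:.2f}", corrected, flags=re.IGNORECASE)
        y_match = re.search(r"Y([-\d.]+)", line, re.IGNORECASE)
        if y_match:
            y = max(BOUNDS["bottom"], min(BOUNDS["top"], float(y_match.group(1))))
            corrected = re.sub(r"Y[-\d.]+", f"Y{y:.2f}", corrected, flags=re.IGNORECASE)
        lines.append(corrected)
    return "\n".join(lines)

# Hypothetical out-of-bounds move: X exceeds the right limit, Y the bottom limit.
print(validate_gcode("G1 X500.00 Y-700.00"))  # -> G1 X420.50 Y-594.50
```

In app.py this runs on every generated program before it is returned, so none of the dropdown models can emit moves outside the plotter's reachable area.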