twarner committed on
Commit
bbd6111
·
1 Parent(s): f1b3e74

Update to latent-gcode diffusion model

Browse files
Files changed (3) hide show
  1. README.md +22 -7
  2. app.py +157 -89
  3. requirements.txt +3 -0
README.md CHANGED
@@ -4,24 +4,39 @@ emoji: ✏️
4
  colorFrom: gray
5
  colorTo: green
6
  sdk: gradio
7
- sdk_version: 5.9.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
- short_description: Text to Polargraph Gcode
 
12
  ---
13
 
14
  # dcode
15
 
16
- Generate polargraph-compatible gcode from text prompts using finetuned diffusion models.
 
 
 
 
 
 
17
 
18
  ## Usage
19
 
20
- 1. Enter a prompt (e.g., "drawing of a cat")
21
- 2. Adjust temperature (higher = more creative)
22
  3. Click Generate
23
- 4. View preview and download gcode
24
 
25
  ## Model
26
 
27
- Finetuned Flan-T5-base on 175k image-gcode pairs.
 
 
 
 
 
 
 
 
 
4
  colorFrom: gray
5
  colorTo: green
6
  sdk: gradio
7
+ sdk_version: "4.44.0"
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
+ hardware: t4-small
12
+ short_description: Text to Polargraph Gcode via Latent Diffusion
13
  ---
14
 
15
  # dcode
16
 
17
+ Generate polargraph-compatible gcode from text prompts using latent diffusion.
18
+
19
+ ## How it works
20
+
21
+ 1. **Text → Latent**: Stable Diffusion generates a latent representation from your text prompt
22
+ 2. **Latent → Gcode**: Custom transformer decoder converts the latent to gcode commands
23
+ 3. **Validation**: Coordinates are clamped to machine bounds
24
 
25
  ## Usage
26
 
27
+ 1. Enter a prompt (e.g., "line drawing of a cat")
28
+ 2. Adjust diffusion steps and guidance scale
29
  3. Click Generate
30
+ 4. View preview and copy gcode
31
 
32
  ## Model
33
 
34
+ - Base: Stable Diffusion 2.1
35
+ - Decoder: 6-layer transformer trained on 175k image-gcode pairs
36
+ - Final loss: 0.107
37
+
38
+ ## Links
39
+
40
+ - [Model](https://huggingface.co/twarner/dcode-latent-gcode)
41
+ - [Dataset](https://huggingface.co/datasets/twarner/dcode-polargraph-gcode)
42
+ - [GitHub](https://github.com/Twarner491/dcode)
app.py CHANGED
@@ -1,40 +1,113 @@
1
- """dcode Gradio Space - Text to Gcode inference with visual preview."""
2
 
3
  import re
4
  import gradio as gr
5
  import torch
6
- from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoTokenizer
7
-
8
- # Available models
9
- MODELS = {
10
- "flan-t5-base (best)": "twarner/dcode-flan-t5-base",
11
- }
12
 
13
  # Machine limits
14
  BOUNDS = {"left": -420.5, "right": 420.5, "top": 594.5, "bottom": -594.5}
15
 
16
- # Cache loaded models
17
- _model_cache = {}
18
 
19
 
20
- def get_model(model_name: str):
21
- """Load and cache model."""
22
- if model_name not in _model_cache:
23
- model_id = MODELS[model_name]
 
 
 
 
24
  device = "cuda" if torch.cuda.is_available() else "cpu"
25
  dtype = torch.float16 if device == "cuda" else torch.float32
26
 
27
- tokenizer = AutoTokenizer.from_pretrained(model_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- if "gpt2" in model_id or "codegen" in model_id:
30
- model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype).to(device)
31
- else:
32
- model = AutoModelForSeq2SeqLM.from_pretrained(model_id, torch_dtype=dtype).to(device)
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- model.eval()
35
- _model_cache[model_name] = (model, tokenizer, device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- return _model_cache[model_name]
38
 
39
 
40
  def validate_gcode(gcode: str) -> str:
@@ -73,14 +146,11 @@ def gcode_to_svg(gcode: str) -> str:
73
  x, y = 0.0, 0.0
74
  pen_down = False
75
 
76
- # Split on newlines first, then also split commands that may be on same line
77
- # Handle gcode that's all on one line by splitting on G0/G1/M commands
78
  lines = []
79
  for line in gcode.split("\n"):
80
  line = line.strip()
81
  if not line:
82
  continue
83
- # Split on gcode commands (G0, G1, G28, M280, etc.)
84
  parts = re.split(r'(?=[GM]\d)', line)
85
  for part in parts:
86
  part = part.strip()
@@ -88,19 +158,16 @@ def gcode_to_svg(gcode: str) -> str:
88
  lines.append(part)
89
 
90
  for line in lines:
91
-
92
- # Pen state from M280 servo commands
93
  if "M280" in line.upper():
94
  match = re.search(r"S(\d+)", line, re.IGNORECASE)
95
  if match:
96
  angle = int(match.group(1))
97
  was_down = pen_down
98
- pen_down = angle < 50 # 40 = down, 90 = up
99
  if was_down and not pen_down and len(current_path) > 1:
100
  paths.append(current_path[:])
101
  current_path = []
102
 
103
- # Position from G0/G1 commands
104
  x_match = re.search(r"X([-\d.]+)", line, re.IGNORECASE)
105
  y_match = re.search(r"Y([-\d.]+)", line, re.IGNORECASE)
106
 
@@ -121,7 +188,6 @@ def gcode_to_svg(gcode: str) -> str:
121
  if len(current_path) > 1:
122
  paths.append(current_path)
123
 
124
- # Build SVG - light mode with dark lines
125
  w = BOUNDS["right"] - BOUNDS["left"]
126
  h = BOUNDS["top"] - BOUNDS["bottom"]
127
  padding = 20
@@ -129,91 +195,92 @@ def gcode_to_svg(gcode: str) -> str:
129
  svg = f'''<svg xmlns="http://www.w3.org/2000/svg"
130
  viewBox="{BOUNDS["left"] - padding} {-BOUNDS["top"] - padding} {w + 2*padding} {h + 2*padding}"
131
  style="background: #fafafa; width: 100%; height: 500px; border-radius: 8px; border: 1px solid #e5e5e5;">
132
- <!-- Work area border -->
133
  <rect x="{BOUNDS["left"]}" y="{-BOUNDS["top"]}" width="{w}" height="{h}"
134
  fill="#fff" stroke="#ccc" stroke-width="2"/>
135
- <!-- Center crosshair -->
136
  <line x1="0" y1="{-BOUNDS["top"]}" x2="0" y2="{-BOUNDS["bottom"]}" stroke="#ddd" stroke-width="1"/>
137
  <line x1="{BOUNDS["left"]}" y1="0" x2="{BOUNDS["right"]}" y2="0" stroke="#ddd" stroke-width="1"/>
138
- <!-- Grid -->
139
- <defs>
140
- <pattern id="grid" width="100" height="100" patternUnits="userSpaceOnUse">
141
- <path d="M 100 0 L 0 0 0 100" fill="none" stroke="#eee" stroke-width="0.5"/>
142
- </pattern>
143
- </defs>
144
- <rect x="{BOUNDS["left"]}" y="{-BOUNDS["top"]}" width="{w}" height="{h}" fill="url(#grid)"/>
145
  '''
146
 
147
- # Draw paths - dark lines
148
  for path in paths:
149
  if len(path) < 2:
150
  continue
151
- # SVG Y is inverted
152
  d = " ".join(f"{'M' if i == 0 else 'L'}{p[0]:.1f},{-p[1]:.1f}" for i, p in enumerate(path))
153
  svg += f'<path d="{d}" fill="none" stroke="#1a1a1a" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>'
154
 
155
- # Stats
156
  total_points = sum(len(p) for p in paths)
157
  svg += f'''
158
  <text x="{BOUNDS["left"] + 10}" y="{-BOUNDS["top"] + 25}" fill="#666" font-family="monospace" font-size="14">
159
  Paths: {len(paths)} | Points: {total_points}
160
  </text>
161
  '''
162
-
163
  svg += "</svg>"
164
  return svg
165
 
166
 
167
- def generate(prompt: str, model_name: str, temperature: float, max_tokens: int):
168
- """Generate gcode from prompt and return both code and visualization."""
169
  if not prompt or not prompt.strip():
170
- empty_svg = gcode_to_svg("")
171
- return "Enter a prompt to generate gcode", empty_svg
172
 
173
  try:
174
- model, tokenizer, device = get_model(model_name)
175
- model_id = MODELS[model_name]
 
 
 
 
 
 
176
 
177
- inputs = tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True)
178
- inputs = {k: v.to(device) for k, v in inputs.items()}
179
-
180
  with torch.no_grad():
181
- outputs = model.generate(
182
- **inputs,
183
- max_new_tokens=max_tokens,
184
- do_sample=True,
185
- temperature=temperature,
186
- top_p=0.9,
187
- pad_token_id=tokenizer.eos_token_id,
188
  )
189
-
190
- # For causal models, skip the input tokens
191
- if "gpt2" in model_id or "codegen" in model_id:
192
- gcode = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
193
- else:
194
- gcode = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
 
196
  gcode = validate_gcode(gcode)
197
  line_count = len(gcode.split("\n"))
198
-
199
- # Generate SVG preview
200
  svg = gcode_to_svg(gcode)
201
 
202
- gcode_with_header = f"; dcode output - {line_count} lines\n; Model: {model_name}\n; Machine validated\n\n{gcode}"
203
  return gcode_with_header, svg
204
 
205
  except Exception as e:
206
- error_svg = gcode_to_svg("")
207
- return f"; Error: {e}", error_svg
 
208
 
209
 
210
  # Custom CSS
211
  custom_css = """
212
- #preview-container {
213
- background: #0a0a0a;
214
- border-radius: 8px;
215
- padding: 0;
216
- }
217
  .gradio-container {
218
  max-width: 1200px !important;
219
  }
@@ -222,9 +289,11 @@ custom_css = """
222
  with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="emerald")) as demo:
223
  gr.Markdown("""
224
  # dcode
225
- **Text → Polargraph Gcode** | Generate machine-compatible gcode from natural language.
226
 
227
- [GitHub](https://github.com/Twarner491/dcode) | [Model](https://huggingface.co/twarner/dcode-flan-t5-base) | [Dataset](https://huggingface.co/datasets/twarner/dcode-polargraph-gcode)
 
 
228
  """)
229
 
230
  with gr.Row():
@@ -234,24 +303,24 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="emerald")) as d
234
  placeholder="drawing of a cat, abstract spiral, portrait...",
235
  lines=2
236
  )
237
- model_dropdown = gr.Dropdown(
238
- choices=list(MODELS.keys()),
239
- value="flan-t5-base (best)",
240
- label="Model"
241
- )
242
  with gr.Row():
243
- temperature = gr.Slider(0.1, 1.5, value=0.8, label="Temperature", info="Higher = more creative")
244
- max_tokens = gr.Slider(256, 2048, value=1024, step=256, label="Max Tokens")
245
 
246
  generate_btn = gr.Button("Generate", variant="primary", size="lg")
247
 
248
  gr.Examples(
249
  examples=[
250
- ["drawing of a cat"],
251
  ["abstract spiral pattern"],
252
  ["simple house with chimney"],
253
- ["portrait of a woman"],
254
- ["geometric shapes"],
255
  ],
256
  inputs=prompt,
257
  )
@@ -260,7 +329,6 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="emerald")) as d
260
  preview = gr.HTML(
261
  value=gcode_to_svg(""),
262
  label="Preview",
263
- elem_id="preview-container"
264
  )
265
 
266
  with gr.Accordion("Gcode Output", open=False):
@@ -273,12 +341,12 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="emerald")) as d
273
 
274
  generate_btn.click(
275
  generate,
276
- [prompt, model_dropdown, temperature, max_tokens],
277
  [gcode_output, preview]
278
  )
279
  prompt.submit(
280
  generate,
281
- [prompt, model_dropdown, temperature, max_tokens],
282
  [gcode_output, preview]
283
  )
284
 
 
1
+ """dcode Gradio Space - Text to Gcode via Latent Diffusion."""
2
 
3
  import re
4
  import gradio as gr
5
  import torch
6
+ from pathlib import Path
 
 
 
 
 
7
 
8
  # Machine limits
9
  BOUNDS = {"left": -420.5, "right": 420.5, "top": 594.5, "bottom": -594.5}
10
 
11
+ # Model caches
12
+ _generator = None
13
 
14
 
15
def get_generator():
    """Load the latent-gcode generator once and return the cached bundle.

    The first call builds the Stable Diffusion pipeline, downloads the decoder
    checkpoint and its config from the ``twarner/dcode-latent-gcode`` repo,
    constructs the projector and decoder modules and loads their weights.
    Every later call returns the same cached dict.

    Returns:
        dict with the keys ``pipe``, ``projector``, ``decoder``,
        ``tokenizer``, ``device``, ``dtype`` and ``max_seq_len``.

    Raises:
        RuntimeError: from ``load_state_dict`` if the checkpoint keys do not
            match the modules built here.
    """
    global _generator
    if _generator is not None:
        return _generator

    # Heavy dependencies are imported lazily, only when the models are needed.
    import json

    import torch.nn as nn
    from diffusers import StableDiffusionPipeline
    from huggingface_hub import hf_hub_download
    from transformers import AutoTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    print("Loading Stable Diffusion pipeline...")
    pipe = StableDiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-1-base",
        torch_dtype=dtype,
        safety_checker=None,
    ).to(device)

    print("Loading gcode decoder...")
    model_path = hf_hub_download("twarner/dcode-latent-gcode", "pytorch_model.bin")
    config_path = hf_hub_download("twarner/dcode-latent-gcode", "config.json")

    with open(config_path) as f:
        config = json.load(f)

    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")

    class LatentProjector(nn.Module):
        """MLP that maps a flattened latent to one hidden-size vector."""

        def __init__(self, latent_dim, hidden_size):
            super().__init__()
            self.proj = nn.Sequential(
                nn.Linear(latent_dim, hidden_size * 2),
                nn.GELU(),
                nn.Linear(hidden_size * 2, hidden_size),
                nn.LayerNorm(hidden_size),
            )

        def forward(self, x):
            return self.proj(x)

    class GcodeDecoder(nn.Module):
        """Transformer decoder over token ids, attending to the projected latent."""

        def __init__(self, hidden_size, vocab_size, num_layers, num_heads, max_seq_len):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, hidden_size)
            self.pos_embed = nn.Embedding(max_seq_len, hidden_size)
            layer = nn.TransformerDecoderLayer(
                hidden_size, num_heads, hidden_size * 4, batch_first=True
            )
            self.decoder = nn.TransformerDecoder(layer, num_layers)
            self.head = nn.Linear(hidden_size, vocab_size)
            self.max_seq_len = max_seq_len

        def forward(self, tgt, memory, tgt_mask=None):
            pos = torch.arange(tgt.size(1), device=tgt.device)
            x = self.embed(tgt) + self.pos_embed(pos)
            x = self.decoder(x, memory, tgt_mask=tgt_mask)
            return self.head(x)

    def strip_prefix(state, prefix):
        """Return the entries of *state* under *prefix*, with the prefix removed."""
        # Slice off only the leading prefix. str.replace() removes every
        # occurrence, so "decoder.decoder.layers.0..." would collapse to
        # "layers.0..." and no longer match GcodeDecoder's own "decoder.*"
        # parameters, making load_state_dict() fail.
        return {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}

    # Flattened size of a 4 x 64 x 64 latent.
    # NOTE(review): presumably the SD latent for the default 512 x 512 output - confirm.
    latent_dim = 4 * 64 * 64
    hidden_size = config.get("hidden_size", 512)
    vocab_size = tokenizer.vocab_size
    num_layers = config.get("num_layers", 6)
    num_heads = config.get("num_heads", 8)
    max_seq_len = config.get("max_seq_len", 1024)

    projector = LatentProjector(latent_dim, hidden_size).to(device, dtype)
    decoder = GcodeDecoder(
        hidden_size, vocab_size, num_layers, num_heads, max_seq_len
    ).to(device, dtype)

    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot execute arbitrary code on load.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)

    projector.load_state_dict(strip_prefix(state_dict, "projector."))
    decoder.load_state_dict(strip_prefix(state_dict, "decoder."))

    projector.eval()
    decoder.eval()

    _generator = {
        "pipe": pipe,
        "projector": projector,
        "decoder": decoder,
        "tokenizer": tokenizer,
        "device": device,
        "dtype": dtype,
        "max_seq_len": max_seq_len,
    }
    print("Models loaded!")

    return _generator
111
 
112
 
113
  def validate_gcode(gcode: str) -> str:
 
146
  x, y = 0.0, 0.0
147
  pen_down = False
148
 
 
 
149
  lines = []
150
  for line in gcode.split("\n"):
151
  line = line.strip()
152
  if not line:
153
  continue
 
154
  parts = re.split(r'(?=[GM]\d)', line)
155
  for part in parts:
156
  part = part.strip()
 
158
  lines.append(part)
159
 
160
  for line in lines:
 
 
161
  if "M280" in line.upper():
162
  match = re.search(r"S(\d+)", line, re.IGNORECASE)
163
  if match:
164
  angle = int(match.group(1))
165
  was_down = pen_down
166
+ pen_down = angle < 50
167
  if was_down and not pen_down and len(current_path) > 1:
168
  paths.append(current_path[:])
169
  current_path = []
170
 
 
171
  x_match = re.search(r"X([-\d.]+)", line, re.IGNORECASE)
172
  y_match = re.search(r"Y([-\d.]+)", line, re.IGNORECASE)
173
 
 
188
  if len(current_path) > 1:
189
  paths.append(current_path)
190
 
 
191
  w = BOUNDS["right"] - BOUNDS["left"]
192
  h = BOUNDS["top"] - BOUNDS["bottom"]
193
  padding = 20
 
195
  svg = f'''<svg xmlns="http://www.w3.org/2000/svg"
196
  viewBox="{BOUNDS["left"] - padding} {-BOUNDS["top"] - padding} {w + 2*padding} {h + 2*padding}"
197
  style="background: #fafafa; width: 100%; height: 500px; border-radius: 8px; border: 1px solid #e5e5e5;">
 
198
  <rect x="{BOUNDS["left"]}" y="{-BOUNDS["top"]}" width="{w}" height="{h}"
199
  fill="#fff" stroke="#ccc" stroke-width="2"/>
 
200
  <line x1="0" y1="{-BOUNDS["top"]}" x2="0" y2="{-BOUNDS["bottom"]}" stroke="#ddd" stroke-width="1"/>
201
  <line x1="{BOUNDS["left"]}" y1="0" x2="{BOUNDS["right"]}" y2="0" stroke="#ddd" stroke-width="1"/>
 
 
 
 
 
 
 
202
  '''
203
 
 
204
  for path in paths:
205
  if len(path) < 2:
206
  continue
 
207
  d = " ".join(f"{'M' if i == 0 else 'L'}{p[0]:.1f},{-p[1]:.1f}" for i, p in enumerate(path))
208
  svg += f'<path d="{d}" fill="none" stroke="#1a1a1a" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>'
209
 
 
210
  total_points = sum(len(p) for p in paths)
211
  svg += f'''
212
  <text x="{BOUNDS["left"] + 10}" y="{-BOUNDS["top"] + 25}" fill="#666" font-family="monospace" font-size="14">
213
  Paths: {len(paths)} | Points: {total_points}
214
  </text>
215
  '''
 
216
  svg += "</svg>"
217
  return svg
218
 
219
 
220
def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, guidance: float):
    """Generate gcode from a text prompt via latent diffusion.

    Args:
        prompt: Text description of the drawing.
        temperature: Softmax temperature for token sampling. Values <= 0
            select the most likely token instead of sampling.
        max_tokens: Upper bound on decoded tokens; also capped at
            ``max_seq_len - 1`` from the loaded generator.
        num_steps: Number of diffusion inference steps.
        guidance: Guidance scale passed to the diffusion pipeline.

    Returns:
        A ``(gcode_text, svg_markup)`` tuple. For an empty prompt or on any
        failure, ``gcode_text`` is a message and ``svg_markup`` is the empty
        preview.
    """
    if not prompt or not prompt.strip():
        return "Enter a prompt to generate gcode", gcode_to_svg("")

    try:
        gen = get_generator()
        pipe = gen["pipe"]
        projector = gen["projector"]
        decoder = gen["decoder"]
        tokenizer = gen["tokenizer"]
        device = gen["device"]
        dtype = gen["dtype"]
        max_seq_len = gen["max_seq_len"]

        # UI sliders may deliver floats; range() and the step count need ints.
        steps = int(num_steps)
        max_new_tokens = int(min(max_tokens, max_seq_len - 1))

        # 1. Text -> latent via Stable Diffusion
        with torch.no_grad():
            result = pipe(
                prompt,
                num_inference_steps=steps,
                guidance_scale=guidance,
                output_type="latent",
            )
        latent = result.images  # expected [1, 4, 64, 64]

        # 2. Latent -> gcode via the decoder
        with torch.no_grad():
            # reshape() also handles a non-contiguous latent, where view() raises.
            latent_flat = latent.reshape(1, -1).to(dtype)
            memory = projector(latent_flat).unsqueeze(1)  # [1, 1, hidden]

            # Explicit None check: a start-token id of 0 is valid and must not
            # be treated as "missing" the way `a or b` would.
            bos_id = tokenizer.bos_token_id
            if bos_id is None:
                bos_id = tokenizer.pad_token_id
            eos_id = tokenizer.eos_token_id

            tokens = torch.tensor([[bos_id]], device=device)

            # NOTE(review): the decoder is called without a causal tgt_mask.
            # If training used one, pass the same mask here - confirm against
            # the training code before changing it.
            for _ in range(max_new_tokens):
                logits = decoder(tokens, memory)
                next_logits = logits[:, -1, :]
                if temperature > 0:
                    probs = torch.softmax(next_logits / temperature, dim=-1)
                    next_token = torch.multinomial(probs, 1)
                else:
                    # Dividing by a non-positive temperature is undefined;
                    # fall back to greedy decoding.
                    next_token = torch.argmax(next_logits, dim=-1, keepdim=True)
                tokens = torch.cat([tokens, next_token], dim=1)

                if next_token.item() == eos_id:
                    break

            gcode = tokenizer.decode(tokens[0], skip_special_tokens=True)

        gcode = validate_gcode(gcode)
        line_count = len(gcode.split("\n"))
        svg = gcode_to_svg(gcode)

        # Collapse the prompt onto one line. A multi-line prompt would end its
        # comment at the first newline and put the rest of the text into the
        # output as unvalidated gcode.
        header_prompt = " ".join(prompt.split())
        gcode_with_header = (
            f"; dcode output - {line_count} lines\n"
            f"; Prompt: {header_prompt}\n"
            f"; Machine validated\n\n{gcode}"
        )
        return gcode_with_header, svg

    except Exception as e:
        # UI boundary: log the full traceback and show the error in the
        # output box instead of crashing the request.
        import traceback
        traceback.print_exc()
        return f"; Error: {e}", gcode_to_svg("")
280
 
281
 
282
  # Custom CSS
283
  custom_css = """
 
 
 
 
 
284
  .gradio-container {
285
  max-width: 1200px !important;
286
  }
 
289
  with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="emerald")) as demo:
290
  gr.Markdown("""
291
  # dcode
292
+ **Text → Polargraph Gcode via Latent Diffusion**
293
 
294
+ Uses Stable Diffusion to generate latents from text, then decodes to machine gcode.
295
+
296
+ [GitHub](https://github.com/Twarner491/dcode) | [Model](https://huggingface.co/twarner/dcode-latent-gcode) | [Dataset](https://huggingface.co/datasets/twarner/dcode-polargraph-gcode)
297
  """)
298
 
299
  with gr.Row():
 
303
  placeholder="drawing of a cat, abstract spiral, portrait...",
304
  lines=2
305
  )
306
+
307
+ with gr.Row():
308
+ temperature = gr.Slider(0.5, 1.5, value=0.9, label="Temperature")
309
+ max_tokens = gr.Slider(256, 1024, value=512, step=128, label="Max Tokens")
310
+
311
  with gr.Row():
312
+ num_steps = gr.Slider(10, 50, value=25, step=5, label="Diffusion Steps")
313
+ guidance = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="Guidance Scale")
314
 
315
  generate_btn = gr.Button("Generate", variant="primary", size="lg")
316
 
317
  gr.Examples(
318
  examples=[
319
+ ["line drawing of a cat"],
320
  ["abstract spiral pattern"],
321
  ["simple house with chimney"],
322
+ ["portrait sketch"],
323
+ ["geometric shapes and lines"],
324
  ],
325
  inputs=prompt,
326
  )
 
329
  preview = gr.HTML(
330
  value=gcode_to_svg(""),
331
  label="Preview",
 
332
  )
333
 
334
  with gr.Accordion("Gcode Output", open=False):
 
341
 
342
  generate_btn.click(
343
  generate,
344
+ [prompt, temperature, max_tokens, num_steps, guidance],
345
  [gcode_output, preview]
346
  )
347
  prompt.submit(
348
  generate,
349
+ [prompt, temperature, max_tokens, num_steps, guidance],
350
  [gcode_output, preview]
351
  )
352
 
requirements.txt CHANGED
@@ -1,3 +1,6 @@
 
1
  torch>=2.0
2
  transformers>=4.36
 
3
  accelerate>=0.25
 
 
1
+ gradio>=4.44.0
2
  torch>=2.0
3
  transformers>=4.36
4
+ diffusers>=0.25
5
  accelerate>=0.25
6
+ huggingface_hub>=0.20