twarner committed
Commit ac2abce · 1 Parent(s): 3f5217f

Support v3 decoder architecture with CNN projection

Files changed (1): app.py +207 -106
app.py CHANGED
@@ -16,7 +16,120 @@ BOUNDS = {"left": -420.5, "right": 420.5, "top": 594.5, "bottom": -594.5}
 _model = None
 
 
-class GcodeDecoderConfig:
+# ============================================================================
+# V3 DECODER ARCHITECTURE
+# ============================================================================
+
+class GcodeDecoderConfigV3:
+    """Config for v3 decoder architecture."""
+
+    def __init__(
+        self,
+        latent_channels: int = 4,
+        latent_size: int = 64,
+        hidden_size: int = 1024,
+        num_layers: int = 12,
+        num_heads: int = 16,
+        vocab_size: int = 8192,
+        max_seq_len: int = 2048,
+        dropout: float = 0.1,
+        ffn_mult: int = 4,
+    ):
+        self.latent_channels = latent_channels
+        self.latent_size = latent_size
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.vocab_size = vocab_size
+        self.max_seq_len = max_seq_len
+        self.dropout = dropout
+        self.ffn_mult = ffn_mult
+
+
+class CNNLatentProjector(nn.Module):
+    """CNN-based latent projector preserving spatial structure."""
+
+    def __init__(self, config: GcodeDecoderConfigV3):
+        super().__init__()
+
+        self.cnn = nn.Sequential(
+            nn.Conv2d(config.latent_channels, 64, 3, stride=2, padding=1),
+            nn.LayerNorm([64, 32, 32]),
+            nn.GELU(),
+            nn.Conv2d(64, 128, 3, stride=2, padding=1),
+            nn.LayerNorm([128, 16, 16]),
+            nn.GELU(),
+            nn.Conv2d(128, 256, 3, stride=2, padding=1),
+            nn.LayerNorm([256, 8, 8]),
+            nn.GELU(),
+            nn.Conv2d(256, config.hidden_size, 3, stride=2, padding=1),
+            nn.LayerNorm([config.hidden_size, 4, 4]),
+            nn.GELU(),
+        )
+
+        self.num_memory_tokens = 16
+        self.memory_pos = nn.Parameter(torch.randn(1, self.num_memory_tokens, config.hidden_size) * 0.02)
+
+    def forward(self, latent: torch.Tensor) -> torch.Tensor:
+        B = latent.shape[0]
+        x = self.cnn(latent)
+        x = x.view(B, x.shape[1], -1).transpose(1, 2)
+        x = x + self.memory_pos.expand(B, -1, -1)
+        return x
+
+
+class GcodeDecoderV3(nn.Module):
+    """Large transformer decoder for gcode generation (v3)."""
+
+    def __init__(self, config: GcodeDecoderConfigV3):
+        super().__init__()
+        self.config = config
+
+        self.latent_proj = CNNLatentProjector(config)
+        self.token_embed = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.pos_embed = nn.Embedding(config.max_seq_len, config.hidden_size)
+        self.embed_drop = nn.Dropout(config.dropout)
+
+        self.layers = nn.ModuleList([
+            nn.TransformerDecoderLayer(
+                d_model=config.hidden_size,
+                nhead=config.num_heads,
+                dim_feedforward=config.hidden_size * config.ffn_mult,
+                dropout=config.dropout,
+                activation='gelu',
+                batch_first=True,
+                norm_first=True,
+            )
+            for _ in range(config.num_layers)
+        ])
+
+        self.ln_f = nn.LayerNorm(config.hidden_size)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+    def forward(self, latent: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
+        B, seq_len = input_ids.shape
+        device = input_ids.device
+        dtype = latent.dtype
+
+        memory = self.latent_proj(latent)
+        positions = torch.arange(seq_len, device=device)
+        x = self.token_embed(input_ids) + self.pos_embed(positions)
+        x = self.embed_drop(x)
+
+        causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=device, dtype=dtype)
+
+        for layer in self.layers:
+            x = layer(x, memory, tgt_mask=causal_mask)
+
+        x = self.ln_f(x)
+        return self.lm_head(x)
+
+
+# ============================================================================
+# V2 DECODER ARCHITECTURE (for backwards compatibility)
+# ============================================================================
+
+class GcodeDecoderConfigV2:
     def __init__(
         self,
         latent_channels: int = 4,
@@ -39,8 +152,8 @@ class GcodeDecoderConfig:
         self.dropout = dropout
 
 
-class GcodeDecoder(nn.Module):
-    def __init__(self, config: GcodeDecoderConfig):
+class GcodeDecoderV2(nn.Module):
+    def __init__(self, config: GcodeDecoderConfigV2):
         super().__init__()
         self.config = config
 
@@ -54,7 +167,6 @@ class GcodeDecoder(nn.Module):
         self.token_embed = nn.Embedding(config.vocab_size, config.hidden_size)
         self.pos_embed = nn.Embedding(config.max_seq_len, config.hidden_size)
 
-        # Individual layers (matches v2 training architecture)
         self.layers = nn.ModuleList([
             nn.TransformerDecoderLayer(
                 d_model=config.hidden_size,
@@ -84,7 +196,6 @@ class GcodeDecoder(nn.Module):
         positions = torch.arange(seq_len, device=device)
         x = self.token_embed(input_ids) + self.pos_embed(positions)
 
-        # Causal mask must match dtype for attention
        causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=device, dtype=dtype)
 
         for layer in self.layers:
@@ -92,43 +203,18 @@ class GcodeDecoder(nn.Module):
 
         x = self.ln_f(x)
         return self.lm_head(x)
-
-    @torch.no_grad()
-    def generate(self, latent, tokenizer, max_length=512, temperature=0.8, top_p=0.9):
-        device = latent.device
-        batch_size = latent.shape[0]
-
-        input_ids = torch.full((batch_size, 1), tokenizer.pad_token_id, dtype=torch.long, device=device)
-
-        for _ in range(max_length - 1):
-            logits = self(latent, input_ids)
-            next_logits = logits[:, -1, :] / temperature
-
-            sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
-            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
-            sorted_indices_to_remove = cumulative_probs > top_p
-            sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
-            sorted_indices_to_remove[:, 0] = False
-
-            for b in range(batch_size):
-                next_logits[b, sorted_indices[b, sorted_indices_to_remove[b]]] = float('-inf')
-
-            probs = torch.softmax(next_logits, dim=-1)
-            next_token = torch.multinomial(probs, num_samples=1)
-            input_ids = torch.cat([input_ids, next_token], dim=1)
-
-            if next_token.item() == tokenizer.eos_token_id:
-                break
-
-        return tokenizer.decode(input_ids[0], skip_special_tokens=True)
 
 
+# ============================================================================
+# MODEL LOADING
+# ============================================================================
+
 def get_model():
-    """Load and cache the SD-Gcode model with full finetuned weights."""
+    """Load and cache the SD-Gcode model."""
     global _model
     if _model is None:
         from diffusers import StableDiffusionPipeline
-        from transformers import AutoTokenizer
+        from transformers import AutoTokenizer, PreTrainedTokenizerFast
         from huggingface_hub import hf_hub_download
 
         device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -143,7 +229,13 @@ def get_model():
         with open(config_path) as f:
             config = json.load(f)
 
-        # Load SD pipeline (we'll replace weights with finetuned ones)
+        # Determine model version
+        gcode_cfg = config.get("gcode_decoder", {})
+        is_v3 = gcode_cfg.get("ffn_mult") is not None or gcode_cfg.get("hidden_size", 768) >= 1024
+
+        print(f"Model version: {'v3' if is_v3 else 'v2'}")
+
+        # Load SD pipeline
         sd_model_id = config.get("sd_model_id", "runwayml/stable-diffusion-v1-5")
         print(f"Loading SD from {sd_model_id}...")
         pipe = StableDiffusionPipeline.from_pretrained(
@@ -152,58 +244,52 @@ def get_model():
             safety_checker=None,
         ).to(device)
 
-        # Build gcode decoder
-        gcode_cfg = config.get("gcode_decoder", {})
-        decoder_config = GcodeDecoderConfig(
-            latent_channels=gcode_cfg.get("latent_channels", 4),
-            latent_size=gcode_cfg.get("latent_size", 64),
-            hidden_size=gcode_cfg.get("hidden_size", 768),
-            num_layers=gcode_cfg.get("num_layers", 6),
-            num_heads=gcode_cfg.get("num_heads", 12),
-            vocab_size=gcode_cfg.get("vocab_size", 32128),
-            max_seq_len=gcode_cfg.get("max_seq_len", 1024),
-        )
-        gcode_decoder = GcodeDecoder(decoder_config).to(device, dtype)
+        # Build decoder based on version
+        if is_v3:
+            decoder_config = GcodeDecoderConfigV3(
+                latent_channels=gcode_cfg.get("latent_channels", 4),
+                latent_size=gcode_cfg.get("latent_size", 64),
+                hidden_size=gcode_cfg.get("hidden_size", 1024),
+                num_layers=gcode_cfg.get("num_layers", 12),
+                num_heads=gcode_cfg.get("num_heads", 16),
+                vocab_size=gcode_cfg.get("vocab_size", 8192),
+                max_seq_len=gcode_cfg.get("max_seq_len", 2048),
+                ffn_mult=gcode_cfg.get("ffn_mult", 4),
+            )
+            gcode_decoder = GcodeDecoderV3(decoder_config).to(device, dtype)
+        else:
+            decoder_config = GcodeDecoderConfigV2(
+                latent_channels=gcode_cfg.get("latent_channels", 4),
+                latent_size=gcode_cfg.get("latent_size", 64),
+                hidden_size=gcode_cfg.get("hidden_size", 768),
+                num_layers=gcode_cfg.get("num_layers", 6),
+                num_heads=gcode_cfg.get("num_heads", 12),
+                vocab_size=gcode_cfg.get("vocab_size", 32128),
+                max_seq_len=gcode_cfg.get("max_seq_len", 1024),
+            )
+            gcode_decoder = GcodeDecoderV2(decoder_config).to(device, dtype)
 
-        # Load ALL finetuned weights
+        # Load weights
         print("Loading finetuned weights...")
         state_dict = torch.load(weights_path, map_location=device, weights_only=False)
 
-        # Debug: print all key prefixes
-        prefixes = set(k.split(".")[0] for k in state_dict.keys())
-        print(f"State dict prefixes: {prefixes}")
-        print(f"Sample keys: {list(state_dict.keys())[:5]}")
-
-        # Load text encoder weights
+        # Load SD components if present
         text_encoder_state = {k.replace("text_encoder.", ""): v for k, v in state_dict.items()
                               if k.startswith("text_encoder.")}
         if text_encoder_state:
             pipe.text_encoder.load_state_dict(text_encoder_state, strict=False)
             print(f"Loaded {len(text_encoder_state)} text encoder weights")
 
-        # Load UNet weights
         unet_state = {k.replace("unet.", ""): v for k, v in state_dict.items()
                       if k.startswith("unet.")}
         if unet_state:
             pipe.unet.load_state_dict(unet_state, strict=False)
             print(f"Loaded {len(unet_state)} UNet weights")
 
-        # Load gcode decoder weights
+        # Load decoder weights
         decoder_state = {k.replace("gcode_decoder.", ""): v for k, v in state_dict.items()
                          if k.startswith("gcode_decoder.")}
         if decoder_state:
-            # Check what keys the model expects vs what we have
-            model_keys = set(gcode_decoder.state_dict().keys())
-            ckpt_keys = set(decoder_state.keys())
-            missing = model_keys - ckpt_keys
-            extra = ckpt_keys - model_keys
-            print(f"Decoder: model expects {len(model_keys)} keys, checkpoint has {len(ckpt_keys)}")
-            if missing:
-                print(f"Missing keys: {list(missing)[:5]}")
-            if extra:
-                print(f"Extra keys: {list(extra)[:5]}")
-
-            # Try loading with strict=True to see errors
             try:
                 gcode_decoder.load_state_dict(decoder_state, strict=True)
                 print(f"Loaded {len(decoder_state)} decoder weights (strict)")
@@ -211,13 +297,19 @@ def get_model():
                 print(f"Strict load failed: {e}")
                 gcode_decoder.load_state_dict(decoder_state, strict=False)
                 print(f"Loaded {len(decoder_state)} decoder weights (non-strict)")
-        else:
-            print("WARNING: No gcode_decoder weights found!")
 
         gcode_decoder.eval()
 
-        # Gcode tokenizer
-        gcode_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
+        # Load gcode tokenizer
+        try:
+            # Try loading custom tokenizer
+            tokenizer_path = hf_hub_download("twarner/dcode-sd-gcode", "gcode_tokenizer/tokenizer.json")
+            gcode_tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)
+            print("Loaded custom gcode tokenizer")
+        except Exception:
+            # Fallback to T5 tokenizer
+            gcode_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
+            print("Using fallback T5 tokenizer")
 
         _model = {
             "pipe": pipe,
@@ -226,12 +318,17 @@ def get_model():
             "device": device,
             "dtype": dtype,
             "num_inference_steps": config.get("num_inference_steps", 20),
+            "is_v3": is_v3,
         }
         print("Model loaded!")
 
     return _model
 
 
+# ============================================================================
+# GCODE PROCESSING
+# ============================================================================
+
 def validate_gcode(gcode: str) -> str:
     """Clamp coordinates to machine bounds."""
     lines = []
@@ -268,13 +365,15 @@ def gcode_to_svg(gcode: str) -> str:
     x, y = 0.0, 0.0
     pen_down = False
 
-    # Split on newlines OR command boundaries (for single-line gcode)
+    # Split on newlines, newline tokens, or command boundaries
     lines = []
+    # Replace newline tokens with actual newlines
+    gcode = gcode.replace("<newline>", "\n")
+
     for line in gcode.replace(";", "\n;").split("\n"):
         line = line.strip()
         if not line:
             continue
-        # Split on G/M commands
         parts = re.split(r'(?=[GM]\d)', line)
         for part in parts:
             part = part.strip()
@@ -316,7 +415,6 @@ def gcode_to_svg(gcode: str) -> str:
     h = BOUNDS["top"] - BOUNDS["bottom"]
     padding = 20
 
-    # Dark mode compatible SVG
     svg = f'''<svg xmlns="http://www.w3.org/2000/svg"
         viewBox="{BOUNDS["left"] - padding} {-BOUNDS["top"] - padding} {w + 2*padding} {h + 2*padding}"
         style="width: 100%; height: 480px; border: 1px solid var(--border, #e0e0e0); border-radius: 4px;">
@@ -354,6 +452,10 @@ def gcode_to_svg(gcode: str) -> str:
     return svg
 
 
+# ============================================================================
+# GENERATION
+# ============================================================================
+
 @spaces.GPU
 def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, guidance: float):
     """Generate gcode from text prompt."""
@@ -367,6 +469,7 @@ def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, g
     gcode_tokenizer = m["gcode_tokenizer"]
     device = m["device"]
     dtype = m["dtype"]
+    is_v3 = m.get("is_v3", False)
 
     # Text -> Latent via SD diffusion
     with torch.no_grad():
@@ -377,25 +480,26 @@ def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, g
             output_type="latent",
         )
         latent = result.images.to(dtype)
-    print(f"Latent shape: {latent.shape}, dtype: {latent.dtype}, device: {latent.device}")
-    print(f"Latent stats: min={latent.min():.3f}, max={latent.max():.3f}, mean={latent.mean():.3f}")
-    print(f"Decoder dtype: {next(gcode_decoder.parameters()).dtype}, device: {next(gcode_decoder.parameters()).device}")
+    print(f"Latent shape: {latent.shape}, dtype: {latent.dtype}")
 
-    # Latent -> Gcode via trained decoder (with debug)
+    # Latent -> Gcode via trained decoder
     with torch.no_grad():
         batch_size = latent.shape[0]
-        # Start with semicolon (gcode comment start) instead of pad
-        # Gcode files start with "; Source: ..."
-        start_tokens = gcode_tokenizer.encode(";", add_special_tokens=False)
-        print(f"Start tokens for ';': {start_tokens}")
-        if start_tokens:
-            start_id = start_tokens[0]
+
+        # Start token
+        if is_v3:
+            # V3 uses custom tokenizer with BOS
+            start_id = gcode_tokenizer.bos_token_id or 0
         else:
-            start_id = gcode_tokenizer.pad_token_id
+            # V2 uses semicolon as start
+            start_tokens = gcode_tokenizer.encode(";", add_special_tokens=False)
+            start_id = start_tokens[0] if start_tokens else gcode_tokenizer.pad_token_id
+
         input_ids = torch.tensor([[start_id]], dtype=torch.long, device=device)
 
-        generated_tokens = []
-        for step in range(min(max_tokens, 1024) - 1):
+        max_gen = min(max_tokens, gcode_decoder.config.max_seq_len - 1)
+
+        for step in range(max_gen):
             logits = gcode_decoder(latent, input_ids)
             next_logits = logits[:, -1, :] / temperature
 
@@ -413,23 +517,17 @@ def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, g
             next_token = torch.multinomial(probs, num_samples=1)
             input_ids = torch.cat([input_ids, next_token], dim=1)
 
-            token_id = next_token.item()
-            generated_tokens.append(token_id)
-
-            # Debug first few tokens
-            if step < 5:
-                token_str = gcode_tokenizer.decode([token_id])
-                # Check logits distribution
-                top5_vals, top5_ids = torch.topk(logits[0, -1, :], 5)
-                top5_tokens = [gcode_tokenizer.decode([i.item()]) for i in top5_ids]
-                print(f"Step {step}: token_id={token_id}, token='{token_str}', top5={list(zip(top5_tokens, top5_vals.tolist()))}")
-
-            if token_id == gcode_tokenizer.eos_token_id:
-                print(f"Hit EOS at step {step}")
+            # Check EOS
+            if next_token.item() == gcode_tokenizer.eos_token_id:
                 break
 
-        print(f"Generated {len(generated_tokens)} tokens")
+        print(f"Generated {input_ids.shape[1]} tokens")
         gcode = gcode_tokenizer.decode(input_ids[0], skip_special_tokens=True)
+
+        # Post-process for v3: restore newlines
+        if is_v3:
+            gcode = gcode.replace("<newline>", "\n")
+
         print(f"Decoded gcode length: {len(gcode)} chars")
 
         gcode = validate_gcode(gcode)
@@ -445,7 +543,10 @@ def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, g
         return f"; Error: {e}", gcode_to_svg("")
 
 
-# Minimal monochrome CSS with dark mode
+# ============================================================================
+# UI
+# ============================================================================
+
 css = """
 @import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;500&display=swap');
 
@@ -550,7 +651,7 @@ with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
 
     with gr.Accordion("settings", open=False):
         temperature = gr.Slider(0.5, 1.5, value=0.8, label="temperature", step=0.1)
-        max_tokens = gr.Slider(256, 1024, value=512, step=128, label="max tokens")
+        max_tokens = gr.Slider(256, 2048, value=1024, step=256, label="max tokens")
         num_steps = gr.Slider(10, 50, value=20, step=5, label="diffusion steps")
         guidance = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="guidance")
 
16
  _model = None
17
 
18
 
19
+ # ============================================================================
20
+ # V3 DECODER ARCHITECTURE
21
+ # ============================================================================
22
+
23
+ class GcodeDecoderConfigV3:
24
+ """Config for v3 decoder architecture."""
25
+
26
+ def __init__(
27
+ self,
28
+ latent_channels: int = 4,
29
+ latent_size: int = 64,
30
+ hidden_size: int = 1024,
31
+ num_layers: int = 12,
32
+ num_heads: int = 16,
33
+ vocab_size: int = 8192,
34
+ max_seq_len: int = 2048,
35
+ dropout: float = 0.1,
36
+ ffn_mult: int = 4,
37
+ ):
38
+ self.latent_channels = latent_channels
39
+ self.latent_size = latent_size
40
+ self.hidden_size = hidden_size
41
+ self.num_layers = num_layers
42
+ self.num_heads = num_heads
43
+ self.vocab_size = vocab_size
44
+ self.max_seq_len = max_seq_len
45
+ self.dropout = dropout
46
+ self.ffn_mult = ffn_mult
47
+
48
+
49
+ class CNNLatentProjector(nn.Module):
50
+ """CNN-based latent projector preserving spatial structure."""
51
+
52
+ def __init__(self, config: GcodeDecoderConfigV3):
53
+ super().__init__()
54
+
55
+ self.cnn = nn.Sequential(
56
+ nn.Conv2d(config.latent_channels, 64, 3, stride=2, padding=1),
57
+ nn.LayerNorm([64, 32, 32]),
58
+ nn.GELU(),
59
+ nn.Conv2d(64, 128, 3, stride=2, padding=1),
60
+ nn.LayerNorm([128, 16, 16]),
61
+ nn.GELU(),
62
+ nn.Conv2d(128, 256, 3, stride=2, padding=1),
63
+ nn.LayerNorm([256, 8, 8]),
64
+ nn.GELU(),
65
+ nn.Conv2d(256, config.hidden_size, 3, stride=2, padding=1),
66
+ nn.LayerNorm([config.hidden_size, 4, 4]),
67
+ nn.GELU(),
68
+ )
69
+
70
+ self.num_memory_tokens = 16
71
+ self.memory_pos = nn.Parameter(torch.randn(1, self.num_memory_tokens, config.hidden_size) * 0.02)
72
+
73
+ def forward(self, latent: torch.Tensor) -> torch.Tensor:
74
+ B = latent.shape[0]
75
+ x = self.cnn(latent)
76
+ x = x.view(B, x.shape[1], -1).transpose(1, 2)
77
+ x = x + self.memory_pos.expand(B, -1, -1)
78
+ return x
79
+
80
+
81
+ class GcodeDecoderV3(nn.Module):
82
+ """Large transformer decoder for gcode generation (v3)."""
83
+
84
+ def __init__(self, config: GcodeDecoderConfigV3):
85
+ super().__init__()
86
+ self.config = config
87
+
88
+ self.latent_proj = CNNLatentProjector(config)
89
+ self.token_embed = nn.Embedding(config.vocab_size, config.hidden_size)
90
+ self.pos_embed = nn.Embedding(config.max_seq_len, config.hidden_size)
91
+ self.embed_drop = nn.Dropout(config.dropout)
92
+
93
+ self.layers = nn.ModuleList([
94
+ nn.TransformerDecoderLayer(
95
+ d_model=config.hidden_size,
96
+ nhead=config.num_heads,
97
+ dim_feedforward=config.hidden_size * config.ffn_mult,
98
+ dropout=config.dropout,
99
+ activation='gelu',
100
+ batch_first=True,
101
+ norm_first=True,
102
+ )
103
+ for _ in range(config.num_layers)
104
+ ])
105
+
106
+ self.ln_f = nn.LayerNorm(config.hidden_size)
107
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
108
+
109
+ def forward(self, latent: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
110
+ B, seq_len = input_ids.shape
111
+ device = input_ids.device
112
+ dtype = latent.dtype
113
+
114
+ memory = self.latent_proj(latent)
115
+ positions = torch.arange(seq_len, device=device)
116
+ x = self.token_embed(input_ids) + self.pos_embed(positions)
117
+ x = self.embed_drop(x)
118
+
119
+ causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=device, dtype=dtype)
120
+
121
+ for layer in self.layers:
122
+ x = layer(x, memory, tgt_mask=causal_mask)
123
+
124
+ x = self.ln_f(x)
125
+ return self.lm_head(x)
126
+
127
+
128
+ # ============================================================================
129
+ # V2 DECODER ARCHITECTURE (for backwards compatibility)
130
+ # ============================================================================
131
+
132
+ class GcodeDecoderConfigV2:
133
  def __init__(
134
  self,
135
  latent_channels: int = 4,
 
152
  self.dropout = dropout
153
 
154
 
155
+ class GcodeDecoderV2(nn.Module):
156
+ def __init__(self, config: GcodeDecoderConfigV2):
157
  super().__init__()
158
  self.config = config
159
 
 
167
  self.token_embed = nn.Embedding(config.vocab_size, config.hidden_size)
168
  self.pos_embed = nn.Embedding(config.max_seq_len, config.hidden_size)
169
 
 
170
  self.layers = nn.ModuleList([
171
  nn.TransformerDecoderLayer(
172
  d_model=config.hidden_size,
 
196
  positions = torch.arange(seq_len, device=device)
197
  x = self.token_embed(input_ids) + self.pos_embed(positions)
198
 
 
199
  causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=device, dtype=dtype)
200
 
201
  for layer in self.layers:
 
203
 
204
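For reference, a minimal sketch of what the new CNN projection does to a Stable Diffusion latent, runnable after the class definitions in app.py (it assumes GcodeDecoderConfigV3 and CNNLatentProjector from the diff above are in scope, with the default 4x64x64 latent and hidden_size 1024). Each stride-2 conv halves the spatial size, 64 -> 32 -> 16 -> 8 -> 4, and the final 4x4 map flattens into the 16 memory tokens the decoder cross-attends to:

import torch

# Sketch, assuming GcodeDecoderConfigV3 and CNNLatentProjector from app.py are in scope.
cfg = GcodeDecoderConfigV3()
proj = CNNLatentProjector(cfg)

# Dummy SD latent: (batch, latent_channels, latent_size, latent_size)
latent = torch.randn(2, cfg.latent_channels, cfg.latent_size, cfg.latent_size)
memory = proj(latent)
print(memory.shape)  # torch.Size([2, 16, 1024]): 4x4 feature map flattened to 16 tokens

Note that the LayerNorm shapes in the Sequential are hard-coded to a 64x64 input, so the projector only accepts the default latent_size.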
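The v2/v3 split in get_model() is inferred rather than stored: a checkpoint config counts as v3 when its gcode_decoder section carries an ffn_mult key or a hidden size of 1024 or more. A self-contained sketch of that heuristic on hypothetical config dicts:

def looks_like_v3(gcode_cfg: dict) -> bool:
    # Mirrors the is_v3 expression in get_model()
    return gcode_cfg.get("ffn_mult") is not None or gcode_cfg.get("hidden_size", 768) >= 1024

print(looks_like_v3({"hidden_size": 1024, "num_layers": 12, "ffn_mult": 4}))  # True  (v3 checkpoint)
print(looks_like_v3({"hidden_size": 768, "num_layers": 6}))                   # False (v2 checkpoint)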
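The sampling step inside generate() is standard nucleus (top-p) sampling. Factored out as a standalone helper, it looks roughly like the sketch below; this is not the code the Space runs (the inline version operates on logits[:, -1, :] after temperature scaling), just the same technique in isolation:

import torch

def sample_top_p(next_logits: torch.Tensor, top_p: float = 0.9) -> torch.Tensor:
    """Sample one token per batch row, keeping only the smallest set of
    tokens whose cumulative probability exceeds top_p. Mutates next_logits."""
    sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
    cumulative = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cumulative > top_p
    remove[:, 1:] = remove[:, :-1].clone()  # shift right so the threshold-crossing token survives
    remove[:, 0] = False
    for b in range(next_logits.shape[0]):
        next_logits[b, sorted_indices[b, remove[b]]] = float('-inf')
    return torch.multinomial(torch.softmax(next_logits, dim=-1), num_samples=1)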
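Both generate() and gcode_to_svg() now strip a literal <newline> token, which suggests the custom v3 tokenizer encodes line breaks as a single vocabulary entry. A hypothetical decoded string shows why this post-processing pass matters before validate_gcode() parses the output line by line:

decoded = "; Source: prompt<newline>G0 X10 Y20<newline>G1 X30 Y40"  # hypothetical decoder output
gcode = decoded.replace("<newline>", "\n")
print(gcode)
# ; Source: prompt
# G0 X10 Y20
# G1 X30 Y40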