twarner commited on
Commit
cb2f3ac
·
1 Parent(s): 7505a9b

Debug state dict keys, fix button variant

Browse files
Files changed (1) hide show
  1. app.py +9 -1
app.py CHANGED
@@ -162,6 +162,11 @@ def get_model():
162
  print("Loading finetuned weights...")
163
  state_dict = torch.load(weights_path, map_location=device, weights_only=False)
164
 
 
 
 
 
 
165
  # Load text encoder weights
166
  text_encoder_state = {k.replace("text_encoder.", ""): v for k, v in state_dict.items()
167
  if k.startswith("text_encoder.")}
@@ -182,6 +187,9 @@ def get_model():
182
  if decoder_state:
183
  gcode_decoder.load_state_dict(decoder_state, strict=False)
184
  print(f"Loaded {len(decoder_state)} decoder weights")
 
 
 
185
 
186
  gcode_decoder.eval()
187
 
@@ -458,7 +466,7 @@ with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
458
  num_steps = gr.Slider(10, 50, value=20, step=5, label="diffusion steps")
459
  guidance = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="guidance")
460
 
461
- generate_btn = gr.Button("generate", variant="primary")
462
 
463
  gr.Examples(
464
  examples=[
 
162
  print("Loading finetuned weights...")
163
  state_dict = torch.load(weights_path, map_location=device, weights_only=False)
164
 
165
+ # Debug: print all key prefixes
166
+ prefixes = set(k.split(".")[0] for k in state_dict.keys())
167
+ print(f"State dict prefixes: {prefixes}")
168
+ print(f"Sample keys: {list(state_dict.keys())[:5]}")
169
+
170
  # Load text encoder weights
171
  text_encoder_state = {k.replace("text_encoder.", ""): v for k, v in state_dict.items()
172
  if k.startswith("text_encoder.")}
 
187
  if decoder_state:
188
  gcode_decoder.load_state_dict(decoder_state, strict=False)
189
  print(f"Loaded {len(decoder_state)} decoder weights")
190
+ else:
191
+ print("WARNING: No gcode_decoder weights found!")
192
+ print(f"Looking for keys starting with 'gcode_decoder.', but found: {[k for k in state_dict.keys() if 'decoder' in k.lower()][:10]}")
193
 
194
  gcode_decoder.eval()
195
 
 
466
  num_steps = gr.Slider(10, 50, value=20, step=5, label="diffusion steps")
467
  guidance = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="guidance")
468
 
469
+ generate_btn = gr.Button("generate", variant="secondary")
470
 
471
  gr.Examples(
472
  examples=[