H-Liu1997 commited on
Commit
47551f4
·
1 Parent(s): e843211

fix: adapt model_manager to HF model API (no schedule_config/cfg_config dicts)

Browse files
Files changed (1) hide show
  1. model_manager.py +39 -32
model_manager.py CHANGED
@@ -62,9 +62,14 @@ class ModelManager:
62
  # Load models from HF Hub
63
  self.vae, self.model = self._load_models(model_name)
64
 
65
- # Save clean copies of user-facing configs (before any runtime injection)
66
- self._base_schedule_config = dict(self.model.schedule_config)
67
- self._base_cfg_config = dict(self.model.cfg_config)
 
 
 
 
 
68
 
69
  # Frame buffer
70
  self.frame_buffer = FrameBuffer(target_buffer_size=4)
@@ -82,7 +87,8 @@ class ModelManager:
82
  self.should_stop = False
83
 
84
  # Model generation state
85
- self.first_chunk = True
 
86
  self.history_length = 30
87
 
88
  print("ModelManager initialized successfully")
@@ -187,16 +193,14 @@ class ModelManager:
187
  self.stream_recovery.reset()
188
  self.vae.clear_cache()
189
  self.first_chunk = True
190
- # Restore clean config before init (clears runtime-injected keys)
191
- self.model.schedule_config.clear()
192
- self.model.schedule_config.update(self._base_schedule_config)
193
- self.model.init_generated(
194
- self.history_length,
195
- batch_size=1,
196
- schedule_config=self.model.schedule_config,
197
- )
198
  print(
199
- f"Model initialized with history length: {self.history_length}, schedule_config: {self.model.schedule_config}"
200
  )
201
 
202
  # Start generation thread
@@ -269,20 +273,16 @@ class ModelManager:
269
  joints_num=22, smoothing_alpha=self.smoothing_alpha
270
  )
271
 
272
- # Restore clean configs before init (clears runtime-injected keys)
273
- self.model.schedule_config.clear()
274
- self.model.schedule_config.update(self._base_schedule_config)
275
- self.model.cfg_config.clear()
276
- self.model.cfg_config.update(self._base_cfg_config)
277
-
278
- # Initialize model (reads steps/chunk_size from model.schedule_config directly)
279
- self.model.init_generated(
280
- self.history_length,
281
- batch_size=1,
282
- schedule_config=self.model.schedule_config,
283
- )
284
  print(
285
- f"Model reset - history: {self.history_length}, smoothing: {self.smoothing_alpha}, schedule_config: {self.model.schedule_config}"
286
  )
287
 
288
  def _generation_loop(self):
@@ -299,12 +299,14 @@ class ModelManager:
299
  try:
300
  step_start = time.time()
301
 
302
- # Generate one token (produces 4 frames from VAE)
303
- text_key = self.model.input_keys["text"]
304
- x = {text_key: [self.current_text]}
305
 
306
  # Generate from model (1 token)
307
- output = self.model.stream_generate_step(x)
 
 
 
308
  generated = output["generated"]
309
 
310
  # Skip if no frames committed yet
@@ -364,8 +366,13 @@ class ModelManager:
364
  "current_text": self.current_text,
365
  "smoothing_alpha": self.smoothing_alpha,
366
  "history_length": self.history_length,
367
- "schedule_config": dict(self.model.schedule_config),
368
- "cfg_config": dict(self.model.cfg_config),
 
 
 
 
 
369
  }
370
 
371
 
 
62
  # Load models from HF Hub
63
  self.vae, self.model = self._load_models(model_name)
64
 
65
+ # Build config dicts from model's individual attributes (HF model API)
66
+ self._base_schedule_config = {
67
+ "chunk_size": self.model.chunk_size,
68
+ "steps": self.model.noise_steps,
69
+ }
70
+ self._base_cfg_config = {
71
+ "cfg_scale": self.model.cfg_scale,
72
+ }
73
 
74
  # Frame buffer
75
  self.frame_buffer = FrameBuffer(target_buffer_size=4)
 
87
  self.should_stop = False
88
 
89
  # Model generation state
90
+ self.first_chunk = True # For VAE stream_decode
91
+ self._model_first_chunk = True # For model stream_generate_step
92
  self.history_length = 30
93
 
94
  print("ModelManager initialized successfully")
 
193
  self.stream_recovery.reset()
194
  self.vae.clear_cache()
195
  self.first_chunk = True
196
+ self._model_first_chunk = True
197
+ # Restore model params from base config
198
+ self.model.chunk_size = self._base_schedule_config["chunk_size"]
199
+ self.model.noise_steps = self._base_schedule_config["steps"]
200
+ self.model.cfg_scale = self._base_cfg_config["cfg_scale"]
201
+ self.model.init_generated(self.history_length, batch_size=1)
 
 
202
  print(
203
+ f"Model initialized with history length: {self.history_length}"
204
  )
205
 
206
  # Start generation thread
 
273
  joints_num=22, smoothing_alpha=self.smoothing_alpha
274
  )
275
 
276
+ # Restore model params from base config
277
+ self.model.chunk_size = self._base_schedule_config["chunk_size"]
278
+ self.model.noise_steps = self._base_schedule_config["steps"]
279
+ self.model.cfg_scale = self._base_cfg_config["cfg_scale"]
280
+ self._model_first_chunk = True
281
+
282
+ # Initialize model
283
+ self.model.init_generated(self.history_length, batch_size=1)
 
 
 
 
284
  print(
285
+ f"Model reset - history: {self.history_length}, smoothing: {self.smoothing_alpha}"
286
  )
287
 
288
  def _generation_loop(self):
 
299
  try:
300
  step_start = time.time()
301
 
302
+ # Generate one token (produces frames from VAE)
303
+ x = {"text": [self.current_text]}
 
304
 
305
  # Generate from model (1 token)
306
+ output = self.model.stream_generate_step(
307
+ x, first_chunk=self._model_first_chunk
308
+ )
309
+ self._model_first_chunk = False
310
  generated = output["generated"]
311
 
312
  # Skip if no frames committed yet
 
366
  "current_text": self.current_text,
367
  "smoothing_alpha": self.smoothing_alpha,
368
  "history_length": self.history_length,
369
+ "schedule_config": {
370
+ "chunk_size": self.model.chunk_size,
371
+ "steps": self.model.noise_steps,
372
+ },
373
+ "cfg_config": {
374
+ "cfg_scale": self.model.cfg_scale,
375
+ },
376
  }
377
 
378