Spaces:
Runtime error
Runtime error
fix: adapt model_manager to HF model API (no schedule_config/cfg_config dicts)
Browse files- model_manager.py +39 -32
model_manager.py
CHANGED
|
@@ -62,9 +62,14 @@ class ModelManager:
|
|
| 62 |
# Load models from HF Hub
|
| 63 |
self.vae, self.model = self._load_models(model_name)
|
| 64 |
|
| 65 |
-
#
|
| 66 |
-
self._base_schedule_config =
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
# Frame buffer
|
| 70 |
self.frame_buffer = FrameBuffer(target_buffer_size=4)
|
|
@@ -82,7 +87,8 @@ class ModelManager:
|
|
| 82 |
self.should_stop = False
|
| 83 |
|
| 84 |
# Model generation state
|
| 85 |
-
self.first_chunk = True
|
|
|
|
| 86 |
self.history_length = 30
|
| 87 |
|
| 88 |
print("ModelManager initialized successfully")
|
|
@@ -187,16 +193,14 @@ class ModelManager:
|
|
| 187 |
self.stream_recovery.reset()
|
| 188 |
self.vae.clear_cache()
|
| 189 |
self.first_chunk = True
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
self.model.
|
| 193 |
-
self.model.
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
schedule_config=self.model.schedule_config,
|
| 197 |
-
)
|
| 198 |
print(
|
| 199 |
-
f"Model initialized with history length: {self.history_length}
|
| 200 |
)
|
| 201 |
|
| 202 |
# Start generation thread
|
|
@@ -269,20 +273,16 @@ class ModelManager:
|
|
| 269 |
joints_num=22, smoothing_alpha=self.smoothing_alpha
|
| 270 |
)
|
| 271 |
|
| 272 |
-
# Restore
|
| 273 |
-
self.model.
|
| 274 |
-
self.model.
|
| 275 |
-
self.model.
|
| 276 |
-
self.
|
| 277 |
-
|
| 278 |
-
# Initialize model
|
| 279 |
-
self.model.init_generated(
|
| 280 |
-
self.history_length,
|
| 281 |
-
batch_size=1,
|
| 282 |
-
schedule_config=self.model.schedule_config,
|
| 283 |
-
)
|
| 284 |
print(
|
| 285 |
-
f"Model reset - history: {self.history_length}, smoothing: {self.smoothing_alpha}
|
| 286 |
)
|
| 287 |
|
| 288 |
def _generation_loop(self):
|
|
@@ -299,12 +299,14 @@ class ModelManager:
|
|
| 299 |
try:
|
| 300 |
step_start = time.time()
|
| 301 |
|
| 302 |
-
# Generate one token (produces
|
| 303 |
-
|
| 304 |
-
x = {text_key: [self.current_text]}
|
| 305 |
|
| 306 |
# Generate from model (1 token)
|
| 307 |
-
output = self.model.stream_generate_step(
|
|
|
|
|
|
|
|
|
|
| 308 |
generated = output["generated"]
|
| 309 |
|
| 310 |
# Skip if no frames committed yet
|
|
@@ -364,8 +366,13 @@ class ModelManager:
|
|
| 364 |
"current_text": self.current_text,
|
| 365 |
"smoothing_alpha": self.smoothing_alpha,
|
| 366 |
"history_length": self.history_length,
|
| 367 |
-
"schedule_config":
|
| 368 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
}
|
| 370 |
|
| 371 |
|
|
|
|
| 62 |
# Load models from HF Hub
|
| 63 |
self.vae, self.model = self._load_models(model_name)
|
| 64 |
|
| 65 |
+
# Build config dicts from model's individual attributes (HF model API)
|
| 66 |
+
self._base_schedule_config = {
|
| 67 |
+
"chunk_size": self.model.chunk_size,
|
| 68 |
+
"steps": self.model.noise_steps,
|
| 69 |
+
}
|
| 70 |
+
self._base_cfg_config = {
|
| 71 |
+
"cfg_scale": self.model.cfg_scale,
|
| 72 |
+
}
|
| 73 |
|
| 74 |
# Frame buffer
|
| 75 |
self.frame_buffer = FrameBuffer(target_buffer_size=4)
|
|
|
|
| 87 |
self.should_stop = False
|
| 88 |
|
| 89 |
# Model generation state
|
| 90 |
+
self.first_chunk = True # For VAE stream_decode
|
| 91 |
+
self._model_first_chunk = True # For model stream_generate_step
|
| 92 |
self.history_length = 30
|
| 93 |
|
| 94 |
print("ModelManager initialized successfully")
|
|
|
|
| 193 |
self.stream_recovery.reset()
|
| 194 |
self.vae.clear_cache()
|
| 195 |
self.first_chunk = True
|
| 196 |
+
self._model_first_chunk = True
|
| 197 |
+
# Restore model params from base config
|
| 198 |
+
self.model.chunk_size = self._base_schedule_config["chunk_size"]
|
| 199 |
+
self.model.noise_steps = self._base_schedule_config["steps"]
|
| 200 |
+
self.model.cfg_scale = self._base_cfg_config["cfg_scale"]
|
| 201 |
+
self.model.init_generated(self.history_length, batch_size=1)
|
|
|
|
|
|
|
| 202 |
print(
|
| 203 |
+
f"Model initialized with history length: {self.history_length}"
|
| 204 |
)
|
| 205 |
|
| 206 |
# Start generation thread
|
|
|
|
| 273 |
joints_num=22, smoothing_alpha=self.smoothing_alpha
|
| 274 |
)
|
| 275 |
|
| 276 |
+
# Restore model params from base config
|
| 277 |
+
self.model.chunk_size = self._base_schedule_config["chunk_size"]
|
| 278 |
+
self.model.noise_steps = self._base_schedule_config["steps"]
|
| 279 |
+
self.model.cfg_scale = self._base_cfg_config["cfg_scale"]
|
| 280 |
+
self._model_first_chunk = True
|
| 281 |
+
|
| 282 |
+
# Initialize model
|
| 283 |
+
self.model.init_generated(self.history_length, batch_size=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
print(
|
| 285 |
+
f"Model reset - history: {self.history_length}, smoothing: {self.smoothing_alpha}"
|
| 286 |
)
|
| 287 |
|
| 288 |
def _generation_loop(self):
|
|
|
|
| 299 |
try:
|
| 300 |
step_start = time.time()
|
| 301 |
|
| 302 |
+
# Generate one token (produces frames from VAE)
|
| 303 |
+
x = {"text": [self.current_text]}
|
|
|
|
| 304 |
|
| 305 |
# Generate from model (1 token)
|
| 306 |
+
output = self.model.stream_generate_step(
|
| 307 |
+
x, first_chunk=self._model_first_chunk
|
| 308 |
+
)
|
| 309 |
+
self._model_first_chunk = False
|
| 310 |
generated = output["generated"]
|
| 311 |
|
| 312 |
# Skip if no frames committed yet
|
|
|
|
| 366 |
"current_text": self.current_text,
|
| 367 |
"smoothing_alpha": self.smoothing_alpha,
|
| 368 |
"history_length": self.history_length,
|
| 369 |
+
"schedule_config": {
|
| 370 |
+
"chunk_size": self.model.chunk_size,
|
| 371 |
+
"steps": self.model.noise_steps,
|
| 372 |
+
},
|
| 373 |
+
"cfg_config": {
|
| 374 |
+
"cfg_scale": self.model.cfg_scale,
|
| 375 |
+
},
|
| 376 |
}
|
| 377 |
|
| 378 |
|