dagloop5 commited on
Commit
864541a
·
verified ·
1 Parent(s): e2e429c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -60
app.py CHANGED
@@ -181,66 +181,100 @@ print("=" * 80)
181
  print("Preloading all models for ZeroGPU tensor packing...")
182
  print("This may take a few minutes...")
183
 
184
- # Inspect available attributes on key components
185
- print(" Inspecting DiffusionStage attributes...")
186
- stage_attrs = [a for a in dir(pipeline.stage_1) if not a.startswith('__')]
187
- print(f" DiffusionStage attributes: {stage_attrs}")
188
-
189
- # For DiffusionStage, the transformer is accessed via _transformer_ctx
190
- # but we need to trigger actual loading by accessing the context
191
- if hasattr(pipeline.stage_1, '_transformer_ctx'):
192
- print(" Loading stage 1 transformer via _transformer_ctx...")
193
- ctx = pipeline.stage_1._transformer_ctx
194
- if hasattr(ctx, '__enter__'):
195
- ctx.__enter__() # Force context entry to load transformer
196
-
197
- if hasattr(pipeline.stage_2, '_transformer_ctx'):
198
- print(" Loading stage 2 transformer via _transformer_ctx...")
199
- ctx = pipeline.stage_2._transformer_ctx
200
- if hasattr(ctx, '__enter__'):
201
- ctx.__enter__()
202
-
203
- # Inspect PromptEncoder attributes
204
- print(" Inspecting PromptEncoder attributes...")
205
- pe_attrs = [a for a in dir(pipeline.prompt_encoder) if not a.startswith('__')]
206
- print(f" PromptEncoder attributes: {pe_attrs}")
207
-
208
- # Try common names for video encoder in PromptEncoder
209
- for attr_name in ['video_encoder', '_video_encoder', 'enc', 'encoder', '_enc']:
210
- if hasattr(pipeline.prompt_encoder, attr_name):
211
- print(f" Loading video encoder via .{attr_name}...")
212
- _ = getattr(pipeline.prompt_encoder, attr_name)
213
- break
214
-
215
- # Inspect and load VideoDecoder
216
- print(" Inspecting VideoDecoder attributes...")
217
- vd_attrs = [a for a in dir(pipeline.video_decoder) if not a.startswith('__')]
218
- print(f" VideoDecoder attributes: {vd_attrs}")
219
- for attr_name in ['model', 'decoder', '_model']:
220
- if hasattr(pipeline.video_decoder, attr_name):
221
- print(f" Loading video decoder via .{attr_name}...")
222
- _ = getattr(pipeline.video_decoder, attr_name)
223
- break
224
-
225
- # Inspect and load AudioDecoder
226
- print(" Inspecting AudioDecoder attributes...")
227
- ad_attrs = [a for a in dir(pipeline.audio_decoder) if not a.startswith('__')]
228
- print(f" AudioDecoder attributes: {ad_attrs}")
229
- for attr_name in ['model', 'decoder', '_model']:
230
- if hasattr(pipeline.audio_decoder, attr_name):
231
- print(f" Loading audio decoder via .{attr_name}...")
232
- _ = getattr(pipeline.audio_decoder, attr_name)
233
- break
234
-
235
- # Inspect and load VideoUpsampler
236
- print(" Inspecting VideoUpsampler attributes...")
237
- up_attrs = [a for a in dir(pipeline.upsampler) if not a.startswith('__')]
238
- print(f" VideoUpsampler attributes: {up_attrs}")
239
- for attr_name in ['model', 'upsampler', '_model']:
240
- if hasattr(pipeline.upsampler, attr_name):
241
- print(f" Loading spatial upsampler via .{attr_name}...")
242
- _ = getattr(pipeline.upsampler, attr_name)
243
- break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
245
  print("All models preloaded for ZeroGPU tensor packing!")
246
  print("=" * 80)
 
181
  print("Preloading all models for ZeroGPU tensor packing...")
182
  print("This may take a few minutes...")
183
 
184
+ # The TI2VidTwoStagesHQPipeline uses context managers for lazy loading.
185
+ # We need to enter the contexts, capture the loaded models, AND preserve them
186
+ # by replacing the pipeline's internal references with lambdas that hold them.
187
+ # This prevents garbage collection and allows ZeroGPU to pack them.
188
+
189
+ # 1. Load transformer via _transformer_ctx (enter context to load, store result)
190
+ print(" Loading stage 1 transformer...")
191
+ pipeline.stage_1._transformer_ctx.__enter__()
192
+ # Capture the actual model from the context
193
+ _stage_1_transformer = pipeline.stage_1._transformer_ctx.__dict__.get('transformer') or \
194
+ getattr(pipeline.stage_1, '_transformer', None)
195
+ # Replace _transformer_ctx with lambda that returns the captured model
196
+ pipeline.stage_1._transformer_ctx = type('ctx', (), {
197
+ '__enter__': lambda s: _stage_1_transformer,
198
+ '__exit__': lambda s, *a: None,
199
+ '__call__': lambda s, *a, **kw: _stage_1_transformer(*a, **kw)
200
+ })()
201
+ print(f" Captured stage 1 transformer: {type(_stage_1_transformer)}")
202
+
203
+ print(" Loading stage 2 transformer...")
204
+ pipeline.stage_2._transformer_ctx.__enter__()
205
+ _stage_2_transformer = pipeline.stage_2._transformer_ctx.__dict__.get('transformer') or \
206
+ getattr(pipeline.stage_2, '_transformer', None)
207
+ pipeline.stage_2._transformer_ctx = type('ctx', (), {
208
+ '__enter__': lambda s: _stage_2_transformer,
209
+ '__exit__': lambda s, *a: None,
210
+ '__call__': lambda s, *a, **kw: _stage_2_transformer(*a, **kw)
211
+ })()
212
+ print(f" Captured stage 2 transformer: {type(_stage_2_transformer)}")
213
+
214
+ # 2. Load text encoder via _text_encoder_ctx
215
+ print(" Loading Gemma text encoder...")
216
+ pipeline.prompt_encoder._text_encoder_ctx.__enter__()
217
+ _text_encoder = pipeline.prompt_encoder._text_encoder_ctx.__dict__.get('text_encoder') or \
218
+ getattr(pipeline.prompt_encoder, '_text_encoder', None)
219
+ # Store as instance attribute and create replacement context
220
+ pipeline.prompt_encoder._text_encoder = _text_encoder
221
+ pipeline.prompt_encoder._text_encoder_ctx = type('ctx', (), {
222
+ '__enter__': lambda s: _text_encoder,
223
+ '__exit__': lambda s, *a: None
224
+ })()
225
+ print(f" Captured text encoder: {type(_text_encoder)}")
226
+
227
+ # 3. Load video encoder (from prompt_encoder's video_encoder method)
228
+ print(" Loading video encoder...")
229
+ _video_encoder = pipeline.prompt_encoder.video_encoder()
230
+ pipeline.prompt_encoder.video_encoder = lambda: _video_encoder
231
+ print(f" Captured video encoder: {type(_video_encoder)}")
232
+
233
+ # 4. Load video decoder via _decoder_builder
234
+ print(" Loading video decoder...")
235
+ _video_decoder = pipeline.video_decoder._decoder_builder()
236
+ pipeline.video_decoder._decoder_builder = lambda: _video_decoder
237
+ # Also try direct model attribute if exists
238
+ if hasattr(pipeline.video_decoder, '_decoder'):
239
+ pipeline.video_decoder._decoder = _video_decoder
240
+ print(f" Captured video decoder: {type(_video_decoder)}")
241
+
242
+ # 5. Load audio decoder via _decoder_builder
243
+ print(" Loading audio decoder...")
244
+ _audio_decoder = pipeline.audio_decoder._decoder_builder()
245
+ pipeline.audio_decoder._decoder_builder = lambda: _audio_decoder
246
+ if hasattr(pipeline.audio_decoder, '_decoder'):
247
+ pipeline.audio_decoder._decoder = _audio_decoder
248
+ print(f" Captured audio decoder: {type(_audio_decoder)}")
249
+
250
+ # 6. Load vocoder (audio decoder has _vocoder_builder)
251
+ print(" Loading vocoder...")
252
+ if hasattr(pipeline.audio_decoder, '_vocoder_builder'):
253
+ _vocoder = pipeline.audio_decoder._vocoder_builder()
254
+ pipeline.audio_decoder._vocoder_builder = lambda: _vocoder
255
+ print(f" Captured vocoder: {type(_vocoder)}")
256
+
257
+ # 7. Load spatial upsampler via _upsampler_builder
258
+ print(" Loading spatial upsampler...")
259
+ _spatial_upsampler = pipeline.upsampler._upsampler_builder()
260
+ pipeline.upsampler._upsampler_builder = lambda: _spatial_upsampler
261
+ # Also try _encoder_builder
262
+ if hasattr(pipeline.upsampler, '_encoder'):
263
+ pipeline.upsampler._encoder = _spatial_upsampler
264
+ print(f" Captured spatial upsampler: {type(_spatial_upsampler)}")
265
+
266
+ # 8. Load image conditioner
267
+ print(" Loading image conditioner...")
268
+ if hasattr(pipeline, 'image_conditioner'):
269
+ if hasattr(pipeline.image_conditioner, 'video_encoder'):
270
+ _ic_encoder = pipeline.image_conditioner.video_encoder()
271
+ pipeline.image_conditioner.video_encoder = lambda: _ic_encoder
272
+
273
+ # Create global references to prevent garbage collection
274
+ # These ensure models stay loaded and ZeroGPU can pack them
275
+ print(" Creating global references to prevent garbage collection...")
276
+ global _stage_1_transformer, _stage_2_transformer, _text_encoder, _video_encoder
277
+ global _video_decoder, _audio_decoder, _vocoder, _spatial_upsampler
278
 
279
  print("All models preloaded for ZeroGPU tensor packing!")
280
  print("=" * 80)