multimodalart HF Staff commited on
Commit
37a733b
·
verified ·
1 Parent(s): bf82e7c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -15
app.py CHANGED
@@ -54,6 +54,20 @@ MODEL_ID = os.environ.get("MODEL_PATH", "diffusers-internal-dev/world-engine-mod
54
  pipe = ModularPipeline.from_pretrained(MODEL_ID, trust_remote_code=True, revision="aot-compatible")
55
  pipe.load_components(["transformer", "vae"], trust_remote_code=True, revision="aot-compatible", torch_dtype=torch.bfloat16)
56
  pipe.load_components(["text_encoder", "tokenizer"], trust_remote_code=True, torch_dtype=torch.bfloat16)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  SEED_FRAME_URLS = [
59
  "https://gist.github.com/user-attachments/assets/5d91c49a-2ae9-418f-99c0-e93ae387e1de",
@@ -147,21 +161,7 @@ def create_gpu_game_loop(command_queue: Queue, initial_seed_image=None, initial_
147
  Generator that keeps GPU allocated and processes commands.
148
  Yields (frame, frame_count) tuples.
149
  """
150
- pipe.to("cuda")
151
- pipe.blocks.sub_blocks['before_denoise'].sub_blocks['setup_kv_cache']._setup_kv_cache(pipe.transformer, pipe.device, torch.bfloat16)
152
- aoti_load_(
153
- pipe.transformer,
154
- "diffusers-internal-dev/world-engine-aot",
155
- "transformer-fp8.pt2",
156
- "transformer-fp8-constants.pt"
157
- )
158
- aoti_load_(
159
- pipe.vae.decoder,
160
- "diffusers-internal-dev/world-engine-aot",
161
- "decoder.pt2",
162
- "decoder-constants.pt"
163
- )
164
-
165
  n_frames = pipe.transformer.config.n_frames
166
  print(f"Model loaded! (n_frames={n_frames})")
167
 
 
54
  pipe = ModularPipeline.from_pretrained(MODEL_ID, trust_remote_code=True, revision="aot-compatible")
55
  pipe.load_components(["transformer", "vae"], trust_remote_code=True, revision="aot-compatible", torch_dtype=torch.bfloat16)
56
  pipe.load_components(["text_encoder", "tokenizer"], trust_remote_code=True, torch_dtype=torch.bfloat16)
57
+ pipe.to("cuda")
58
+ pipe.blocks.sub_blocks['before_denoise'].sub_blocks['setup_kv_cache']._setup_kv_cache(pipe.transformer, pipe.device, torch.bfloat16)
59
+ aoti_load_(
60
+ pipe.transformer,
61
+ "diffusers-internal-dev/world-engine-aot",
62
+ "transformer-fp8.pt2",
63
+ "transformer-fp8-constants.pt"
64
+ )
65
+ aoti_load_(
66
+ pipe.vae.decoder,
67
+ "diffusers-internal-dev/world-engine-aot",
68
+ "decoder.pt2",
69
+ "decoder-constants.pt"
70
+ )
71
 
72
  SEED_FRAME_URLS = [
73
  "https://gist.github.com/user-attachments/assets/5d91c49a-2ae9-418f-99c0-e93ae387e1de",
 
161
  Generator that keeps GPU allocated and processes commands.
162
  Yields (frame, frame_count) tuples.
163
  """
164
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  n_frames = pipe.transformer.config.n_frames
166
  print(f"Model loaded! (n_frames={n_frames})")
167