rahul7star committed
Commit d9b2184 · verified · 1 Parent(s): 561e1e1

Update app_exp.py


One commit before this, everything loads; just fix FA3 if required.
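
The FA3 path this message refers to was a guarded install in the previous revision. A minimal sketch of how that toggle could be restored, assuming the wheel from that revision still applies and that flash_attn_interface is the importable module name (both assumptions, not part of this commit):

import importlib
import subprocess

from huggingface_hub import hf_hub_download

enable_fa3 = False  # stay off unless the wheel installs and imports cleanly
try:
    # Wheel location taken from the previous revision of app_exp.py.
    wheel = hf_hub_download(
        repo_id="rahul7star/flash-attn-3",
        repo_type="model",
        filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
    )
    subprocess.run(["pip", "install", wheel], check=True)
    importlib.invalidate_caches()
    importlib.import_module("flash_attn_interface")  # assumed module name
    enable_fa3 = True
except Exception as e:
    print(f"FA3 unavailable, keeping the xformers fallback: {e}")

# The DiT could then be built with enable_flashattn3=enable_fa3 instead of False.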

Files changed (1): app_exp.py (+110 −179)
app_exp.py CHANGED
@@ -6,46 +6,16 @@ import sys
 import subprocess
 import tempfile
 import numpy as np
-import site
-import importlib
+import spaces
 from PIL import Image
-from huggingface_hub import snapshot_download, hf_hub_download
-
-# ============================================================
-# 0️⃣ Install required packages
-# ============================================================
-subprocess.run(["pip3", "install", "-U", "cache-dit"], check=True)
-
-
-
-import cache_dit
-
-enable_fa3 = False # default if FA3 cannot be loaded
-
-try:
-    print("Installing FlashAttention 3...")
-    flash_attention_wheel = hf_hub_download(
-        repo_id="rahul7star/flash-attn-3",
-        repo_type="model",
-        filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
-    )
-    subprocess.run(["pip", "install", flash_attention_wheel], check=True)
-    site.addsitedir(site.getsitepackages()[0])
-    importlib.invalidate_caches()
-    enable_fa3 = True
-    print("✅ FlashAttention 3 installed and enabled")
-except Exception as e:
-    print(f"⚠️ Could not install FlashAttention 3: {e}")
-    # enable_fa3 remains False
 
-
-# ============================================================
-# 1️⃣ Repository & Weights
-# ============================================================
+# Define paths
 REPO_PATH = "LongCat-Video"
 CHECKPOINT_DIR = os.path.join(REPO_PATH, "weights", "LongCat-Video")
 
+# Clone the repository if it doesn't exist
 if not os.path.exists(REPO_PATH):
+    print(f"Cloning LongCat-Video repository to '{REPO_PATH}'...")
     subprocess.run(
         ["git", "clone", "https://github.com/meituan-longcat/LongCat-Video.git", REPO_PATH],
         check=True
@@ -53,6 +23,7 @@ if not os.path.exists(REPO_PATH):
 
 sys.path.insert(0, os.path.abspath(REPO_PATH))
 
+from huggingface_hub import snapshot_download
 from longcat_video.pipeline_longcat_video import LongCatVideoPipeline
 from longcat_video.modules.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
 from longcat_video.modules.autoencoder_kl_wan import AutoencoderKLWan
@@ -60,9 +31,9 @@ from longcat_video.modules.longcat_video_dit import LongCatVideoTransformer3DMod
 from longcat_video.context_parallel import context_parallel_util
 from transformers import AutoTokenizer, UMT5EncoderModel
 from diffusers.utils import export_to_video
-from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
 from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
 
+# Download weights if needed
 if not os.path.exists(CHECKPOINT_DIR):
     snapshot_download(
         repo_id="meituan-longcat/LongCat-Video",
@@ -71,60 +42,33 @@ if not os.path.exists(CHECKPOINT_DIR):
         ignore_patterns=["*.md", "*.gitattributes", "assets/*"]
     )
 
-# ============================================================
-# 2️⃣ Device & Models (with cache & quantization)
-# ============================================================
-device = "cuda" if torch.cuda.is_available() else "cpu"
-torch_dtype = torch.bfloat16 if device=="cuda" else torch.float32
 pipe = None
+device = "cuda" if torch.cuda.is_available() else "cpu"
+torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32
 
+print("--- Initializing Models ---")
 try:
     cp_split_hw = context_parallel_util.get_optimal_split(1)
     tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR, subfolder="tokenizer", torch_dtype=torch_dtype)
-
-    # Text encoder with 4-bit quantization
-    text_encoder = UMT5EncoderModel.from_pretrained(
-        CHECKPOINT_DIR,
-        subfolder="text_encoder",
-        torch_dtype=torch_dtype,
-        quantization_config=TransformersBitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_quant_type="nf4",
-            bnb_4bit_compute_dtype=torch_dtype
-        )
-    )
-
+    text_encoder = UMT5EncoderModel.from_pretrained(CHECKPOINT_DIR, subfolder="text_encoder", torch_dtype=torch_dtype)
     vae = AutoencoderKLWan.from_pretrained(CHECKPOINT_DIR, subfolder="vae", torch_dtype=torch_dtype)
     scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(CHECKPOINT_DIR, subfolder="scheduler", torch_dtype=torch_dtype)
 
-    # DiT model with FP8/4-bit quantization + cache
+    bnb_4bit_config = DiffusersBitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.bfloat16
+    )
+
     dit = LongCatVideoTransformer3DModel.from_pretrained(
         CHECKPOINT_DIR,
-        enable_flashattn3=enable_fa3,
+        enable_flashattn3=False,
+        enable_flashattn2=False,
         enable_xformers=True,
         subfolder="dit",
         cp_split_hw=cp_split_hw,
-        torch_dtype=torch_dtype
-    )
-
-    # Enable Cache-DiT
-    cache_dit.enable_cache(
-        cache_dit.BlockAdapter(
-            transformer=dit,
-            blocks=dit.blocks,
-            forward_pattern=cache_dit.ForwardPattern.Pattern_3,
-            check_forward_pattern=False,
-            has_separate_cfg=False
-        ),
-        cache_config=cache_dit.DBCacheConfig(
-            Fn_compute_blocks=1,
-            Bn_compute_blocks=1,
-            max_warmup_steps=5,
-            max_cached_steps=50,
-            max_continuous_cached_steps=50,
-            residual_diff_threshold=0.01,
-            num_inference_steps=50
-        )
+        torch_dtype=torch_dtype,
+        quantization_config=bnb_4bit_config
     )
 
     pipe = LongCatVideoPipeline(
@@ -135,172 +79,159 @@ try:
         dit=dit,
     )
     pipe.to(device)
-    print("✅ Models loaded with Cache-DiT and quantization")
+
+    pipe.dit.load_lora(os.path.join(CHECKPOINT_DIR, 'lora/cfg_step_lora.safetensors'), 'cfg_step_lora')
+    pipe.dit.load_lora(os.path.join(CHECKPOINT_DIR, 'lora/refinement_lora.safetensors'), 'refinement_lora')
 
 except Exception as e:
-    print(f"Failed to load models: {e}")
+    print("Failed to load models:", e)
     pipe = None
 
-# ============================================================
-# 3️⃣ Generation Helper
-# ============================================================
 def torch_gc():
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
 
-def check_duration(
+def check_duration(*args, duration_t2v=2, **kwargs):
+    fps = 30
+    return duration_t2v * fps # total frames
+
+@spaces.GPU(duration=check_duration)
+def generate_video(
     mode,
-    prompt,
-    neg_prompt,
+    prompt,
+    neg_prompt,
     image,
     height, width, resolution,
     seed,
     use_distill,
     use_refine,
-    progress
+    duration_t2v=2,
+    progress=gr.Progress(track_tqdm=True)
 ):
-    if use_distill and resolution=="480p":
-        return 180
-    elif resolution=="720p":
-        return 360
-    else:
-        return 900
-
-@spaces.GPU(duration=90)
-def generate_video(mode, prompt, neg_prompt, image, height, width, resolution,
-                   seed, use_distill, use_refine, duration_sec, progress=gr.Progress(track_tqdm=True)):
-
     if pipe is None:
-        raise gr.Error("Models not loaded")
+        raise gr.Error("Models failed to load.")
+
+    fps = 30
+    num_frames = duration_t2v * fps
+    print(prompt)
 
-    fps = 15 if use_distill else 30
-    num_frames = int(duration_sec * fps)
     generator = torch.Generator(device=device).manual_seed(int(seed))
     is_distill = use_distill or use_refine
-    print(prompt)
 
-    progress(0.2, desc="Stage 1: Base Video Generation")
-    pipe.dit.enable_loras(['cfg_step_lora'] if is_distill else [])
-    num_inference_steps = 12 if is_distill else 24
-    guidance_scale = 2.0 if is_distill else 4.0
-    curr_neg_prompt = "" if is_distill else neg_prompt
+    if is_distill:
+        pipe.dit.enable_loras(['cfg_step_lora'])
+        num_inference_steps = 16
+        guidance_scale = 1.0
+        current_neg_prompt = ""
+    else:
+        num_inference_steps = 50
+        guidance_scale = 4.0
+        current_neg_prompt = neg_prompt
 
-    if mode=="t2v":
+    if mode == "t2v":
         output = pipe.generate_t2v(
             prompt=prompt,
-            negative_prompt=curr_neg_prompt,
+            negative_prompt=current_neg_prompt,
             height=height,
             width=width,
             num_frames=num_frames,
             num_inference_steps=num_inference_steps,
             use_distill=is_distill,
             guidance_scale=guidance_scale,
-            generator=generator
+            generator=generator,
         )[0]
     else:
-        pil_img = Image.fromarray(image)
+        pil_image = Image.fromarray(image)
         output = pipe.generate_i2v(
-            image=pil_img,
+            image=pil_image,
             prompt=prompt,
-            negative_prompt=curr_neg_prompt,
+            negative_prompt=current_neg_prompt,
             resolution=resolution,
             num_frames=num_frames,
             num_inference_steps=num_inference_steps,
             use_distill=is_distill,
             guidance_scale=guidance_scale,
-            generator=generator
+            generator=generator,
        )[0]
 
-    pipe.dit.disable_all_loras()
+    if is_distill:
+        pipe.dit.disable_all_loras()
     torch_gc()
 
     if use_refine:
-        progress(0.5, desc="Stage 2: Refinement")
+        progress(0.5, desc="Refinement")
         pipe.dit.enable_loras(['refinement_lora'])
         pipe.dit.enable_bsa()
-        stage1_video_pil = [(frame*255).astype(np.uint8) for frame in output]
+        stage1_video_pil = [(frame * 255).astype(np.uint8) for frame in output]
         stage1_video_pil = [Image.fromarray(img) for img in stage1_video_pil]
-        refine_image = Image.fromarray(image) if mode=='i2v' else None
+        refine_image = Image.fromarray(image) if mode == 'i2v' else None
         output = pipe.generate_refine(
             image=refine_image,
             prompt=prompt,
             stage1_video=stage1_video_pil,
-            num_cond_frames=1 if mode=='i2v' else 0,
+            num_cond_frames=1 if mode == 'i2v' else 0,
             num_inference_steps=50,
-            generator=generator
+            generator=generator,
         )[0]
         pipe.dit.disable_all_loras()
         pipe.dit.disable_bsa()
         torch_gc()
 
     progress(1.0, desc="Exporting video")
-    print('video generated')
-    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
-        export_to_video(output, f.name, fps=fps)
-        return f.name
-
-# ============================================================
-# 4️⃣ Gradio UI
-# ============================================================
-css=".fillable{max-width:960px !important}"
+    print("video generated")
+    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_video_file:
+        export_to_video(output, temp_video_file.name, fps=fps)
+        return temp_video_file.name
 
+# --- Gradio UI ---
+css = '.fillable{max-width: 960px !important}'
 with gr.Blocks(css=css) as demo:
-    gr.Markdown("# 🎬 LongCat-Video with Cache-DiT & Quantization")
-    gr.Markdown("13.6B parameter dense video-generation model by Meituan — [[Model](https://huggingface.co/meituan-longcat/LongCat-Video)]")
+    gr.Markdown("# 🎬 LongCat-Video")
+    gr.Markdown("13.6B parameter dense video-generation model — [HuggingFace](https://huggingface.co/meituan-longcat/LongCat-Video)")
 
-    with gr.Tabs():
-        # Text-to-Video
+    with gr.Tabs() as tabs:
         with gr.TabItem("Text-to-Video"):
             mode_t2v = gr.State("t2v")
-            with gr.Row():
-                with gr.Column(scale=2):
-                    prompt_t2v = gr.Textbox(label="Prompt", lines=4)
-                    neg_prompt_t2v = gr.Textbox(label="Negative Prompt", lines=2, value="blurry, low quality")
-                    height_t2v = gr.Slider(256,1024,step=64,value=480,label="Height")
-                    width_t2v = gr.Slider(256,1024,step=64,value=832,label="Width")
-                    seed_t2v = gr.Number(value=42,label="Seed")
-                    distill_t2v = gr.Checkbox(value=True,label="Use Distill Mode")
-                    refine_t2v = gr.Checkbox(value=False,label="Use Refine Mode")
-                    duration_t2v = gr.Slider(1,20,step=1,value=2,label="Video Duration (seconds)")
-                    t2v_button = gr.Button("Generate Video")
-                with gr.Column(scale=3):
-                    video_output_t2v = gr.Video(label="Generated Video")
+            prompt_t2v = gr.Textbox(label="Prompt", lines=4)
+            neg_prompt_t2v = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, low quality, static, subtitles")
+            height_t2v = gr.Slider(256, 1024, step=64, value=480, label="Height")
+            width_t2v = gr.Slider(256, 1024, step=64, value=832, label="Width")
+            seed_t2v = gr.Number(label="Seed", value=42, precision=0)
+            distill_t2v = gr.Checkbox(label="Use Distill Mode", value=True)
+            refine_t2v = gr.Checkbox(label="Use Refine Mode", value=False)
+            duration_t2v = gr.Slider(1, 20, step=1, value=2, label="Video Duration (seconds)")
+            t2v_button = gr.Button("Generate Video")
+            video_output_t2v = gr.Video(label="Generated Video", interactive=False)
+
+            t2v_button.click(
+                fn=generate_video,
+                inputs=[mode_t2v, prompt_t2v, neg_prompt_t2v, gr.State(None),
+                        height_t2v, width_t2v, gr.State(None),
+                        seed_t2v, distill_t2v, refine_t2v, duration_t2v],
+                outputs=video_output_t2v
+            )
 
-        # Image-to-Video
         with gr.TabItem("Image-to-Video"):
             mode_i2v = gr.State("i2v")
-            with gr.Row():
-                with gr.Column(scale=2):
-                    image_i2v = gr.Image(type="numpy", label="Input Image")
-                    prompt_i2v = gr.Textbox(label="Prompt", lines=4)
-                    neg_prompt_i2v = gr.Textbox(label="Negative Prompt", lines=2, value="blurry, low quality")
-                    resolution_i2v = gr.Dropdown(["480p","720p"], value="480p", label="Resolution")
-                    seed_i2v = gr.Number(value=42,label="Seed")
-                    distill_i2v = gr.Checkbox(value=True,label="Use Distill Mode")
-                    refine_i2v = gr.Checkbox(value=False,label="Use Refine Mode")
-                    duration_i2v = gr.Slider(1,20,step=1,value=2,label="Video Duration (seconds)")
-                    i2v_button = gr.Button("Generate Video")
-                with gr.Column(scale=3):
-                    video_output_i2v = gr.Video(label="Generated Video")
-
-    # Bind events
-    t2v_button.click(
-        generate_video,
-        inputs=[mode_t2v, prompt_t2v, neg_prompt_t2v, gr.State(None),
-                height_t2v, width_t2v, gr.State("480p"),
-                seed_t2v, distill_t2v, refine_t2v, duration_t2v],
-        outputs=video_output_t2v
-    )
-
-    i2v_button.click(
-        generate_video,
-        inputs=[mode_i2v, prompt_i2v, neg_prompt_i2v, image_i2v,
-                gr.State(None), gr.State(None), resolution_i2v,
-                seed_i2v, distill_i2v, refine_i2v, duration_i2v],
-        outputs=video_output_i2v
-    )
-
-# Launch
-if __name__=="__main__":
-    demo.launch()
+            image_i2v = gr.Image(type="numpy", label="Input Image")
+            prompt_i2v = gr.Textbox(label="Prompt", lines=4)
+            neg_prompt_i2v = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, low quality, static, subtitles, watermark")
+            resolution_i2v = gr.Dropdown(["480p", "720p"], value="480p", label="Resolution")
+            seed_i2v = gr.Number(label="Seed", value=42, precision=0)
+            distill_i2v = gr.Checkbox(label="Use Distill Mode", value=True)
+            refine_i2v = gr.Checkbox(label="Use Refine Mode", value=False)
+            duration_i2v = gr.Slider(1, 20, step=1, value=2, label="Video Duration (seconds)")
+            i2v_button = gr.Button("Generate Video")
+            video_output_i2v = gr.Video(label="Generated Video", interactive=False)
+
+            i2v_button.click(
+                fn=generate_video,
+                inputs=[mode_i2v, prompt_i2v, neg_prompt_i2v, image_i2v,
+                        gr.State(None), gr.State(None), resolution_i2v,
+                        seed_i2v, distill_i2v, refine_i2v, duration_i2v],
+                outputs=video_output_i2v
+            )
+
+if __name__ == "__main__":
+    demo.launch()
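
The duration argument of spaces.GPU in this revision is a callable, which ZeroGPU evaluates with the same arguments as the decorated function to budget GPU time per call. Two caveats with check_duration as written: it declares duration_t2v as keyword-only behind *args, so a positional call (which is how Gradio supplies inputs) always leaves it at the default of 2, and it returns duration_t2v * fps, a frame count, where seconds are expected. A minimal sketch that mirrors the wrapped signature instead; the 30 s of GPU time per video second and the 600 s cap are illustrative assumptions, not taken from this commit:

import spaces

def get_duration(mode, prompt, neg_prompt, image, height, width, resolution,
                 seed, use_distill, use_refine, duration_t2v=2, progress=None):
    # Same positional order as generate_video, so the slider value
    # actually reaches duration_t2v when arguments arrive positionally.
    # Assumed budget: ~30 s of GPU time per requested video second, capped.
    return min(30 * int(duration_t2v), 600)

@spaces.GPU(duration=get_duration)
def generate_video(mode, prompt, neg_prompt, image, height, width, resolution,
                   seed, use_distill, use_refine, duration_t2v=2, progress=None):
    ...  # generation body as in app_exp.py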