rahul7star commited on
Commit
c44ad11
·
verified ·
1 Parent(s): b651dea

Update app_exp.py

Browse files
Files changed (1) hide show
  1. app_exp.py +135 -134
app_exp.py CHANGED
@@ -1,195 +1,196 @@
1
- import spaces
2
  import os
3
  import sys
 
4
  import tempfile
5
- import datetime
6
  import numpy as np
7
  from PIL import Image
8
- import gradio as gr
9
  import torch
10
- import torch.distributed as dist
11
  from torchvision.io import write_video
12
 
13
# ============================================================
# 1️⃣ Repo & checkpoint paths
# ============================================================
# FIX: `subprocess` was used below without being imported anywhere in this
# version of the file, so the first-run clone raised NameError.
import subprocess

# Where the upstream repo is cloned and where its weights are expected.
REPO_PATH = "LongCat-Video"
CHECKPOINT_DIR = os.path.join(REPO_PATH, "weights", "LongCat-Video")

# Clone the model repository on first run only.
if not os.path.exists(REPO_PATH):
    subprocess.run(
        ["git", "clone", "https://github.com/meituan-longcat/LongCat-Video.git", REPO_PATH],
        check=True,
    )

# Make the cloned package importable.
sys.path.insert(0, os.path.abspath(REPO_PATH))
22
-
23
  from longcat_video.pipeline_longcat_video import LongCatVideoPipeline
24
  from longcat_video.modules.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
25
  from longcat_video.modules.autoencoder_kl_wan import AutoencoderKLWan
26
  from longcat_video.modules.longcat_video_dit import LongCatVideoTransformer3DModel
27
- from longcat_video.context_parallel.context_parallel_util import init_context_parallel
28
  from longcat_video.context_parallel import context_parallel_util
29
  import cache_dit
 
30
  from transformers import AutoTokenizer, UMT5EncoderModel
 
31
 
32
  device = "cuda" if torch.cuda.is_available() else "cpu"
33
- torch_dtype = torch.bfloat16 if device=="cuda" else torch.float32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
def torch_gc():
    """Free cached GPU memory; a no-op when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
39
 
40
# ============================================================
# 2️⃣ Model loader with cache & 4-bit/FP8 quantization
# ============================================================
def load_models(checkpoint_dir=CHECKPOINT_DIR, cp_size=1, quantize=True, cache=True):
    """Assemble the LongCat-Video pipeline from its on-disk components.

    Args:
        checkpoint_dir: Root folder holding the ``tokenizer``, ``text_encoder``,
            ``vae``, ``scheduler`` and ``dit`` subfolders.
        cp_size: Context-parallel world size; 1 means no context parallelism.
        quantize: When True, load the DiT transformer with 4-bit NF4 weights.
        cache: When True, wrap the DiT blocks with cache-dit's DBCache.

    Returns:
        A ``LongCatVideoPipeline`` already moved onto the selected device.
    """
    split_hw = context_parallel_util.get_optimal_split(cp_size)

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir, subfolder="tokenizer", torch_dtype=torch_dtype)
    text_encoder = UMT5EncoderModel.from_pretrained(checkpoint_dir, subfolder="text_encoder", torch_dtype=torch_dtype)
    vae = AutoencoderKLWan.from_pretrained(checkpoint_dir, subfolder="vae", torch_dtype=torch_dtype)
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(checkpoint_dir, subfolder="scheduler", torch_dtype=torch_dtype)

    quant_cfg = None
    if quantize:
        # Local import keeps diffusers' bitsandbytes config optional.
        from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
        quant_cfg = DiffusersBitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch_dtype,
        )

    dit = LongCatVideoTransformer3DModel.from_pretrained(
        checkpoint_dir,
        subfolder="dit",
        cp_split_hw=split_hw,
        torch_dtype=torch_dtype,
        quantization_config=quant_cfg,
    )

    if cache:
        from cache_dit import enable_cache, BlockAdapter, ForwardPattern, DBCacheConfig
        enable_cache(
            BlockAdapter(transformer=dit, blocks=dit.blocks, forward_pattern=ForwardPattern.Pattern_3),
            cache_config=DBCacheConfig(Fn_compute_blocks=1),
        )

    pipeline = LongCatVideoPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        scheduler=scheduler,
        dit=dit,
    )
    pipeline.to(device)
    return pipeline


pipe = load_models()

# ============================================================
# 3️⃣ LoRA refinement
# ============================================================
# Preload and activate the refinement LoRA plus block-sparse attention so the
# refine pass is ready to run.
pipe.dit.load_lora(os.path.join(CHECKPOINT_DIR, 'lora/refinement_lora.safetensors'), 'refinement_lora')
pipe.dit.enable_loras(['refinement_lora'])
pipe.dit.enable_bsa()
94
-
95
# ============================================================
# 4️⃣ Video generation function
# ============================================================
@spaces.GPU(duration=60)
def generate_video(
    mode,
    prompt,
    neg_prompt,
    image,
    height,
    width,
    num_frames,
    seed,
    use_refine,
):
    """Generate a video and return the path to a temporary mp4 file.

    Args:
        mode: "t2v" for text-to-video; any other value runs image-to-video.
        prompt / neg_prompt: Positive and negative text prompts.
        image: Numpy array source frame (only used when mode != "t2v").
        height, width: Output resolution in pixels.
        num_frames: Number of frames to generate.
        seed: Integer seed for the torch RNG.
        use_refine: When True, run a second refinement pass with the
            refinement LoRA and block-sparse attention enabled.
    """
    rng = torch.Generator(device=device).manual_seed(int(seed))

    if mode == "t2v":
        output = pipe.generate_t2v(
            prompt=prompt,
            negative_prompt=neg_prompt,
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=50,
            guidance_scale=4.0,
            generator=rng,
        )[0]
    else:
        output = pipe.generate_i2v(
            image=Image.fromarray(image),
            prompt=prompt,
            negative_prompt=neg_prompt,
            resolution=f"{height}x{width}",
            num_frames=num_frames,
            num_inference_steps=50,
            guidance_scale=4.0,
            generator=rng,
            use_kv_cache=True,
            offload_kv_cache=False,
        )[0]

    if use_refine:
        pipe.dit.enable_loras(['refinement_lora'])
        pipe.dit.enable_bsa()
        # assumes stage-1 frames are float arrays scaled to [0, 1] — TODO confirm
        stage1_frames = [Image.fromarray((frame * 255).astype(np.uint8)) for frame in output]
        output = pipe.generate_refine(
            stage1_video=stage1_frames,
            prompt=prompt,
            num_cond_frames=1,
            num_inference_steps=50,
            generator=rng,
        )[0]

    # NOTE(review): if generate_refine already returns uint8 frames, the *255
    # below would overflow after the cast — verify the pipeline's output range.
    frames = torch.from_numpy(np.array(output))
    frames = (frames * 255).clamp(0, 255).to(torch.uint8)

    # delete=False so Gradio can read the file after the handle is closed.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
        write_video(f.name, frames, fps=15, video_codec="libx264", options={"crf": "18"})
    return f.name
158
-
159
# ============================================================
# 5️⃣ Gradio interface
# ============================================================
with gr.Blocks() as demo:
    gr.Markdown("# 🎬 Optimized LongCat-Video Demo (FA3 removed)")
    with gr.Tab("Text-to-Video"):
        prompt_t2v = gr.Textbox(label="Prompt", lines=3)
        neg_prompt_t2v = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, low quality")
        height_t2v = gr.Slider(256, 1024, value=480, step=64, label="Height")
        width_t2v = gr.Slider(256, 1024, value=832, step=64, label="Width")
        frames_t2v = gr.Slider(8, 180, value=48, step=1, label="Frames")
        seed_t2v = gr.Number(value=42, label="Seed", precision=0)
        refine_t2v = gr.Checkbox(label="Use Refine", value=True)
        out_t2v = gr.Video(label="Generated Video")
        btn_t2v = gr.Button("Generate")
        # FIX: Gradio event inputs must be components — raw values like "t2v"
        # and None are invalid. Constants are wrapped in gr.State instead.
        btn_t2v.click(
            generate_video,
            inputs=[gr.State("t2v"), prompt_t2v, neg_prompt_t2v, gr.State(None),
                    height_t2v, width_t2v, frames_t2v, seed_t2v, refine_t2v],
            outputs=out_t2v,
        )
    with gr.Tab("Image-to-Video"):
        image_i2v = gr.Image(type="numpy")
        prompt_i2v = gr.Textbox(label="Prompt", lines=3)
        neg_prompt_i2v = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, low quality")
        frames_i2v = gr.Slider(8, 180, value=48, step=1, label="Frames")
        seed_i2v = gr.Number(value=42, label="Seed", precision=0)
        refine_i2v = gr.Checkbox(label="Use Refine", value=True)
        out_i2v = gr.Video(label="Generated Video")
        btn_i2v = gr.Button("Generate")
        # FIX: same as above — the bare ints 480/832 must be gr.State values.
        btn_i2v.click(
            generate_video,
            inputs=[gr.State("i2v"), prompt_i2v, neg_prompt_i2v, image_i2v,
                    gr.State(480), gr.State(832), frames_i2v, seed_i2v, refine_i2v],
            outputs=out_i2v,
        )

if __name__ == "__main__":
    demo.launch()
 
 
1
import os
import sys
import subprocess
import tempfile

import numpy as np
from PIL import Image

import torch
import gradio as gr
from torchvision.io import write_video

# LongCat-Video imports
# Location of the cloned upstream repo and its model weights.
REPO_PATH = "LongCat-Video"
CHECKPOINT_DIR = os.path.join(REPO_PATH, "weights", "LongCat-Video")

# Clone the upstream repository on first run only.
if not os.path.exists(REPO_PATH):
    subprocess.run(
        ["git", "clone", "https://github.com/meituan-longcat/LongCat-Video.git", REPO_PATH],
        check=True
    )

# Make the cloned package importable before importing from it.
sys.path.insert(0, os.path.abspath(REPO_PATH))
from huggingface_hub import snapshot_download
from longcat_video.pipeline_longcat_video import LongCatVideoPipeline
from longcat_video.modules.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from longcat_video.modules.autoencoder_kl_wan import AutoencoderKLWan
from longcat_video.modules.longcat_video_dit import LongCatVideoTransformer3DModel
from longcat_video.context_parallel import context_parallel_util
import cache_dit

from transformers import AutoTokenizer, UMT5EncoderModel
from diffusers.utils import export_to_video

device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32

# --- Download weights if missing ---
if not os.path.exists(CHECKPOINT_DIR):
    snapshot_download(
        repo_id="meituan-longcat/LongCat-Video",
        local_dir=CHECKPOINT_DIR,
        # NOTE(review): local_dir_use_symlinks is deprecated/ignored in recent
        # huggingface_hub releases — confirm the installed version still honors it.
        local_dir_use_symlinks=False,
        ignore_patterns=["*.md", "*.gitattributes", "assets/*"]
    )

# --- Initialize models ---
# cp_size=1: single-device run, so no context parallelism.
cp_split_hw = context_parallel_util.get_optimal_split(1)
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR, subfolder="tokenizer", torch_dtype=torch_dtype)
text_encoder = UMT5EncoderModel.from_pretrained(CHECKPOINT_DIR, subfolder="text_encoder", torch_dtype=torch_dtype)
vae = AutoencoderKLWan.from_pretrained(CHECKPOINT_DIR, subfolder="vae", torch_dtype=torch_dtype)
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(CHECKPOINT_DIR, subfolder="scheduler", torch_dtype=torch_dtype)

dit = LongCatVideoTransformer3DModel.from_pretrained(
    CHECKPOINT_DIR,
    subfolder="dit",
    cp_split_hw=cp_split_hw,
    torch_dtype=torch_dtype,
    enable_flashattn3=False,
    enable_flashattn2=False,
    enable_xformers=False  # <- disables FA3/xFormers completely
)

pipe = LongCatVideoPipeline(
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    vae=vae,
    scheduler=scheduler,
    dit=dit,
)
pipe.to(device)

# --- Load LoRAs ---
# Both LoRAs are loaded up-front; generate_video() enables/disables them per run.
cfg_lora = os.path.join(CHECKPOINT_DIR, 'lora/cfg_step_lora.safetensors')
refine_lora = os.path.join(CHECKPOINT_DIR, 'lora/refinement_lora.safetensors')
pipe.dit.load_lora(cfg_lora, 'cfg_step_lora')
pipe.dit.load_lora(refine_lora, 'refinement_lora')

# --- Enable Cache-DiT for DiT transformer ---
cache_dit.enable_cache(pipe.dit)
80
 
81
def torch_gc():
    """Release cached CUDA allocator memory; does nothing on CPU-only hosts."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
85
 
86
# --- Video generation function ---
def generate_video(
    mode,
    prompt,
    neg_prompt,
    image,
    num_frames,
    seed,
    use_distill,
    use_refine,
):
    """Generate a video with LongCat-Video and return the path to an mp4.

    Args:
        mode: "t2v" for text-to-video; any other value runs image-to-video.
        prompt / neg_prompt: Positive and negative text prompts.
        image: Numpy array source frame (only used outside "t2v" mode).
        num_frames: Number of frames to generate.
        seed: Integer seed for the torch RNG.
        use_distill: Use the distilled cfg-step LoRA (16 steps, CFG off).
        use_refine: Run an additional refinement pass; implies distilled
            sampling for the first stage.
    """
    if pipe is None:
        raise gr.Error("Models not loaded.")

    rng = torch.Generator(device=device).manual_seed(int(seed))

    # Distilled sampling applies when either flag is set: few steps, no CFG,
    # and the negative prompt is dropped.
    distill_active = use_distill or use_refine
    if distill_active:
        pipe.dit.enable_loras(['cfg_step_lora'])
        steps, cfg, neg = 16, 1.0, ""
    else:
        steps, cfg, neg = 50, 4.0, neg_prompt

    if mode == "t2v":
        output = pipe.generate_t2v(
            prompt=prompt,
            negative_prompt=neg,
            height=480,
            width=832,
            num_frames=num_frames,
            num_inference_steps=steps,
            guidance_scale=cfg,
            generator=rng,
        )[0]
    else:  # i2v
        output = pipe.generate_i2v(
            image=Image.fromarray(image),
            prompt=prompt,
            negative_prompt=neg,
            resolution="480p",
            num_frames=num_frames,
            num_inference_steps=steps,
            guidance_scale=cfg,
            generator=rng,
        )[0]

    if distill_active:
        pipe.dit.disable_all_loras()

    torch_gc()

    # Optional refinement
    if use_refine:
        pipe.dit.enable_loras(['refinement_lora'])
        pipe.dit.enable_bsa()

        # assumes stage-1 frames are float arrays scaled to [0, 1] — TODO confirm
        stage1_frames = [Image.fromarray((frame * 255).astype(np.uint8)) for frame in output]

        cond_image = Image.fromarray(image) if mode == 'i2v' else None
        output = pipe.generate_refine(
            image=cond_image,
            prompt=prompt,
            stage1_video=stage1_frames,
            num_cond_frames=1 if mode == 'i2v' else 0,
            num_inference_steps=50,
            generator=rng,
        )[0]

        pipe.dit.disable_all_loras()
        pipe.dit.disable_bsa()
        torch_gc()

    # Export video
    # delete=False so Gradio can stream the file after the handle is closed.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        export_to_video(output, tmp.name, fps=15)
    return tmp.name
168
+
169
# --- Gradio UI ---
# Constrain the app width for a tighter layout.
css = ".fillable{max-width: 960px !important}"
with gr.Blocks(css=css) as demo:
    gr.Markdown("# 🎬 LongCat-Video Optimized")
    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(label="Prompt", lines=4)
            neg_prompt_input = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, low quality")
            seed_input = gr.Number(label="Seed", value=42, precision=0)
            frames_slider = gr.Slider(16, 128, value=48, step=1, label="Number of Frames")
            distill_checkbox = gr.Checkbox(label="Use Distill Mode", value=True)
            refine_checkbox = gr.Checkbox(label="Use Refine Mode", value=False)
            t2v_button = gr.Button("Generate Video")
        with gr.Column(scale=3):
            video_output = gr.Video(label="Generated Video", interactive=False)

    # Text-to-video only; constant mode and (absent) image are passed via gr.State.
    t2v_button.click(
        fn=generate_video,
        inputs=[
            gr.State("t2v"), prompt_input, neg_prompt_input,
            gr.State(None), frames_slider, seed_input,
            distill_checkbox, refine_checkbox
        ],
        outputs=video_output
    )

if __name__ == "__main__":
    demo.launch()