xxxpo13 committed on
Commit
cc00960
·
verified ·
1 Parent(s): 354fcf0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +281 -95
app.py CHANGED
@@ -1,59 +1,256 @@
1
  import os
2
  import uuid
3
  import gradio as gr
4
- import subprocess
5
- import tempfile
6
- import shutil
7
-
8
- cpu_offloading=True
9
-
10
- def run_inference_multigpu(gpus, variant, model_path, temp, guidance_scale, video_guidance_scale, resolution, prompt):
11
- """
12
- Runs the external multi-GPU inference script and returns the path to the generated video.
13
- """
14
- # Create a temporary directory to store inputs and outputs
15
- with tempfile.TemporaryDirectory() as tmpdir:
16
- output_video = os.path.join(tmpdir, f"{uuid.uuid4()}_output.mp4")
17
-
18
- # Path to the external shell script
19
- script_path = "./scripts/app_multigpu_engine.sh" # Updated script path
20
-
21
- # Prepare the command
22
- cmd = [
23
- script_path,
24
- str(gpus),
25
- variant,
26
- model_path,
27
- 't2v', # Task is always 't2v' since 'i2v' is removed
28
- str(temp),
29
- str(guidance_scale),
30
- str(video_guidance_scale),
31
- resolution,
32
- output_video,
33
- prompt # Pass the prompt directly as an argument
34
- ]
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  try:
37
- # Run the external script
38
- subprocess.run(cmd, check=True)
39
- except subprocess.CalledProcessError as e:
40
- raise RuntimeError(f"Error during video generation: {e}")
41
-
42
- # After generation, move the video to a permanent location
43
- final_output = os.path.join("generated_videos", f"{uuid.uuid4()}_output.mp4")
44
- os.makedirs("generated_videos", exist_ok=True)
45
- shutil.move(output_video, final_output)
46
-
47
- return final_output
48
-
49
- def generate_text_to_video(prompt, temp, guidance_scale, video_guidance_scale, resolution, gpus):
50
- model_path = "./pyramid_flow_model" # Use the model path as specified
51
- # Determine variant based on resolution
52
- if resolution == "768p":
53
- variant = "diffusion_transformer_768p"
54
  else:
55
- variant = "diffusion_transformer_384p"
56
- return run_inference_multigpu(gpus, variant, model_path, temp, guidance_scale, video_guidance_scale, resolution, prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  # Gradio interface
59
  with gr.Blocks() as demo:
@@ -69,11 +266,6 @@ Pyramid Flow is a training-efficient **Autoregressive Video Generation** model b
69
 
70
  # Shared settings
71
  with gr.Row():
72
- gpus_dropdown = gr.Dropdown(
73
- choices=[2, 4],
74
- value=4,
75
- label="Number of GPUs"
76
- )
77
  resolution_dropdown = gr.Dropdown(
78
  choices=["768p", "384p"],
79
  value="768p",
@@ -83,11 +275,7 @@ Pyramid Flow is a training-efficient **Autoregressive Video Generation** model b
83
  with gr.Tab("Text-to-Video"):
84
  with gr.Row():
85
  with gr.Column():
86
- text_prompt = gr.Textbox(
87
- label="Prompt (Less than 128 words)",
88
- placeholder="Enter a text prompt for the video",
89
- lines=2
90
- )
91
  temp_slider = gr.Slider(1, 31, value=16, step=1, label="Duration")
92
  guidance_scale_slider = gr.Slider(1.0, 15.0, value=9.0, step=0.1, label="Guidance Scale")
93
  video_guidance_scale_slider = gr.Slider(1.0, 10.0, value=5.0, step=0.1, label="Video Guidance Scale")
@@ -96,50 +284,48 @@ Pyramid Flow is a training-efficient **Autoregressive Video Generation** model b
96
  txt_output = gr.Video(label="Generated Video")
97
  gr.Examples(
98
  examples=[
99
- [
100
- "A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors",
101
- 16,
102
- 9.0,
103
- 5.0,
104
- "768p",
105
- 4
106
- ],
107
- [
108
- "Beautiful, snowy Tokyo city is bustling. The camera moves through the bustling city street, following several people enjoying the beautiful snowy weather and shopping at nearby stalls. Gorgeous sakura petals are flying through the wind along with snowflakes",
109
- 16,
110
- 9.0,
111
- 5.0,
112
- "768p",
113
- 4
114
- ],
115
- [
116
- "Extreme close-up of chicken and green pepper kebabs grilling on a barbeque with flames. Shallow focus and light smoke. vivid colours",
117
- 31,
118
- 9.0,
119
- 5.0,
120
- "768p",
121
- 4
122
- ],
123
  ],
124
- inputs=[text_prompt, temp_slider, guidance_scale_slider, video_guidance_scale_slider, resolution_dropdown, gpus_dropdown],
125
  outputs=[txt_output],
126
  fn=generate_text_to_video,
127
  cache_examples='lazy',
128
  )
129
 
130
- # Update generate function for Text-to-Video
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  txt_generate.click(
132
  generate_text_to_video,
133
- inputs=[
134
- text_prompt,
135
- temp_slider,
136
- guidance_scale_slider,
137
- video_guidance_scale_slider,
138
- resolution_dropdown,
139
- gpus_dropdown
140
- ],
141
  outputs=txt_output
142
  )
143
 
 
 
 
 
 
 
144
  # Launch Gradio app
145
- demo.launch(share=False)
 
1
  import os
2
  import uuid
3
  import gradio as gr
4
+ import torch
5
+ import PIL
6
+ from PIL import Image
7
+ from pyramid_dit import PyramidDiTForVideoGeneration
8
+ from diffusers.utils import export_to_video
9
+ from huggingface_hub import snapshot_download
10
+ import threading
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
# ---- Module-wide configuration -----------------------------------------

# Hugging Face repository hosting the Pyramid Flow weights.
model_repo = "rain1011/pyramid-flow-sd3" # Replace with the actual model repository on Hugging Face
# Numeric precision used for inference; "bf16" and "fp32" are supported.
model_dtype = "bf16" # Support bf16 and fp32

# Checkpoint sub-directory per resolution variant.
variants = {
    'high': 'diffusion_transformer_768p', # For high-resolution version
    'low': 'diffusion_transformer_384p' # For low-resolution version
}
# A variant counts as "downloaded" when this file exists in its directory.
required_file = 'config.json' # Ensure config.json is present

# Frame geometry for each resolution variant.
width_high = 1280
height_high = 768
width_low = 640
height_low = 384

# Offload idle sub-modules to CPU so a single GPU can hold the pipeline.
cpu_offloading = True # enable cpu_offloading by default

# Local checkpoint directory, resolved under the current working directory.
current_directory = os.getcwd()
model_path = os.path.join(current_directory, "pyramid_flow_model") # Directory to store the model

# Cache of initialised pipelines keyed by variant, guarded by a lock so
# concurrent requests never initialise the same variant twice.
model_cache = {}
model_cache_lock = threading.Lock()
35
+
36
# Download the model if not already present
def download_model_from_hf(model_repo, model_dir, variants, required_file):
    """Download the checkpoint snapshot unless it is already complete locally.

    A variant is considered present when ``required_file`` exists inside its
    sub-directory of ``model_dir``; if the directory is missing or any variant
    is incomplete, the whole snapshot is (re-)downloaded.

    Args:
        model_repo: Hugging Face repo id (e.g. ``"rain1011/pyramid-flow-sd3"``).
        model_dir: Local directory the snapshot is stored in.
        variants: Mapping of variant key -> checkpoint sub-directory to check.
        required_file: File name whose presence marks a variant as complete.

    Raises:
        Exception: re-raises whatever ``snapshot_download`` failed with.
    """
    need_download = False
    if not os.path.exists(model_dir):
        print(f"[INFO] Model directory '{model_dir}' does not exist. Initiating download...")
        need_download = True
    else:
        # Check if all required files exist for each variant
        for variant_key, variant_dir in variants.items():
            variant_path = os.path.join(model_dir, variant_dir)
            file_path = os.path.join(variant_path, required_file)
            if not os.path.exists(file_path):
                print(f"[WARNING] Required file '{required_file}' missing in '{variant_path}'.")
                need_download = True
                break

    if need_download:
        print(f"[INFO] Downloading model from '{model_repo}' to '{model_dir}'...")
        try:
            # NOTE: `local_dir_use_symlinks` was removed here — it is
            # deprecated and ignored by recent huggingface_hub releases;
            # passing `local_dir` alone already materialises real files.
            snapshot_download(
                repo_id=model_repo,
                local_dir=model_dir,
                repo_type='model'
            )
            print("[INFO] Model download complete.")
        except Exception as e:
            print(f"[ERROR] Failed to download the model: {e}")
            raise
    else:
        print(f"[INFO] All required model files are present in '{model_dir}'. Skipping download.")
67
+
68
+ # Download model from Hugging Face if not present
69
+ download_model_from_hf(model_repo, model_path, variants, required_file)
70
+
71
# Function to initialize the model based on user options
def initialize_model(variant):
    """Construct a PyramidDiTForVideoGeneration pipeline for one variant.

    Args:
        variant: '768p' selects the high-resolution checkpoint; any other
            value selects the 384p checkpoint.

    Returns:
        Tuple ``(model, torch_dtype_selected)`` where the dtype is what the
        callers should use for autocast.

    Raises:
        FileNotFoundError: when the variant's config.json is missing.
        Exception: re-raises any failure from pipeline construction.
    """
    print(f"[INFO] Initializing model with variant='{variant}', using bf16 precision...")

    # Map the UI resolution label onto the checkpoint sub-directory.
    variant_dir = variants['high'] if variant == '768p' else variants['low']
    base_path = model_path
    print(f"[DEBUG] Model base path: {base_path}")

    # Fail fast if the checkpoint for this variant is incomplete on disk.
    config_path = os.path.join(model_path, variant_dir, 'config.json')
    if not os.path.exists(config_path):
        print(f"[ERROR] config.json not found in '{os.path.join(model_path, variant_dir)}'.")
        raise FileNotFoundError(f"config.json not found in '{os.path.join(model_path, variant_dir)}'.")

    torch_dtype_selected = torch.bfloat16 if model_dtype == "bf16" else torch.float32

    try:
        model = PyramidDiTForVideoGeneration(
            base_path,                     # base model path
            model_dtype=model_dtype,       # "bf16" by default
            model_variant=variant_dir,     # checkpoint sub-directory name
            cpu_offloading=cpu_offloading, # CPU offloading flag
        )

        # Tiled VAE decoding keeps peak memory bounded on long videos.
        model.vae.enable_tiling()

        if not torch.cuda.is_available():
            print("[WARNING] CUDA is not available. Proceeding without GPU.")
        else:
            torch.cuda.set_device(0)
            # With CPU offloading enabled the pipeline moves sub-modules to
            # devices on its own; otherwise pin everything on the GPU now.
            if not cpu_offloading:
                model.vae.to("cuda")
                model.dit.to("cuda")
                model.text_encoder.to("cuda")

        print("[INFO] Model initialized successfully.")
        return model, torch_dtype_selected
    except Exception as e:
        print(f"[ERROR] Error initializing model: {e}")
        raise
121
+
122
# Function to get the model from cache or initialize it
def initialize_model_cached(variant):
    """Return the ``(model, dtype)`` pair for *variant*, building it at most once.

    Double-checked locking: the unlocked membership test keeps the common
    (already-cached) path lock-free, while the re-check under the lock
    prevents two threads from both initialising the same variant.
    """
    key = variant

    # Fast path: already initialised.
    if key in model_cache:
        return model_cache[key]

    with model_cache_lock:
        # Re-check under the lock: another thread may have won the race.
        if key not in model_cache:
            model_cache[key] = initialize_model(variant)

    return model_cache[key]
135
+
136
def resize_crop_image(img: PIL.Image.Image, tgt_width, tgt_height):
    """Scale *img* to cover ``tgt_width`` x ``tgt_height``, then centre-crop.

    The scale factor is the larger of the two axis ratios, so the resized
    image covers the target box in both dimensions (aspect ratio preserved);
    the overhang is cropped equally from both sides.
    """
    src_w, src_h = img.width, img.height

    # Cover-fit: scale by whichever axis needs the larger enlargement.
    factor = max(tgt_width / src_w, tgt_height / src_h)
    new_w = round(src_w * factor)
    new_h = round(src_h * factor)
    resized = img.resize((new_w, new_h), resample=PIL.Image.LANCZOS)

    # Centred crop box; PIL accepts fractional coordinates here.
    box = (
        (new_w - tgt_width) / 2,
        (new_h - tgt_height) / 2,
        (new_w + tgt_width) / 2,
        (new_h + tgt_height) / 2,
    )
    return resized.crop(box)
152
+
153
# Function to generate text-to-video
def generate_text_to_video(prompt, temp, guidance_scale, video_guidance_scale, resolution, progress=gr.Progress()):
    """Run the text-to-video pipeline and return the path of the saved .mp4.

    On any failure an error-message string is returned instead of a path.
    """
    progress(0, desc="Loading model")
    print("[DEBUG] generate_text_to_video called.")

    # Select checkpoint variant and frame geometry from the resolution label.
    if resolution == "768p":
        variant, width, height = '768p', width_high, height_high
    else:
        variant, width, height = '384p', width_low, height_low

    def progress_callback(i, m):
        # Forward per-step progress (step i of m) to the Gradio progress bar.
        progress(i / m)

    # Initialize model based on user options using cached function
    try:
        model, torch_dtype_selected = initialize_model_cached(variant)
    except Exception as e:
        print(f"[ERROR] Model initialization failed: {e}")
        return f"Model initialization failed: {e}"

    try:
        print("[INFO] Starting text-to-video generation...")
        with torch.no_grad(), torch.autocast('cuda', dtype=torch_dtype_selected):
            frames = model.generate(
                prompt=prompt,
                num_inference_steps=[20, 20, 20],
                video_num_inference_steps=[10, 10, 10],
                height=height,
                width=width,
                temp=temp,
                guidance_scale=guidance_scale,
                video_guidance_scale=video_guidance_scale,
                output_type="pil",
                cpu_offloading=cpu_offloading,
                save_memory=True,
                callback=progress_callback,
            )
        print("[INFO] Text-to-video generation completed.")
    except Exception as e:
        print(f"[ERROR] Error during text-to-video generation: {e}")
        return f"Error during video generation: {e}"

    # Export under a collision-free name in the working directory.
    video_path = f"{str(uuid.uuid4())}_text_to_video_sample.mp4"
    try:
        export_to_video(frames, video_path, fps=24)
        print(f"[INFO] Video exported to {video_path}.")
    except Exception as e:
        print(f"[ERROR] Error exporting video: {e}")
        return f"Error exporting video: {e}"
    return video_path
201
+
202
# Function to generate image-to-video
def generate_image_to_video(image, prompt, temp, video_guidance_scale, resolution, progress=gr.Progress()):
    """Run the image-to-video pipeline and return the path of the saved .mp4.

    Args:
        image: Input PIL image used to condition the video.
        prompt: Text prompt guiding the generation.
        temp: Duration parameter forwarded to the pipeline.
        video_guidance_scale: Guidance strength for the video stages.
        resolution: "768p" or "384p"; selects model variant and frame size.
        progress: Gradio progress tracker (injected by the UI).

    Returns:
        Path to the exported video on success, otherwise an error-message string.
    """
    progress(0, desc="Loading model")
    print("[DEBUG] generate_image_to_video called.")
    variant = '768p' if resolution == "768p" else '384p'
    height = height_high if resolution == "768p" else height_low
    width = width_high if resolution == "768p" else width_low

    # Conform the input image to the exact frame geometry of the variant.
    try:
        image = resize_crop_image(image, width, height)
        print("[INFO] Image resized and cropped successfully.")
    except Exception as e:
        print(f"[ERROR] Error processing image: {e}")
        return f"Error processing image: {e}"

    def progress_callback(i, m):
        # Forward per-step progress (step i of m) to the Gradio progress bar.
        progress(i / m)

    # Initialize model based on user options using cached function
    try:
        model, torch_dtype_selected = initialize_model_cached(variant)
    except Exception as e:
        print(f"[ERROR] Model initialization failed: {e}")
        return f"Model initialization failed: {e}"

    try:
        print("[INFO] Starting image-to-video generation...")
        with torch.no_grad(), torch.autocast('cuda', dtype=torch_dtype_selected):
            frames = model.generate_i2v(
                prompt=prompt,
                input_image=image,
                num_inference_steps=[10, 10, 10],
                temp=temp,
                video_guidance_scale=video_guidance_scale,
                output_type="pil",
                # Consistency fix: honour the module-level flag (as the
                # text-to-video path does) instead of a hard-coded True.
                # Same value today, but a single knob controls both paths.
                cpu_offloading=cpu_offloading,
                save_memory=True,
                callback=progress_callback,
            )
        print("[INFO] Image-to-video generation completed.")
    except Exception as e:
        print(f"[ERROR] Error during image-to-video generation: {e}")
        return f"Error during video generation: {e}"

    # Export under a collision-free name in the working directory.
    video_path = f"{str(uuid.uuid4())}_image_to_video_sample.mp4"
    try:
        export_to_video(frames, video_path, fps=24)
        print(f"[INFO] Video exported to {video_path}.")
    except Exception as e:
        print(f"[ERROR] Error exporting video: {e}")
        return f"Error exporting video: {e}"
    return video_path
254
 
255
  # Gradio interface
256
  with gr.Blocks() as demo:
 
266
 
267
  # Shared settings
268
  with gr.Row():
 
 
 
 
 
269
  resolution_dropdown = gr.Dropdown(
270
  choices=["768p", "384p"],
271
  value="768p",
 
275
  with gr.Tab("Text-to-Video"):
276
  with gr.Row():
277
  with gr.Column():
278
+ text_prompt = gr.Textbox(label="Prompt (Less than 128 words)", placeholder="Enter a text prompt for the video", lines=2)
 
 
 
 
279
  temp_slider = gr.Slider(1, 31, value=16, step=1, label="Duration")
280
  guidance_scale_slider = gr.Slider(1.0, 15.0, value=9.0, step=0.1, label="Guidance Scale")
281
  video_guidance_scale_slider = gr.Slider(1.0, 10.0, value=5.0, step=0.1, label="Video Guidance Scale")
 
284
  txt_output = gr.Video(label="Generated Video")
285
  gr.Examples(
286
  examples=[
287
+ ["A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors", 16, 9.0, 5.0, "768p"],
288
+ ["Beautiful, snowy Tokyo city is bustling. The camera moves through the bustling city street, following several people enjoying the beautiful snowy weather and shopping at nearby stalls. Gorgeous sakura petals are flying through the wind along with snowflakes", 16, 9.0, 5.0, "768p"],
289
+ ["Extreme close-up of chicken and green pepper kebabs grilling on a barbeque with flames. Shallow focus and light smoke. vivid colours", 31, 9.0, 5.0, "768p"],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  ],
291
+ inputs=[text_prompt, temp_slider, guidance_scale_slider, video_guidance_scale_slider, resolution_dropdown],
292
  outputs=[txt_output],
293
  fn=generate_text_to_video,
294
  cache_examples='lazy',
295
  )
296
 
297
+ with gr.Tab("Image-to-Video"):
298
+ with gr.Row():
299
+ with gr.Column():
300
+ image_input = gr.Image(type="pil", label="Input Image")
301
+ image_prompt = gr.Textbox(label="Prompt (Less than 128 words)", placeholder="Enter a text prompt for the video", lines=2)
302
+ image_temp_slider = gr.Slider(2, 16, value=16, step=1, label="Duration")
303
+ image_video_guidance_scale_slider = gr.Slider(1.0, 7.0, value=4.0, step=0.1, label="Video Guidance Scale")
304
+ img_generate = gr.Button("Generate Video")
305
+ with gr.Column():
306
+ img_output = gr.Video(label="Generated Video")
307
+ gr.Examples(
308
+ examples=[
309
+ ['assets/the_great_wall.jpg', 'FPV flying over the Great Wall', 16, 4.0, "768p"]
310
+ ],
311
+ inputs=[image_input, image_prompt, image_temp_slider, image_video_guidance_scale_slider, resolution_dropdown],
312
+ outputs=[img_output],
313
+ fn=generate_image_to_video,
314
+ cache_examples='lazy',
315
+ )
316
+
317
+ # Update generate functions to include resolution options
318
  txt_generate.click(
319
  generate_text_to_video,
320
+ inputs=[text_prompt, temp_slider, guidance_scale_slider, video_guidance_scale_slider, resolution_dropdown],
 
 
 
 
 
 
 
321
  outputs=txt_output
322
  )
323
 
324
+ img_generate.click(
325
+ generate_image_to_video,
326
+ inputs=[image_input, image_prompt, image_temp_slider, image_video_guidance_scale_slider, resolution_dropdown],
327
+ outputs=img_output
328
+ )
329
+
330
  # Launch Gradio app
331
+ demo.launch(share=True)