badrerootunix committed on
Commit
0171411
·
1 Parent(s): a472df1

Optimize settings: 6 steps default, better resize logic, 24fps

Browse files
Files changed (1) hide show
  1. app.py +42 -26
app.py CHANGED
@@ -1,8 +1,9 @@
1
  import os
2
  import spaces
3
  import torch
4
- from diffusers import WanImageToVideoPipeline
5
- from diffusers.utils import export_to_video
 
6
  import gradio as gr
7
  import tempfile
8
  import numpy as np
@@ -13,18 +14,18 @@ import random
13
  # MODEL CONFIGURATION
14
  # =========================================================
15
 
16
- MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
17
  HF_TOKEN = os.environ.get("HF_TOKEN")
18
  MAX_DIM = 832
19
  MIN_DIM = 480
20
  SQUARE_DIM = 640
21
  MULTIPLE_OF = 16
22
  MAX_SEED = np.iinfo(np.int32).max
23
- FIXED_FPS = 16
24
  MIN_FRAMES_MODEL = 8
25
- MAX_FRAMES_MODEL = 49
26
  MIN_DURATION = 0.5
27
- MAX_DURATION = 2.0
28
 
29
  # =========================================================
30
  # LOAD PIPELINE
@@ -42,7 +43,7 @@ pipe = WanImageToVideoPipeline.from_pretrained(
42
  # =========================================================
43
 
44
  default_prompt_i2v = "Generate a video with smooth and natural movement. Objects should have visible motion while maintaining fluid transitions."
45
- default_negative_prompt = "low quality, worst quality, blurry, distorted, deformed, ugly, bad anatomy, static, frozen"
46
 
47
  # =========================================================
48
  # IMAGE RESIZING LOGIC
@@ -50,31 +51,46 @@ default_negative_prompt = "low quality, worst quality, blurry, distorted, deform
50
 
51
  def resize_image(image: Image.Image) -> Image.Image:
52
  width, height = image.size
53
-
54
- # Determine orientation and set target dimensions
55
- if width > height: # Landscape
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  target_w = MAX_DIM
57
- target_h = MIN_DIM
58
- elif height > width: # Portrait
59
- target_w = MIN_DIM
60
  target_h = MAX_DIM
61
- else: # Square
62
- target_w = SQUARE_DIM
63
- target_h = SQUARE_DIM
64
 
65
- # Make divisible by 16
66
- target_w = (target_w // MULTIPLE_OF) * MULTIPLE_OF
67
- target_h = (target_h // MULTIPLE_OF) * MULTIPLE_OF
 
68
 
69
- return image.resize((target_w, target_h), Image.LANCZOS)
70
 
71
  # =========================================================
72
  # UTILITY FUNCTIONS
73
  # =========================================================
74
 
75
  def get_num_frames(duration_seconds: float):
76
- frames = int(round(duration_seconds * FIXED_FPS))
77
- return max(MIN_FRAMES_MODEL, min(MAX_FRAMES_MODEL, frames))
78
 
79
  # =========================================================
80
  # MAIN GENERATION FUNCTION
@@ -85,8 +101,8 @@ def generate_video(
85
  input_image,
86
  prompt,
87
  negative_prompt=default_negative_prompt,
88
- duration_seconds=1.5,
89
- steps=4,
90
  guidance_scale=1.0,
91
  seed=42,
92
  randomize_seed=False,
@@ -158,7 +174,7 @@ with gr.Blocks() as demo:
158
  minimum=MIN_DURATION,
159
  maximum=MAX_DURATION,
160
  step=0.5,
161
- value=1.0,
162
  label="Duration (seconds)"
163
  )
164
 
@@ -173,7 +189,7 @@ with gr.Blocks() as demo:
173
  minimum=4,
174
  maximum=12,
175
  step=1,
176
- value=8,
177
  label="Inference Steps"
178
  )
179
 
 
1
  import os
2
  import spaces
3
  import torch
4
+ from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
5
+ from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
6
+ from diffusers.utils.export_utils import export_to_video
7
  import gradio as gr
8
  import tempfile
9
  import numpy as np
 
14
  # MODEL CONFIGURATION
15
  # =========================================================
16
 
17
+ MODEL_ID = os.getenv("MODEL_ID", "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers")
18
  HF_TOKEN = os.environ.get("HF_TOKEN")
19
  MAX_DIM = 832
20
  MIN_DIM = 480
21
  SQUARE_DIM = 640
22
  MULTIPLE_OF = 16
23
  MAX_SEED = np.iinfo(np.int32).max
24
+ FIXED_FPS = 24
25
  MIN_FRAMES_MODEL = 8
26
+ MAX_FRAMES_MODEL = 81
27
  MIN_DURATION = 0.5
28
+ MAX_DURATION = 3.0
29
 
30
  # =========================================================
31
  # LOAD PIPELINE
 
43
  # =========================================================
44
 
45
  default_prompt_i2v = "Generate a video with smooth and natural movement. Objects should have visible motion while maintaining fluid transitions."
46
+ default_negative_prompt = "low quality, worst quality, blurry, distorted, deformed, ugly, bad anatomy, static, frozen, overall gray"
47
 
48
  # =========================================================
49
  # IMAGE RESIZING LOGIC
 
51
 
52
def resize_image(image: Image.Image) -> Image.Image:
    """Fit *image* into the model's supported resolutions.

    Square inputs are resized straight to SQUARE_DIM x SQUARE_DIM. Other
    inputs are first center-cropped if their aspect ratio falls outside
    [MIN_DIM/MAX_DIM, MAX_DIM/MIN_DIM], then scaled so the long side is
    MAX_DIM, snapped to multiples of MULTIPLE_OF, and clamped to the
    [MIN_DIM, MAX_DIM] range. Returns a new LANCZOS-resampled image.
    """
    width, height = image.size

    # Squares bypass all aspect-ratio handling.
    if width == height:
        return image.resize((SQUARE_DIM, SQUARE_DIM), Image.LANCZOS)

    widest_allowed = MAX_DIM / MIN_DIM
    tallest_allowed = MIN_DIM / MAX_DIM
    aspect = width / height

    # Center-crop anything wider/taller than the supported extremes.
    cropped = image
    if aspect > widest_allowed:
        keep_w = int(round(height * widest_allowed))
        x0 = (width - keep_w) // 2
        cropped = image.crop((x0, 0, x0 + keep_w, height))
    elif aspect < tallest_allowed:
        keep_h = int(round(width / tallest_allowed))
        y0 = (height - keep_h) // 2
        cropped = image.crop((0, y0, width, y0 + keep_h))

    # Scale so the longer side hits MAX_DIM, preserving the (possibly
    # cropped) aspect ratio.
    cur_w, cur_h = cropped.size
    cur_aspect = cur_w / cur_h
    if cur_w > cur_h:
        target_w = MAX_DIM
        target_h = int(round(target_w / cur_aspect))
    else:
        target_h = MAX_DIM
        target_w = int(round(target_h * cur_aspect))

    # Snap both dimensions to the model's block size, then clamp into the
    # supported range (MIN_DIM and MAX_DIM are themselves multiples of 16,
    # so the clamp cannot break divisibility).
    final_w = round(target_w / MULTIPLE_OF) * MULTIPLE_OF
    final_h = round(target_h / MULTIPLE_OF) * MULTIPLE_OF
    final_w = max(MIN_DIM, min(MAX_DIM, final_w))
    final_h = max(MIN_DIM, min(MAX_DIM, final_h))

    return cropped.resize((final_w, final_h), Image.LANCZOS)
87
 
88
  # =========================================================
89
  # UTILITY FUNCTIONS
90
  # =========================================================
91
 
92
def get_num_frames(duration_seconds: float):
    """Convert a duration in seconds into a model frame count.

    Rounds duration * FIXED_FPS to the nearest integer, clamps it into
    [MIN_FRAMES_MODEL, MAX_FRAMES_MODEL], and adds one (the model counts
    the initial conditioning frame).
    """
    requested = int(round(duration_seconds * FIXED_FPS))
    clamped = min(MAX_FRAMES_MODEL, max(MIN_FRAMES_MODEL, requested))
    return 1 + clamped
 
94
 
95
  # =========================================================
96
  # MAIN GENERATION FUNCTION
 
101
  input_image,
102
  prompt,
103
  negative_prompt=default_negative_prompt,
104
+ duration_seconds=2.0,
105
+ steps=6,
106
  guidance_scale=1.0,
107
  seed=42,
108
  randomize_seed=False,
 
174
  minimum=MIN_DURATION,
175
  maximum=MAX_DURATION,
176
  step=0.5,
177
+ value=2.0,
178
  label="Duration (seconds)"
179
  )
180
 
 
189
  minimum=4,
190
  maximum=12,
191
  step=1,
192
+ value=6,
193
  label="Inference Steps"
194
  )
195