Sushantkas commited on
Commit
76374a0
·
verified ·
1 Parent(s): 1d7f1a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -3
app.py CHANGED
@@ -45,21 +45,65 @@ except:
45
 
46
 
47
  try:
48
- print("Enabling vae Slicing")
49
  pipe.enable_vae_slicing()
50
  print("VAE Slicing Enabled")
51
- except:
52
- print("Model VAE Slicing Failed")
 
 
 
 
 
 
 
 
 
 
 
53
 
 
 
 
 
 
54
 
 
 
 
 
 
55
 
 
 
56
 
 
 
 
 
 
57
 
 
 
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
 
61
 
62
 
 
63
  def greet(name):
64
  return "Hello " + name + "!!"
65
 
 
45
 
46
 
47
  try:
48
+ print("Enabling VAE Slicing")
49
  pipe.enable_vae_slicing()
50
  print("VAE Slicing Enabled")
51
+ except Exception as e:
52
+ print("VAE Slicing Failed")
53
+ print("Error:", str(e))
54
+
55
+
56
+
57
+
58
+ # Loading function for Image
59
+ from diffusers.utils import load_image
60
+
61
+ def prepare_vertical_image(pipe, image_path, base_width=384, base_height=672):
62
+ """
63
+ Loads and resizes an image for Wan I2V vertical video generation.
64
 
65
+ Args:
66
+ pipe: WanImageToVideoPipeline (already loaded)
67
+ image_path (str): Path or URL to image
68
+ base_width (int): Desired width before adjustment
69
+ base_height (int): Desired height before adjustment
70
 
71
+ Returns:
72
+ resized_image (PIL.Image)
73
+ final_width (int)
74
+ final_height (int)
75
+ """
76
 
77
+ # Load image
78
+ image = load_image(image_path)
79
 
80
+ # Ensure compatibility with Wan spatial constraints
81
+ mod_value = (
82
+ pipe.vae_scale_factor_spatial *
83
+ pipe.transformer.config.patch_size[1]
84
+ )
85
 
86
+ final_width = (base_width // mod_value) * mod_value
87
+ final_height = (base_height // mod_value) * mod_value
88
 
89
+ resized_image = image.resize((final_width, final_height))
90
+
91
+ return resized_image, final_width, final_height
92
+
93
+
94
+
95
+ # how to use the Image loading
96
+ image, width, height = prepare_vertical_image(
97
+ pipe,
98
+ "input.jpg",
99
+ base_width=384,
100
+ base_height=672
101
+ )
102
 
103
 
104
 
105
 
106
+ @spaces.GPU(duration=60)
107
  def greet(name):
108
  return "Hello " + name + "!!"
109