IceClear commited on
Commit
1ac903d
·
1 Parent(s): 512f3c8
Files changed (1) hide show
  1. app.py +27 -0
app.py CHANGED
@@ -11,6 +11,9 @@
11
  # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # // See the License for the specific language governing permissions and
13
  # // limitations under the License.
 
 
 
14
  import os
15
  import torch
16
  import mediapy
@@ -58,6 +61,27 @@ from pathlib import Path
58
  from urllib.parse import urlparse
59
  from torch.hub import download_url_to_file, get_dir
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
62
  """Load file form http url, will download models if necessary.
63
 
@@ -126,6 +150,7 @@ def configure_sequence_parallel(sp_size):
126
  if sp_size > 1:
127
  init_sequence_parallel(sp_size)
128
 
 
129
  def configure_runner(sp_size):
130
  config_path = os.path.join('./configs_3b', 'main.yaml')
131
  config = load_config(config_path)
@@ -141,6 +166,7 @@ def configure_runner(sp_size):
141
  runner.vae.set_memory_limit(**runner.config.vae.memory_limit)
142
  return runner
143
 
 
144
  def generation_step(runner, text_embeds_dict, cond_latents):
145
  def _move_to_cuda(x):
146
  return [i.to(torch.device("cuda")) for i in x]
@@ -197,6 +223,7 @@ def generation_step(runner, text_embeds_dict, cond_latents):
197
 
198
  return samples
199
 
 
200
  def generation_loop(video_path='./test_videos', output_dir='./results', seed=666, batch_size=1, cfg_scale=1.0, cfg_rescale=0.0, sample_steps=1, res_h=1280, res_w=720, sp_size=1):
201
  runner = configure_runner(1)
202
  output_dir = 'output/out.mp4'
 
11
  # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # // See the License for the specific language governing permissions and
13
  # // limitations under the License.
14
+ from torchvision.transforms import functional as TVF
15
+ import spaces
16
+ from typing import Union
17
  import os
18
  import torch
19
  import mediapy
 
61
  from urllib.parse import urlparse
62
  from torch.hub import download_url_to_file, get_dir
63
 
64
+ class DivisibleCrop:
65
+ def __init__(self, factor):
66
+ if not isinstance(factor, tuple):
67
+ factor = (factor, factor)
68
+
69
+ self.height_factor, self.width_factor = factor[0], factor[1]
70
+
71
+ def __call__(self, image: Union[torch.Tensor, Image.Image]):
72
+ if isinstance(image, torch.Tensor):
73
+ height, width = image.shape[-2:]
74
+ elif isinstance(image, Image.Image):
75
+ width, height = image.size
76
+ else:
77
+ raise NotImplementedError
78
+
79
+ cropped_height = height - (height % self.height_factor)
80
+ cropped_width = width - (width % self.width_factor)
81
+
82
+ image = TVF.center_crop(img=image, output_size=(cropped_height, cropped_width))
83
+ return image
84
+
85
  def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
86
  """Load file form http url, will download models if necessary.
87
 
 
150
  if sp_size > 1:
151
  init_sequence_parallel(sp_size)
152
 
153
+ @spaces.GPU(duration=90)
154
  def configure_runner(sp_size):
155
  config_path = os.path.join('./configs_3b', 'main.yaml')
156
  config = load_config(config_path)
 
166
  runner.vae.set_memory_limit(**runner.config.vae.memory_limit)
167
  return runner
168
 
169
+ @spaces.GPU(duration=90)
170
  def generation_step(runner, text_embeds_dict, cond_latents):
171
  def _move_to_cuda(x):
172
  return [i.to(torch.device("cuda")) for i in x]
 
223
 
224
  return samples
225
 
226
+ @spaces.GPU(duration=90)
227
  def generation_loop(video_path='./test_videos', output_dir='./results', seed=666, batch_size=1, cfg_scale=1.0, cfg_rescale=0.0, sample_steps=1, res_h=1280, res_w=720, sp_size=1):
228
  runner = configure_runner(1)
229
  output_dir = 'output/out.mp4'