Spaces:
Build error
Build error
Haoxin Chen commited on
Commit ·
15190a9
1
Parent(s): a8f3a29
fix ckpt path
Browse files- i2v_test.py +1 -1
- t2v_test.py +8 -8
i2v_test.py
CHANGED
|
@@ -68,7 +68,7 @@ class Image2Video():
|
|
| 68 |
return os.path.join(self.result_dir, f"{prompt_str}.mp4")
|
| 69 |
|
| 70 |
def download_model(self):
|
| 71 |
-
REPO_ID = 'VideoCrafter/Image2Video-512
|
| 72 |
filename_list = ['model.ckpt']
|
| 73 |
if not os.path.exists('./checkpoints/i2v_512_v1/'):
|
| 74 |
os.makedirs('./checkpoints/i2v_512_v1/')
|
|
|
|
| 68 |
return os.path.join(self.result_dir, f"{prompt_str}.mp4")
|
| 69 |
|
| 70 |
def download_model(self):
|
| 71 |
+
REPO_ID = 'VideoCrafter/Image2Video-512'
|
| 72 |
filename_list = ['model.ckpt']
|
| 73 |
if not os.path.exists('./checkpoints/i2v_512_v1/'):
|
| 74 |
os.makedirs('./checkpoints/i2v_512_v1/')
|
t2v_test.py
CHANGED
|
@@ -12,8 +12,8 @@ class Text2Video():
|
|
| 12 |
self.result_dir = result_dir
|
| 13 |
if not os.path.exists(self.result_dir):
|
| 14 |
os.mkdir(self.result_dir)
|
| 15 |
-
ckpt_path='checkpoints/
|
| 16 |
-
config_file='configs/
|
| 17 |
config = OmegaConf.load(config_file)
|
| 18 |
model_config = config.pop("model", OmegaConf.create())
|
| 19 |
model_config['params']['unet_config']['params']['use_checkpoint']=False
|
|
@@ -39,7 +39,7 @@ class Text2Video():
|
|
| 39 |
batch_size=1
|
| 40 |
channels = model.model.diffusion_model.in_channels
|
| 41 |
frames = model.temporal_length
|
| 42 |
-
h, w =
|
| 43 |
noise_shape = [batch_size, channels, frames, h, w]
|
| 44 |
|
| 45 |
#prompts = batch_size * [""]
|
|
@@ -59,15 +59,15 @@ class Text2Video():
|
|
| 59 |
return os.path.join(self.result_dir, f"{prompt_str}.mp4")
|
| 60 |
|
| 61 |
def download_model(self):
|
| 62 |
-
REPO_ID = 'VideoCrafter/Text2Video-
|
| 63 |
filename_list = ['model.ckpt']
|
| 64 |
-
if not os.path.exists('./checkpoints/
|
| 65 |
-
os.makedirs('./checkpoints/
|
| 66 |
for filename in filename_list:
|
| 67 |
-
local_file = os.path.join('./checkpoints/
|
| 68 |
|
| 69 |
if not os.path.exists(local_file):
|
| 70 |
-
hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/
|
| 71 |
|
| 72 |
|
| 73 |
if __name__ == '__main__':
|
|
|
|
| 12 |
self.result_dir = result_dir
|
| 13 |
if not os.path.exists(self.result_dir):
|
| 14 |
os.mkdir(self.result_dir)
|
| 15 |
+
ckpt_path='checkpoints/base_1024_v1/model.ckpt'
|
| 16 |
+
config_file='configs/inference_t2v_1024_v1.0.yaml'
|
| 17 |
config = OmegaConf.load(config_file)
|
| 18 |
model_config = config.pop("model", OmegaConf.create())
|
| 19 |
model_config['params']['unet_config']['params']['use_checkpoint']=False
|
|
|
|
| 39 |
batch_size=1
|
| 40 |
channels = model.model.diffusion_model.in_channels
|
| 41 |
frames = model.temporal_length
|
| 42 |
+
h, w = 576 // 8, 1024 // 8
|
| 43 |
noise_shape = [batch_size, channels, frames, h, w]
|
| 44 |
|
| 45 |
#prompts = batch_size * [""]
|
|
|
|
| 59 |
return os.path.join(self.result_dir, f"{prompt_str}.mp4")
|
| 60 |
|
| 61 |
def download_model(self):
|
| 62 |
+
REPO_ID = 'VideoCrafter/Text2Video-1024'
|
| 63 |
filename_list = ['model.ckpt']
|
| 64 |
+
if not os.path.exists('./checkpoints/base_1024_v1/'):
|
| 65 |
+
os.makedirs('./checkpoints/base_1024_v1/')
|
| 66 |
for filename in filename_list:
|
| 67 |
+
local_file = os.path.join('./checkpoints/base_1024_v1/', filename)
|
| 68 |
|
| 69 |
if not os.path.exists(local_file):
|
| 70 |
+
hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/base_1024_v1/', local_dir_use_symlinks=False)
|
| 71 |
|
| 72 |
|
| 73 |
if __name__ == '__main__':
|