Spaces:
Runtime error
Runtime error
Update demo/t2v.py
Browse files- demo/t2v.py +8 -2
demo/t2v.py
CHANGED
|
@@ -36,14 +36,14 @@ class Text2Video():
|
|
| 36 |
self.download_model(model_folder, model_filename)
|
| 37 |
if not os.path.isfile(os.path.join(model_folder, 'InternVideo2-stage2_1b-224p-f4.pt')):
|
| 38 |
self.download_internvideo2(model_folder)
|
| 39 |
-
self.agent = torch.load(os.path.join(model_folder, model_filename))
|
| 40 |
model_name = 'internvideo2'
|
| 41 |
|
| 42 |
# Get ViCLIP
|
| 43 |
viclip_global_instance = ViCLIPGlobalInstance(model_name)
|
| 44 |
if not viclip_global_instance._instantiated:
|
| 45 |
print("Instantiating InternVideo2")
|
| 46 |
-
viclip_global_instance.instantiate()
|
| 47 |
self.clip = viclip_global_instance.viclip
|
| 48 |
self.tokenizer = viclip_global_instance.viclip_tokenizer
|
| 49 |
|
|
@@ -51,8 +51,11 @@ class Text2Video():
|
|
| 51 |
if not os.path.exists(self.result_dir):
|
| 52 |
os.mkdir(self.result_dir)
|
| 53 |
|
|
|
|
| 54 |
def get_prompt(self, prompt, duration):
|
| 55 |
torch.cuda.empty_cache()
|
|
|
|
|
|
|
| 56 |
print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
|
| 57 |
start = time.time()
|
| 58 |
|
|
@@ -88,6 +91,9 @@ class Text2Video():
|
|
| 88 |
|
| 89 |
save_videos(prior_recon.unsqueeze(0), self.result_dir, filenames=[prompt_str], fps=15)
|
| 90 |
print(f"Saved in {prompt_str}.mp4. Time used: {(time.time() - start):.2f} seconds")
|
|
|
|
|
|
|
|
|
|
| 91 |
return os.path.join(self.result_dir, f"{prompt_str}.mp4")
|
| 92 |
|
| 93 |
def download_model(self, model_folder, model_filename):
|
|
|
|
| 36 |
self.download_model(model_folder, model_filename)
|
| 37 |
if not os.path.isfile(os.path.join(model_folder, 'InternVideo2-stage2_1b-224p-f4.pt')):
|
| 38 |
self.download_internvideo2(model_folder)
|
| 39 |
+
self.agent = torch.load(os.path.join(model_folder, model_filename),map_location='cpu')
|
| 40 |
model_name = 'internvideo2'
|
| 41 |
|
| 42 |
# Get ViCLIP
|
| 43 |
viclip_global_instance = ViCLIPGlobalInstance(model_name)
|
| 44 |
if not viclip_global_instance._instantiated:
|
| 45 |
print("Instantiating InternVideo2")
|
| 46 |
+
viclip_global_instance.instantiate(device='cpu')
|
| 47 |
self.clip = viclip_global_instance.viclip
|
| 48 |
self.tokenizer = viclip_global_instance.viclip_tokenizer
|
| 49 |
|
|
|
|
| 51 |
if not os.path.exists(self.result_dir):
|
| 52 |
os.mkdir(self.result_dir)
|
| 53 |
|
| 54 |
+
@spaces.GPU
|
| 55 |
def get_prompt(self, prompt, duration):
|
| 56 |
torch.cuda.empty_cache()
|
| 57 |
+
self.agent.to('cuda')
|
| 58 |
+
self.clip.to('cuda')
|
| 59 |
print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
|
| 60 |
start = time.time()
|
| 61 |
|
|
|
|
| 91 |
|
| 92 |
save_videos(prior_recon.unsqueeze(0), self.result_dir, filenames=[prompt_str], fps=15)
|
| 93 |
print(f"Saved in {prompt_str}.mp4. Time used: {(time.time() - start):.2f} seconds")
|
| 94 |
+
# Offload GPU
|
| 95 |
+
self.agent.to('cpu')
|
| 96 |
+
self.clip.to('cpu')
|
| 97 |
return os.path.join(self.result_dir, f"{prompt_str}.mp4")
|
| 98 |
|
| 99 |
def download_model(self, model_folder, model_filename):
|