Fixed args to forward
Browse files- video2world_hf.py +2 -2
video2world_hf.py
CHANGED
|
@@ -77,7 +77,7 @@ class DiffusionVideo2World(PreTrainedModel):
|
|
| 77 |
num_input_frames=cfg.num_input_frames,
|
| 78 |
)
|
| 79 |
|
| 80 |
-
def forward(self):
|
| 81 |
cfg = self.config
|
| 82 |
|
| 83 |
# Handle multiple prompts if prompt file is provided
|
|
@@ -86,7 +86,7 @@ class DiffusionVideo2World(PreTrainedModel):
|
|
| 86 |
prompts = read_prompts_from_file(cfg.batch_input_path)
|
| 87 |
else:
|
| 88 |
# Single prompt case
|
| 89 |
-
prompts = [{"prompt":
|
| 90 |
|
| 91 |
os.makedirs(cfg.video_save_folder, exist_ok=True)
|
| 92 |
for i, input_dict in enumerate(prompts):
|
|
|
|
| 77 |
num_input_frames=cfg.num_input_frames,
|
| 78 |
)
|
| 79 |
|
| 80 |
+
def forward(self, prompt, input_image_or_video_path):
|
| 81 |
cfg = self.config
|
| 82 |
|
| 83 |
# Handle multiple prompts if prompt file is provided
|
|
|
|
| 86 |
prompts = read_prompts_from_file(cfg.batch_input_path)
|
| 87 |
else:
|
| 88 |
# Single prompt case
|
| 89 |
+
prompts = [{"prompt": prompt, "visual_input": input_image_or_video_path}]
|
| 90 |
|
| 91 |
os.makedirs(cfg.video_save_folder, exist_ok=True)
|
| 92 |
for i, input_dict in enumerate(prompts):
|