Spaces:
Runtime error
Runtime error
| from typing import Optional | |
| import os | |
| from tqdm import tqdm | |
| import json, requests | |
| import fal_client | |
| # import json | |
def infer_text_guided_vg_bench(
    model_name,
    result_folder: str = "results",
    experiment_name: str = "Exp_Text-Guided_VG",
    overwrite_model_outputs: bool = False,
    overwrite_inputs: bool = False,
    limit_videos_amount: Optional[int] = None,
    benchmark_prompt_path: str = "t2v_vbench_1000.json",
):
    """
    Performs inference on the VideogenHub dataset using the provided text-guided video generation model.

    Each prompt is submitted to the matching fal.ai text-to-video endpoint, and the
    returned video is downloaded to ``<result_folder>/<experiment_name>/<model_name>/<key>``,
    where ``<key>`` is the prompt's key in the benchmark JSON file.

    Args:
        model_name: Name of the model to run inference on. One of "AnimateDiff",
            "AnimateDiffTurbo" or "FastSVD".
        result_folder (str, optional): Path to the root directory where the results should be saved.
            Defaults to 'results'.
        experiment_name (str, optional): Name of the folder inside 'result_folder' where results
            for this particular experiment will be stored. Defaults to "Exp_Text-Guided_VG".
        overwrite_model_outputs (bool, optional): If True, regenerate videos even when the output
            file already exists. If False (the default), existing outputs are skipped, which is
            what allows an interrupted run to be resumed.
        overwrite_inputs (bool, optional): If set to True, will overwrite any pre-existing input
            lookup file (dataset_lookup.json). Typically, should be set to False unless there's a
            need to update the inputs. Defaults to False.
        limit_videos_amount (int, optional): Maximum number of prompts to go through (prompts whose
            output already exists and is skipped count towards the limit). If None, all prompts
            in the dataset are processed.
        benchmark_prompt_path (str, optional): Path to the benchmark JSON file. It must map output
            file names to objects that have a "prompt_en" key. Defaults to "t2v_vbench_1000.json".

    Returns:
        None. Results are saved in the specified directory.

    Raises:
        ValueError: If ``model_name`` is not one of the supported models. This is checked before
            any file is read or written.
    """
    # fal.ai endpoint (without the "fal-ai/" prefix) for each supported model name.
    fal_model_ids = {
        "AnimateDiff": "fast-animatediff/text-to-video",
        "AnimateDiffTurbo": "fast-animatediff/turbo/text-to-video",
        "FastSVD": "fast-svd/text-to-video",
    }
    # Validate up front so a typo does not leave a half-initialised results folder behind.
    if model_name not in fal_model_ids:
        raise ValueError(
            f"Invalid model_name {model_name!r}; expected one of {sorted(fal_model_ids)}"
        )
    fal_model_name = fal_model_ids[model_name]

    with open(benchmark_prompt_path, "r", encoding="utf-8") as f:
        prompts = json.load(f)

    experiment_dir = os.path.join(result_folder, experiment_name)
    save_path = os.path.join(experiment_dir, "dataset_lookup.json")
    if overwrite_inputs or not os.path.exists(save_path):
        os.makedirs(experiment_dir, exist_ok=True)
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(prompts, f, indent=4)

    print(
        "========> Running Benchmark Dataset:",
        experiment_name,
        "| Model:",
        model_name,
    )

    dest_folder = os.path.join(experiment_dir, model_name)
    os.makedirs(dest_folder, exist_ok=True)

    # Apply the limit by position so it counts prompts rather than relying on a number
    # embedded in the file name; this also gives tqdm the correct total.
    items = list(prompts.items())
    if limit_videos_amount is not None:
        items = items[: max(limit_videos_amount, 0)]

    for file_basename, prompt in tqdm(items):
        dest_file = os.path.join(dest_folder, file_basename)
        if not overwrite_model_outputs and os.path.exists(dest_file):
            print("========> Skipping", dest_file, ", it already exists")
            continue

        print("========> Inferencing", dest_file)
        handler = fal_client.submit(
            f"fal-ai/{fal_model_name}",
            arguments={"prompt": prompt["prompt_en"]},
        )
        # Wait for the generation result; the response is read as {"video": {"url": ...}}.
        result = handler.get()
        download_mp4(result["video"]["url"], dest_file)
def download_mp4(url, filename, timeout=60):
    """Download the resource at ``url`` and save it to ``filename``.

    The body is streamed into a temporary ``<filename>.part`` file, which is moved
    to ``filename`` only after the whole response has been written. A failed or
    interrupted download therefore never leaves a truncated file under the final
    name, so it is not mistaken for a finished output by code that checks whether
    the file exists.

    Args:
        url: URL of the file to download.
        filename: Local destination path (str or path-like).
        timeout: Seconds to wait for the connection and for each chunk of data
            before giving up. Defaults to 60.

    Returns:
        None. Network errors are reported with ``print`` and not re-raised, so a
        batch run can carry on with the next item.
    """
    tmp_filename = os.fspath(filename) + ".part"
    try:
        # Using the response as a context manager releases the connection even on error.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()  # Check if the request was successful
            with open(tmp_filename, "wb") as file:
                # Write the response content to the file in chunks
                for chunk in response.iter_content(chunk_size=8192):
                    file.write(chunk)
        # Atomically publish the finished download under its final name.
        os.replace(tmp_filename, filename)
    except requests.exceptions.RequestException as e:
        print(f"Error downloading file: {e}")
        # Drop any partial data so a later run retries this file.
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)
if __name__ == "__main__":
    import argparse

    # Defaults reproduce the previously hard-coded run, so running the script
    # without arguments behaves exactly as before.
    parser = argparse.ArgumentParser(
        description="Run text-guided video generation benchmark inference through fal.ai."
    )
    parser.add_argument(
        "--model_name",
        default="FastSVD",
        help="AnimateDiff, AnimateDiffTurbo or FastSVD (default: FastSVD)",
    )
    parser.add_argument(
        "--result_folder",
        default="/mnt/tjena/maxku/max_projects/VideoGenHub/results",
        help="Root directory where results are written",
    )
    parser.add_argument(
        "--limit_videos_amount",
        type=int,
        default=None,
        help="Maximum number of prompts to process (default: all)",
    )
    parser.add_argument(
        "--overwrite_model_outputs",
        action="store_true",
        help="Regenerate videos even if the output file already exists",
    )
    args = parser.parse_args()

    infer_text_guided_vg_bench(
        result_folder=args.result_folder,
        model_name=args.model_name,
        limit_videos_amount=args.limit_videos_amount,
        overwrite_model_outputs=args.overwrite_model_outputs,
    )