Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +11 -0
- .gitignore +211 -0
- Custom_training.md +33 -0
- Dockerfile +33 -0
- GPT_evaluation/evaluate_benchmark.sh +51 -0
- GPT_evaluation/evaluate_benchmark_1_correctness.py +186 -0
- GPT_evaluation/evaluate_benchmark_2_detailed_orientation.py +186 -0
- GPT_evaluation/evaluate_benchmark_3_context.py +186 -0
- GPT_evaluation/evaluate_benchmark_4_temporal.py +185 -0
- GPT_evaluation/evaluate_benchmark_5_consistency.py +193 -0
- GPT_evaluation/evaluate_zeroshot.py +207 -0
- GPT_evaluation/evaluate_zeroshot.sh +25 -0
- LICENSE.md +14 -0
- LICENSE_Lavis.md +14 -0
- README.md +411 -0
- clean_stage3_json.py +35 -0
- convert_cmd_to_json.py +45 -0
- convert_csv_to_json2.py +34 -0
- environment.yml +317 -0
- evaluation/Goldfish_eval/movies/eval_model_summary_llama_vid.sh +66 -0
- evaluation/Goldfish_eval/movies/eval_model_summary_movie_chat.sh +44 -0
- evaluation/Goldfish_eval/movies/eval_model_summary_movie_qa.sh +63 -0
- evaluation/Goldfish_eval/movies/eval_q_related_info_llama_vid.sh +57 -0
- evaluation/Goldfish_eval/movies/eval_q_related_info_movie_chat.sh +42 -0
- evaluation/Goldfish_eval/movies/eval_q_related_info_movie_qa.sh +57 -0
- evaluation/Goldfish_eval/movies/submit_batch_jobs_llama_vid.py +14 -0
- evaluation/Goldfish_eval/movies/submit_batch_jobs_movie_qa.py +16 -0
- evaluation/Goldfish_eval/movies/submit_batch_jobs_moviechat.py +14 -0
- evaluation/Goldfish_eval/retrival_accuracy/eval_retrieval_acc_tvqa_job.sh +51 -0
- evaluation/Goldfish_eval/retrival_accuracy/eval_retrieval_acc_tvqa_job_sub_v.sh +50 -0
- evaluation/Goldfish_eval/retrival_accuracy/eval_retrieval_acc_tvqa_job_sub_v_sub.sh +51 -0
- evaluation/Goldfish_eval/retrival_accuracy/eval_retrieval_acc_tvqa_job_vision_vision.sh +51 -0
- evaluation/Goldfish_eval/tvqa_eval/eval_model_summary.sh +59 -0
- evaluation/Goldfish_eval/tvqa_eval/eval_q_related_info.sh +71 -0
- evaluation/Goldfish_eval/tvqa_eval/submit_batch_jobs.py +25 -0
- evaluation/eval_goldfish_llama_vid.py +616 -0
- evaluation/eval_goldfish_movie_chat.py +453 -0
- evaluation/eval_goldfish_movie_qa.py +591 -0
- evaluation/eval_goldfish_tvqa_long.py +535 -0
- evaluation/eval_minigpt4_video.py +201 -0
- evaluation/eval_retrieval_acc_tvqa.py +316 -0
- filter_json.py +63 -0
- goldfish_demo.py +198 -0
- goldfish_inference.py +62 -0
- goldfish_lv.py +654 -0
- index.py +103 -0
- minigpt4/__init__.py +31 -0
- minigpt4/common/__init__.py +0 -0
- minigpt4/common/config.py +474 -0
- minigpt4/common/dist_utils.py +146 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
repo_imgs/Goldfish_results_table.JPG filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
repo_imgs/MiniGPT4-video_fig.jpg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
repo_imgs/demo_1.JPG filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
repo_imgs/goldfishai.jpg filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
repo_imgs/goldfishai_png.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
repo_imgs/minigpt4_demo_icon.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
repo_imgs/online_demo.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
repo_imgs/sample_1.gif filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
repo_imgs/sample_2.gif filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
repo_imgs/sample_3.gif filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
repo_imgs/teaser_fig_final_final.jpg filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# PyInstaller
|
| 31 |
+
# Usually these files are written by a python script from a template
|
| 32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 33 |
+
*.manifest
|
| 34 |
+
*.spec
|
| 35 |
+
|
| 36 |
+
# Installer logs
|
| 37 |
+
pip-log.txt
|
| 38 |
+
pip-delete-this-directory.txt
|
| 39 |
+
|
| 40 |
+
# Unit test / coverage reports
|
| 41 |
+
htmlcov/
|
| 42 |
+
.tox/
|
| 43 |
+
.nox/
|
| 44 |
+
.coverage
|
| 45 |
+
.coverage.*
|
| 46 |
+
.cache
|
| 47 |
+
nosetests.xml
|
| 48 |
+
coverage.xml
|
| 49 |
+
*.cover
|
| 50 |
+
*.py,cover
|
| 51 |
+
.hypothesis/
|
| 52 |
+
.pytest_cache/
|
| 53 |
+
cover/
|
| 54 |
+
|
| 55 |
+
# Translations
|
| 56 |
+
*.mo
|
| 57 |
+
*.pot
|
| 58 |
+
|
| 59 |
+
# Django stuff:
|
| 60 |
+
*.log
|
| 61 |
+
local_settings.py
|
| 62 |
+
db.sqlite3
|
| 63 |
+
db.sqlite3-journal
|
| 64 |
+
|
| 65 |
+
# Flask stuff:
|
| 66 |
+
instance/
|
| 67 |
+
.webassets-cache
|
| 68 |
+
|
| 69 |
+
# Scrapy stuff:
|
| 70 |
+
.scrapy
|
| 71 |
+
|
| 72 |
+
# Sphinx documentation
|
| 73 |
+
docs/_build/
|
| 74 |
+
|
| 75 |
+
# PyBuilder
|
| 76 |
+
.pybuilder/
|
| 77 |
+
target/
|
| 78 |
+
|
| 79 |
+
# Jupyter Notebook
|
| 80 |
+
.ipynb_checkpoints
|
| 81 |
+
|
| 82 |
+
# IPython
|
| 83 |
+
profile_default/
|
| 84 |
+
ipython_config.py
|
| 85 |
+
|
| 86 |
+
# pyenv
|
| 87 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 88 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 89 |
+
# .python-version
|
| 90 |
+
|
| 91 |
+
# pipenv
|
| 92 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 93 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 94 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 95 |
+
# install all needed dependencies.
|
| 96 |
+
#Pipfile.lock
|
| 97 |
+
|
| 98 |
+
# poetry
|
| 99 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 100 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 101 |
+
# commonly ignored for libraries.
|
| 102 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 103 |
+
#poetry.lock
|
| 104 |
+
|
| 105 |
+
# pdm
|
| 106 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 107 |
+
#pdm.lock
|
| 108 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 109 |
+
# in version control.
|
| 110 |
+
# https://pdm.fming.dev/#use-with-ide
|
| 111 |
+
.pdm.toml
|
| 112 |
+
|
| 113 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 114 |
+
__pypackages__/
|
| 115 |
+
|
| 116 |
+
# Celery stuff
|
| 117 |
+
celerybeat-schedule
|
| 118 |
+
celerybeat.pid
|
| 119 |
+
|
| 120 |
+
# SageMath parsed files
|
| 121 |
+
*.sage.py
|
| 122 |
+
|
| 123 |
+
# Environments
|
| 124 |
+
.env
|
| 125 |
+
.venv
|
| 126 |
+
env/
|
| 127 |
+
venv/
|
| 128 |
+
ENV/
|
| 129 |
+
env.bak/
|
| 130 |
+
venv.bak/
|
| 131 |
+
|
| 132 |
+
# Spyder project settings
|
| 133 |
+
.spyderproject
|
| 134 |
+
.spyproject
|
| 135 |
+
|
| 136 |
+
# Rope project settings
|
| 137 |
+
.ropeproject
|
| 138 |
+
|
| 139 |
+
# mkdocs documentation
|
| 140 |
+
/site
|
| 141 |
+
|
| 142 |
+
# mypy
|
| 143 |
+
.mypy_cache/
|
| 144 |
+
.dmypy.json
|
| 145 |
+
dmypy.json
|
| 146 |
+
|
| 147 |
+
# Pyre type checker
|
| 148 |
+
.pyre/
|
| 149 |
+
|
| 150 |
+
# pytype static type analyzer
|
| 151 |
+
.pytype/
|
| 152 |
+
|
| 153 |
+
# Cython debug symbols
|
| 154 |
+
cython_debug/
|
| 155 |
+
|
| 156 |
+
# PyCharm
|
| 157 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 158 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 159 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 160 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 161 |
+
.idea/
|
| 162 |
+
|
| 163 |
+
wandb/
|
| 164 |
+
jobs/logs/
|
| 165 |
+
*.out
|
| 166 |
+
*ipynb
|
| 167 |
+
.history/
|
| 168 |
+
*.json
|
| 169 |
+
# *.sh
|
| 170 |
+
.ipynb_common
|
| 171 |
+
logs/
|
| 172 |
+
results/
|
| 173 |
+
prompts/
|
| 174 |
+
output/
|
| 175 |
+
ckpt/
|
| 176 |
+
divide_vqa.py
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
slurm*
|
| 180 |
+
sbatch_generate*
|
| 181 |
+
# ignore all videos and subtitles
|
| 182 |
+
*.mp4
|
| 183 |
+
*.mp3
|
| 184 |
+
*.vtt
|
| 185 |
+
*.mkv
|
| 186 |
+
*.srt
|
| 187 |
+
# ignore text files
|
| 188 |
+
*.txt
|
| 189 |
+
# ignore *.err and *.out
|
| 190 |
+
*.err
|
| 191 |
+
*.out
|
| 192 |
+
*.pth
|
| 193 |
+
*.pt
|
| 194 |
+
*.json
|
| 195 |
+
# ignore workspace folder
|
| 196 |
+
workspace/*
|
| 197 |
+
flagged/*
|
| 198 |
+
jobs_video/eval/choose_best_ckpt/*
|
| 199 |
+
datasets/*
|
| 200 |
+
demo_job_new.sh
|
| 201 |
+
gemini_eval
|
| 202 |
+
llama3.py
|
| 203 |
+
evaluation_subtitles.zip
|
| 204 |
+
minigpt4/models/transformers
|
| 205 |
+
new_workspace
|
| 206 |
+
minigpt4_video
|
| 207 |
+
minigpt4_video_eval
|
| 208 |
+
Infinibench
|
| 209 |
+
goldfish_inference_latency.py
|
| 210 |
+
run.py
|
| 211 |
+
evaluation/eval_infinibench.py
|
Custom_training.md
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Customizing MiniGPT4-video for your own Video-text dataset
|
| 2 |
+
|
| 3 |
+
## Add your own video dataloader
|
| 4 |
+
Construct your own dataloader here `minigpt4/datasets/datasets/video_datasets.py` based on the existing dataloaders.<br>
|
| 5 |
+
Copy Video_loader_template class and edit it according to you data nature.
|
| 6 |
+
|
| 7 |
+
## Create config file for your dataloader
|
| 8 |
+
Here `minigpt4/configs/datasets/dataset_name/default.yaml` creates your yaml file that includes paths to your dataset.<br>
|
| 9 |
+
Copy the template file `minigpt4/configs/datasets/template/default.yaml` and edit the paths to your dataset.
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
## Register your dataloader
|
| 13 |
+
In the `minigpt4/datasets/builders/image_text_pair_builder.py` file
|
| 14 |
+
Import your data loader class from the `minigpt4/datasets/datasets/video_datasets.py` file <br>
|
| 15 |
+
Copy and edit the VideoTemplateBuilder class.<br>
|
| 16 |
+
put the train_dataset_cls = YourVideoLoaderClass that you imported from `minigpt4/datasets/datasets/video_datasets.py` file.
|
| 17 |
+
|
| 18 |
+
## Edit training config file
|
| 19 |
+
Add your dataset to the datasets in the yml file as shown below:
|
| 20 |
+
```yaml
|
| 21 |
+
datasets:
|
| 22 |
+
dataset_name: # change this to your dataset name
|
| 23 |
+
batch_size: 4 # change this to your desired batch size
|
| 24 |
+
vis_processor:
|
| 25 |
+
train:
|
| 26 |
+
name: "blip2_image_train"
|
| 27 |
+
image_size: 224
|
| 28 |
+
text_processor:
|
| 29 |
+
train:
|
| 30 |
+
name: "blip_caption"
|
| 31 |
+
sample_ratio: 200 # if you including joint training with other datasets, you can set the sample ratio here
|
| 32 |
+
```
|
| 33 |
+
|
Dockerfile
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM pytorch/pytorch:2.2.2-cuda11.8-cudnn8-runtime
|
| 2 |
+
# FROM nvidia/cuda:12.5.1-cudnn-runtime-ubuntu20.04
|
| 3 |
+
# FROM nvcr.io/nvidia/pytorch:24.01-py3
|
| 4 |
+
# Install necessary tools
|
| 5 |
+
RUN apt-get update && apt-get install -y curl gnupg wget
|
| 6 |
+
|
| 7 |
+
# Add the NVIDIA GPG key and repository
|
| 8 |
+
RUN curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \
|
| 9 |
+
&& curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \
|
| 10 |
+
sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \
|
| 11 |
+
tee /etc/apt/sources.list.d/nvidia-container-toolkit.list \
|
| 12 |
+
&& apt-get update
|
| 13 |
+
|
| 14 |
+
# Install the NVIDIA container toolkit
|
| 15 |
+
RUN apt-get install -y nvidia-container-toolkit
|
| 16 |
+
# Set the default runtime to nvidia
|
| 17 |
+
ENV NVIDIA_VISIBLE_DEVICES=all
|
| 18 |
+
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
|
| 19 |
+
|
| 20 |
+
# RUN apt install python3-pip -y
|
| 21 |
+
COPY ./ /app
|
| 22 |
+
WORKDIR /app
|
| 23 |
+
|
| 24 |
+
RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y
|
| 25 |
+
RUN apt-get install gcc -y
|
| 26 |
+
|
| 27 |
+
RUN pip install -r requirements.txt
|
| 28 |
+
|
| 29 |
+
ENV CUDA_VISIBLE_DEVICES=0
|
| 30 |
+
ENV HF_TKN="put your huggingface token here"
|
| 31 |
+
|
| 32 |
+
EXPOSE 7860
|
| 33 |
+
CMD ["python", "minigpt4_video_demo.py"]
|
GPT_evaluation/evaluate_benchmark.sh
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# Define common arguments for all scripts
|
| 4 |
+
|
| 5 |
+
PRED="pred_path"
|
| 6 |
+
OUTPUT_DIR="output_dir"
|
| 7 |
+
API_KEY="api_key"
|
| 8 |
+
NUM_TASKS=128
|
| 9 |
+
|
| 10 |
+
# Run the "correctness" evaluation script
|
| 11 |
+
python evaluate_benchmark_1_correctness.py \
|
| 12 |
+
--pred_path "${PRED_GENERIC}" \
|
| 13 |
+
--output_dir "${OUTPUT_DIR}/correctness_eval" \
|
| 14 |
+
--output_json "${OUTPUT_DIR}/correctness_results.json" \
|
| 15 |
+
--api_key $API_KEY \
|
| 16 |
+
--num_tasks $NUM_TASKS
|
| 17 |
+
|
| 18 |
+
# Run the "detailed orientation" evaluation script
|
| 19 |
+
python evaluate_benchmark_2_detailed_orientation.py \
|
| 20 |
+
--pred_path "${PRED_GENERIC}" \
|
| 21 |
+
--output_dir "${OUTPUT_DIR}/detailed_eval" \
|
| 22 |
+
--output_json "${OUTPUT_DIR}/detailed_orientation_results.json" \
|
| 23 |
+
--api_key $API_KEY \
|
| 24 |
+
--num_tasks $NUM_TASKS
|
| 25 |
+
|
| 26 |
+
# Run the "contextual understanding" evaluation script
|
| 27 |
+
python evaluate_benchmark_3_context.py \
|
| 28 |
+
--pred_path "${PRED_GENERIC}" \
|
| 29 |
+
--output_dir "${OUTPUT_DIR}/context_eval" \
|
| 30 |
+
--output_json "${OUTPUT_DIR}/contextual_understanding_results.json" \
|
| 31 |
+
--api_key $API_KEY \
|
| 32 |
+
--num_tasks $NUM_TASKS
|
| 33 |
+
|
| 34 |
+
# Run the "temporal understanding" evaluation script
|
| 35 |
+
python evaluate_benchmark_4_temporal.py \
|
| 36 |
+
--pred_path "${PRED_TEMPORAL}" \
|
| 37 |
+
--output_dir "${OUTPUT_DIR}/temporal_eval" \
|
| 38 |
+
--output_json "${OUTPUT_DIR}/temporal_understanding_results.json" \
|
| 39 |
+
--api_key $API_KEY \
|
| 40 |
+
--num_tasks $NUM_TASKS
|
| 41 |
+
|
| 42 |
+
# Run the "consistency" evaluation script
|
| 43 |
+
python evaluate_benchmark_5_consistency.py \
|
| 44 |
+
--pred_path "${PRED_CONSISTENCY}" \
|
| 45 |
+
--output_dir "${OUTPUT_DIR}/consistency_eval" \
|
| 46 |
+
--output_json "${OUTPUT_DIR}/consistency_results.json" \
|
| 47 |
+
--api_key $API_KEY \
|
| 48 |
+
--num_tasks $NUM_TASKS
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
echo "All evaluations completed!"
|
GPT_evaluation/evaluate_benchmark_1_correctness.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import openai
|
| 2 |
+
import os
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import ast
|
| 6 |
+
from multiprocessing.pool import Pool
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def parse_args():
|
| 10 |
+
parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3")
|
| 11 |
+
parser.add_argument("--pred_path", required=True, help="The path to file containing prediction.")
|
| 12 |
+
parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.")
|
| 13 |
+
parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.")
|
| 14 |
+
parser.add_argument("--api_key", required=True, help="OpenAI API key.")
|
| 15 |
+
parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.")
|
| 16 |
+
args = parser.parse_args()
|
| 17 |
+
return args
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def annotate(prediction_set, caption_files, output_dir):
|
| 21 |
+
"""
|
| 22 |
+
Evaluates question and answer pairs using GPT-3
|
| 23 |
+
Returns a score for correctness.
|
| 24 |
+
"""
|
| 25 |
+
for file in caption_files:
|
| 26 |
+
key = file[:-5] # Strip file extension
|
| 27 |
+
qa_set = prediction_set[key]
|
| 28 |
+
question = qa_set['q']
|
| 29 |
+
answer = qa_set['a']
|
| 30 |
+
pred = qa_set['pred']
|
| 31 |
+
try:
|
| 32 |
+
# Compute the correctness score
|
| 33 |
+
completion = openai.ChatCompletion.create(
|
| 34 |
+
model="gpt-3.5-turbo",
|
| 35 |
+
messages=[
|
| 36 |
+
{
|
| 37 |
+
"role": "system",
|
| 38 |
+
"content":
|
| 39 |
+
"You are an intelligent chatbot designed for evaluating the factual accuracy of generative outputs for video-based question-answer pairs. "
|
| 40 |
+
"Your task is to compare the predicted answer with the correct answer and determine if they are factually consistent. Here's how you can accomplish the task:"
|
| 41 |
+
"------"
|
| 42 |
+
"##INSTRUCTIONS: "
|
| 43 |
+
"- Focus on the factual consistency between the predicted answer and the correct answer. The predicted answer should not contain any misinterpretations or misinformation.\n"
|
| 44 |
+
"- The predicted answer must be factually accurate and align with the video content.\n"
|
| 45 |
+
"- Consider synonyms or paraphrases as valid matches.\n"
|
| 46 |
+
"- Evaluate the factual accuracy of the prediction compared to the answer."
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"role": "user",
|
| 50 |
+
"content":
|
| 51 |
+
"Please evaluate the following video-based question-answer pair:\n\n"
|
| 52 |
+
f"Question: {question}\n"
|
| 53 |
+
f"Correct Answer: {answer}\n"
|
| 54 |
+
f"Predicted Answer: {pred}\n\n"
|
| 55 |
+
"Provide your evaluation only as a factual accuracy score where the factual accuracy score is an integer value between 0 and 5, with 5 indicating the highest level of factual consistency. "
|
| 56 |
+
"Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the factual accuracy score in INTEGER, not STRING."
|
| 57 |
+
"DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
|
| 58 |
+
"For example, your response should look like this: {''score': 4.8}."
|
| 59 |
+
}
|
| 60 |
+
]
|
| 61 |
+
)
|
| 62 |
+
# Convert response to a Python dictionary.
|
| 63 |
+
response_message = completion["choices"][0]["message"]["content"]
|
| 64 |
+
response_dict = ast.literal_eval(response_message)
|
| 65 |
+
result_qa_pair = [response_dict, qa_set]
|
| 66 |
+
|
| 67 |
+
# Save the question-answer pairs to a json file.
|
| 68 |
+
with open(f"{output_dir}/{key}.json", "w") as f:
|
| 69 |
+
json.dump(result_qa_pair, f)
|
| 70 |
+
|
| 71 |
+
except Exception as e:
|
| 72 |
+
print(f"Error processing file '{key}': {e}")
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def main():
|
| 76 |
+
"""
|
| 77 |
+
Main function to control the flow of the program.
|
| 78 |
+
"""
|
| 79 |
+
# Parse arguments.
|
| 80 |
+
args = parse_args()
|
| 81 |
+
|
| 82 |
+
file = open(args.pred_path)
|
| 83 |
+
pred_contents = json.load(file)
|
| 84 |
+
|
| 85 |
+
# Dictionary to store the count of occurrences for each video_id
|
| 86 |
+
video_id_counts = {}
|
| 87 |
+
new_pred_contents = []
|
| 88 |
+
|
| 89 |
+
# Iterate through each sample in pred_contents
|
| 90 |
+
for sample in pred_contents:
|
| 91 |
+
video_id = sample['video_name']
|
| 92 |
+
if video_id in video_id_counts:
|
| 93 |
+
video_id_counts[video_id] += 1
|
| 94 |
+
else:
|
| 95 |
+
video_id_counts[video_id] = 0
|
| 96 |
+
|
| 97 |
+
# Create a new sample with the modified key
|
| 98 |
+
new_sample = sample
|
| 99 |
+
new_sample['video_name'] = f"{video_id}_{video_id_counts[video_id]}"
|
| 100 |
+
new_pred_contents.append(new_sample)
|
| 101 |
+
|
| 102 |
+
# Generating list of id's and corresponding files
|
| 103 |
+
id_list = [x['video_name'] for x in new_pred_contents]
|
| 104 |
+
caption_files = [f"{id}.json" for id in id_list]
|
| 105 |
+
|
| 106 |
+
output_dir = args.output_dir
|
| 107 |
+
# Generate output directory if not exists.
|
| 108 |
+
if not os.path.exists(output_dir):
|
| 109 |
+
os.makedirs(output_dir)
|
| 110 |
+
|
| 111 |
+
# Preparing dictionary of question-answer sets
|
| 112 |
+
prediction_set = {}
|
| 113 |
+
for sample in new_pred_contents:
|
| 114 |
+
id = sample['video_name']
|
| 115 |
+
question = sample['Q']
|
| 116 |
+
answer = sample['A']
|
| 117 |
+
pred = sample['pred']
|
| 118 |
+
qa_set = {"q": question, "a": answer, "pred": pred}
|
| 119 |
+
prediction_set[id] = qa_set
|
| 120 |
+
|
| 121 |
+
# Set the OpenAI API key.
|
| 122 |
+
openai.api_key = args.api_key
|
| 123 |
+
num_tasks = args.num_tasks
|
| 124 |
+
|
| 125 |
+
# While loop to ensure that all captions are processed.
|
| 126 |
+
while True:
|
| 127 |
+
try:
|
| 128 |
+
# Files that have not been processed yet.
|
| 129 |
+
completed_files = os.listdir(output_dir)
|
| 130 |
+
print(f"completed_files: {len(completed_files)}")
|
| 131 |
+
|
| 132 |
+
# Files that have not been processed yet.
|
| 133 |
+
incomplete_files = [f for f in caption_files if f not in completed_files]
|
| 134 |
+
print(f"incomplete_files: {len(incomplete_files)}")
|
| 135 |
+
|
| 136 |
+
# Break the loop when there are no incomplete files
|
| 137 |
+
if len(incomplete_files) == 0:
|
| 138 |
+
break
|
| 139 |
+
if len(incomplete_files) <= num_tasks:
|
| 140 |
+
num_tasks = 1
|
| 141 |
+
|
| 142 |
+
# Split tasks into parts.
|
| 143 |
+
part_len = len(incomplete_files) // num_tasks
|
| 144 |
+
all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)]
|
| 145 |
+
task_args = [(prediction_set, part, args.output_dir) for part in all_parts]
|
| 146 |
+
|
| 147 |
+
# Use a pool of workers to process the files in parallel.
|
| 148 |
+
with Pool() as pool:
|
| 149 |
+
pool.starmap(annotate, task_args)
|
| 150 |
+
|
| 151 |
+
except Exception as e:
|
| 152 |
+
print(f"Error: {e}")
|
| 153 |
+
|
| 154 |
+
# Combine all the processed files into one
|
| 155 |
+
combined_contents = {}
|
| 156 |
+
json_path = args.output_json
|
| 157 |
+
|
| 158 |
+
# Iterate through json files
|
| 159 |
+
for file_name in os.listdir(output_dir):
|
| 160 |
+
if file_name.endswith(".json"):
|
| 161 |
+
file_path = os.path.join(output_dir, file_name)
|
| 162 |
+
with open(file_path, "r") as json_file:
|
| 163 |
+
content = json.load(json_file)
|
| 164 |
+
combined_contents[file_name[:-5]] = content
|
| 165 |
+
|
| 166 |
+
# Write combined content to a json file
|
| 167 |
+
with open(json_path, "w") as json_file:
|
| 168 |
+
json.dump(combined_contents, json_file)
|
| 169 |
+
print("All evaluation completed!")
|
| 170 |
+
|
| 171 |
+
# Calculate average score
|
| 172 |
+
score_sum = 0
|
| 173 |
+
count = 0
|
| 174 |
+
for key, result in combined_contents.items():
|
| 175 |
+
count += 1
|
| 176 |
+
score_match = result[0]['score']
|
| 177 |
+
score = int(score_match)
|
| 178 |
+
score_sum += score
|
| 179 |
+
average_score = score_sum / count
|
| 180 |
+
|
| 181 |
+
print("Average score for correctness:", average_score)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
if __name__ == "__main__":
|
| 185 |
+
main()
|
| 186 |
+
|
GPT_evaluation/evaluate_benchmark_2_detailed_orientation.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import openai
|
| 2 |
+
import os
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import ast
|
| 6 |
+
from multiprocessing.pool import Pool
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def parse_args():
|
| 10 |
+
parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3")
|
| 11 |
+
parser.add_argument("--pred_path", required=True, help="The path to file containing prediction.")
|
| 12 |
+
parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.")
|
| 13 |
+
parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.")
|
| 14 |
+
parser.add_argument("--api_key", required=True, help="OpenAI API key.")
|
| 15 |
+
parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.")
|
| 16 |
+
args = parser.parse_args()
|
| 17 |
+
return args
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def annotate(prediction_set, caption_files, output_dir):
|
| 21 |
+
"""
|
| 22 |
+
Evaluates question and answer pairs using GPT-3 and
|
| 23 |
+
returns a score for detailed orientation.
|
| 24 |
+
"""
|
| 25 |
+
for file in caption_files:
|
| 26 |
+
key = file[:-5] # Strip file extension
|
| 27 |
+
qa_set = prediction_set[key]
|
| 28 |
+
question = qa_set['q']
|
| 29 |
+
answer = qa_set['a']
|
| 30 |
+
pred = qa_set['pred']
|
| 31 |
+
try:
|
| 32 |
+
# Compute the detailed-orientation score
|
| 33 |
+
completion = openai.ChatCompletion.create(
|
| 34 |
+
model="gpt-3.5-turbo",
|
| 35 |
+
messages=[
|
| 36 |
+
{
|
| 37 |
+
"role": "system",
|
| 38 |
+
"content":
|
| 39 |
+
"You are an intelligent chatbot designed for evaluating the detail orientation of generative outputs for video-based question-answer pairs. "
|
| 40 |
+
"Your task is to compare the predicted answer with the correct answer and determine its level of detail, considering both completeness and specificity. Here's how you can accomplish the task:"
|
| 41 |
+
"------"
|
| 42 |
+
"##INSTRUCTIONS: "
|
| 43 |
+
"- Check if the predicted answer covers all major points from the video. The response should not leave out any key aspects.\n"
|
| 44 |
+
"- Evaluate whether the predicted answer includes specific details rather than just generic points. It should provide comprehensive information that is tied to specific elements of the video.\n"
|
| 45 |
+
"- Consider synonyms or paraphrases as valid matches.\n"
|
| 46 |
+
"- Provide a single evaluation score that reflects the level of detail orientation of the prediction, considering both completeness and specificity."
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"role": "user",
|
| 50 |
+
"content":
|
| 51 |
+
"Please evaluate the following video-based question-answer pair:\n\n"
|
| 52 |
+
f"Question: {question}\n"
|
| 53 |
+
f"Correct Answer: {answer}\n"
|
| 54 |
+
f"Predicted Answer: {pred}\n\n"
|
| 55 |
+
"Provide your evaluation only as a detail orientation score where the detail orientation score is an integer value between 0 and 5, with 5 indicating the highest level of detail orientation. "
|
| 56 |
+
"Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the detail orientation score in INTEGER, not STRING."
|
| 57 |
+
"DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
|
| 58 |
+
"For example, your response should look like this: {''score': 4.8}."
|
| 59 |
+
}
|
| 60 |
+
]
|
| 61 |
+
)
|
| 62 |
+
# Convert response to a Python dictionary.
|
| 63 |
+
response_message = completion["choices"][0]["message"]["content"]
|
| 64 |
+
response_dict = ast.literal_eval(response_message)
|
| 65 |
+
result_qa_pair = [response_dict, qa_set]
|
| 66 |
+
|
| 67 |
+
# Save the question-answer pairs to a json file.
|
| 68 |
+
with open(f"{output_dir}/{key}.json", "w") as f:
|
| 69 |
+
json.dump(result_qa_pair, f)
|
| 70 |
+
|
| 71 |
+
except Exception as e:
|
| 72 |
+
print(f"Error processing file '{key}': {e}")
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def main():
|
| 76 |
+
"""
|
| 77 |
+
Main function to control the flow of the program.
|
| 78 |
+
"""
|
| 79 |
+
# Parse arguments.
|
| 80 |
+
args = parse_args()
|
| 81 |
+
|
| 82 |
+
file = open(args.pred_path)
|
| 83 |
+
pred_contents = json.load(file)
|
| 84 |
+
|
| 85 |
+
# Dictionary to store the count of occurrences for each video_id
|
| 86 |
+
video_id_counts = {}
|
| 87 |
+
new_pred_contents = []
|
| 88 |
+
|
| 89 |
+
# Iterate through each sample in pred_contents
|
| 90 |
+
for sample in pred_contents:
|
| 91 |
+
video_id = sample['video_name']
|
| 92 |
+
if video_id in video_id_counts:
|
| 93 |
+
video_id_counts[video_id] += 1
|
| 94 |
+
else:
|
| 95 |
+
video_id_counts[video_id] = 0
|
| 96 |
+
|
| 97 |
+
# Create a new sample with the modified key
|
| 98 |
+
new_sample = sample
|
| 99 |
+
new_sample['video_name'] = f"{video_id}_{video_id_counts[video_id]}"
|
| 100 |
+
new_pred_contents.append(new_sample)
|
| 101 |
+
|
| 102 |
+
# Generating list of id's and corresponding files
|
| 103 |
+
id_list = [x['video_name'] for x in new_pred_contents]
|
| 104 |
+
caption_files = [f"{id}.json" for id in id_list]
|
| 105 |
+
|
| 106 |
+
output_dir = args.output_dir
|
| 107 |
+
# Generate output directory if not exists.
|
| 108 |
+
if not os.path.exists(output_dir):
|
| 109 |
+
os.makedirs(output_dir)
|
| 110 |
+
|
| 111 |
+
# Preparing dictionary of question-answer sets
|
| 112 |
+
prediction_set = {}
|
| 113 |
+
for sample in new_pred_contents:
|
| 114 |
+
id = sample['video_name']
|
| 115 |
+
question = sample['Q']
|
| 116 |
+
answer = sample['A']
|
| 117 |
+
pred = sample['pred']
|
| 118 |
+
qa_set = {"q": question, "a": answer, "pred": pred}
|
| 119 |
+
prediction_set[id] = qa_set
|
| 120 |
+
|
| 121 |
+
# Set the OpenAI API key.
|
| 122 |
+
openai.api_key = args.api_key
|
| 123 |
+
num_tasks = args.num_tasks
|
| 124 |
+
|
| 125 |
+
# While loop to ensure that all captions are processed.
|
| 126 |
+
while True:
|
| 127 |
+
try:
|
| 128 |
+
# Files that have not been processed yet.
|
| 129 |
+
completed_files = os.listdir(output_dir)
|
| 130 |
+
print(f"completed_files: {len(completed_files)}")
|
| 131 |
+
|
| 132 |
+
# Files that have not been processed yet.
|
| 133 |
+
incomplete_files = [f for f in caption_files if f not in completed_files]
|
| 134 |
+
print(f"incomplete_files: {len(incomplete_files)}")
|
| 135 |
+
|
| 136 |
+
# Break the loop when there are no incomplete files
|
| 137 |
+
if len(incomplete_files) == 0:
|
| 138 |
+
break
|
| 139 |
+
if len(incomplete_files) <= num_tasks:
|
| 140 |
+
num_tasks = 1
|
| 141 |
+
|
| 142 |
+
# Split tasks into parts.
|
| 143 |
+
part_len = len(incomplete_files) // num_tasks
|
| 144 |
+
all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)]
|
| 145 |
+
task_args = [(prediction_set, part, args.output_dir) for part in all_parts]
|
| 146 |
+
|
| 147 |
+
# Use a pool of workers to process the files in parallel.
|
| 148 |
+
with Pool() as pool:
|
| 149 |
+
pool.starmap(annotate, task_args)
|
| 150 |
+
|
| 151 |
+
except Exception as e:
|
| 152 |
+
print(f"Error: {e}")
|
| 153 |
+
|
| 154 |
+
# Combine all the processed files into one
|
| 155 |
+
combined_contents = {}
|
| 156 |
+
json_path = args.output_json
|
| 157 |
+
|
| 158 |
+
# Iterate through json files
|
| 159 |
+
for file_name in os.listdir(output_dir):
|
| 160 |
+
if file_name.endswith(".json"):
|
| 161 |
+
file_path = os.path.join(output_dir, file_name)
|
| 162 |
+
with open(file_path, "r") as json_file:
|
| 163 |
+
content = json.load(json_file)
|
| 164 |
+
combined_contents[file_name[:-5]] = content
|
| 165 |
+
|
| 166 |
+
# Write combined content to a json file
|
| 167 |
+
with open(json_path, "w") as json_file:
|
| 168 |
+
json.dump(combined_contents, json_file)
|
| 169 |
+
print("All evaluation completed!")
|
| 170 |
+
|
| 171 |
+
# Calculate average score
|
| 172 |
+
score_sum = 0
|
| 173 |
+
count = 0
|
| 174 |
+
for key, result in combined_contents.items():
|
| 175 |
+
count += 1
|
| 176 |
+
score_match = result[0]['score']
|
| 177 |
+
score = int(score_match)
|
| 178 |
+
score_sum += score
|
| 179 |
+
average_score = score_sum / count
|
| 180 |
+
|
| 181 |
+
print("Average score for detailed orientation:", average_score)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
if __name__ == "__main__":
|
| 185 |
+
main()
|
| 186 |
+
|
GPT_evaluation/evaluate_benchmark_3_context.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import openai
|
| 2 |
+
import os
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import ast
|
| 6 |
+
from multiprocessing.pool import Pool
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def parse_args():
|
| 10 |
+
parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3")
|
| 11 |
+
parser.add_argument("--pred_path", required=True, help="The path to file containing prediction.")
|
| 12 |
+
parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.")
|
| 13 |
+
parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.")
|
| 14 |
+
parser.add_argument("--api_key", required=True, help="OpenAI API key.")
|
| 15 |
+
parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.")
|
| 16 |
+
args = parser.parse_args()
|
| 17 |
+
return args
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def annotate(prediction_set, caption_files, output_dir):
|
| 21 |
+
"""
|
| 22 |
+
Evaluates question and answer pairs using GPT-3 and
|
| 23 |
+
returns a score for contextual understanding.
|
| 24 |
+
"""
|
| 25 |
+
for file in caption_files:
|
| 26 |
+
key = file[:-5] # Strip file extension
|
| 27 |
+
qa_set = prediction_set[key]
|
| 28 |
+
question = qa_set['q']
|
| 29 |
+
answer = qa_set['a']
|
| 30 |
+
pred = qa_set['pred']
|
| 31 |
+
try:
|
| 32 |
+
# Compute the contextual understanding score
|
| 33 |
+
completion = openai.ChatCompletion.create(
|
| 34 |
+
model="gpt-3.5-turbo",
|
| 35 |
+
messages=[
|
| 36 |
+
{
|
| 37 |
+
"role": "system",
|
| 38 |
+
"content":
|
| 39 |
+
"You are an intelligent chatbot designed for evaluating the contextual understanding of generative outputs for video-based question-answer pairs. "
|
| 40 |
+
"Your task is to compare the predicted answer with the correct answer and determine if the generated response aligns with the overall context of the video content. Here's how you can accomplish the task:"
|
| 41 |
+
"------"
|
| 42 |
+
"##INSTRUCTIONS: "
|
| 43 |
+
"- Evaluate whether the predicted answer aligns with the overall context of the video content. It should not provide information that is out of context or misaligned.\n"
|
| 44 |
+
"- The predicted answer must capture the main themes and sentiments of the video.\n"
|
| 45 |
+
"- Consider synonyms or paraphrases as valid matches.\n"
|
| 46 |
+
"- Provide your evaluation of the contextual understanding of the prediction compared to the answer."
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"role": "user",
|
| 50 |
+
"content":
|
| 51 |
+
"Please evaluate the following video-based question-answer pair:\n\n"
|
| 52 |
+
f"Question: {question}\n"
|
| 53 |
+
f"Correct Answer: {answer}\n"
|
| 54 |
+
f"Predicted Answer: {pred}\n\n"
|
| 55 |
+
"Provide your evaluation only as a contextual understanding score where the contextual understanding score is an integer value between 0 and 5, with 5 indicating the highest level of contextual understanding. "
|
| 56 |
+
"Please generate the response in the form of a Python dictionary string with keys 'score', where its value is contextual understanding score in INTEGER, not STRING."
|
| 57 |
+
"DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
|
| 58 |
+
"For example, your response should look like this: {''score': 4.8}."
|
| 59 |
+
}
|
| 60 |
+
]
|
| 61 |
+
)
|
| 62 |
+
# Convert response to a Python dictionary.
|
| 63 |
+
response_message = completion["choices"][0]["message"]["content"]
|
| 64 |
+
response_dict = ast.literal_eval(response_message)
|
| 65 |
+
result_qa_pair = [response_dict, qa_set]
|
| 66 |
+
|
| 67 |
+
# Save the question-answer pairs to a json file.
|
| 68 |
+
with open(f"{output_dir}/{key}.json", "w") as f:
|
| 69 |
+
json.dump(result_qa_pair, f)
|
| 70 |
+
|
| 71 |
+
except Exception as e:
|
| 72 |
+
print(f"Error processing file '{key}': {e}")
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def main():
|
| 76 |
+
"""
|
| 77 |
+
Main function to control the flow of the program.
|
| 78 |
+
"""
|
| 79 |
+
# Parse arguments.
|
| 80 |
+
args = parse_args()
|
| 81 |
+
|
| 82 |
+
file = open(args.pred_path)
|
| 83 |
+
pred_contents = json.load(file)
|
| 84 |
+
|
| 85 |
+
# Dictionary to store the count of occurrences for each video_id
|
| 86 |
+
video_id_counts = {}
|
| 87 |
+
new_pred_contents = []
|
| 88 |
+
|
| 89 |
+
# Iterate through each sample in pred_contents
|
| 90 |
+
for sample in pred_contents:
|
| 91 |
+
video_id = sample['video_name']
|
| 92 |
+
if video_id in video_id_counts:
|
| 93 |
+
video_id_counts[video_id] += 1
|
| 94 |
+
else:
|
| 95 |
+
video_id_counts[video_id] = 0
|
| 96 |
+
|
| 97 |
+
# Create a new sample with the modified key
|
| 98 |
+
new_sample = sample
|
| 99 |
+
new_sample['video_name'] = f"{video_id}_{video_id_counts[video_id]}"
|
| 100 |
+
new_pred_contents.append(new_sample)
|
| 101 |
+
|
| 102 |
+
# Generating list of id's and corresponding files
|
| 103 |
+
id_list = [x['video_name'] for x in new_pred_contents]
|
| 104 |
+
caption_files = [f"{id}.json" for id in id_list]
|
| 105 |
+
|
| 106 |
+
output_dir = args.output_dir
|
| 107 |
+
# Generate output directory if not exists.
|
| 108 |
+
if not os.path.exists(output_dir):
|
| 109 |
+
os.makedirs(output_dir)
|
| 110 |
+
|
| 111 |
+
# Preparing dictionary of question-answer sets
|
| 112 |
+
prediction_set = {}
|
| 113 |
+
for sample in new_pred_contents:
|
| 114 |
+
id = sample['video_name']
|
| 115 |
+
question = sample['Q']
|
| 116 |
+
answer = sample['A']
|
| 117 |
+
pred = sample['pred']
|
| 118 |
+
qa_set = {"q": question, "a": answer, "pred": pred}
|
| 119 |
+
prediction_set[id] = qa_set
|
| 120 |
+
|
| 121 |
+
# Set the OpenAI API key.
|
| 122 |
+
openai.api_key = args.api_key
|
| 123 |
+
num_tasks = args.num_tasks
|
| 124 |
+
|
| 125 |
+
# While loop to ensure that all captions are processed.
|
| 126 |
+
while True:
|
| 127 |
+
try:
|
| 128 |
+
# Files that have not been processed yet.
|
| 129 |
+
completed_files = os.listdir(output_dir)
|
| 130 |
+
print(f"completed_files: {len(completed_files)}")
|
| 131 |
+
|
| 132 |
+
# Files that have not been processed yet.
|
| 133 |
+
incomplete_files = [f for f in caption_files if f not in completed_files]
|
| 134 |
+
print(f"incomplete_files: {len(incomplete_files)}")
|
| 135 |
+
|
| 136 |
+
# Break the loop when there are no incomplete files
|
| 137 |
+
if len(incomplete_files) == 0:
|
| 138 |
+
break
|
| 139 |
+
if len(incomplete_files) <= num_tasks:
|
| 140 |
+
num_tasks = 1
|
| 141 |
+
|
| 142 |
+
# Split tasks into parts.
|
| 143 |
+
part_len = len(incomplete_files) // num_tasks
|
| 144 |
+
all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)]
|
| 145 |
+
task_args = [(prediction_set, part, args.output_dir) for part in all_parts]
|
| 146 |
+
|
| 147 |
+
# Use a pool of workers to process the files in parallel.
|
| 148 |
+
with Pool() as pool:
|
| 149 |
+
pool.starmap(annotate, task_args)
|
| 150 |
+
|
| 151 |
+
except Exception as e:
|
| 152 |
+
print(f"Error: {e}")
|
| 153 |
+
|
| 154 |
+
# Combine all the processed files into one
|
| 155 |
+
combined_contents = {}
|
| 156 |
+
json_path = args.output_json
|
| 157 |
+
|
| 158 |
+
# Iterate through json files
|
| 159 |
+
for file_name in os.listdir(output_dir):
|
| 160 |
+
if file_name.endswith(".json"):
|
| 161 |
+
file_path = os.path.join(output_dir, file_name)
|
| 162 |
+
with open(file_path, "r") as json_file:
|
| 163 |
+
content = json.load(json_file)
|
| 164 |
+
combined_contents[file_name[:-5]] = content
|
| 165 |
+
|
| 166 |
+
# Write combined content to a json file
|
| 167 |
+
with open(json_path, "w") as json_file:
|
| 168 |
+
json.dump(combined_contents, json_file)
|
| 169 |
+
print("All evaluation completed!")
|
| 170 |
+
|
| 171 |
+
# Calculate average score
|
| 172 |
+
score_sum = 0
|
| 173 |
+
count = 0
|
| 174 |
+
for key, result in combined_contents.items():
|
| 175 |
+
count += 1
|
| 176 |
+
score_match = result[0]['score']
|
| 177 |
+
score = int(score_match)
|
| 178 |
+
score_sum += score
|
| 179 |
+
average_score = score_sum / count
|
| 180 |
+
|
| 181 |
+
print("Average score for contextual understanding:", average_score)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
if __name__ == "__main__":
|
| 185 |
+
main()
|
| 186 |
+
|
GPT_evaluation/evaluate_benchmark_4_temporal.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import openai
|
| 2 |
+
import os
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import ast
|
| 6 |
+
from multiprocessing.pool import Pool
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def parse_args():
|
| 10 |
+
parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3")
|
| 11 |
+
parser.add_argument("--pred_path", required=True, help="The path to file containing prediction.")
|
| 12 |
+
parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.")
|
| 13 |
+
parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.")
|
| 14 |
+
parser.add_argument("--api_key", required=True, help="OpenAI API key.")
|
| 15 |
+
parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.")
|
| 16 |
+
args = parser.parse_args()
|
| 17 |
+
return args
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def annotate(prediction_set, caption_files, output_dir):
|
| 21 |
+
"""
|
| 22 |
+
Evaluates question and answer pairs using GPT-3 and
|
| 23 |
+
returns a score for temporal understanding.
|
| 24 |
+
"""
|
| 25 |
+
for file in caption_files:
|
| 26 |
+
key = file[:-5] # Strip file extension
|
| 27 |
+
qa_set = prediction_set[key]
|
| 28 |
+
question = qa_set['q']
|
| 29 |
+
answer = qa_set['a']
|
| 30 |
+
pred = qa_set['pred']
|
| 31 |
+
try:
|
| 32 |
+
# Compute the temporal understanding score
|
| 33 |
+
completion = openai.ChatCompletion.create(
|
| 34 |
+
model="gpt-3.5-turbo",
|
| 35 |
+
messages=[
|
| 36 |
+
{
|
| 37 |
+
"role": "system",
|
| 38 |
+
"content":
|
| 39 |
+
"You are an intelligent chatbot designed for evaluating the temporal understanding of generative outputs for video-based question-answer pairs. "
|
| 40 |
+
"Your task is to compare the predicted answer with the correct answer and determine if they correctly reflect the temporal sequence of events in the video content. Here's how you can accomplish the task:"
|
| 41 |
+
"------"
|
| 42 |
+
"##INSTRUCTIONS: "
|
| 43 |
+
"- Focus on the temporal consistency between the predicted answer and the correct answer. The predicted answer should correctly reflect the sequence of events or details as they are presented in the video content.\n"
|
| 44 |
+
"- Consider synonyms or paraphrases as valid matches, but only if the temporal order is maintained.\n"
|
| 45 |
+
"- Evaluate the temporal accuracy of the prediction compared to the answer."
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"role": "user",
|
| 49 |
+
"content":
|
| 50 |
+
"Please evaluate the following video-based question-answer pair:\n\n"
|
| 51 |
+
f"Question: {question}\n"
|
| 52 |
+
f"Correct Answer: {answer}\n"
|
| 53 |
+
f"Predicted Answer: {pred}\n\n"
|
| 54 |
+
"Provide your evaluation only as a temporal accuracy score where the temporal accuracy score is an integer value between 0 and 5, with 5 indicating the highest level of temporal consistency. "
|
| 55 |
+
"Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the temporal accuracy score in INTEGER, not STRING."
|
| 56 |
+
"DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
|
| 57 |
+
"For example, your response should look like this: {''score': 4.8}."
|
| 58 |
+
}
|
| 59 |
+
]
|
| 60 |
+
)
|
| 61 |
+
# Convert response to a Python dictionary.
|
| 62 |
+
response_message = completion["choices"][0]["message"]["content"]
|
| 63 |
+
response_dict = ast.literal_eval(response_message)
|
| 64 |
+
result_qa_pair = [response_dict, qa_set]
|
| 65 |
+
|
| 66 |
+
# Save the question-answer pairs to a json file.
|
| 67 |
+
with open(f"{output_dir}/{key}.json", "w") as f:
|
| 68 |
+
json.dump(result_qa_pair, f)
|
| 69 |
+
|
| 70 |
+
except Exception as e:
|
| 71 |
+
print(f"Error processing file '{key}': {e}")
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def main():
|
| 75 |
+
"""
|
| 76 |
+
Main function to control the flow of the program.
|
| 77 |
+
"""
|
| 78 |
+
# Parse arguments.
|
| 79 |
+
args = parse_args()
|
| 80 |
+
|
| 81 |
+
file = open(args.pred_path)
|
| 82 |
+
pred_contents = json.load(file)
|
| 83 |
+
|
| 84 |
+
# Dictionary to store the count of occurrences for each video_id
|
| 85 |
+
video_id_counts = {}
|
| 86 |
+
new_pred_contents = []
|
| 87 |
+
|
| 88 |
+
# Iterate through each sample in pred_contents
|
| 89 |
+
for sample in pred_contents:
|
| 90 |
+
video_id = sample['video_name']
|
| 91 |
+
if video_id in video_id_counts:
|
| 92 |
+
video_id_counts[video_id] += 1
|
| 93 |
+
else:
|
| 94 |
+
video_id_counts[video_id] = 0
|
| 95 |
+
|
| 96 |
+
# Create a new sample with the modified key
|
| 97 |
+
new_sample = sample
|
| 98 |
+
new_sample['video_name'] = f"{video_id}_{video_id_counts[video_id]}"
|
| 99 |
+
new_pred_contents.append(new_sample)
|
| 100 |
+
|
| 101 |
+
# Generating list of id's and corresponding files
|
| 102 |
+
id_list = [x['video_name'] for x in new_pred_contents]
|
| 103 |
+
caption_files = [f"{id}.json" for id in id_list]
|
| 104 |
+
|
| 105 |
+
output_dir = args.output_dir
|
| 106 |
+
# Generate output directory if not exists.
|
| 107 |
+
if not os.path.exists(output_dir):
|
| 108 |
+
os.makedirs(output_dir)
|
| 109 |
+
|
| 110 |
+
# Preparing dictionary of question-answer sets
|
| 111 |
+
prediction_set = {}
|
| 112 |
+
for sample in new_pred_contents:
|
| 113 |
+
id = sample['video_name']
|
| 114 |
+
question = sample['Q']
|
| 115 |
+
answer = sample['A']
|
| 116 |
+
pred = sample['pred']
|
| 117 |
+
qa_set = {"q": question, "a": answer, "pred": pred}
|
| 118 |
+
prediction_set[id] = qa_set
|
| 119 |
+
|
| 120 |
+
# Set the OpenAI API key.
|
| 121 |
+
openai.api_key = args.api_key
|
| 122 |
+
num_tasks = args.num_tasks
|
| 123 |
+
|
| 124 |
+
# While loop to ensure that all captions are processed.
|
| 125 |
+
while True:
|
| 126 |
+
try:
|
| 127 |
+
# Files that have not been processed yet.
|
| 128 |
+
completed_files = os.listdir(output_dir)
|
| 129 |
+
print(f"completed_files: {len(completed_files)}")
|
| 130 |
+
|
| 131 |
+
# Files that have not been processed yet.
|
| 132 |
+
incomplete_files = [f for f in caption_files if f not in completed_files]
|
| 133 |
+
print(f"incomplete_files: {len(incomplete_files)}")
|
| 134 |
+
|
| 135 |
+
# Break the loop when there are no incomplete files
|
| 136 |
+
if len(incomplete_files) == 0:
|
| 137 |
+
break
|
| 138 |
+
if len(incomplete_files) <= num_tasks:
|
| 139 |
+
num_tasks = 1
|
| 140 |
+
|
| 141 |
+
# Split tasks into parts.
|
| 142 |
+
part_len = len(incomplete_files) // num_tasks
|
| 143 |
+
all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)]
|
| 144 |
+
task_args = [(prediction_set, part, args.output_dir) for part in all_parts]
|
| 145 |
+
|
| 146 |
+
# Use a pool of workers to process the files in parallel.
|
| 147 |
+
with Pool() as pool:
|
| 148 |
+
pool.starmap(annotate, task_args)
|
| 149 |
+
|
| 150 |
+
except Exception as e:
|
| 151 |
+
print(f"Error: {e}")
|
| 152 |
+
|
| 153 |
+
# Combine all the processed files into one
|
| 154 |
+
combined_contents = {}
|
| 155 |
+
json_path = args.output_json
|
| 156 |
+
|
| 157 |
+
# Iterate through json files
|
| 158 |
+
for file_name in os.listdir(output_dir):
|
| 159 |
+
if file_name.endswith(".json"):
|
| 160 |
+
file_path = os.path.join(output_dir, file_name)
|
| 161 |
+
with open(file_path, "r") as json_file:
|
| 162 |
+
content = json.load(json_file)
|
| 163 |
+
combined_contents[file_name[:-5]] = content
|
| 164 |
+
|
| 165 |
+
# Write combined content to a json file
|
| 166 |
+
with open(json_path, "w") as json_file:
|
| 167 |
+
json.dump(combined_contents, json_file)
|
| 168 |
+
print("All evaluation completed!")
|
| 169 |
+
|
| 170 |
+
# Calculate average score
|
| 171 |
+
score_sum = 0
|
| 172 |
+
count = 0
|
| 173 |
+
for key, result in combined_contents.items():
|
| 174 |
+
count += 1
|
| 175 |
+
score_match = result[0]['score']
|
| 176 |
+
score = int(score_match)
|
| 177 |
+
score_sum += score
|
| 178 |
+
average_score = score_sum / count
|
| 179 |
+
|
| 180 |
+
print("Average score temporal understanding:", average_score)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
if __name__ == "__main__":
|
| 184 |
+
main()
|
| 185 |
+
|
GPT_evaluation/evaluate_benchmark_5_consistency.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import openai
|
| 2 |
+
import os
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import ast
|
| 6 |
+
from multiprocessing.pool import Pool
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def parse_args():
|
| 10 |
+
parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3")
|
| 11 |
+
parser.add_argument("--pred_path", required=True, help="The path to file containing prediction.")
|
| 12 |
+
parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.")
|
| 13 |
+
parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.")
|
| 14 |
+
parser.add_argument("--api_key", required=True, help="OpenAI API key.")
|
| 15 |
+
parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.")
|
| 16 |
+
args = parser.parse_args()
|
| 17 |
+
return args
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def annotate(prediction_set, caption_files, output_dir):
|
| 21 |
+
"""
|
| 22 |
+
Evaluates question and answer pairs using GPT-3 and
|
| 23 |
+
returns a score for consistency.
|
| 24 |
+
"""
|
| 25 |
+
for file in caption_files:
|
| 26 |
+
key = file[:-5] # Strip file extension
|
| 27 |
+
qa_set = prediction_set[key]
|
| 28 |
+
question1 = qa_set['q1']
|
| 29 |
+
question2 = qa_set['q2']
|
| 30 |
+
answer = qa_set['a']
|
| 31 |
+
pred1 = qa_set['pred1']
|
| 32 |
+
pred2 = qa_set['pred2']
|
| 33 |
+
try:
|
| 34 |
+
# Compute the consistency score
|
| 35 |
+
completion = openai.ChatCompletion.create(
|
| 36 |
+
model="gpt-3.5-turbo",
|
| 37 |
+
messages=[
|
| 38 |
+
{
|
| 39 |
+
"role": "system",
|
| 40 |
+
"content":
|
| 41 |
+
"You are an intelligent chatbot designed for evaluating the consistency of generative outputs for similar video-based question-answer pairs. "
|
| 42 |
+
"You will be given two very similar questions, a common answer common to both the questions and predicted answers for the two questions ."
|
| 43 |
+
"Your task is to compare the predicted answers for two very similar question, with a common correct answer and determine if they are consistent. Here's how you can accomplish the task:"
|
| 44 |
+
"------"
|
| 45 |
+
"##INSTRUCTIONS: "
|
| 46 |
+
"- Focus on the consistency between the two predicted answers and the correct answer. Both predicted answers should correspond to the correct answer and to each other, and should not contain any contradictions or significant differences in the conveyed information.\n"
|
| 47 |
+
"- Both predicted answers must be consistent with each other and the correct answer, in terms of the information they provide about the video content.\n"
|
| 48 |
+
"- Consider synonyms or paraphrases as valid matches, but only if they maintain the consistency in the conveyed information.\n"
|
| 49 |
+
"- Evaluate the consistency of the two predicted answers compared to the correct answer."
|
| 50 |
+
},
|
| 51 |
+
{
|
| 52 |
+
"role": "user",
|
| 53 |
+
"content":
|
| 54 |
+
"Please evaluate the following video-based question-answer pair:\n\n"
|
| 55 |
+
f"Question 1: {question1}\n"
|
| 56 |
+
f"Question 2: {question2}\n"
|
| 57 |
+
f"Correct Answer: {answer}\n"
|
| 58 |
+
f"Predicted Answer to Question 1: {pred1}\n"
|
| 59 |
+
f"Predicted Answer to Question 2: {pred2}\n\n"
|
| 60 |
+
"Provide your evaluation only as a consistency score where the consistency score is an integer value between 0 and 5, with 5 indicating the highest level of consistency. "
|
| 61 |
+
"Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the consistency score in INTEGER, not STRING."
|
| 62 |
+
"DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
|
| 63 |
+
"For example, your response should look like this: {''score': 4.8}."
|
| 64 |
+
}
|
| 65 |
+
]
|
| 66 |
+
)
|
| 67 |
+
# Convert response to a Python dictionary.
|
| 68 |
+
response_message = completion["choices"][0]["message"]["content"]
|
| 69 |
+
response_dict = ast.literal_eval(response_message)
|
| 70 |
+
result_qa_pair = [response_dict, qa_set]
|
| 71 |
+
|
| 72 |
+
# Save the question-answer pairs to a json file.
|
| 73 |
+
with open(f"{output_dir}/{key}.json", "w") as f:
|
| 74 |
+
json.dump(result_qa_pair, f)
|
| 75 |
+
|
| 76 |
+
except Exception as e:
|
| 77 |
+
print(f"Error processing file '{key}': {e}")
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def main():
|
| 81 |
+
"""
|
| 82 |
+
Main function to control the flow of the program.
|
| 83 |
+
"""
|
| 84 |
+
# Parse arguments.
|
| 85 |
+
args = parse_args()
|
| 86 |
+
|
| 87 |
+
file = open(args.pred_path)
|
| 88 |
+
pred_contents = json.load(file)
|
| 89 |
+
|
| 90 |
+
# Dictionary to store the count of occurrences for each video_id
|
| 91 |
+
video_id_counts = {}
|
| 92 |
+
new_pred_contents = []
|
| 93 |
+
|
| 94 |
+
# Iterate through each sample in pred_contents
|
| 95 |
+
for sample in pred_contents:
|
| 96 |
+
video_id = sample['video_name']
|
| 97 |
+
if video_id in video_id_counts:
|
| 98 |
+
video_id_counts[video_id] += 1
|
| 99 |
+
else:
|
| 100 |
+
video_id_counts[video_id] = 0
|
| 101 |
+
|
| 102 |
+
# Create a new sample with the modified key
|
| 103 |
+
new_sample = sample
|
| 104 |
+
new_sample['video_name'] = f"{video_id}_{video_id_counts[video_id]}"
|
| 105 |
+
new_pred_contents.append(new_sample)
|
| 106 |
+
|
| 107 |
+
# Generating list of id's and corresponding files
|
| 108 |
+
id_list = [x['video_name'] for x in new_pred_contents]
|
| 109 |
+
caption_files = [f"{id}.json" for id in id_list]
|
| 110 |
+
|
| 111 |
+
output_dir = args.output_dir
|
| 112 |
+
# Generate output directory if not exists.
|
| 113 |
+
if not os.path.exists(output_dir):
|
| 114 |
+
os.makedirs(output_dir)
|
| 115 |
+
|
| 116 |
+
# Preparing dictionary of question-answer sets
|
| 117 |
+
prediction_set = {}
|
| 118 |
+
for sample in new_pred_contents:
|
| 119 |
+
id = sample['video_name']
|
| 120 |
+
question1 = sample['Q1']
|
| 121 |
+
question2 = sample['Q1']
|
| 122 |
+
answer = sample['A']
|
| 123 |
+
pred1 = sample['pred1']
|
| 124 |
+
pred2 = sample['pred2']
|
| 125 |
+
qa_set = {"q1": question1, "q2": question2, "a": answer, "pred1": pred1, "pred2": pred2}
|
| 126 |
+
prediction_set[id] = qa_set
|
| 127 |
+
|
| 128 |
+
# Set the OpenAI API key.
|
| 129 |
+
openai.api_key = args.api_key
|
| 130 |
+
num_tasks = args.num_tasks
|
| 131 |
+
|
| 132 |
+
# While loop to ensure that all captions are processed.
|
| 133 |
+
while True:
|
| 134 |
+
try:
|
| 135 |
+
# Files that have not been processed yet.
|
| 136 |
+
completed_files = os.listdir(output_dir)
|
| 137 |
+
print(f"completed_files: {len(completed_files)}")
|
| 138 |
+
|
| 139 |
+
# Files that have not been processed yet.
|
| 140 |
+
incomplete_files = [f for f in caption_files if f not in completed_files]
|
| 141 |
+
print(f"incomplete_files: {len(incomplete_files)}")
|
| 142 |
+
|
| 143 |
+
# Break the loop when there are no incomplete files
|
| 144 |
+
if len(incomplete_files) == 0:
|
| 145 |
+
break
|
| 146 |
+
if len(incomplete_files) <= num_tasks:
|
| 147 |
+
num_tasks = 1
|
| 148 |
+
|
| 149 |
+
# Split tasks into parts.
|
| 150 |
+
part_len = len(incomplete_files) // num_tasks
|
| 151 |
+
all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)]
|
| 152 |
+
task_args = [(prediction_set, part, args.output_dir) for part in all_parts]
|
| 153 |
+
|
| 154 |
+
# Use a pool of workers to process the files in parallel.
|
| 155 |
+
with Pool() as pool:
|
| 156 |
+
pool.starmap(annotate, task_args)
|
| 157 |
+
|
| 158 |
+
except Exception as e:
|
| 159 |
+
print(f"Error: {e}")
|
| 160 |
+
|
| 161 |
+
# Combine all the processed files into one
|
| 162 |
+
combined_contents = {}
|
| 163 |
+
json_path = args.output_json
|
| 164 |
+
|
| 165 |
+
# Iterate through json files
|
| 166 |
+
for file_name in os.listdir(output_dir):
|
| 167 |
+
if file_name.endswith(".json"):
|
| 168 |
+
file_path = os.path.join(output_dir, file_name)
|
| 169 |
+
with open(file_path, "r") as json_file:
|
| 170 |
+
content = json.load(json_file)
|
| 171 |
+
combined_contents[file_name[:-5]] = content
|
| 172 |
+
|
| 173 |
+
# Write combined content to a json file
|
| 174 |
+
with open(json_path, "w") as json_file:
|
| 175 |
+
json.dump(combined_contents, json_file)
|
| 176 |
+
print("All evaluation completed!")
|
| 177 |
+
|
| 178 |
+
# Calculate average score
|
| 179 |
+
score_sum = 0
|
| 180 |
+
count = 0
|
| 181 |
+
for key, result in combined_contents.items():
|
| 182 |
+
count += 1
|
| 183 |
+
score_match = result[0]['score']
|
| 184 |
+
score = int(score_match)
|
| 185 |
+
score_sum += score
|
| 186 |
+
average_score = score_sum / count
|
| 187 |
+
|
| 188 |
+
print("Average score for consistency:", average_score)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
if __name__ == "__main__":
|
| 192 |
+
main()
|
| 193 |
+
|
GPT_evaluation/evaluate_zeroshot.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import openai
|
| 2 |
+
import os
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import ast
|
| 6 |
+
from multiprocessing.pool import Pool
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def parse_args():
|
| 10 |
+
parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3")
|
| 11 |
+
parser.add_argument("--pred_path", required=True, help="The path to file containing prediction.")
|
| 12 |
+
parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.")
|
| 13 |
+
parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.")
|
| 14 |
+
parser.add_argument("--api_key", required=True, help="OpenAI API key.")
|
| 15 |
+
parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.")
|
| 16 |
+
args = parser.parse_args()
|
| 17 |
+
return args
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def annotate(prediction_set, caption_files, output_dir):
|
| 21 |
+
"""
|
| 22 |
+
Evaluates question and answer pairs using GPT-3
|
| 23 |
+
Returns a score for correctness.
|
| 24 |
+
"""
|
| 25 |
+
for file in caption_files:
|
| 26 |
+
key = file[:-5] # Strip file extension
|
| 27 |
+
qa_set = prediction_set[key]
|
| 28 |
+
question = qa_set['q']
|
| 29 |
+
answer = qa_set['a']
|
| 30 |
+
pred = qa_set['pred']
|
| 31 |
+
try:
|
| 32 |
+
# Compute the correctness score
|
| 33 |
+
completion = openai.ChatCompletion.create(
|
| 34 |
+
model="gpt-3.5-turbo",
|
| 35 |
+
messages=[
|
| 36 |
+
{
|
| 37 |
+
"role": "system",
|
| 38 |
+
"content":
|
| 39 |
+
"You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. "
|
| 40 |
+
"Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:"
|
| 41 |
+
"------"
|
| 42 |
+
"##INSTRUCTIONS: "
|
| 43 |
+
"- Focus on the meaningful match between the predicted answer and the correct answer.\n"
|
| 44 |
+
"- Consider synonyms or paraphrases as valid matches.\n"
|
| 45 |
+
"- Evaluate the correctness of the prediction compared to the answer."
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"role": "user",
|
| 49 |
+
"content":
|
| 50 |
+
"Please evaluate the following video-based question-answer pair:\n\n"
|
| 51 |
+
f"Question: {question}\n"
|
| 52 |
+
f"Correct Answer: {answer}\n"
|
| 53 |
+
f"Predicted Answer: {pred}\n\n"
|
| 54 |
+
"Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. "
|
| 55 |
+
"Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING."
|
| 56 |
+
"DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
|
| 57 |
+
"For example, your response should look like this: {'pred': 'yes', 'score': 4.8}."
|
| 58 |
+
}
|
| 59 |
+
]
|
| 60 |
+
)
|
| 61 |
+
# Convert response to a Python dictionary.
|
| 62 |
+
response_message = completion["choices"][0]["message"]["content"]
|
| 63 |
+
response_dict = ast.literal_eval(response_message)
|
| 64 |
+
result_qa_pair = [response_dict, qa_set]
|
| 65 |
+
|
| 66 |
+
# Save the question-answer pairs to a json file.
|
| 67 |
+
with open(f"{output_dir}/{key}.json", "w") as f:
|
| 68 |
+
json.dump(result_qa_pair, f)
|
| 69 |
+
|
| 70 |
+
except Exception as e:
|
| 71 |
+
print(f"Error processing file '{key}': {e}")
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def main():
|
| 75 |
+
"""
|
| 76 |
+
Main function to control the flow of the program.
|
| 77 |
+
"""
|
| 78 |
+
# Parse arguments.
|
| 79 |
+
args = parse_args()
|
| 80 |
+
|
| 81 |
+
file = open(args.pred_path)
|
| 82 |
+
pred_contents = json.load(file)
|
| 83 |
+
|
| 84 |
+
# Dictionary to store the count of occurrences for each video_id
|
| 85 |
+
video_id_counts = {}
|
| 86 |
+
new_pred_contents = []
|
| 87 |
+
|
| 88 |
+
# Iterate through each sample in pred_contents
|
| 89 |
+
for sample in pred_contents:
|
| 90 |
+
video_id = sample['video_name']
|
| 91 |
+
if video_id in video_id_counts:
|
| 92 |
+
video_id_counts[video_id] += 1
|
| 93 |
+
else:
|
| 94 |
+
video_id_counts[video_id] = 0
|
| 95 |
+
|
| 96 |
+
# Create a new sample with the modified key
|
| 97 |
+
new_sample = sample
|
| 98 |
+
new_sample['video_name'] = f"{video_id}_{video_id_counts[video_id]}"
|
| 99 |
+
new_pred_contents.append(new_sample)
|
| 100 |
+
|
| 101 |
+
# Generating list of id's and corresponding files
|
| 102 |
+
id_list = [x['video_name'] for x in new_pred_contents]
|
| 103 |
+
caption_files = [f"{id}.json" for id in id_list]
|
| 104 |
+
|
| 105 |
+
output_dir = args.output_dir
|
| 106 |
+
# Generate output directory if not exists.
|
| 107 |
+
if not os.path.exists(output_dir):
|
| 108 |
+
os.makedirs(output_dir)
|
| 109 |
+
|
| 110 |
+
# Preparing dictionary of question-answer sets
|
| 111 |
+
prediction_set = {}
|
| 112 |
+
for sample in new_pred_contents:
|
| 113 |
+
id = sample['video_name']
|
| 114 |
+
question = sample['Q']
|
| 115 |
+
answer = sample['A']
|
| 116 |
+
pred = sample['pred']
|
| 117 |
+
qa_set = {"q": question, "a": answer, "pred": pred}
|
| 118 |
+
prediction_set[id] = qa_set
|
| 119 |
+
|
| 120 |
+
# Set the OpenAI API key.
|
| 121 |
+
openai.api_key = args.api_key
|
| 122 |
+
num_tasks = args.num_tasks
|
| 123 |
+
|
| 124 |
+
# While loop to ensure that all captions are processed.
|
| 125 |
+
while True:
|
| 126 |
+
try:
|
| 127 |
+
# Files that have not been processed yet.
|
| 128 |
+
completed_files = os.listdir(output_dir)
|
| 129 |
+
print(f"completed_files: {len(completed_files)}")
|
| 130 |
+
|
| 131 |
+
# Files that have not been processed yet.
|
| 132 |
+
incomplete_files = [f for f in caption_files if f not in completed_files]
|
| 133 |
+
print(f"incomplete_files: {len(incomplete_files)}")
|
| 134 |
+
|
| 135 |
+
# Break the loop when there are no incomplete files
|
| 136 |
+
if len(incomplete_files) == 0:
|
| 137 |
+
break
|
| 138 |
+
if len(incomplete_files) <= num_tasks:
|
| 139 |
+
num_tasks = 1
|
| 140 |
+
|
| 141 |
+
# Split tasks into parts.
|
| 142 |
+
part_len = len(incomplete_files) // num_tasks
|
| 143 |
+
all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)]
|
| 144 |
+
task_args = [(prediction_set, part, args.output_dir) for part in all_parts]
|
| 145 |
+
|
| 146 |
+
# Use a pool of workers to process the files in parallel.
|
| 147 |
+
with Pool() as pool:
|
| 148 |
+
pool.starmap(annotate, task_args)
|
| 149 |
+
|
| 150 |
+
except Exception as e:
|
| 151 |
+
print(f"Error: {e}")
|
| 152 |
+
|
| 153 |
+
# Combine all the processed files into one
|
| 154 |
+
combined_contents = {}
|
| 155 |
+
json_path = args.output_json
|
| 156 |
+
|
| 157 |
+
# Iterate through json files
|
| 158 |
+
for file_name in os.listdir(output_dir):
|
| 159 |
+
if file_name.endswith(".json"):
|
| 160 |
+
file_path = os.path.join(output_dir, file_name)
|
| 161 |
+
with open(file_path, "r") as json_file:
|
| 162 |
+
content = json.load(json_file)
|
| 163 |
+
combined_contents[file_name[:-5]] = content
|
| 164 |
+
|
| 165 |
+
# Write combined content to a json file
|
| 166 |
+
with open(json_path, "w") as json_file:
|
| 167 |
+
json.dump(combined_contents, json_file)
|
| 168 |
+
print("All evaluation completed!")
|
| 169 |
+
|
| 170 |
+
# Calculate average score and accuracy
|
| 171 |
+
score_sum = 0
|
| 172 |
+
count = 0
|
| 173 |
+
yes_count = 0
|
| 174 |
+
no_count = 0
|
| 175 |
+
for key, result in combined_contents.items():
|
| 176 |
+
# Computing score
|
| 177 |
+
count += 1
|
| 178 |
+
try :
|
| 179 |
+
score_match = result[0]['score']
|
| 180 |
+
score = int(score_match)
|
| 181 |
+
score_sum += score
|
| 182 |
+
except:
|
| 183 |
+
print("Score not found for", key)
|
| 184 |
+
continue
|
| 185 |
+
|
| 186 |
+
# Computing accuracy
|
| 187 |
+
try:
|
| 188 |
+
pred = result[0]['pred']
|
| 189 |
+
if "yes" in pred.lower():
|
| 190 |
+
yes_count += 1
|
| 191 |
+
elif "no" in pred.lower():
|
| 192 |
+
no_count += 1
|
| 193 |
+
except:
|
| 194 |
+
print("Prediction not found for", key)
|
| 195 |
+
continue
|
| 196 |
+
|
| 197 |
+
average_score = score_sum / count
|
| 198 |
+
accuracy = yes_count / (yes_count + no_count)
|
| 199 |
+
print("Yes count:", yes_count)
|
| 200 |
+
print("No count:", no_count)
|
| 201 |
+
print("Accuracy:", accuracy)
|
| 202 |
+
print("Average score:", average_score)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
if __name__ == "__main__":
|
| 206 |
+
main()
|
| 207 |
+
|
GPT_evaluation/evaluate_zeroshot.sh
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --partition=batch
|
| 3 |
+
#SBATCH --job-name=zeroshot_eval%j
|
| 4 |
+
#SBATCH --output=zeroshot_eval%j.out
|
| 5 |
+
#SBATCH --error=zeroshot_eval%j.err
|
| 6 |
+
#SBATCH --time=0-10:00:00
|
| 7 |
+
#SBATCH --mem=64G
|
| 8 |
+
#SBATCH --nodes=1
|
| 9 |
+
|
| 10 |
+
## run the application:
|
| 11 |
+
|
| 12 |
+
# PRED="pred_path"
|
| 13 |
+
# OUTPUT_DIR="output_dir"
|
| 14 |
+
# API_KEY="api_key"
|
| 15 |
+
# NUM_TASKS=128
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
python evaluate_zeroshot.py \
|
| 19 |
+
--pred_path ${PRED} \
|
| 20 |
+
--output_dir "${OUTPUT_DIR}/fewshot_accuracy" \
|
| 21 |
+
--output_json "${OUTPUT_DIR}/fewshot_accuracy_results.json"\
|
| 22 |
+
--api_key $API_KEY \
|
| 23 |
+
--num_tasks $NUM_TASKS
|
| 24 |
+
|
| 25 |
+
echo pred_path: $PRED
|
LICENSE.md
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
BSD 3-Clause License
|
| 2 |
+
|
| 3 |
+
Copyright 2023 Deyao Zhu
|
| 4 |
+
All rights reserved.
|
| 5 |
+
|
| 6 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
| 7 |
+
|
| 8 |
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
| 9 |
+
|
| 10 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
| 11 |
+
|
| 12 |
+
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
| 13 |
+
|
| 14 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
LICENSE_Lavis.md
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
BSD 3-Clause License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2022 Salesforce, Inc.
|
| 4 |
+
All rights reserved.
|
| 5 |
+
|
| 6 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
| 7 |
+
|
| 8 |
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
| 9 |
+
|
| 10 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
| 11 |
+
|
| 12 |
+
3. Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
| 13 |
+
|
| 14 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
README.md
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# [ECCV 2024 Accepted]Goldfish: Vision-Language Understanding of Arbitrarily Long Videos
|
| 2 |
+
# [CVPR2024W]MiniGPT4-Video: Advancing Multimodal LLMs for Video Understanding with Interleaved Visual-Textual Tokens
|
| 3 |
+
**This repo contains the codes for MiniGPT4-video for short video understanding and Goldfish for long video understanding.**
|
| 4 |
+
<h3 style="text-align: center;">Online Demos</h3>
|
| 5 |
+
<div style="display: flex; justify-content: center; gap: 40px;">
|
| 6 |
+
<div style="text-align: center;">
|
| 7 |
+
<a href='https://goldfishdemo.loophole.site'>
|
| 8 |
+
<img src='repo_imgs/goldfishai_png.png' width=200 height=200>
|
| 9 |
+
</a>
|
| 10 |
+
<div>
|
| 11 |
+
<font size=3>
|
| 12 |
+
<div>
|
| 13 |
+
<img src="repo_imgs/goldfishai_png.png" width=18>
|
| 14 |
+
<a href="https://vision-cair.github.io/Goldfish_website/">Project Page</a>
|
| 15 |
+
<a href="https://arxiv.org/abs/2407.12679">📝 arXiv Paper</a>
|
| 16 |
+
<a href="https://huggingface.co/datasets/Vision-CAIR/TVQA-Long/tree/main">🤗 TVQA-Long Dataset</a>
|
| 17 |
+
</div>
|
| 18 |
+
</font>
|
| 19 |
+
</div>
|
| 20 |
+
</div>
|
| 21 |
+
<div style="text-align: center;">
|
| 22 |
+
<a href='https://huggingface.co/spaces/Vision-CAIR/MiniGPT4-video'>
|
| 23 |
+
<img src='repo_imgs/minigpt4_demo_icon.png' width=200 height=200>
|
| 24 |
+
</a>
|
| 25 |
+
<div>
|
| 26 |
+
<font size=3>
|
| 27 |
+
<div>
|
| 28 |
+
<a href="https://vision-cair.github.io/MiniGPT4-video/">🎞️ Project Page</a>
|
| 29 |
+
<a href="https://arxiv.org/abs/2404.03413">📝 arXiv Paper</a>
|
| 30 |
+
</div>
|
| 31 |
+
</font>
|
| 32 |
+
</div>
|
| 33 |
+
</div>
|
| 34 |
+
</div>
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+

|
| 38 |
+
## Overview
|
| 39 |
+
Most current LLM-based models for video understanding can
|
| 40 |
+
process videos within minutes but struggle with processing lengthy videos
|
| 41 |
+
due to the “noise and redundancy challenge” and “memory and compu-
|
| 42 |
+
tation” challenges. In this paper, we present Goldfish, a methodology
|
| 43 |
+
tailored for comprehending videos of arbitrary lengths. We also introduce
|
| 44 |
+
the TVQA-long benchmark, specifically designed to evaluate models’
|
| 45 |
+
capabilities in understanding long videos with questions in both vision
|
| 46 |
+
and text content. Goldfish approaches these challenges with an efficient
|
| 47 |
+
retrieval mechanism that initially gathers the top-k video clips relevant to
|
| 48 |
+
the instruction before proceeding to provide the desired response. This de-
|
| 49 |
+
sign of the retrieval mechanism enables the Goldfish to efficiently process
|
| 50 |
+
arbitrarily long video sequences, facilitating its application in contexts
|
| 51 |
+
such as movies or television series. To facilitate the retrieval process, we
|
| 52 |
+
developed MiniGPT4-Video that generates detailed descriptions for the
|
| 53 |
+
video clips. In addressing the scarcity of benchmarks for long video evalu-
|
| 54 |
+
ation, we adapted the TVQA short video benchmark for extended content
|
| 55 |
+
analysis by aggregating questions from entire episodes, thereby shifting
|
| 56 |
+
the evaluation from partial to full episode comprehension. We attained a
|
| 57 |
+
41.78% accuracy rate on the TVQA-long benchmark, surpassing previous
|
| 58 |
+
methods by 14.94%. Our MiniGPT4-Video also shows exceptional perfor-
|
| 59 |
+
mance in short video comprehension, exceeding existing state-of-the-art
|
| 60 |
+
methods by 3.23%, 2.03%, 16.5% and 23.59% on the MSVD, MSRVTT,
|
| 61 |
+
TGIF,and TVQA short video benchmarks, respectively. These results
|
| 62 |
+
indicate that our models have significant improvements in both long and
|
| 63 |
+
short-video understanding.
|
| 64 |
+
### Goldfish framework (Long videos)
|
| 65 |
+
<br>
|
| 66 |
+

|
| 67 |
+
### MiniGPT4-Video (Short videos)
|
| 68 |
+

|
| 69 |
+
|
| 70 |
+
[](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-tgif-qa?p=minigpt4-video-advancing-multimodal-llms-for)
|
| 71 |
+
|
| 72 |
+
[](https://paperswithcode.com/sota/zero-shot-video-question-answer-on-tvqa?p=minigpt4-video-advancing-multimodal-llms-for)
|
| 73 |
+
|
| 74 |
+
[](https://paperswithcode.com/sota/video-based-generative-performance-1?p=minigpt4-video-advancing-multimodal-llms-for)
|
| 75 |
+
|
| 76 |
+
[](https://paperswithcode.com/sota/video-based-generative-performance-3?p=minigpt4-video-advancing-multimodal-llms-for)
|
| 77 |
+
|
| 78 |
+
[](https://paperswithcode.com/sota/video-based-generative-performance-4?p=minigpt4-video-advancing-multimodal-llms-for)
|
| 79 |
+
|
| 80 |
+
[](https://paperswithcode.com/sota/video-based-generative-performance-5?p=minigpt4-video-advancing-multimodal-llms-for)
|
| 81 |
+
|
| 82 |
+
[](https://paperswithcode.com/sota/video-based-generative-performance-2?p=minigpt4-video-advancing-multimodal-llms-for)
|
| 83 |
+
|
| 84 |
+
[](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-msvd-qa?p=minigpt4-video-advancing-multimodal-llms-for)
|
| 85 |
+
|
| 86 |
+
[](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-msrvtt-qa?p=minigpt4-video-advancing-multimodal-llms-for)
|
| 87 |
+
|
| 88 |
+
[](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-activitynet?p=minigpt4-video-advancing-multimodal-llms-for)
|
| 89 |
+
|
| 90 |
+

|
| 91 |
+

|
| 92 |
+

|
| 93 |
+
## :rocket: Demo
|
| 94 |
+
**1. Clone the repository** <br>
|
| 95 |
+
```bash
|
| 96 |
+
git clone https://github.com/Vision-CAIR/MiniGPT4-video.git
|
| 97 |
+
cd MiniGPT4-video
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
**2. Set up the environment** <br>
|
| 101 |
+
```bash
|
| 102 |
+
conda env create -f environment.yml
|
| 103 |
+
```
|
| 104 |
+
**3. Download the checkpoints**
|
| 105 |
+
|
| 106 |
+
| MiniGPT4-Video (Llama2 Chat 7B) | MiniGPT4-Video (Mistral 7B) |
|
| 107 |
+
:------------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:
|
| 108 |
+
| [Download](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/blob/main/checkpoints/video_llama_checkpoint_last.pth) | [Download](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/blob/main/checkpoints/video_mistral_checkpoint_last.pth) |
|
| 109 |
+
|
| 110 |
+
**4. Run the demo** <br>
|
| 111 |
+
Goldfish demo
|
| 112 |
+
```bash
|
| 113 |
+
# For recommended performance, add the parameter --use_openai_embedding True to the command below and set the API key in the environment variable OPENAI_API_KEY otherwise the model will use the default embeddings.
|
| 114 |
+
export OPENAI_API_KEY="your_openai_key"
|
| 115 |
+
# Llama2
|
| 116 |
+
python goldfish_demo.py --ckpt path_to_video_checkpoint --cfg-path test_configs/llama2_test_config.yaml
|
| 117 |
+
# Mistral
|
| 118 |
+
python goldfish_demo.py --ckpt path_to_video_checkpoint --cfg-path test_configs/mistral_test_config.yaml
|
| 119 |
+
```
|
| 120 |
+
MiniGPT4-Video demo
|
| 121 |
+
```bash
|
| 122 |
+
# Llama2
|
| 123 |
+
python minigpt4_video_demo.py --ckpt path_to_video_checkpoint --cfg-path test_configs/llama2_test_config.yaml
|
| 124 |
+
# Mistral
|
| 125 |
+
python minigpt4_video_demo.py --ckpt path_to_video_checkpoint --cfg-path test_configs/mistral_test_config.yaml
|
| 126 |
+
```
|
| 127 |
+
### Inference
|
| 128 |
+
Do the previous steps and replace step 4 with this step <br>
|
| 129 |
+
Goldfish inference
|
| 130 |
+
```bash
|
| 131 |
+
# For recommended performance, add the parameter --use_openai_embedding True to the command below and set the API key in the environment variable OPENAI_API_KEY otherwise the model will use the default embeddings.
|
| 132 |
+
export OPENAI_API_KEY="your_openai_key"
|
| 133 |
+
# Llama2
|
| 134 |
+
python goldfish_inference.py --ckpt path_to_llama2_checkpoint --cfg-path test_configs/llama2_test_config.yaml --video_path path_to_video --question "Your question here"
|
| 135 |
+
# Mistral
|
| 136 |
+
python goldfish_inference.py --ckpt path_to_mistral_checkpoint --cfg-path test_configs/mistral_test_config.yaml --video_path path_to_video --question "Your question here"
|
| 137 |
+
```
|
| 138 |
+
MiniGPT4-Video inference
|
| 139 |
+
```bash
|
| 140 |
+
# Llama2
|
| 141 |
+
python minigpt4_video_inference.py --ckpt path_to_llama2_checkpoint --cfg-path test_configs/llama2_test_config.yaml --video_path path_to_video --question "Your question here"
|
| 142 |
+
# Mistral
|
| 143 |
+
python minigpt4_video_inference.py --ckpt path_to_mistral_checkpoint --cfg-path test_configs/mistral_test_config.yaml --video_path path_to_video --question "Your question here"
|
| 144 |
+
```
|
| 145 |
+
## :fire: Training
|
| 146 |
+
For both Goldfish and MiniGPT4-Video, the only training part is the MiniGPT4-Video model. <br>
|
| 147 |
+
### To customize MiniGPT4-Video for your own Video-text dataset
|
| 148 |
+
<!-- point to file here Custom_training.md -->
|
| 149 |
+
You can find the steps to customize MiniGPT4-Video for your own video-text dataset in [Custom_training.md](Custom_training.md)
|
| 150 |
+
### Training datasets
|
| 151 |
+
After downloading the datasets below, **you should go to the datasets configuration folder here minigpt4/configs/datasets set the paths for each dataset there.**<br>
|
| 152 |
+
Image text training<br>
|
| 153 |
+
You can find the steps to download the datasets in [MiniGPT4](https://github.com/Vision-CAIR/MiniGPT-4/tree/main/dataset)<br>
|
| 154 |
+
+ LAION <br>
|
| 155 |
+
+ Conceptual Captions <br>
|
| 156 |
+
+ SBU <br>
|
| 157 |
+
|
| 158 |
+
Video text training:<br>
|
| 159 |
+
|
| 160 |
+
+ [CMD](https://www.robots.ox.ac.uk/~vgg/data/condensed-movies/) <br>
|
| 161 |
+
+ [Webvid](https://github.com/m-bain/webvid/) <br> <!-- + [Webvid](https://huggingface.co/datasets/TempoFunk/webvid-10M?row=2) <br> -->
|
| 162 |
+
+ [Video Instructional Dataset 100K](https://huggingface.co/datasets/MBZUAI/VideoInstruct-100K) <br>
|
| 163 |
+
|
| 164 |
+
You can find the datasets annotation files for video_text datasets here [download](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/tree/main/datasets/training_datasets) <br>
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
### Model training:
|
| 168 |
+
You can edit the number of gpus in the each script.sh below<br>
|
| 169 |
+
#### Stage 1 (image text pretraining)
|
| 170 |
+
|
| 171 |
+
You can directly download the pretrained MiniGPT4 [checkpoint](https://drive.google.com/file/d/11nAPjEok8eAGGEG1N2vXo3kBLCg0WgUk/view?usp=sharing) aligned with Llama2. <br>
|
| 172 |
+
|
| 173 |
+
Or train by yourself:
|
| 174 |
+
|
| 175 |
+
```bash
|
| 176 |
+
# pretrain
|
| 177 |
+
# Llama2
|
| 178 |
+
torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/224_minigpt4_llama2_image.yaml
|
| 179 |
+
# Mistral
|
| 180 |
+
torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/224_minigpt4_mistral_image.yaml
|
| 181 |
+
|
| 182 |
+
# align
|
| 183 |
+
# To launch the second stage alignment, first specify the path to the checkpoint file trained in pretrain stage.
|
| 184 |
+
# Llama2
|
| 185 |
+
torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/224_minigpt4_llama2_image_align.yaml
|
| 186 |
+
# Mistral
|
| 187 |
+
torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/224_minigpt4_mistral_image_align.yaml
|
| 188 |
+
```
|
| 189 |
+
You can download our trained weights for this stage from here [Llama2](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/blob/main/checkpoints/image_llama2_checkpoint.pth) [Mistral](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/blob/main/checkpoints/image_mistral_checkpoint.pth)<br>
|
| 190 |
+
#### Stage 2 (video captioning pretraining)
|
| 191 |
+
|
| 192 |
+
For **Llama2** <br>
|
| 193 |
+
set the cfg-path in the script to `train_configs/224_v2_llama2_video_stage_2.yaml` <br>
|
| 194 |
+
set the model name here `minigpt4/configs/datasets/cmd_video/default.yaml` and `minigpt4/configs/datasets/webvid/default.yaml` to llama2<br>
|
| 195 |
+
For **Mistral**<br>
|
| 196 |
+
set the cfg-path in the script to `train_configs/224_v2_mistral_video_stage_2.yaml` <br>
|
| 197 |
+
set the model name here `minigpt4/configs/datasets/cmd_video/default.yaml` and `minigpt4/configs/datasets/webvid/default.yaml` to mistral<br>
|
| 198 |
+
|
| 199 |
+
```bash
|
| 200 |
+
bash training_scripts/stage_2.sh
|
| 201 |
+
```
|
| 202 |
+
You can download our trained weights for this stage from here [Llama2](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/blob/main/checkpoints/video_captioning_llama_checkpoint_last.pth) [Mistral](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/blob/main/checkpoints/video_captioning_mistral_checkpoint_last.pth)<br>
|
| 203 |
+
|
| 204 |
+
#### Stage 3 (video Instruction finetuning)
|
| 205 |
+
|
| 206 |
+
For **Llama2** <br>
|
| 207 |
+
set the cfg-path in the script to `train_configs/224_v2_llama2_video_stage_3.yaml` <br>
|
| 208 |
+
set the model name here `minigpt4/configs/datasets/video_chatgpt/default.yaml` to llama2<br>
|
| 209 |
+
|
| 210 |
+
For **Mistral**<br>
|
| 211 |
+
set the cfg-path in the script to `train_configs/224_v2_mistral_video_stage_3.yaml` <br>
|
| 212 |
+
set the model name here `minigpt4/configs/datasets/video_chatgpt/default.yaml` to mistral<br>
|
| 213 |
+
|
| 214 |
+
```bash
|
| 215 |
+
bash training_scripts/stage_3.sh
|
| 216 |
+
```
|
| 217 |
+
You can download our trained weights for this stage from here [Llama2](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/blob/main/checkpoints/video_llama_checkpoint_last.pth) [Mistral](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/blob/main/checkpoints/video_mistral_checkpoint_last.pth)<br>
|
| 218 |
+
|
| 219 |
+
## :zap: MiniGPT4-Video Evaluation
|
| 220 |
+
To reproduce the results use the best checkpoints for each model <br>
|
| 221 |
+
[Llama2](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/blob/main/checkpoints/video_llama_checkpoint_best.pth) [Mistral](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/blob/main/checkpoints/video_mistral_checkpoint_best.pth)<br>
|
| 222 |
+
We used the same evaluation as [Video-ChatGPT](https://mbzuai-oryx.github.io/Video-ChatGPT/)<br>
|
| 223 |
+
|
| 224 |
+
|Method| Using Subtitles | Information Correctness | Detailed Orientation | Contextual Understanding | Temporal Understanding | Consistency |
|
| 225 |
+
|:--------------------:|:----:|:------------------------:|:---------------------:|:-------------------------:|:-----------------------:|:------------:|
|
| 226 |
+
| LLaMA Adapter | :x:| 2.03 | 2.32| 2.30| 1.98| 2.15 |
|
| 227 |
+
| Video LLaMA| :x:| 1.96 | 2.18| 2.16| 1.82| 1.79 |
|
| 228 |
+
| Video Chat| :x:| 2.23 | 2.50| 2.53| 1.94| 2.24 |
|
| 229 |
+
| Video-ChatGPT | :x:| 2.40 | 2.52| 2.62| 1.98| 2.37 |
|
| 230 |
+
| BT-Adapter-7B | :x:| 2.68 | 2.69| 3.27| 2.34| 2.46 |
|
| 231 |
+
| LLaMA-VID-7B| :x:| 2.96 | 3.00| 3.53| 2.46| 2.51 |
|
| 232 |
+
| **Ours-7B Llama2**| :x:| 2.93 | 2.97| 3.45| **2.47**| **2.60**|
|
| 233 |
+
| **Ours-7B Llama2**| :white_check_mark:| **3.08** | **3.02**| **3.57**| **2.65**| **2.67**|
|
| 234 |
+
| **Ours-7B Mistral** | :x:| 2.83|2.52 |3.01 |2.32 |2.40 |
|
| 235 |
+
| **Ours-7B Mistral**| :white_check_mark:| 2.91 | 2.57| 3.11|2.33 | 2.39|
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
|Method| Using Subtitles | MSVD Acc.↑ | MSVD Score↑ | MSRVTT Acc.↑ | MSRVTT Score↑ | TGIF Acc.↑ | TGIF Score↑ | ActivityNet Acc.↑ | ActivityNet Score↑ | TVQA Acc.↑ |
|
| 240 |
+
|:---------------------------------------:|:----------------:|:-----------:|:------------:|:--------------:|:---------------:|:-----------:|:------------:|:-------------------:|:--------------------:|:------------:|
|
| 241 |
+
| FrozenBiLM|:x:|32.2| --|16.8 |--| 41 |-- |24.7|--|29.7 |
|
| 242 |
+
| LLaMA Adapter|:x:|54.9| 3.1 |43.8 |2.7| -- |-- |34.2| 2.7| --|
|
| 243 |
+
| Video LLaMA|:x:|51.6| 2.5 |29|1.8| -- |-- |12.4| 1.1| --|
|
| 244 |
+
| Video Chat|:x:|56.3| 2.8 |45|2.5|34.4| 2.3 |26.5| 2.2|--|
|
| 245 |
+
| Video-ChatGPT|:x:|64.9| 3.3 |49.3 |2.8|51.4| 3.0 |35.2| 2.7|23.35|
|
| 246 |
+
| BT-Adapter-7B|:x:|67.7| 3.7 |57|3.2| -- |-- |45.7| 3.2| --|
|
| 247 |
+
| LLaMA-VID-7B |:x:|69.7| 3.7 |57.7 |3.2| -- |-- |**47.4**| **3.3**| --|
|
| 248 |
+
| **Ours-7B LLama2**|:x:|72.93|3.84|58.83|3.29|67.9|3.71| 45.85 |3.23|36.45|
|
| 249 |
+
| **Ours-7B Llama2**|:white_check_mark:|72.93|3.84|**59.73**|**3.3** |67.9|3.71| 46.3|3.4 |46.94|
|
| 250 |
+
| **Ours-7B Mistral**|:x:|**73.92**|**4.06**|58.26|3.52|**72.22**|**4.08**|44.25 |3.35|33.90|
|
| 251 |
+
| **Ours-7B Mistral**|:white_check_mark:|**73.92**|**4.06**|58.68|3.53 |**72.22**|**4.08**| 44.38|3.36 |**54.21** |
|
| 252 |
+
|
| 253 |
+
### Download datasets for evaluation
|
| 254 |
+
+ [MSVD](https://www.cs.utexas.edu/users/ml/clamp/videoDescription/) <br>
|
| 255 |
+
+ [MSRVTT](https://cove.thecvf.com/datasets/839) <br>
|
| 256 |
+
+ [TGIF](https://github.com/YunseokJANG/tgif-qa/blob/master/dataset/README.md) <br>
|
| 257 |
+
+ [ActivityNet](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/hanoona_bangalath_mbzuai_ac_ae/ESa302OCJMNHsMk7wuBbQc8BZH5CqlcdCWiSpXynQZDfAQ?e=CrOPbm) <br>
|
| 258 |
+
+ [TVQA](https://nlp.cs.unc.edu/data/jielei/tvqa/tvqa_public_html/download_tvqa.html) <br>
|
| 259 |
+
+ [Video-ChatGPT benchmark](https://mbzuai-oryx.github.io/Video-ChatGPT/) <br>
|
| 260 |
+
|
| 261 |
+
You can find the evaluation datasets annotation files [download](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/tree/main/datasets/evaluation_datasets) <br>
|
| 262 |
+
|
| 263 |
+
Subtitles for MSR-VTT,and ActivityNet are availabe here [download](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/resolve/main/datasets/evaluation_subtitles.zip)
|
| 264 |
+
note these subtitles are generated using <a href="https://github.com/openai/whisper">whisper model<br>
|
| 265 |
+
TVQA subtitles can be downloaded from [here](https://nlp.cs.unc.edu/data/jielei/tvqa/tvqa_public_html/download_tvqa.html)
|
| 266 |
+
### Run evaluation script
|
| 267 |
+
Set the each evaluation script parameters in the script <br>
|
| 268 |
+
```
|
| 269 |
+
NAME="" # Name of the experiment
|
| 270 |
+
BATCH_SIZE=8 # batch size
|
| 271 |
+
CKPT_PATH="" # path to the checkpoint
|
| 272 |
+
DATASET="msvd" # dataset name, available datasets: tvqa, msrvtt, msvd, activitynet,tgif ,video_chatgpt_generic,video_chatgpt_temporal,video_chatgpt_consistency
|
| 273 |
+
# set the paths to the dataset files
|
| 274 |
+
videos_path="" # path to the videos file
|
| 275 |
+
subtitles_path="" # path to the subtitles file if the dataset is msrvtt, activitynet or tvqa else set it to ""
|
| 276 |
+
ann_path="" # path to the annotations file
|
| 277 |
+
cfg_path="" # path to the config file
|
| 278 |
+
```
|
| 279 |
+
<br>
|
| 280 |
+
|
| 281 |
+
```bash
|
| 282 |
+
bash evaluation/minigpt4_video_eval/minigpt4_video_evalualtion.sh
|
| 283 |
+
```
|
| 284 |
+
Then Use GPT3.5 turbo to compare the predictions with the ground truth and generate the accuracy and scores <br>
|
| 285 |
+
Set these variables in both evaluate_benchmark.sh and evaluate_zeroshot.sh <br>
|
| 286 |
+
```bash
|
| 287 |
+
PRED="path_to_predictions"
|
| 288 |
+
OUTPUT_DIR="path_to_output_dir"
|
| 289 |
+
API_KEY="openAI_key"
|
| 290 |
+
NUM_TASKS=128
|
| 291 |
+
```
|
| 292 |
+
Then to evaluate [Video-ChatGPT benchmark] run the following script <br>
|
| 293 |
+
```bash
|
| 294 |
+
bash GPT_evaluation/evaluate_benchmark.sh
|
| 295 |
+
```
|
| 296 |
+
To evaluate open ended questions run the following script <br>
|
| 297 |
+
```bash
|
| 298 |
+
bash GPT_evaluation/evaluate_zeroshot.py
|
| 299 |
+
```
|
| 300 |
+
|
| 301 |
+
## :zap: Goldfish Evaluation
|
| 302 |
+
**Long video benchmarking results on four benchmarks: LLama-Vid, MovieChat, Movie QA, and our proposed TVQA-Long. The "V" modality indicates the use of video frames only, while "V+T" indicates the use of both video frames and subtitles**
|
| 303 |
+
|
| 304 |
+
<!--  -->
|
| 305 |
+
| Method | Modalities | LLama-Vid Acc.↑ | LLama-Vid Score↑ | MovieChat Acc.↑ | MovieChat Score↑ | Movie QA Acc.↑ | Movie QA Score↑ | TVQA-Long Acc.↑ | TVQA-Long Score↑ |
|
| 306 |
+
|-------------|------------|-----------------|------------------|-----------------|------------------|----------------|-----------------|------------|-------------|
|
| 307 |
+
| LLAMA-VID | V | 20.68 | 2.41 | 53.2 | 3.81 | 24.42 | 2.19 | 24.63 | 2.16 |
|
| 308 |
+
| MovieChat | V | 11.71 | 1.45 | NA | NA | 16.18 | 1.68 | 5.0 | 0.86 |
|
| 309 |
+
| Ours | V | **23.09** | 2.19 | **67.6** | **4.23** | **28.49** | **2.8** | **28.61** | **2.78** |
|
| 310 |
+
| LLAMA-VID | V+T | 41.4† | 3.07† | NA | NA | 37.65† | 3.03† | 26.86 | 2.21 |
|
| 311 |
+
| Ours | V+T | 31.49 | 2.48 | NA | NA | 35.24 | **3.1** | **41.78** | **3.21** |
|
| 312 |
+
|
| 313 |
+
**Note: The dagger † symbol indicates the method used the benchmark during training, which implies unfair comparison.**
|
| 314 |
+
|
| 315 |
+
To reproduce the results use the `checkpoints/video_llama_checkpoint_last.pth` and use openAI embedding `--use_openai_embedding=True`<br>
|
| 316 |
+
### Download datasets for evaluation
|
| 317 |
+
For **Llama-vid** and **MovieQA** <br>
|
| 318 |
+
Dowlnoad the original MovieNet data with movies and annotations from [here](https://opendatalab.com/OpenDataLab/MovieNet/tree/main/raw)<br>
|
| 319 |
+
This will be the souce videos for LLama-vid and MovieQA <br>
|
| 320 |
+
#### Filtered Annotations same as illestrated in the paper and used for evaluation
|
| 321 |
+
[Llama-vid](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/tree/main/datasets/goldfish_eval_datasets/llama_vid)<br>
|
| 322 |
+
[MovieQA](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/tree/main/datasets/goldfish_eval_datasets/movie_qa)<br>
|
| 323 |
+
For **Moviechat** the only available videos while implementing this work is 10 % of the training data and this what we used for evalaution and can be found [here](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/blob/main/datasets/goldfish_eval_datasets/movie_chat/available_movies_list.txt) <br>
|
| 324 |
+
Full dataset can be found [here](https://huggingface.co/datasets/Enxin/MovieChat-1K_train/tree/main) <br>
|
| 325 |
+
For **TVQA-Long** <br>
|
| 326 |
+
if you want to use TVQA-Long for another model (llama-vid),both videos and annotations can be found here [TVQA-Long](https://huggingface.co/datasets/Vision-CAIR/TVQA-Long/tree/main).
|
| 327 |
+
For Goldfish evalaution we will use the separated clips from the original TVQA dataset <br>
|
| 328 |
+
### Run the evaluation scripts
|
| 329 |
+
``` bash
|
| 330 |
+
# Llama-vid evalauation
|
| 331 |
+
# set these parameters in the script
|
| 332 |
+
videos_path="path to the videos"
|
| 333 |
+
subtitle_path="path to the subtitles"
|
| 334 |
+
video_clips_saving_path="path to save the video clips"
|
| 335 |
+
annotation_file="path to the annotation file"
|
| 336 |
+
movienet_annotations_dir="path to the movienet annotations directory"
|
| 337 |
+
NEIGHBOURS=3
|
| 338 |
+
use_openai_embedding="whether to use openai embeddings or not"
|
| 339 |
+
# then run the script
|
| 340 |
+
bash evaluation/Goldfish_eval/movies/eval_model_summary_llama_vid.sh
|
| 341 |
+
|
| 342 |
+
# MovieQA evaluation
|
| 343 |
+
# same as above but set the parameters in the script to the MovieQA paths
|
| 344 |
+
bash evaluation/Goldfish_eval/movies/eval_model_summary_movie_qa.sh
|
| 345 |
+
|
| 346 |
+
# MovieChat evaluation
|
| 347 |
+
# set these parameters in the script
|
| 348 |
+
dataset_path="path to the movies folder"
|
| 349 |
+
annotation_json_folder="path to the jsons folder"
|
| 350 |
+
# then run the script
|
| 351 |
+
bash evaluation/Goldfish_eval/movies/eval_model_summary_movie_chat.sh
|
| 352 |
+
```
|
| 353 |
+
### TVQA-Long
|
| 354 |
+
For Goldfish evaluation we can use the original separated clips from the original TVQA dataset <br>
|
| 355 |
+
Download the original TVQA videos and clips subtitles for short videos from [here](https://nlp.cs.unc.edu/data/jielei/tvqa/tvqa_public_html/download_tvqa.html)<br>
|
| 356 |
+
tvqa_long_annotation [here](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/tree/main/datasets/goldfish_eval_datasets/tvqa/tvqa_val_edited.json) <br>
|
| 357 |
+
tvqa_json_subtitles [here](https://huggingface.co/Vision-CAIR/MiniGPT4-Video/tree/main/datasets/goldfish_eval_datasets/tvqa/tvqa_preprocessed_subtitles.json)<br>
|
| 358 |
+
|
| 359 |
+
```bash
|
| 360 |
+
# set these parameters in the script
|
| 361 |
+
tvqa_json_subtitles="path to the tvqa json subtitles file"
|
| 362 |
+
tvqa_clips_subtitles="path to the tvqa clips subtitles"
|
| 363 |
+
videos_frames="path to the video frames"
|
| 364 |
+
tvqa_long_annotation="path to the TVQA-Long annotation file"
|
| 365 |
+
NEIGHBOURS= 3
|
| 366 |
+
use_openai_embedding="whether to use openai embeddings or not"
|
| 367 |
+
# then run the script
|
| 368 |
+
bash evaluation/Goldfish_eval/tvqa_eval/eval_model_summary.sh
|
| 369 |
+
````
|
| 370 |
+
|
| 371 |
+
Then Use GPT3.5 turbo to compare the predictions with the ground truth and generate the accuracy and scores <br>
|
| 372 |
+
Set these variables in evaluate_zeroshot.sh <br>
|
| 373 |
+
```bash
|
| 374 |
+
PRED="path_to_predictions"
|
| 375 |
+
OUTPUT_DIR="path_to_output_dir"
|
| 376 |
+
API_KEY="openAI_key"
|
| 377 |
+
NUM_TASKS=128
|
| 378 |
+
```
|
| 379 |
+
To evaluate open ended questions run the following script <br>
|
| 380 |
+
```bash
|
| 381 |
+
bash GPT_evaluation/evaluate_zeroshot.sh
|
| 382 |
+
```
|
| 383 |
+
|
| 384 |
+
## Citation
|
| 385 |
+
If you're using MiniGPT4-Video or Goldfish in your research or applications, please cite using this BibTeX:
|
| 386 |
+
```
|
| 387 |
+
@misc{ataallah2024goldfishvisionlanguageunderstandingarbitrarily,
|
| 388 |
+
title={Goldfish: Vision-Language Understanding of Arbitrarily Long Videos},
|
| 389 |
+
author={Kirolos Ataallah and Xiaoqian Shen and Eslam Abdelrahman and Essam Sleiman and Mingchen Zhuge and Jian Ding and Deyao Zhu and Jürgen Schmidhuber and Mohamed Elhoseiny},
|
| 390 |
+
year={2024},
|
| 391 |
+
eprint={2407.12679},
|
| 392 |
+
archivePrefix={arXiv},
|
| 393 |
+
primaryClass={cs.CV},
|
| 394 |
+
url={https://arxiv.org/abs/2407.12679},
|
| 395 |
+
}
|
| 396 |
+
@article{ataallah2024minigpt4,
|
| 397 |
+
title={MiniGPT4-Video: Advancing Multimodal LLMs for Video Understanding with Interleaved Visual-Textual Tokens},
|
| 398 |
+
author={Ataallah, Kirolos and Shen, Xiaoqian and Abdelrahman, Eslam and Sleiman, Essam and Zhu, Deyao and Ding, Jian and Elhoseiny, Mohamed},
|
| 399 |
+
journal={arXiv preprint arXiv:2404.03413},
|
| 400 |
+
year={2024}
|
| 401 |
+
}
|
| 402 |
+
|
| 403 |
+
```
|
| 404 |
+
|
| 405 |
+
## Acknowledgements
|
| 406 |
+
[MiniGPT4](https://github.com/Vision-CAIR/MiniGPT-4) <br>
|
| 407 |
+
[Video-ChatGPT](https://mbzuai-oryx.github.io/Video-ChatGPT)
|
| 408 |
+
|
| 409 |
+
## License
|
| 410 |
+
This repository is under [BSD 3-Clause License](LICENSE.md).
|
| 411 |
+
Many codes are based on [MiniGPT4](https://github.com/Vision-CAIR/MiniGPT-4).
|
clean_stage3_json.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
VIDEO_DIR = "datasets/stage3/videos"
|
| 5 |
+
JSON_PATH = "datasets/stage3/video_instruct_data.json"
|
| 6 |
+
OUTPUT_JSON = "datasets/stage3/video_instruct_data_clean.json"
|
| 7 |
+
|
| 8 |
+
def main():
|
| 9 |
+
print("🚀 开始清洗 Stage 3 JSON...")
|
| 10 |
+
# 1. 扫描本地视频 ID
|
| 11 |
+
existing_ids = set()
|
| 12 |
+
for f in os.listdir(VIDEO_DIR):
|
| 13 |
+
if f.endswith(('.mp4', '.mkv', '.webm')):
|
| 14 |
+
existing_ids.add(os.path.splitext(f)[0])
|
| 15 |
+
print(f"✅ 本地视频数: {len(existing_ids)}")
|
| 16 |
+
|
| 17 |
+
# 2. 读取全量 JSON
|
| 18 |
+
with open(JSON_PATH, 'r') as f:
|
| 19 |
+
data = json.load(f)
|
| 20 |
+
|
| 21 |
+
# 3. 过滤:只保留本地有的
|
| 22 |
+
clean_data = []
|
| 23 |
+
for item in data:
|
| 24 |
+
# 兼容不同的键名情况
|
| 25 |
+
vid = item.get("video_id") or item.get("video_name") or item.get("image_id")
|
| 26 |
+
if vid in existing_ids:
|
| 27 |
+
clean_data.append(item)
|
| 28 |
+
|
| 29 |
+
# 4. 保存
|
| 30 |
+
with open(OUTPUT_JSON, 'w') as f:
|
| 31 |
+
json.dump(clean_data, f)
|
| 32 |
+
print(f"🎉 清洗完毕!有效数据: {len(clean_data)} 条。已保存至 {OUTPUT_JSON}")
|
| 33 |
+
|
| 34 |
+
if __name__ == "__main__":
|
| 35 |
+
main()
|
convert_cmd_to_json.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
# ================= 配置 =================
|
| 6 |
+
BASE_DIR = "datasets"
|
| 7 |
+
METADATA_DIR = os.path.join(BASE_DIR, "CondensedMovies_Metadata")
|
| 8 |
+
VIDEO_DIR = os.path.join(BASE_DIR, "CondensedMovies_Videos")
|
| 9 |
+
OUTPUT_JSON = os.path.join(BASE_DIR, "cmd_annotations.json")
|
| 10 |
+
# ========================================
|
| 11 |
+
|
| 12 |
+
def main():
|
| 13 |
+
print("🚀 生成标准 CMD JSON...")
|
| 14 |
+
|
| 15 |
+
# 1. 读取 CSV
|
| 16 |
+
df_clips = pd.read_csv(os.path.join(METADATA_DIR, "clips.csv"))
|
| 17 |
+
df_desc = pd.read_csv(os.path.join(METADATA_DIR, "descriptions.csv"))
|
| 18 |
+
df_merged = pd.merge(df_clips, df_desc, on="videoid", how="inner")
|
| 19 |
+
|
| 20 |
+
# 2. 扫描本地视频 (现在它们都在根目录了,且都是 mp4)
|
| 21 |
+
existing_ids = set()
|
| 22 |
+
for f in os.listdir(VIDEO_DIR):
|
| 23 |
+
if f.endswith(".mp4"):
|
| 24 |
+
existing_ids.add(os.path.splitext(f)[0])
|
| 25 |
+
|
| 26 |
+
print(f"✅ 本地找到 {len(existing_ids)} 个视频")
|
| 27 |
+
|
| 28 |
+
# 3. 生成列表
|
| 29 |
+
annotations = []
|
| 30 |
+
for _, row in df_merged.iterrows():
|
| 31 |
+
vid = row['videoid']
|
| 32 |
+
if vid in existing_ids:
|
| 33 |
+
# 只要 image_id 和 caption,完全符合原始代码要求
|
| 34 |
+
annotations.append({
|
| 35 |
+
"image_id": vid,
|
| 36 |
+
"caption": row['description']
|
| 37 |
+
})
|
| 38 |
+
|
| 39 |
+
# 4. 保存
|
| 40 |
+
with open(OUTPUT_JSON, 'w') as f:
|
| 41 |
+
json.dump(annotations, f)
|
| 42 |
+
print(f"🎉 JSON 生成完毕: {len(annotations)} 条数据")
|
| 43 |
+
|
| 44 |
+
if __name__ == "__main__":
|
| 45 |
+
main()
|
convert_csv_to_json2.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
# 读取 CSV
|
| 6 |
+
csv_path = 'datasets/stage3/video_instruct_data.csv'
|
| 7 |
+
df = pd.read_csv(csv_path)
|
| 8 |
+
|
| 9 |
+
json_data = []
|
| 10 |
+
|
| 11 |
+
# 遍历每一行
|
| 12 |
+
for index, row in df.iterrows():
|
| 13 |
+
# 获取视频ID
|
| 14 |
+
vid = str(row['video_id']).strip()
|
| 15 |
+
|
| 16 |
+
# 获取问题和答案
|
| 17 |
+
question = str(row['q']).strip()
|
| 18 |
+
answer = str(row['a']).strip()
|
| 19 |
+
|
| 20 |
+
# 【关键修改】这里改回代码喜欢的 "q" 和 "a"
|
| 21 |
+
entry = {
|
| 22 |
+
"video_id": vid,
|
| 23 |
+
"q": question, # 之前写的是 "instruction",现在改回 "q"
|
| 24 |
+
"a": answer, # 之前写的是 "answer",现在改回 "a"
|
| 25 |
+
"length": 100
|
| 26 |
+
}
|
| 27 |
+
json_data.append(entry)
|
| 28 |
+
|
| 29 |
+
# 覆盖保存为 JSON
|
| 30 |
+
output_path = 'datasets/stage3/video_instruct_data.json'
|
| 31 |
+
with open(output_path, 'w', encoding='utf-8') as f:
|
| 32 |
+
json.dump(json_data, f, indent=4)
|
| 33 |
+
|
| 34 |
+
print(f"转换完成!已重新生成符合代码要求的 JSON。")
|
environment.yml
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: goldfish
|
| 2 |
+
channels:
|
| 3 |
+
- conda-forge
|
| 4 |
+
dependencies:
|
| 5 |
+
- _libgcc_mutex=0.1=conda_forge
|
| 6 |
+
- _openmp_mutex=4.5=2_gnu
|
| 7 |
+
- archspec=0.2.2=pyhd8ed1ab_0
|
| 8 |
+
- boltons=23.1.1=pyhd8ed1ab_0
|
| 9 |
+
- brotli-python=1.1.0=py39h3d6467e_1
|
| 10 |
+
- bzip2=1.0.8=hd590300_5
|
| 11 |
+
- c-ares=1.25.0=hd590300_0
|
| 12 |
+
- ca-certificates=2024.2.2=hbcca054_0
|
| 13 |
+
- certifi=2024.2.2=pyhd8ed1ab_0
|
| 14 |
+
- cffi=1.16.0=py39h7a31438_0
|
| 15 |
+
- charset-normalizer=3.3.2=pyhd8ed1ab_0
|
| 16 |
+
- colorama=0.4.6=pyhd8ed1ab_0
|
| 17 |
+
- conda=23.11.0=py39hf3d152e_1
|
| 18 |
+
- conda-libmamba-solver=23.12.0=pyhd8ed1ab_0
|
| 19 |
+
- conda-package-handling=2.2.0=pyh38be061_0
|
| 20 |
+
- conda-package-streaming=0.9.0=pyhd8ed1ab_0
|
| 21 |
+
- cudatoolkit=11.8.0=h4ba93d1_12
|
| 22 |
+
- cudatoolkit-dev=11.7.0=h1de0b5d_6
|
| 23 |
+
- distro=1.9.0=pyhd8ed1ab_0
|
| 24 |
+
- faiss=1.7.4=py39cuda112h460e57a_0_cuda
|
| 25 |
+
- fmt=10.1.1=h00ab1b0_1
|
| 26 |
+
- freetype=2.12.1=h267a509_2
|
| 27 |
+
- gmp=6.1.2=hf484d3e_1000
|
| 28 |
+
- gnutls=3.5.19=h2a4e5f8_1
|
| 29 |
+
- icu=73.2=h59595ed_0
|
| 30 |
+
- idna=3.6=pyhd8ed1ab_0
|
| 31 |
+
- jsonpatch=1.33=pyhd8ed1ab_0
|
| 32 |
+
- jsonpointer=2.4=py39hf3d152e_3
|
| 33 |
+
- keyutils=1.6.1=h166bdaf_0
|
| 34 |
+
- krb5=1.21.2=h659d440_0
|
| 35 |
+
- ld_impl_linux-64=2.40=h41732ed_0
|
| 36 |
+
- libarchive=3.7.2=h2aa1ff5_1
|
| 37 |
+
- libblas=3.9.0=20_linux64_openblas
|
| 38 |
+
- libcblas=3.9.0=20_linux64_openblas
|
| 39 |
+
- libcurl=8.5.0=hca28451_0
|
| 40 |
+
- libedit=3.1.20191231=he28a2e2_2
|
| 41 |
+
- libev=4.33=hd590300_2
|
| 42 |
+
- libfaiss=1.7.4=cuda112hb18a002_0_cuda
|
| 43 |
+
- libfaiss-avx2=1.7.4=cuda112h1234567_0_cuda
|
| 44 |
+
- libffi=3.4.2=h7f98852_5
|
| 45 |
+
- libgcc-ng=13.2.0=h807b86a_3
|
| 46 |
+
- libgfortran-ng=13.2.0=h69a702a_3
|
| 47 |
+
- libgfortran5=13.2.0=ha4646dd_3
|
| 48 |
+
- libgomp=13.2.0=h807b86a_3
|
| 49 |
+
- libiconv=1.17=hd590300_2
|
| 50 |
+
- liblapack=3.9.0=20_linux64_openblas
|
| 51 |
+
- libmamba=1.5.6=had39da4_0
|
| 52 |
+
- libmambapy=1.5.6=py39h10defb6_0
|
| 53 |
+
- libnghttp2=1.58.0=h47da74e_1
|
| 54 |
+
- libnsl=2.0.1=hd590300_0
|
| 55 |
+
- libopenblas=0.3.25=pthreads_h413a1c8_0
|
| 56 |
+
- libpng=1.6.39=h753d276_0
|
| 57 |
+
- libsolv=0.7.27=hfc55251_0
|
| 58 |
+
- libsqlite=3.44.2=h2797004_0
|
| 59 |
+
- libssh2=1.11.0=h0841786_0
|
| 60 |
+
- libstdcxx-ng=13.2.0=h7e041cc_3
|
| 61 |
+
- libuuid=2.38.1=h0b41bf4_0
|
| 62 |
+
- libxcrypt=4.4.36=hd590300_1
|
| 63 |
+
- libxml2=2.12.3=h232c23b_0
|
| 64 |
+
- libzlib=1.2.13=hd590300_5
|
| 65 |
+
- lz4-c=1.9.4=hcb278e6_0
|
| 66 |
+
- lzo=2.10=h516909a_1000
|
| 67 |
+
- menuinst=2.0.1=py39hf3d152e_0
|
| 68 |
+
- ncurses=6.4=h59595ed_2
|
| 69 |
+
- nettle=3.3=0
|
| 70 |
+
- numpy=1.26.3=py39h474f0d3_0
|
| 71 |
+
- openh264=1.8.0=hdbcaa40_1000
|
| 72 |
+
- openssl=3.2.1=hd590300_0
|
| 73 |
+
- packaging=23.2=pyhd8ed1ab_0
|
| 74 |
+
- pip=23.3.2=pyhd8ed1ab_0
|
| 75 |
+
- platformdirs=4.1.0=pyhd8ed1ab_0
|
| 76 |
+
- pluggy=1.3.0=pyhd8ed1ab_0
|
| 77 |
+
- pybind11-abi=4=hd8ed1ab_3
|
| 78 |
+
- pycosat=0.6.6=py39hd1e30aa_0
|
| 79 |
+
- pycparser=2.21=pyhd8ed1ab_0
|
| 80 |
+
- pysocks=1.7.1=pyha2e5f31_6
|
| 81 |
+
- python=3.9.18=h0755675_1_cpython
|
| 82 |
+
- python_abi=3.9=4_cp39
|
| 83 |
+
- readline=8.2=h8228510_1
|
| 84 |
+
- reproc=14.2.4.post0=hd590300_1
|
| 85 |
+
- reproc-cpp=14.2.4.post0=h59595ed_1
|
| 86 |
+
- requests=2.31.0=pyhd8ed1ab_0
|
| 87 |
+
- ruamel.yaml=0.18.5=py39hd1e30aa_0
|
| 88 |
+
- ruamel.yaml.clib=0.2.7=py39hd1e30aa_2
|
| 89 |
+
- tk=8.6.13=noxft_h4845f30_101
|
| 90 |
+
- tqdm=4.66.1=pyhd8ed1ab_0
|
| 91 |
+
- urllib3=2.1.0=pyhd8ed1ab_0
|
| 92 |
+
- wheel=0.42.0=pyhd8ed1ab_0
|
| 93 |
+
- x264=1!152.20180717=h14c3975_1001
|
| 94 |
+
- xz=5.2.6=h166bdaf_0
|
| 95 |
+
- yaml-cpp=0.8.0=h59595ed_0
|
| 96 |
+
- zlib=1.2.13=hd590300_5
|
| 97 |
+
- zstandard=0.22.0=py39h6e5214e_0
|
| 98 |
+
- zstd=1.5.5=hfc55251_0
|
| 99 |
+
- pip:
|
| 100 |
+
- accelerate==0.25.0
|
| 101 |
+
- aiofiles==23.2.1
|
| 102 |
+
- aiohttp==3.9.1
|
| 103 |
+
- aiosignal==1.3.1
|
| 104 |
+
- altair==5.2.0
|
| 105 |
+
- annotated-types==0.6.0
|
| 106 |
+
- antlr4-python3-runtime==4.9.3
|
| 107 |
+
- anyio==4.2.0
|
| 108 |
+
- appdirs==1.4.4
|
| 109 |
+
- asgiref==3.7.2
|
| 110 |
+
- async-timeout==4.0.3
|
| 111 |
+
- attrs==23.2.0
|
| 112 |
+
- backoff==2.2.1
|
| 113 |
+
- bcrypt==4.1.2
|
| 114 |
+
- beautifulsoup4==4.12.2
|
| 115 |
+
- bitarray==2.9.2
|
| 116 |
+
- bitsandbytes==0.42.0
|
| 117 |
+
- bleach==6.1.0
|
| 118 |
+
- blinker==1.7.0
|
| 119 |
+
- braceexpand==0.1.7
|
| 120 |
+
- build==1.0.3
|
| 121 |
+
- cachetools==5.3.2
|
| 122 |
+
- chardet==5.2.0
|
| 123 |
+
- chroma-hnswlib==0.7.3
|
| 124 |
+
- chromadb==0.4.22
|
| 125 |
+
- click==8.1.7
|
| 126 |
+
- cmake==3.25.0
|
| 127 |
+
- colbert-ai==0.2.18
|
| 128 |
+
- coloredlogs==15.0.1
|
| 129 |
+
- contourpy==1.2.0
|
| 130 |
+
- cycler==0.12.1
|
| 131 |
+
- datasets==2.17.0
|
| 132 |
+
- decorator==4.4.2
|
| 133 |
+
- decord==0.6.0
|
| 134 |
+
- deprecated==1.2.14
|
| 135 |
+
- dill==0.3.8
|
| 136 |
+
- docker-pycreds==0.4.0
|
| 137 |
+
- docopt==0.6.2
|
| 138 |
+
- einops==0.7.0
|
| 139 |
+
- exceptiongroup==1.2.0
|
| 140 |
+
- faiss-gpu==1.7.2
|
| 141 |
+
- fastapi==0.108.0
|
| 142 |
+
- ffmpeg==1.4
|
| 143 |
+
- ffmpeg-python==0.2.0
|
| 144 |
+
- ffmpy==0.3.1
|
| 145 |
+
- filelock==3.13.1
|
| 146 |
+
- flask==3.0.2
|
| 147 |
+
- flatbuffers==23.5.26
|
| 148 |
+
- fonttools==4.47.0
|
| 149 |
+
- frozenlist==1.4.1
|
| 150 |
+
- fsspec==2023.10.0
|
| 151 |
+
- ftfy==6.1.3
|
| 152 |
+
- future==0.18.3
|
| 153 |
+
- gdown==4.7.1
|
| 154 |
+
- git-python==1.0.3
|
| 155 |
+
- gitdb==4.0.11
|
| 156 |
+
- gitpython==3.1.40
|
| 157 |
+
- google-auth==2.26.1
|
| 158 |
+
- googleapis-common-protos==1.62.0
|
| 159 |
+
- gradio
|
| 160 |
+
- gradio-client
|
| 161 |
+
- h11==0.14.0
|
| 162 |
+
- h5py==3.10.0
|
| 163 |
+
- httpcore==1.0.2
|
| 164 |
+
- httptools==0.6.1
|
| 165 |
+
- httpx==0.26.0
|
| 166 |
+
- huggingface-hub
|
| 167 |
+
- humanfriendly==10.0
|
| 168 |
+
- imageio==2.33.1
|
| 169 |
+
- imageio-ffmpeg==0.4.9
|
| 170 |
+
- importlib-metadata==6.11.0
|
| 171 |
+
- importlib-resources==6.1.1
|
| 172 |
+
- inquirerpy==0.3.4
|
| 173 |
+
- iopath==0.1.10
|
| 174 |
+
- itsdangerous==2.1.2
|
| 175 |
+
- jinja2==3.1.2
|
| 176 |
+
- joblib==1.3.2
|
| 177 |
+
- jsonschema==4.20.0
|
| 178 |
+
- jsonschema-specifications==2023.12.1
|
| 179 |
+
- kaggle==1.6.0
|
| 180 |
+
- kiwisolver==1.4.5
|
| 181 |
+
- kubernetes==29.0.0
|
| 182 |
+
- lazy-loader==0.3
|
| 183 |
+
- lit==15.0.7
|
| 184 |
+
- llvmlite==0.41.1
|
| 185 |
+
- markdown-it-py==3.0.0
|
| 186 |
+
- matplotlib==3.8.2
|
| 187 |
+
- mdurl==0.1.2
|
| 188 |
+
- mmh3==4.1.0
|
| 189 |
+
- monotonic==1.6
|
| 190 |
+
- more-itertools==10.1.0
|
| 191 |
+
- moviepy==1.0.3
|
| 192 |
+
- mpmath==1.3.0
|
| 193 |
+
- multidict==6.0.4
|
| 194 |
+
- multiprocess==0.70.16
|
| 195 |
+
- mutagen==1.47.0
|
| 196 |
+
- networkx==3.2.1
|
| 197 |
+
- ninja==1.11.1.1
|
| 198 |
+
- nltk==3.8.1
|
| 199 |
+
- numba==0.58.1
|
| 200 |
+
- omegaconf==2.3.0
|
| 201 |
+
- onnxruntime==1.16.3
|
| 202 |
+
- openai
|
| 203 |
+
- openai-whisper==20231117
|
| 204 |
+
- opencv-python==4.7.0.72
|
| 205 |
+
- opentelemetry-api==1.22.0
|
| 206 |
+
- opentelemetry-exporter-otlp-proto-common==1.22.0
|
| 207 |
+
- opentelemetry-exporter-otlp-proto-grpc==1.22.0
|
| 208 |
+
- opentelemetry-instrumentation==0.43b0
|
| 209 |
+
- opentelemetry-instrumentation-asgi==0.43b0
|
| 210 |
+
- opentelemetry-instrumentation-fastapi==0.43b0
|
| 211 |
+
- opentelemetry-proto==1.22.0
|
| 212 |
+
- opentelemetry-sdk==1.22.0
|
| 213 |
+
- opentelemetry-semantic-conventions==0.43b0
|
| 214 |
+
- opentelemetry-util-http==0.43b0
|
| 215 |
+
- orjson==3.9.10
|
| 216 |
+
- overrides==7.4.0
|
| 217 |
+
- pandas==2.0.0
|
| 218 |
+
- pathtools==0.1.2
|
| 219 |
+
- peft==0.2.0
|
| 220 |
+
- pfzy==0.3.4
|
| 221 |
+
- pillow==10.2.0
|
| 222 |
+
- plotly==5.18.0
|
| 223 |
+
- portalocker==2.8.2
|
| 224 |
+
- posthog==3.3.0
|
| 225 |
+
- proglog==0.1.10
|
| 226 |
+
- progressbar2==4.3.2
|
| 227 |
+
- prompt-toolkit==3.0.43
|
| 228 |
+
- protobuf==4.25.1
|
| 229 |
+
- psutil==5.9.7
|
| 230 |
+
- pulsar-client==3.4.0
|
| 231 |
+
- pyarrow==15.0.0
|
| 232 |
+
- pyarrow-hotfix==0.6
|
| 233 |
+
- pyasn1==0.5.1
|
| 234 |
+
- pyasn1-modules==0.3.0
|
| 235 |
+
- pycocoevalcap==1.2
|
| 236 |
+
- pycocotools==2.0.6
|
| 237 |
+
- pycryptodomex==3.19.1
|
| 238 |
+
- pydantic==2.5.3
|
| 239 |
+
- pydantic-core==2.14.6
|
| 240 |
+
- pydub==0.25.1
|
| 241 |
+
- pygments==2.17.2
|
| 242 |
+
- pyparsing==3.1.1
|
| 243 |
+
- pypika==0.48.9
|
| 244 |
+
- pyproject-hooks==1.0.0
|
| 245 |
+
- pysrt==1.1.2
|
| 246 |
+
- python-dateutil==2.8.2
|
| 247 |
+
- python-dotenv==1.0.0
|
| 248 |
+
- python-multipart==0.0.6
|
| 249 |
+
- python-slugify==8.0.1
|
| 250 |
+
- python-utils==3.8.1
|
| 251 |
+
- pytubefix==6.5.1
|
| 252 |
+
- pytz==2023.3.post1
|
| 253 |
+
- pyyaml==6.0.1
|
| 254 |
+
- referencing==0.32.0
|
| 255 |
+
- regex==2023.12.25
|
| 256 |
+
- rich==13.7.0
|
| 257 |
+
- rouge==1.0.1
|
| 258 |
+
- rpds-py==0.16.2
|
| 259 |
+
- rsa==4.9
|
| 260 |
+
- safetensors==0.4.1
|
| 261 |
+
- scikit-image==0.22.0
|
| 262 |
+
- scikit-learn==1.3.2
|
| 263 |
+
- scipy==1.11.4
|
| 264 |
+
- seaborn==0.13.1
|
| 265 |
+
- semantic-version==2.10.0
|
| 266 |
+
- sentence-transformers==2.2.2
|
| 267 |
+
- sentencepiece==0.1.97
|
| 268 |
+
- sentry-sdk==1.39.1
|
| 269 |
+
- setproctitle==1.3.3
|
| 270 |
+
- setuptools==69.0.3
|
| 271 |
+
- shellingham==1.5.4
|
| 272 |
+
- six==1.16.0
|
| 273 |
+
- smmap==5.0.1
|
| 274 |
+
- sniffio==1.3.0
|
| 275 |
+
- soundfile==0.12.1
|
| 276 |
+
- soupsieve==2.5
|
| 277 |
+
- starlette==0.32.0.post1
|
| 278 |
+
- sympy==1.12
|
| 279 |
+
- tenacity==8.2.3
|
| 280 |
+
- text-unidecode==1.3
|
| 281 |
+
- threadpoolctl==3.2.0
|
| 282 |
+
- tifffile==2023.12.9
|
| 283 |
+
- tiktoken==0.5.2
|
| 284 |
+
- timm
|
| 285 |
+
- tokenizers==0.15.2
|
| 286 |
+
- tomli==2.0.1
|
| 287 |
+
- tomlkit==0.12.0
|
| 288 |
+
- toolz==0.12.0
|
| 289 |
+
- torch==2.2.2
|
| 290 |
+
- torchaudio==2.2.2
|
| 291 |
+
- torchvision==0.17.2
|
| 292 |
+
- transformers
|
| 293 |
+
#- triton==2.0.0
|
| 294 |
+
- typer==0.9.0
|
| 295 |
+
- typing-extensions==4.9.0
|
| 296 |
+
- tzdata==2023.4
|
| 297 |
+
- ujson==5.9.0
|
| 298 |
+
- uvicorn==0.25.0
|
| 299 |
+
- uvloop==0.19.0
|
| 300 |
+
- visual-genome==1.1.1
|
| 301 |
+
- wandb==0.14.2
|
| 302 |
+
- watchfiles==0.21.0
|
| 303 |
+
- wcwidth==0.2.13
|
| 304 |
+
- webdataset==0.2.48
|
| 305 |
+
- webencodings==0.5.1
|
| 306 |
+
- websocket-client==1.7.0
|
| 307 |
+
- websockets
|
| 308 |
+
- webvtt-py==0.4.6
|
| 309 |
+
- wrapt==1.16.0
|
| 310 |
+
- xxhash==3.4.1
|
| 311 |
+
- yarl==1.9.4
|
| 312 |
+
- youtube-dl==2021.12.17
|
| 313 |
+
- yt-dlp
|
| 314 |
+
- zipp
|
| 315 |
+
#- vllm
|
| 316 |
+
#- openai-whisper
|
| 317 |
+
#- triton==2.0.0
|
evaluation/Goldfish_eval/movies/eval_model_summary_llama_vid.sh
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --partition=batch
|
| 3 |
+
#SBATCH --job-name=L_RAG_general_summary_3_subtitles_together_%j
|
| 4 |
+
#SBATCH --output=L_RAG_general_summary_3_subtitles_together_%j.out
|
| 5 |
+
#SBATCH --error=L_RAG_general_summary_3_subtitles_together_%j.err
|
| 6 |
+
#SBATCH --time=0-23:00:00
|
| 7 |
+
#SBATCH --mem=64G
|
| 8 |
+
#SBATCH --gres=gpu:a100:1
|
| 9 |
+
#SBATCH --nodes=1
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
## run the application:
|
| 13 |
+
|
| 14 |
+
CKPT_PATH="checkpoints/video_llama_checkpoint_last.pth"
|
| 15 |
+
START=$1
|
| 16 |
+
END=$2
|
| 17 |
+
BATCH_SIZE=4
|
| 18 |
+
|
| 19 |
+
NEIGHBOURS=3
|
| 20 |
+
## Dataset paths
|
| 21 |
+
videos_path="path to the videos"
|
| 22 |
+
subtitle_path="path to the subtitles"
|
| 23 |
+
video_clips_saving_path="path to save the video clips"
|
| 24 |
+
annotation_file="path to the annotation file"
|
| 25 |
+
movienet_annotations_dir="path to the movienet annotations directory"
|
| 26 |
+
# if you want to use openai embedding, then you need to set the OPENAI_API_KEY
|
| 27 |
+
use_openai_embedding=True
|
| 28 |
+
export OPENAI_API_KEY="your_openai_key"
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# if start and end are not provided, then use the whole dataset
|
| 33 |
+
if [ -z "$START" ]
|
| 34 |
+
then
|
| 35 |
+
START=0
|
| 36 |
+
fi
|
| 37 |
+
if [ -z "$END" ]
|
| 38 |
+
then
|
| 39 |
+
END=100000
|
| 40 |
+
fi
|
| 41 |
+
echo "Start: $START"
|
| 42 |
+
echo "End: $END"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# # Vision + subtitles
|
| 47 |
+
exp_name="Vsion_subtitles_model_summary_subtitle"
|
| 48 |
+
echo $exp_name
|
| 49 |
+
python evaluation/eval_goldfish_llama_vid.py --index_subtitles_together --neighbours=$NEIGHBOURS --start=$START --end=$END --batch_size $BATCH_SIZE --ckpt $CKPT_PATH --exp_name=$exp_name\
|
| 50 |
+
--videos_path $videos_path --subtitle_path $subtitle_path --video_clips_saving_path $video_clips_saving_path --annotation_path $annotation_path --movienet_annotations_dir $movienet_annotations_dir --use_openai_embedding $use_openai_embedding
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# vision only
|
| 54 |
+
# exp_name="vision_only"
|
| 55 |
+
# echo $exp_name
|
| 56 |
+
# python eval_goldfish_llama_vid.py --vision_only --model_summary_only --neighbours=$NEIGHBOURS --start=$START --end=$END --batch_size $BATCH_SIZE --ckpt $CKPT_PATH --exp_name=$exp_name\
|
| 57 |
+
# --videos_path $videos_path --subtitle_path $subtitle_path --video_clips_saving_path $video_clips_saving_path --annotation_path $annotation_path --movienet_annotations_dir $movienet_annotations_dir --use_openai_embedding $use_openai_embedding
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# subtiltes only (eliminate the vision)
|
| 61 |
+
# exp_name="subtitles_only"
|
| 62 |
+
# echo $exp_name
|
| 63 |
+
# python eval_goldfish_llama_vid.py --index_subtitles_together --subtitles_only --neighbours=$NEIGHBOURS --start=$START --end=$END --batch_size $BATCH_SIZE --ckpt $CKPT_PATH --exp_name=$exp_name\
|
| 64 |
+
# --videos_path $videos_path --subtitle_path $subtitle_path --video_clips_saving_path $video_clips_saving_path --annotation_path $annotation_path --movienet_annotations_dir $movienet_annotations_dir --use_openai_embedding $use_openai_embedding
|
| 65 |
+
|
| 66 |
+
|
evaluation/Goldfish_eval/movies/eval_model_summary_movie_chat.sh
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --partition=batch
|
| 3 |
+
#SBATCH --job-name=MC_RAG_general_summary_all_%j
|
| 4 |
+
#SBATCH --output=MC_RAG_general_summary_all_%j.out
|
| 5 |
+
#SBATCH --error=MC_RAG_general_summary_all_%j.err
|
| 6 |
+
#SBATCH --time=0-23:00:00
|
| 7 |
+
#SBATCH --mem=64G
|
| 8 |
+
#SBATCH --gres=gpu:a100:1
|
| 9 |
+
#SBATCH --nodes=1
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
## run the application:
|
| 13 |
+
CKPT_PATH="checkpoints/video_llama_checkpoint_last.pth"
|
| 14 |
+
START=$1
|
| 15 |
+
END=$2
|
| 16 |
+
BATCH_SIZE=4
|
| 17 |
+
# if start and end are not provided, then use the whole dataset
|
| 18 |
+
if [ -z "$START" ]
|
| 19 |
+
then
|
| 20 |
+
START=0
|
| 21 |
+
fi
|
| 22 |
+
if [ -z "$END" ]
|
| 23 |
+
then
|
| 24 |
+
END=100000
|
| 25 |
+
fi
|
| 26 |
+
echo "Start: $START"
|
| 27 |
+
echo "End: $END"
|
| 28 |
+
|
| 29 |
+
NEIGHBOURS=-1 # use the whole neighbourhood for the global mode
|
| 30 |
+
|
| 31 |
+
dataset_path="path to the movies folder"
|
| 32 |
+
annotation_json_folder="path to the jsons folder"
|
| 33 |
+
# if you want to use openai embedding, then you need to set the OPENAI_API_KEY
|
| 34 |
+
use_openai_embedding=True
|
| 35 |
+
export OPENAI_API_KEY="your_openai_key"
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
exp_name="model_summary_and_subtitle"
|
| 40 |
+
fps=2
|
| 41 |
+
|
| 42 |
+
# use general summary
|
| 43 |
+
python evaluation/eval_goldfish_movie_chat.py --fps=$fps --neighbours_global=$NEIGHBOURS --start=$START --end=$END --batch_size $BATCH_SIZE --ckpt $CKPT_PATH --exp_name=$exp_name\
|
| 44 |
+
--dataset_videos_path $dataset_path --annotation_json_folder $annotation_json_folder --use_openai_embedding $use_openai_embedding
|
evaluation/Goldfish_eval/movies/eval_model_summary_movie_qa.sh
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --partition=batch
|
| 3 |
+
#SBATCH --job-name=M_RAG_general_summary_1_subtitles_together_%j
|
| 4 |
+
#SBATCH --output=M_RAG_general_summary_1_subtitles_together_%j.out
|
| 5 |
+
#SBATCH --error=M_RAG_general_summary_1_subtitles_together_%j.err
|
| 6 |
+
#SBATCH --time=0-23:00:00
|
| 7 |
+
#SBATCH --mem=100G
|
| 8 |
+
#SBATCH --gres=gpu:a100:1
|
| 9 |
+
#SBATCH --nodes=1
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
## run the application:
|
| 13 |
+
CKPT_PATH="checkpoints/video_llama_checkpoint_last.pth"
|
| 14 |
+
START=$1
|
| 15 |
+
END=$2
|
| 16 |
+
BATCH_SIZE=4
|
| 17 |
+
|
| 18 |
+
NEIGHBOURS=3
|
| 19 |
+
## Dataset paths
|
| 20 |
+
videos_path="path to the videos"
|
| 21 |
+
subtitle_path="path to the subtitles"
|
| 22 |
+
video_clips_saving_path="path to save the video clips"
|
| 23 |
+
annotation_file="path to the annotation file"
|
| 24 |
+
movienet_annotations_dir="path to the movienet annotations directory"
|
| 25 |
+
# if you want to use openai embedding, then you need to set the OPENAI_API_KEY
|
| 26 |
+
use_openai_embedding=True
|
| 27 |
+
export OPENAI_API_KEY="your_openai_key"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# if start and end are not provided, then use the whole dataset
|
| 32 |
+
if [ -z "$START" ]
|
| 33 |
+
then
|
| 34 |
+
START=0
|
| 35 |
+
fi
|
| 36 |
+
if [ -z "$END" ]
|
| 37 |
+
then
|
| 38 |
+
END=100000
|
| 39 |
+
fi
|
| 40 |
+
echo "Start: $START"
|
| 41 |
+
echo "End: $END"
|
| 42 |
+
echo "Batch size: $BATCH_SIZE"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# # Vision + subtitles
|
| 46 |
+
exp_name="Vsion_subtitles_model_summary_subtitle"
|
| 47 |
+
echo $exp_name
|
| 48 |
+
python evaluation/eval_goldfish_movie_qa.py --add_unknown --index_subtitles_together --neighbours=$NEIGHBOURS --start=$START --end=$END --batch_size $BATCH_SIZE --ckpt $CKPT_PATH --exp_name=$exp_name\
|
| 49 |
+
--videos_path $videos_path --subtitle_path $subtitle_path --video_clips_saving_path $video_clips_saving_path --annotation_path $annotation_path --movienet_annotations_dir $movienet_annotations_dir --use_openai_embedding $use_openai_embedding
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# vision only
|
| 53 |
+
# exp_name="vision_only"
|
| 54 |
+
# echo $exp_name
|
| 55 |
+
# python eval_goldfish_movie_qa.py --add_unknown --vision_only --model_summary_only --neighbours=$NEIGHBOURS --start=$START --end=$END --batch_size $BATCH_SIZE --ckpt $CKPT_PATH --exp_name=$exp_name\
|
| 56 |
+
# --videos_path $videos_path --subtitle_path $subtitle_path --video_clips_saving_path $video_clips_saving_path --annotation_path $annotation_path --movienet_annotations_dir $movienet_annotations_dir --use_openai_embedding $use_openai_embedding
|
| 57 |
+
|
| 58 |
+
# subtiltes only (eliminate the vision)
|
| 59 |
+
# exp_name="subtitles_only"
|
| 60 |
+
# echo $exp_name
|
| 61 |
+
# python eval_goldfish_movie_qa.py --add_unknown --index_subtitles_together --subtitles_only --neighbours=$NEIGHBOURS --start=$START --end=$END --batch_size $BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --exp_name=$exp_name\
|
| 62 |
+
# --videos_path $videos_path --subtitle_path $subtitle_path --video_clips_saving_path $video_clips_saving_path --annotation_path $annotation_path --movienet_annotations_dir $movienet_annotations_dir --use_openai_embedding $use_openai_embedding
|
| 63 |
+
|
evaluation/Goldfish_eval/movies/eval_q_related_info_llama_vid.sh
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --partition=batch
|
| 3 |
+
#SBATCH --job-name=job_name%j
|
| 4 |
+
#SBATCH --output=job_name%j.out
|
| 5 |
+
#SBATCH --error=job_name%j.err
|
| 6 |
+
#SBATCH --time=0-23:00:00
|
| 7 |
+
#SBATCH --mem=64G
|
| 8 |
+
#SBATCH --gres=gpu:a100:1
|
| 9 |
+
#SBATCH --nodes=1
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
## run the application:
|
| 13 |
+
CKPT_PATH="checkpoints/video_llama_checkpoint_last.pth"
|
| 14 |
+
BATCH_SIZE=4
|
| 15 |
+
START=$1
|
| 16 |
+
END=$2
|
| 17 |
+
|
| 18 |
+
NEIGHBOURS=3
|
| 19 |
+
|
| 20 |
+
# Dataset paths
|
| 21 |
+
videos_path="path to the videos"
|
| 22 |
+
subtitle_path="path to the subtitles"
|
| 23 |
+
video_clips_saving_path="path to save the video clips"
|
| 24 |
+
annotation_file="path to the annotation file"
|
| 25 |
+
movienet_annotations_dir="path to the movienet annotations directory"
|
| 26 |
+
# if you want to use openai embedding, then you need to set the OPENAI_API_KEY
|
| 27 |
+
use_openai_embedding=True
|
| 28 |
+
export OPENAI_API_KEY="your_openai_key"
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# if start and end are not provided, then use the whole dataset
|
| 32 |
+
if [ -z "$START" ]
|
| 33 |
+
then
|
| 34 |
+
START=0
|
| 35 |
+
fi
|
| 36 |
+
if [ -z "$END" ]
|
| 37 |
+
then
|
| 38 |
+
END=100000
|
| 39 |
+
fi
|
| 40 |
+
echo "Start: $START"
|
| 41 |
+
echo "End: $END"
|
| 42 |
+
|
| 43 |
+
# # Vision + subtitles
|
| 44 |
+
exp_name="Vsion_subtitles_model_summary_subtitle"
|
| 45 |
+
echo $exp_name
|
| 46 |
+
python evaluation/eval_goldfish_llama_vid.py --use_clips_for_info --index_subtitles_together --neighbours=$NEIGHBOURS --start=$START --end=$END --batch_size $BATCH_SIZE --ckpt $CKPT_PATH --exp_name=$exp_name\
|
| 47 |
+
--videos_path $videos_path --subtitle_path $subtitle_path --video_clips_saving_path $video_clips_saving_path --annotation_path $annotation_path --movienet_annotations_dir $movienet_annotations_dir --use_openai_embedding $use_openai_embedding
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# vision only
|
| 51 |
+
# exp_name="vision_only"
|
| 52 |
+
# echo $exp_name
|
| 53 |
+
# python evaluation/eval_goldfish_llama_vid.py --use_clips_for_info --vision_only --model_summary_only --neighbours=$NEIGHBOURS --start=$START --end=$END --batch_size $BATCH_SIZE --ckpt $CKPT_PATH --exp_name=$exp_name\
|
| 54 |
+
# --videos_path $videos_path --subtitle_path $subtitle_path --video_clips_saving_path $video_clips_saving_path --annotation_path $annotation_path --movienet_annotations_dir $movienet_annotations_dir --use_openai_embedding $use_openai_embedding
|
| 55 |
+
|
| 56 |
+
# # subtiltes only (eliminate the vision)
|
| 57 |
+
# it is only from summaries no need to run it with clips
|
evaluation/Goldfish_eval/movies/eval_q_related_info_movie_chat.sh
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --partition=batch
|
| 3 |
+
#SBATCH --job-name=job_name%j
|
| 4 |
+
#SBATCH --output=job_name%j.out
|
| 5 |
+
#SBATCH --error=job_name%j.err
|
| 6 |
+
#SBATCH --time=0-23:00:00
|
| 7 |
+
#SBATCH --mem=64G
|
| 8 |
+
#SBATCH --gres=gpu:a100:1
|
| 9 |
+
#SBATCH --nodes=1
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
## run the application:
|
| 13 |
+
CKPT_PATH="checkpoints/video_llama_checkpoint_last.pth"
|
| 14 |
+
BATCH_SIZE=4
|
| 15 |
+
START=$1
|
| 16 |
+
END=$2
|
| 17 |
+
# if start and end are not provided, then use the whole dataset
|
| 18 |
+
if [ -z "$START" ]
|
| 19 |
+
then
|
| 20 |
+
START=0
|
| 21 |
+
fi
|
| 22 |
+
if [ -z "$END" ]
|
| 23 |
+
then
|
| 24 |
+
END=100000
|
| 25 |
+
fi
|
| 26 |
+
echo "Start: $START"
|
| 27 |
+
echo "End: $END"
|
| 28 |
+
|
| 29 |
+
NEIGHBOURS=-1 # use the whole neighbourhood for the global mode
|
| 30 |
+
dataset_path="path to the movies folder"
|
| 31 |
+
annotation_json_folder="path to the jsons folder"
|
| 32 |
+
# if you want to use openai embedding, then you need to set the OPENAI_API_KEY
|
| 33 |
+
use_openai_embedding=True
|
| 34 |
+
export OPENAI_API_KEY="your_openai_key"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
exp_name="model_summary_and_subtitle"
|
| 38 |
+
fps=2
|
| 39 |
+
|
| 40 |
+
# use this for both info and general summary --v_sum_and_info
|
| 41 |
+
|
| 42 |
+
python evaluation/eval_goldfish_movie_chat.py --fps=$fps --neighbours_global=$NEIGHBOURS --batch_size=$BATCH_SIZE --start=$START --end=$END --use_clips_for_info --ckpt $CKPT_PATH --exp_name=$exp_name --dataset_videos_path $dataset_path --annotation_json_folder $annotation_json_folder --use_openai_embedding $use_openai_embedding
|
evaluation/Goldfish_eval/movies/eval_q_related_info_movie_qa.sh
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --partition=batch
|
| 3 |
+
#SBATCH --job-name=M_RAG_clips_for_info_3_subtitles_together_%j
|
| 4 |
+
#SBATCH --output=M_RAG_clips_for_info_3_subtitles_together_%j.out
|
| 5 |
+
#SBATCH --error=M_RAG_clips_for_info_3_subtitles_together_%j.err
|
| 6 |
+
#SBATCH --time=0-23:00:00
|
| 7 |
+
#SBATCH --mem=64G
|
| 8 |
+
#SBATCH --gres=gpu:a100:1
|
| 9 |
+
#SBATCH --nodes=1
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
## run the application:
|
| 13 |
+
NAME="ckpt_92"
|
| 14 |
+
CKPT_PATH="checkpoints/video_llama_checkpoint_last.pth"
|
| 15 |
+
BATCH_SIZE=4
|
| 16 |
+
START=$1
|
| 17 |
+
END=$2
|
| 18 |
+
|
| 19 |
+
NEIGHBOURS=3
|
| 20 |
+
# Dataset paths
|
| 21 |
+
videos_path="path to the videos"
|
| 22 |
+
subtitle_path="path to the subtitles"
|
| 23 |
+
video_clips_saving_path="path to save the video clips"
|
| 24 |
+
annotation_file="path to the annotation file"
|
| 25 |
+
movienet_annotations_dir="path to the movienet annotations directory"
|
| 26 |
+
# if you want to use openai embedding, then you need to set the OPENAI_API_KEY
|
| 27 |
+
use_openai_embedding=True
|
| 28 |
+
export OPENAI_API_KEY="your_openai_key"
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# if start and end are not provided, then use the whole dataset
|
| 32 |
+
if [ -z "$START" ]
|
| 33 |
+
then
|
| 34 |
+
START=0
|
| 35 |
+
fi
|
| 36 |
+
if [ -z "$END" ]
|
| 37 |
+
then
|
| 38 |
+
END=100000
|
| 39 |
+
fi
|
| 40 |
+
echo "Start: $START"
|
| 41 |
+
echo "End: $END"
|
| 42 |
+
echo "Batch size: $BATCH_SIZE"
|
| 43 |
+
|
| 44 |
+
# # Vision + subtitles
|
| 45 |
+
# exp_name="Vsion_subtitles_model_summary_subtitle"
|
| 46 |
+
# echo $exp_name
|
| 47 |
+
python evaluation/eval_goldfish_movie_qa.py --add_unknown --use_clips_for_info --use_choices_for_info --index_subtitles_together --neighbours=$NEIGHBOURS --start=$START --end=$END --batch_size $BATCH_SIZE --ckpt $CKPT_PATH --exp_name=$exp_name\
|
| 48 |
+
--videos_path $videos_path --subtitle_path $subtitle_path --video_clips_saving_path $video_clips_saving_path --annotation_path $annotation_path --movienet_annotations_dir $movienet_annotations_dir --use_openai_embedding $use_openai_embedding
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# vision only
|
| 52 |
+
# exp_name="vision_only"
|
| 53 |
+
# echo $exp_name
|
| 54 |
+
# python evaluation/eval_goldfish_movie_qa.py --add_unknown --use_clips_for_info --use_choices_for_info --vision_only --model_summary_only --neighbours=$NEIGHBOURS --start=$START --end=$END --batch_size $BATCH_SIZE --ckpt $CKPT_PATH --exp_name=$exp_name\
|
| 55 |
+
# --videos_path $videos_path --subtitle_path $subtitle_path --video_clips_saving_path $video_clips_saving_path --annotation_path $annotation_path --movienet_annotations_dir $movienet_annotations_dir --use_openai_embedding $use_openai_embedding
|
| 56 |
+
|
| 57 |
+
|
evaluation/Goldfish_eval/movies/submit_batch_jobs_llama_vid.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
# bash_script = 'eval_q_related_info_llama_vid.sh'
|
| 4 |
+
|
| 5 |
+
bash_script = 'eval_model_summary_llama_vid.sh'
|
| 6 |
+
start=0
|
| 7 |
+
end=45
|
| 8 |
+
step=11
|
| 9 |
+
for i in range(start, end, step):
|
| 10 |
+
# print(i, i+step, job_id)
|
| 11 |
+
# job_id+=1
|
| 12 |
+
cmd=f'sbatch {bash_script} {str(i)} {str(i+step)}'
|
| 13 |
+
# print(cmd)
|
| 14 |
+
os.system(cmd)
|
evaluation/Goldfish_eval/movies/submit_batch_jobs_movie_qa.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
bash_script = 'eval_model_summary_movie_qa.sh'
|
| 5 |
+
# bash_script = 'eval_q_related_info_movie_qa.sh'
|
| 6 |
+
start=0
|
| 7 |
+
end=30
|
| 8 |
+
step=4
|
| 9 |
+
for i in range(start, end, step):
|
| 10 |
+
# print(i, i+step, job_id)
|
| 11 |
+
# job_id+=1
|
| 12 |
+
cmd=f'sbatch {bash_script} {str(i)} {str(i+step)}'
|
| 13 |
+
# print(cmd)
|
| 14 |
+
os.system(cmd)
|
| 15 |
+
|
| 16 |
+
|
evaluation/Goldfish_eval/movies/submit_batch_jobs_moviechat.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
bash_script = 'eval_q_related_info_movie_chat.sh'
|
| 4 |
+
|
| 5 |
+
# bash_script = 'eval_model_summary_movie_chat.sh'
|
| 6 |
+
start=0
|
| 7 |
+
end=101
|
| 8 |
+
step=26
|
| 9 |
+
for i in range(start, end, step):
|
| 10 |
+
# print(i, i+step, job_id)
|
| 11 |
+
# job_id+=1
|
| 12 |
+
cmd=f'sbatch {bash_script} {str(i)} {str(i+step)}'
|
| 13 |
+
# print(cmd)
|
| 14 |
+
os.system(cmd)
|
evaluation/Goldfish_eval/retrival_accuracy/eval_retrieval_acc_tvqa_job.sh
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --partition=batch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
#SBATCH --job-name=Retrieval_acc_3_%j
|
| 6 |
+
#SBATCH --output=Retrieval_acc_3_%j.out
|
| 7 |
+
#SBATCH --error=Retrieval_acc_3_%j.err
|
| 8 |
+
#SBATCH --time=0-23:00:00
|
| 9 |
+
#SBATCH --mem=100G
|
| 10 |
+
#SBATCH --gres=gpu:a100:1
|
| 11 |
+
#SBATCH --nodes=1
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
## run the application:
|
| 15 |
+
cd ../../../
|
| 16 |
+
NAME="ckpt_92"
|
| 17 |
+
CKPT_PATH="checkpoints/video_llama_checkpoint_last.pth"
|
| 18 |
+
START=$1
|
| 19 |
+
END=$2
|
| 20 |
+
BATCH_SIZE=8
|
| 21 |
+
|
| 22 |
+
# if start and end are not provided, then use the whole dataset
|
| 23 |
+
if [ -z "$START" ]
|
| 24 |
+
then
|
| 25 |
+
START=0
|
| 26 |
+
fi
|
| 27 |
+
if [ -z "$END" ]
|
| 28 |
+
then
|
| 29 |
+
END=100000
|
| 30 |
+
fi
|
| 31 |
+
echo "Start: $START"
|
| 32 |
+
echo "End: $END"
|
| 33 |
+
echo "Batch size: $BATCH_SIZE"
|
| 34 |
+
|
| 35 |
+
NEIGHBOURS=1
|
| 36 |
+
exp_name="vision"
|
| 37 |
+
|
| 38 |
+
python evaluation/eval_retrieval_acc_tvqa.py --start=$START --end=$END --neighbours=$NEIGHBOURS --batch_size=$BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --exp_name=$exp_name
|
| 39 |
+
|
| 40 |
+
# python evaluation/eval_retrieval_acc_tvqa.py --vision_only --start=$START --end=$END --neighbours=$NEIGHBOURS --batch_size=$BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --exp_name=$exp_name
|
| 41 |
+
|
| 42 |
+
# python evaluation/eval_retrieval_acc_tvqa.py --subtitles_only --start=$START --end=$END --neighbours=$NEIGHBOURS --batch_size=$BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --exp_name=$exp_name
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# exp_name="subtitles"
|
| 47 |
+
# python evaluation/eval_retrieval_acc_tvqa.py --start=$START --end=$END --neighbours=$NEIGHBOURS --batch_size=$BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --exp_name=$exp_name
|
| 48 |
+
|
| 49 |
+
# python evaluation/eval_retrieval_acc_tvqa.py --vision_only --start=$START --end=$END --neighbours=$NEIGHBOURS --batch_size=$BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --exp_name=$exp_name
|
| 50 |
+
|
| 51 |
+
# python evaluation/eval_retrieval_acc_tvqa.py --subtitles_only --start=$START --end=$END --neighbours=$NEIGHBOURS --batch_size=$BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --exp_name=$exp_name
|
evaluation/Goldfish_eval/retrival_accuracy/eval_retrieval_acc_tvqa_job_sub_v.sh
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --partition=batch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
#SBATCH --job-name=Retrieval_acc_3_%j
|
| 6 |
+
#SBATCH --output=Retrieval_acc_3_%j.out
|
| 7 |
+
#SBATCH --error=Retrieval_acc_3_%j.err
|
| 8 |
+
#SBATCH --time=0-23:00:00
|
| 9 |
+
#SBATCH --mem=100G
|
| 10 |
+
#SBATCH --gres=gpu:a100:1
|
| 11 |
+
#SBATCH --nodes=1
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
## run the application:
|
| 15 |
+
NAME="ckpt_92"
|
| 16 |
+
CKPT_PATH="checkpoints/video_llama_checkpoint_last.pth"
|
| 17 |
+
START=$1
|
| 18 |
+
END=$2
|
| 19 |
+
BATCH_SIZE=8
|
| 20 |
+
|
| 21 |
+
# if start and end are not provided, then use the whole dataset
|
| 22 |
+
if [ -z "$START" ]
|
| 23 |
+
then
|
| 24 |
+
START=0
|
| 25 |
+
fi
|
| 26 |
+
if [ -z "$END" ]
|
| 27 |
+
then
|
| 28 |
+
END=100000
|
| 29 |
+
fi
|
| 30 |
+
echo "Start: $START"
|
| 31 |
+
echo "End: $END"
|
| 32 |
+
echo "Batch size: $BATCH_SIZE"
|
| 33 |
+
|
| 34 |
+
NEIGHBOURS=1
|
| 35 |
+
# exp_name="vision"
|
| 36 |
+
|
| 37 |
+
# python evaluation/eval_retrieval_acc_tvqa.py --start=$START --end=$END --neighbours=$NEIGHBOURS --batch_size=$BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --exp_name=$exp_name
|
| 38 |
+
|
| 39 |
+
# python evaluation/eval_retrieval_acc_tvqa.py --vision_only --start=$START --end=$END --neighbours=$NEIGHBOURS --batch_size=$BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --exp_name=$exp_name
|
| 40 |
+
|
| 41 |
+
# python evaluation/eval_retrieval_acc_tvqa.py --subtitles_only --start=$START --end=$END --neighbours=$NEIGHBOURS --batch_size=$BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --exp_name=$exp_name
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
exp_name="subtitles"
|
| 46 |
+
# python evaluation/eval_retrieval_acc_tvqa.py --start=$START --end=$END --neighbours=$NEIGHBOURS --batch_size=$BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --exp_name=$exp_name
|
| 47 |
+
|
| 48 |
+
python evaluation/eval_retrieval_acc_tvqa.py --vision_only --start=$START --end=$END --neighbours=$NEIGHBOURS --batch_size=$BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --exp_name=$exp_name
|
| 49 |
+
|
| 50 |
+
# python evaluation/eval_retrieval_acc_tvqa.py --subtitles_only --start=$START --end=$END --neighbours=$NEIGHBOURS --batch_size=$BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --exp_name=$exp_name
|
evaluation/Goldfish_eval/retrival_accuracy/eval_retrieval_acc_tvqa_job_sub_v_sub.sh
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --partition=batch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
#SBATCH --job-name=Retrieval_acc_3_%j
|
| 6 |
+
#SBATCH --output=Retrieval_acc_3_%j.out
|
| 7 |
+
#SBATCH --error=Retrieval_acc_3_%j.err
|
| 8 |
+
#SBATCH --time=0-23:00:00
|
| 9 |
+
#SBATCH --mem=100G
|
| 10 |
+
#SBATCH --gres=gpu:a100:1
|
| 11 |
+
#SBATCH --nodes=1
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
## run the application:
|
| 15 |
+
|
| 16 |
+
NAME="ckpt_92"
|
| 17 |
+
CKPT_PATH="checkpoints/video_llama_checkpoint_last.pth"
|
| 18 |
+
START=$1
|
| 19 |
+
END=$2
|
| 20 |
+
BATCH_SIZE=8
|
| 21 |
+
|
| 22 |
+
# if start and end are not provided, then use the whole dataset
|
| 23 |
+
if [ -z "$START" ]
|
| 24 |
+
then
|
| 25 |
+
START=0
|
| 26 |
+
fi
|
| 27 |
+
if [ -z "$END" ]
|
| 28 |
+
then
|
| 29 |
+
END=100000
|
| 30 |
+
fi
|
| 31 |
+
echo "Start: $START"
|
| 32 |
+
echo "End: $END"
|
| 33 |
+
echo "Batch size: $BATCH_SIZE"
|
| 34 |
+
|
| 35 |
+
NEIGHBOURS=1
|
| 36 |
+
# exp_name="vision"
|
| 37 |
+
|
| 38 |
+
# python evaluation/eval_retrieval_acc_tvqa.py --start=$START --end=$END --neighbours=$NEIGHBOURS --batch_size=$BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --exp_name=$exp_name
|
| 39 |
+
|
| 40 |
+
# python evaluation/eval_retrieval_acc_tvqa.py --vision_only --start=$START --end=$END --neighbours=$NEIGHBOURS --batch_size=$BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --exp_name=$exp_name
|
| 41 |
+
|
| 42 |
+
# python evaluation/eval_retrieval_acc_tvqa.py --subtitles_only --start=$START --end=$END --neighbours=$NEIGHBOURS --batch_size=$BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --exp_name=$exp_name
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
exp_name="subtitles"
|
| 47 |
+
python evaluation/eval_retrieval_acc_tvqa.py --start=$START --end=$END --neighbours=$NEIGHBOURS --batch_size=$BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --exp_name=$exp_name
|
| 48 |
+
|
| 49 |
+
# python evaluation/eval_retrieval_acc_tvqa.py --vision_only --start=$START --end=$END --neighbours=$NEIGHBOURS --batch_size=$BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --exp_name=$exp_name
|
| 50 |
+
|
| 51 |
+
# python evaluation/eval_retrieval_acc_tvqa.py --subtitles_only --start=$START --end=$END --neighbours=$NEIGHBOURS --batch_size=$BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --exp_name=$exp_name
|
evaluation/Goldfish_eval/retrival_accuracy/eval_retrieval_acc_tvqa_job_vision_vision.sh
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --partition=batch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
#SBATCH --job-name=Retrieval_acc_3_%j
|
| 6 |
+
#SBATCH --output=Retrieval_acc_3_%j.out
|
| 7 |
+
#SBATCH --error=Retrieval_acc_3_%j.err
|
| 8 |
+
#SBATCH --time=0-23:00:00
|
| 9 |
+
#SBATCH --mem=100G
|
| 10 |
+
#SBATCH --gres=gpu:a100:1
|
| 11 |
+
#SBATCH --nodes=1
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
## run the application:
|
| 15 |
+
cd ../../../
|
| 16 |
+
NAME="ckpt_92"
|
| 17 |
+
CKPT_PATH="checkpoints/video_llama_checkpoint_last.pth"
|
| 18 |
+
START=$1
|
| 19 |
+
END=$2
|
| 20 |
+
BATCH_SIZE=8
|
| 21 |
+
|
| 22 |
+
# if start and end are not provided, then use the whole dataset
|
| 23 |
+
if [ -z "$START" ]
|
| 24 |
+
then
|
| 25 |
+
START=0
|
| 26 |
+
fi
|
| 27 |
+
if [ -z "$END" ]
|
| 28 |
+
then
|
| 29 |
+
END=100000
|
| 30 |
+
fi
|
| 31 |
+
echo "Start: $START"
|
| 32 |
+
echo "End: $END"
|
| 33 |
+
echo "Batch size: $BATCH_SIZE"
|
| 34 |
+
|
| 35 |
+
NEIGHBOURS=1
|
| 36 |
+
exp_name="vision"
|
| 37 |
+
|
| 38 |
+
# python evaluation/eval_retrieval_acc_tvqa.py --start=$START --end=$END --neighbours=$NEIGHBOURS --batch_size=$BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --exp_name=$exp_name
|
| 39 |
+
|
| 40 |
+
python evaluation/eval_retrieval_acc_tvqa.py --vision_only --start=$START --end=$END --neighbours=$NEIGHBOURS --batch_size=$BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --exp_name=$exp_name
|
| 41 |
+
|
| 42 |
+
# python evaluation/eval_retrieval_acc_tvqa.py --subtitles_only --start=$START --end=$END --neighbours=$NEIGHBOURS --batch_size=$BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --exp_name=$exp_name
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# exp_name="subtitles"
|
| 47 |
+
# python evaluation/eval_retrieval_acc_tvqa.py --start=$START --end=$END --neighbours=$NEIGHBOURS --batch_size=$BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --exp_name=$exp_name
|
| 48 |
+
|
| 49 |
+
# python evaluation/eval_retrieval_acc_tvqa.py --vision_only --start=$START --end=$END --neighbours=$NEIGHBOURS --batch_size=$BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --exp_name=$exp_name
|
| 50 |
+
|
| 51 |
+
# python evaluation/eval_retrieval_acc_tvqa.py --subtitles_only --start=$START --end=$END --neighbours=$NEIGHBOURS --batch_size=$BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --exp_name=$exp_name
|
evaluation/Goldfish_eval/tvqa_eval/eval_model_summary.sh
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --partition=batch
|
| 3 |
+
#SBATCH --job-name=job_name%j
|
| 4 |
+
#SBATCH --output=job_name%j.out
|
| 5 |
+
#SBATCH --error=job_name%j.err
|
| 6 |
+
#SBATCH --time=0-23:00:00
|
| 7 |
+
#SBATCH --mem=64G
|
| 8 |
+
#SBATCH --gres=gpu:a100:1
|
| 9 |
+
#SBATCH --nodes=1
|
| 10 |
+
|
| 11 |
+
## run the application:
|
| 12 |
+
cd ../../../
|
| 13 |
+
CKPT_PATH="checkpoints/video_llama_checkpoint_last.pth"
|
| 14 |
+
START=$1
|
| 15 |
+
END=$2
|
| 16 |
+
|
| 17 |
+
BATCH_SIZE=4
|
| 18 |
+
NEIGHBOURS=3
|
| 19 |
+
|
| 20 |
+
# tvqa_json_subtitles="path to the tvqa json subtitles file"
|
| 21 |
+
# tvqa_clips_subtitles="path to the tvqa clips subtitles"
|
| 22 |
+
# videos_frames="path to the video frames"
|
| 23 |
+
# annotation_path="path to the TVQA-Long annotation file"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
tvqa_json_subtitles="datasets/evaluation_datasets/goldfish_eval_datasets/tvqa/tvqa_preprocessed_subtitles.json"
|
| 27 |
+
tvqa_clips_subtitles="/ibex/project/c2090/datasets/TVR_dataset/videos/tvqa_subtitles"
|
| 28 |
+
videos_frames="/ibex/project/c2090/datasets/TVR_dataset/videos/video_files/frames_hq/"
|
| 29 |
+
annotation_path="datasets/evaluation_datasets/goldfish_eval_datasets/tvqa/tvqa_val_edited.json"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# if start and end are not provided, then use the whole dataset
|
| 33 |
+
if [ -z "$START" ]
|
| 34 |
+
then
|
| 35 |
+
START=0
|
| 36 |
+
fi
|
| 37 |
+
if [ -z "$END" ]
|
| 38 |
+
then
|
| 39 |
+
END=100000
|
| 40 |
+
fi
|
| 41 |
+
echo "Start: $START"
|
| 42 |
+
echo "End: $END"
|
| 43 |
+
|
| 44 |
+
# # Vision + subtitles
|
| 45 |
+
exp_name="Vsion_subtitles_model_summary_subtitle_videoLLM"
|
| 46 |
+
echo $exp_name
|
| 47 |
+
python eval_goldfish_tvqa_long.py --add_unknown --index_subtitles_together --neighbours=$NEIGHBOURS --start=$START --end=$END --batch_size $BATCH_SIZE --ckpt $CKPT_PATH --exp_name=$exp_name\
|
| 48 |
+
--tvqa_json_subtitles $tvqa_json_subtitles --tvqa_clips_subtitles $tvqa_clips_subtitles --videos_frames $videos_frames --annotation_path $annotation_path
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# vision only
|
| 52 |
+
# exp_name="vision_only"
|
| 53 |
+
# echo $exp_name
|
| 54 |
+
# python eval_goldfish_tvqa_long.py --add_unknown --vision_only --model_summary_only --neighbours=$NEIGHBOURS --start=$START --end=$END --batch_size $BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --exp_name=$exp_name
|
| 55 |
+
|
| 56 |
+
# # subtiltes only (eliminate the vision)
|
| 57 |
+
# exp_name="subtitles_only"
|
| 58 |
+
# echo $exp_name
|
| 59 |
+
# python eval_goldfish_tvqa_long.py --add_unknown --index_subtitles_together --subtitles_only --neighbours=$NEIGHBOURS --start=$START --end=$END --batch_size $BATCH_SIZE --name $NAME --ckpt $CKPT_PATH --exp_name=$exp_name
|
evaluation/Goldfish_eval/tvqa_eval/eval_q_related_info.sh
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --partition=batch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
#SBATCH --job-name=RAG_clips_info_1_vision_%j
|
| 6 |
+
#SBATCH --output=RAG_clips_info_1_vision_%j.out
|
| 7 |
+
#SBATCH --error=RAG_clips_info_1_vision_%j.err
|
| 8 |
+
#SBATCH --time=0-23:00:00
|
| 9 |
+
#SBATCH --mem=64G
|
| 10 |
+
#SBATCH --gres=gpu:a100:1
|
| 11 |
+
#SBATCH --nodes=1
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
## run the application:
|
| 15 |
+
cd ../../../
|
| 16 |
+
START=$1
|
| 17 |
+
END=$2
|
| 18 |
+
|
| 19 |
+
BATCH_SIZE=4
|
| 20 |
+
NEIGHBOURS=3
|
| 21 |
+
CKPT_PATH="checkpoints/video_llama_checkpoint_last.pth"
|
| 22 |
+
# tvqa_json_subtitles="path to the tvqa json subtitles file"
|
| 23 |
+
# tvqa_clips_subtitles="path to the tvqa clips subtitles"
|
| 24 |
+
# videos_frames="path to the video frames"
|
| 25 |
+
# annotation_path="path to the TVQA-Long annotation file"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
tvqa_json_subtitles="datasets/evaluation_datasets/goldfish_eval_datasets/tvqa/tvqa_preprocessed_subtitles.json"
|
| 29 |
+
tvqa_clips_subtitles="/ibex/project/c2090/datasets/TVR_dataset/videos/tvqa_subtitles"
|
| 30 |
+
videos_frames="/ibex/project/c2090/datasets/TVR_dataset/videos/video_files/frames_hq/"
|
| 31 |
+
annotation_path="datasets/evaluation_datasets/goldfish_eval_datasets/tvqa/tvqa_val_edited.json"
|
| 32 |
+
|
| 33 |
+
# if start and end are not provided, then use the whole dataset
|
| 34 |
+
if [ -z "$START" ]
|
| 35 |
+
then
|
| 36 |
+
START=0
|
| 37 |
+
fi
|
| 38 |
+
if [ -z "$END" ]
|
| 39 |
+
then
|
| 40 |
+
END=100000
|
| 41 |
+
fi
|
| 42 |
+
echo "Start: $START"
|
| 43 |
+
echo "End: $END"
|
| 44 |
+
|
| 45 |
+
# # Vision + subtitles
|
| 46 |
+
exp_name="Vsion_subtitles_model_summary_subtitle"
|
| 47 |
+
echo $exp_name
|
| 48 |
+
python eval_goldfish_tvqa_long.py --add_unknown --use_clips_for_info --use_choices_for_info --index_subtitles_together --neighbours=$NEIGHBOURS --start=$START --end=$END --batch_size $BATCH_SIZE --ckpt $CKPT_PATH --exp_name=$exp_name\
|
| 49 |
+
--tvqa_json_subtitles $tvqa_json_subtitles --tvqa_clips_subtitles $tvqa_clips_subtitles --videos_frames $videos_frames --annotation_path $annotation_path
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# exp_name="Vsion_subtitles_info_only"
|
| 53 |
+
# echo $exp_name
|
| 54 |
+
# python eval_goldfish_tvqa_long.py --add_unknown --info_only --use_clips_for_info --use_choices_for_info --index_subtitles_together --neighbours=$NEIGHBOURS --start=$START --end=$END --batch_size $BATCH_SIZE --ckpt $CKPT_PATH --exp_name=$exp_name\
|
| 55 |
+
# --tvqa_json_subtitles $tvqa_json_subtitles --tvqa_clips_subtitles $tvqa_clips_subtitles --videos_frames $videos_frames --annotation_path $annotation_path
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# exp_name="info_sub_after_retrieval"
|
| 59 |
+
# echo $exp_name
|
| 60 |
+
# python eval_goldfish_tvqa_long.py --add_unknown --subtitles_only_after_retrieval --use_clips_for_info --use_choices_for_info --index_subtitles_together --neighbours=$NEIGHBOURS --start=$START --end=$END --batch_size $BATCH_SIZE --ckpt $CKPT_PATH --exp_name=$exp_name\
|
| 61 |
+
# --tvqa_json_subtitles $tvqa_json_subtitles --tvqa_clips_subtitles $tvqa_clips_subtitles --videos_frames $videos_frames --annotation_path $annotation_path
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# vision only
|
| 68 |
+
# exp_name="vision_only"
|
| 69 |
+
# echo $exp_name
|
| 70 |
+
# python eval_goldfish_tvqa_long.py --add_unknown --use_clips_for_info --use_choices_for_info --vision_only --model_summary_only --neighbours=$NEIGHBOURS --start=$START --end=$END --batch_size $BATCH_SIZE --ckpt $CKPT_PATH --exp_name=$exp_name\
|
| 71 |
+
# --tvqa_json_subtitles $tvqa_json_subtitles --tvqa_clips_subtitles $tvqa_clips_subtitles --videos_frames $videos_frames --annotation_path $annotation_path
|
evaluation/Goldfish_eval/tvqa_eval/submit_batch_jobs.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
bash_script = 'RAG_summary.sh'
|
| 5 |
+
# bash_script = 'RAG.sh'
|
| 6 |
+
|
| 7 |
+
# general
|
| 8 |
+
start=0
|
| 9 |
+
end=850
|
| 10 |
+
step=60
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# bash_script="RAG_summary_R_ablations.sh"
|
| 14 |
+
# sample 50
|
| 15 |
+
# start=0
|
| 16 |
+
# end=52
|
| 17 |
+
# step=6
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# job_id=32434597
|
| 21 |
+
for i in range(start, end, step):
|
| 22 |
+
# print(i, i+step, job_id)
|
| 23 |
+
# job_id+=1
|
| 24 |
+
cmd=f'sbatch {bash_script} {str(i)} {str(i+step)}'
|
| 25 |
+
os.system(cmd)
|
evaluation/eval_goldfish_llama_vid.py
ADDED
|
@@ -0,0 +1,616 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
project_dir = os.getcwd()
|
| 4 |
+
sys.path.append(project_dir)
|
| 5 |
+
import json
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from goldfish_lv import GoldFish_LV,split_subtitles,time_to_seconds
|
| 8 |
+
import argparse
|
| 9 |
+
import json
|
| 10 |
+
import torch
|
| 11 |
+
import re
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
from PIL import Image
|
| 14 |
+
from index import MemoryIndex
|
| 15 |
+
import torch
|
| 16 |
+
import random
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch.backends.cudnn as cudnn
|
| 19 |
+
import shutil
|
| 20 |
+
def str2bool(v):
|
| 21 |
+
if isinstance(v, bool):
|
| 22 |
+
return v
|
| 23 |
+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
| 24 |
+
return True
|
| 25 |
+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
| 26 |
+
return False
|
| 27 |
+
else:
|
| 28 |
+
raise argparse.ArgumentTypeError('Boolean value expected.')
|
| 29 |
+
|
| 30 |
+
def get_arguments():
|
| 31 |
+
parser = argparse.ArgumentParser(description="Inference parameters")
|
| 32 |
+
parser.add_argument("--neighbours", type=int, default=-1)
|
| 33 |
+
parser.add_argument("--name", type=str,default="ckpt_92",help="name of the experiment")
|
| 34 |
+
parser.add_argument("--add_unknown", action='store_true')
|
| 35 |
+
parser.add_argument("--use_chatgpt", action='store_true')
|
| 36 |
+
parser.add_argument("--use_choices_for_info", action='store_true')
|
| 37 |
+
parser.add_argument("--use_gt_information", action='store_true')
|
| 38 |
+
parser.add_argument("--inference_text", action='store_true')
|
| 39 |
+
parser.add_argument("--use_gt_information_with_distraction", action='store_true')
|
| 40 |
+
parser.add_argument("--num_distraction", type=int, default=2)
|
| 41 |
+
parser.add_argument("--add_confidance_score", action='store_true')
|
| 42 |
+
parser.add_argument("--use_original_video", action='store_true')
|
| 43 |
+
parser.add_argument("--use_video_embedding", action='store_true')
|
| 44 |
+
parser.add_argument("--use_clips_for_info", action='store_true')
|
| 45 |
+
parser.add_argument("--use_GT_video", action='store_true')
|
| 46 |
+
parser.add_argument("--use_gt_summary", action='store_true')
|
| 47 |
+
parser.add_argument("--index_subtitles", action='store_true')
|
| 48 |
+
parser.add_argument("--index_subtitles_together", action='store_true')
|
| 49 |
+
|
| 50 |
+
parser.add_argument("--ask_the_question_early", action='store_true')
|
| 51 |
+
parser.add_argument("--clip_in_ask_early", action='store_true')
|
| 52 |
+
parser.add_argument("--summary_with_subtitles_only", action='store_true')
|
| 53 |
+
parser.add_argument("--use_coherent_description", action='store_true')
|
| 54 |
+
|
| 55 |
+
parser.add_argument("--start", default=0, type=int)
|
| 56 |
+
parser.add_argument("--end", default=100000, type=int)
|
| 57 |
+
parser.add_argument("--exp_name", type=str,default="",help="name of eval folder")
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
parser.add_argument("--vision_only", action='store_true')
|
| 61 |
+
parser.add_argument("--model_summary_only", action='store_true')
|
| 62 |
+
parser.add_argument("--subtitles_only", action='store_true')
|
| 63 |
+
parser.add_argument("--info_only", action='store_true')
|
| 64 |
+
|
| 65 |
+
parser.add_argument("--cfg-path", default="test_configs/llama2_test_config.yaml")
|
| 66 |
+
parser.add_argument("--ckpt", type=str, default="checkpoints/video_llama_checkpoint_last.pth")
|
| 67 |
+
parser.add_argument("--add_subtitles", action='store_true')
|
| 68 |
+
parser.add_argument("--eval_opt", type=str, default='all')
|
| 69 |
+
parser.add_argument("--max_new_tokens", type=int, default=300)
|
| 70 |
+
parser.add_argument("--batch_size", type=int, default=8)
|
| 71 |
+
parser.add_argument("--lora_r", type=int, default=64)
|
| 72 |
+
parser.add_argument("--lora_alpha", type=int, default=16)
|
| 73 |
+
parser.add_argument("--video_path", type=str, help="path to the video")
|
| 74 |
+
parser.add_argument("--use_openai_embedding",type=str2bool, default=False)
|
| 75 |
+
parser.add_argument("--annotation_path", type=str, help="path to the annotation file")
|
| 76 |
+
parser.add_argument("--videos_path", type=str, help="path to the videos directory")
|
| 77 |
+
parser.add_argument("--subtitle_path", type=str, help="path to the subtitles directory")
|
| 78 |
+
parser.add_argument("--movienet_annotations_dir", type=str, help="path to the movienet annotations directory")
|
| 79 |
+
parser.add_argument("--video_clips_saving_path", type=str, help="path to save the splitted small video clips")
|
| 80 |
+
|
| 81 |
+
parser.add_argument("--save_path", type=str, help="path to save the results")
|
| 82 |
+
|
| 83 |
+
parser.add_argument("--options", nargs="+")
|
| 84 |
+
return parser.parse_args()
|
| 85 |
+
def time_to_seconds(subrip_time):
|
| 86 |
+
return subrip_time.hours * 3600 + subrip_time.minutes * 60 + subrip_time.seconds + subrip_time.milliseconds / 1000
|
| 87 |
+
|
| 88 |
+
def clean_text(subtitles_text):
|
| 89 |
+
# Remove unwanted characters except for letters, digits, and single quotes
|
| 90 |
+
subtitles_text = re.sub(r'[^a-zA-Z0-9\s\']', '', subtitles_text)
|
| 91 |
+
# Replace multiple spaces with a single space
|
| 92 |
+
subtitles_text = re.sub(r'\s+', ' ', subtitles_text)
|
| 93 |
+
return subtitles_text.strip()
|
| 94 |
+
|
| 95 |
+
class LlamaVidQAEval (GoldFish_LV):
|
| 96 |
+
|
| 97 |
+
def __init__(self,args):
|
| 98 |
+
super().__init__(args)
|
| 99 |
+
self.save_json_path = "new_workspace/clips_summary/movienet"
|
| 100 |
+
if args.use_openai_embedding:
|
| 101 |
+
self.save_pkls_path = "new_workspace/open_ai_embedding/movienet"
|
| 102 |
+
else:
|
| 103 |
+
self.save_pkls_path = "new_workspace/embedding/movienet"
|
| 104 |
+
os.makedirs(self.save_json_path, exist_ok=True)
|
| 105 |
+
annotation_path=args.annotation_path
|
| 106 |
+
with open(annotation_path, 'r') as f:
|
| 107 |
+
self.movies_dict = json.load(f)
|
| 108 |
+
self.max_sub_len=400
|
| 109 |
+
self.max_num_images=45
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def _get_movie_data(self,videoname):
|
| 113 |
+
video_images_path =f"{args.videos_path}/{videoname}"
|
| 114 |
+
movie_clips_path =f"{args.video_clips_saving_path}/{videoname}"
|
| 115 |
+
subtitle_path = f"{args.subtitle_path}/{videoname}.srt"
|
| 116 |
+
annotation_file=f"{args.movienet_annotations_dir}/{videoname}.json"
|
| 117 |
+
# load the annotation file
|
| 118 |
+
with open(annotation_file, 'r') as f:
|
| 119 |
+
movie_annotation = json.load(f)
|
| 120 |
+
return video_images_path,subtitle_path,movie_annotation,movie_clips_path
|
| 121 |
+
def _store_subtitles_paragraphs(self,subtitle_path,important_data,number_of_paragraphs):
|
| 122 |
+
paragraphs=[]
|
| 123 |
+
movie_name=subtitle_path.split('/')[-1].split('.')[0]
|
| 124 |
+
# if there is no story, split the subtitles into paragraphs
|
| 125 |
+
paragraphs = split_subtitles(subtitle_path, number_of_paragraphs)
|
| 126 |
+
for i,paragraph in enumerate(paragraphs):
|
| 127 |
+
paragraph=clean_text(paragraph)
|
| 128 |
+
important_data.update({f"subtitle_{i}__{movie_name}_clip_{str(i).zfill(2)}": paragraph})
|
| 129 |
+
return important_data
|
| 130 |
+
def _get_shots_subtitles(self,movie_annotation):
|
| 131 |
+
shots_subtitles={}
|
| 132 |
+
if movie_annotation['story'] is not None:
|
| 133 |
+
for section in movie_annotation['story']:
|
| 134 |
+
for shot in section['subtitle']:
|
| 135 |
+
shot_number=shot['shot']
|
| 136 |
+
shot_subtitle=' '.join(shot['sentences'])
|
| 137 |
+
shots_subtitles[shot_number]=clean_text(shot_subtitle)
|
| 138 |
+
|
| 139 |
+
return shots_subtitles
|
| 140 |
+
|
| 141 |
+
def prepare_input_images(self,clip_path,shots_subtitles,use_subtitles):
|
| 142 |
+
total_frames=len(os.listdir(clip_path))
|
| 143 |
+
movie_name=clip_path.split('/')[-2]
|
| 144 |
+
clip_name=clip_path.split('/')[-1]
|
| 145 |
+
sampling_interval=int(total_frames//self.max_num_images)
|
| 146 |
+
if sampling_interval==0:
|
| 147 |
+
sampling_interval=1
|
| 148 |
+
use_subtitles_save_name="subtitles" if use_subtitles else "no_subtitles"
|
| 149 |
+
video_frames_path = os.path.join(clip_path)
|
| 150 |
+
total_num_frames=len(os.listdir(video_frames_path))
|
| 151 |
+
sampling_interval = round(total_num_frames / self.max_num_images)
|
| 152 |
+
if sampling_interval == 0:
|
| 153 |
+
sampling_interval = 1
|
| 154 |
+
number_of_words=0
|
| 155 |
+
video_images_list=sorted(os.listdir(video_frames_path))
|
| 156 |
+
images = []
|
| 157 |
+
img_placeholder = ""
|
| 158 |
+
for i,frame in enumerate(video_images_list):
|
| 159 |
+
if i % sampling_interval == 0:
|
| 160 |
+
frame = Image.open(os.path.join(video_frames_path,frame)).convert("RGB")
|
| 161 |
+
frame = self.vis_processor(frame)
|
| 162 |
+
images.append(frame)
|
| 163 |
+
img_placeholder += '<Img><ImageHere>'
|
| 164 |
+
shot_num=video_images_list[i].split('_')[1]
|
| 165 |
+
if shots_subtitles.get(shot_num) is not None:
|
| 166 |
+
sub=clean_text(shots_subtitles[shot_num])
|
| 167 |
+
number_of_words+=len(sub.split(' '))
|
| 168 |
+
if number_of_words<= self.max_sub_len and use_subtitles:
|
| 169 |
+
img_placeholder+=f'<Cap>{sub}'
|
| 170 |
+
if len(images) >= self.max_num_images:
|
| 171 |
+
break
|
| 172 |
+
if len(images) ==0:
|
| 173 |
+
print("Video not found",video_frames_path)
|
| 174 |
+
|
| 175 |
+
if 0 <len(images) < self.max_num_images:
|
| 176 |
+
last_item = images[-1]
|
| 177 |
+
while len(images) < self.max_num_images:
|
| 178 |
+
images.append(last_item)
|
| 179 |
+
img_placeholder += '<Img><ImageHere>'
|
| 180 |
+
images = torch.stack(images)
|
| 181 |
+
|
| 182 |
+
return images,img_placeholder
|
| 183 |
+
|
| 184 |
+
def _get_movie_summaries(self,video_images_path,use_subtitles,shots_subtitles,movie_clips_path):
|
| 185 |
+
video_images_list=sorted(os.listdir(video_images_path))
|
| 186 |
+
max_caption_index = 0
|
| 187 |
+
preds = {}
|
| 188 |
+
movie_name=movie_clips_path.split('/')[-1]
|
| 189 |
+
videos_summaries=[]
|
| 190 |
+
previous_caption=""
|
| 191 |
+
batch_size=args.batch_size
|
| 192 |
+
batch_images=[]
|
| 193 |
+
batch_instructions=[]
|
| 194 |
+
clip_numbers=[]
|
| 195 |
+
clip_number=0
|
| 196 |
+
conversations=[]
|
| 197 |
+
for i in tqdm(range(0,len(video_images_list),135), desc="Inference video clips", total=len(video_images_list)/120):
|
| 198 |
+
images=[]
|
| 199 |
+
# Add the previous caption to the new video clip
|
| 200 |
+
# if batch_size==1:
|
| 201 |
+
# previous_caption="You are analysing a one long video of mutiple clips and this is the summary from all previous clips :"+videos_summaries[-1] +"\n\n"if len(videos_summaries)>0 else ""
|
| 202 |
+
if previous_caption != "":
|
| 203 |
+
img_placeholder = previous_caption+" "
|
| 204 |
+
else:
|
| 205 |
+
img_placeholder = ""
|
| 206 |
+
number_of_words=0
|
| 207 |
+
max_num_words=400
|
| 208 |
+
max_num_images=45
|
| 209 |
+
clip_number_str=str(clip_number).zfill(2)
|
| 210 |
+
clip_path=os.path.join(movie_clips_path,f"{movie_name}_clip_{clip_number_str}")
|
| 211 |
+
os.makedirs(clip_path, exist_ok=True)
|
| 212 |
+
conversation=""
|
| 213 |
+
for j in range(i,i+135,3):
|
| 214 |
+
if j >= len(video_images_list):
|
| 215 |
+
break
|
| 216 |
+
image_path = os.path.join(video_images_path, video_images_list[j])
|
| 217 |
+
# copy the images to clip folder
|
| 218 |
+
# if the image is already copied, skip it
|
| 219 |
+
if not os.path.exists(os.path.join(clip_path,video_images_list[j])):
|
| 220 |
+
shutil.copy(image_path,clip_path)
|
| 221 |
+
img=Image.open(image_path)
|
| 222 |
+
images.append(self.vis_processor(img))
|
| 223 |
+
img_placeholder += '<Img><ImageHere>'
|
| 224 |
+
shot_num=int(video_images_list[j].split('_')[1])
|
| 225 |
+
if use_subtitles:
|
| 226 |
+
if shots_subtitles.get(shot_num) is not None:
|
| 227 |
+
sub=clean_text(shots_subtitles[shot_num])
|
| 228 |
+
number_of_words+=len(sub.split(' '))
|
| 229 |
+
if number_of_words<= max_num_words and use_subtitles:
|
| 230 |
+
img_placeholder+=f'<Cap>{sub}'
|
| 231 |
+
conversation+=sub+" "
|
| 232 |
+
if len(images) >= max_num_images:
|
| 233 |
+
break
|
| 234 |
+
if len(images) ==0:
|
| 235 |
+
print("Video not found",video_images_path)
|
| 236 |
+
continue
|
| 237 |
+
if 0 <len(images) < max_num_images:
|
| 238 |
+
last_item = images[-1]
|
| 239 |
+
while len(images) < max_num_images:
|
| 240 |
+
images.append(last_item)
|
| 241 |
+
img_placeholder += '<Img><ImageHere>'
|
| 242 |
+
images = torch.stack(images)
|
| 243 |
+
print(images.shape)
|
| 244 |
+
clip_numbers.append(clip_number_str)
|
| 245 |
+
clip_number+=1
|
| 246 |
+
conversations.append(clean_text(conversation))
|
| 247 |
+
instruction = img_placeholder + '\n' + self.summary_instruction
|
| 248 |
+
batch_images.append(images)
|
| 249 |
+
batch_instructions.append(instruction)
|
| 250 |
+
if len(batch_images) < batch_size:
|
| 251 |
+
continue
|
| 252 |
+
# run inference for the batch
|
| 253 |
+
batch_images = torch.stack(batch_images)
|
| 254 |
+
batch_pred=self.run_images(batch_images,batch_instructions)
|
| 255 |
+
for i,pred in enumerate(batch_pred):
|
| 256 |
+
max_caption_index += 1
|
| 257 |
+
videos_summaries.append(pred)
|
| 258 |
+
if args.use_coherent_description:
|
| 259 |
+
preds[f'caption_{max_caption_index}__{movie_name}_clip_{clip_numbers[i]}'] = f"model_summary :{pred}\nVideo conversation :{conversations[i]}"
|
| 260 |
+
else:
|
| 261 |
+
preds[f'caption_{max_caption_index}__{movie_name}_clip_{clip_numbers[i]}'] = pred
|
| 262 |
+
if conversations[i]!="" and use_subtitles:
|
| 263 |
+
preds[f'subtitle_{max_caption_index}__{movie_name}_clip_{clip_numbers[i]}'] = conversations[i]
|
| 264 |
+
|
| 265 |
+
batch_images=[]
|
| 266 |
+
batch_instructions=[]
|
| 267 |
+
clip_numbers=[]
|
| 268 |
+
conversations=[]
|
| 269 |
+
|
| 270 |
+
# run inference for the last batch
|
| 271 |
+
if len(batch_images)>0:
|
| 272 |
+
batch_images = torch.stack(batch_images)
|
| 273 |
+
batch_pred=self.run_images(batch_images,batch_instructions)
|
| 274 |
+
for k,pred in enumerate(batch_pred):
|
| 275 |
+
max_caption_index += 1
|
| 276 |
+
videos_summaries.append(pred)
|
| 277 |
+
if args.use_coherent_description:
|
| 278 |
+
preds[f'caption_{max_caption_index}__{movie_name}_clip_{clip_numbers[k]}'] = f"model_summary :{pred}\nVideo conversation :{conversations[k]}"
|
| 279 |
+
else:
|
| 280 |
+
preds[f'caption_{max_caption_index}__{movie_name}_clip_{clip_numbers[k]}'] = pred
|
| 281 |
+
if conversations[k]!="" and use_subtitles:
|
| 282 |
+
preds[f'subtitle_{max_caption_index}__{movie_name}_clip_{clip_numbers[k]}'] = conversations[k]
|
| 283 |
+
|
| 284 |
+
batch_images=[]
|
| 285 |
+
batch_instructions=[]
|
| 286 |
+
return preds
|
| 287 |
+
def movie_inference(self,videoname,use_subtitles):
|
| 288 |
+
embedding_path=os.path.join(self.save_pkls_path,f"{videoname}.pkl")
|
| 289 |
+
if args.index_subtitles_together:
|
| 290 |
+
file_path=os.path.join(self.save_json_path,f"{videoname}.json")
|
| 291 |
+
embedding_path=os.path.join(self.save_pkls_path,f"{videoname}.pkl")
|
| 292 |
+
else:
|
| 293 |
+
file_path=os.path.join(self.save_json_path,f"no_subtiltles_{videoname}.json")
|
| 294 |
+
embedding_path=os.path.join(self.save_pkls_path,f"no_subtiltles_{videoname}.pkl")
|
| 295 |
+
|
| 296 |
+
if args.subtitles_only:
|
| 297 |
+
file_path=os.path.join(self.save_json_path,f"subtiltles_only_{videoname}.json")
|
| 298 |
+
embedding_path=os.path.join(self.save_pkls_path,f"subtiltles_only_{videoname}.pkl")
|
| 299 |
+
|
| 300 |
+
if os.path.exists(file_path):
|
| 301 |
+
print("Already processed")
|
| 302 |
+
return file_path,embedding_path
|
| 303 |
+
important_data = {}
|
| 304 |
+
video_images_path,subtitle_path,movie_annotation,movie_clips_path=self._get_movie_data(videoname)
|
| 305 |
+
shots_subtitles={}
|
| 306 |
+
if use_subtitles:
|
| 307 |
+
if movie_annotation['story'] is not None:
|
| 308 |
+
shots_subtitles=self._get_shots_subtitles(movie_annotation)
|
| 309 |
+
if args.subtitles_only:
|
| 310 |
+
number_of_paragraphs=20
|
| 311 |
+
important_data=self._store_subtitles_paragraphs(subtitle_path,important_data,number_of_paragraphs)
|
| 312 |
+
else:
|
| 313 |
+
preds=self._get_movie_summaries(video_images_path,use_subtitles,shots_subtitles,movie_clips_path)
|
| 314 |
+
if len(shots_subtitles)==0 and use_subtitles:
|
| 315 |
+
number_of_paragraphs=len(preds)
|
| 316 |
+
important_data=self._store_subtitles_paragraphs(subtitle_path,important_data,number_of_paragraphs)
|
| 317 |
+
important_data.update(preds)
|
| 318 |
+
with open(file_path, 'w') as file:
|
| 319 |
+
json.dump(important_data, file, indent=4)
|
| 320 |
+
return file_path,embedding_path
|
| 321 |
+
def answer_movie_questions_RAG(self,qa_list,information_RAG_path,embedding_path):
|
| 322 |
+
QA_external_memory=MemoryIndex(args.neighbours, use_openai=args.use_openai_embedding)
|
| 323 |
+
if os.path.exists(embedding_path):
|
| 324 |
+
QA_external_memory.load_embeddings_from_pkl(embedding_path)
|
| 325 |
+
else:
|
| 326 |
+
QA_external_memory.load_documents_from_json(information_RAG_path,embedding_path)
|
| 327 |
+
summarization_external_memory=MemoryIndex(-1, use_openai=args.use_openai_embedding)
|
| 328 |
+
if os.path.exists(embedding_path):
|
| 329 |
+
summarization_external_memory.load_embeddings_from_pkl(embedding_path)
|
| 330 |
+
else:
|
| 331 |
+
summarization_external_memory.load_documents_from_json(information_RAG_path,embedding_path)
|
| 332 |
+
|
| 333 |
+
# get the most similar context from the external memory to this instruction
|
| 334 |
+
general_related_context_keys_list=[]
|
| 335 |
+
general_related_context_documents_list=[]
|
| 336 |
+
summary_related_context_documents_list=[]
|
| 337 |
+
summary_related_context_keys_list=[]
|
| 338 |
+
total_batch_pred=[]
|
| 339 |
+
related_text=[]
|
| 340 |
+
qa_genearl_prompts=[]
|
| 341 |
+
qa_summary_prompts=[]
|
| 342 |
+
qa_general=[]
|
| 343 |
+
qa_summary=[]
|
| 344 |
+
for qa in qa_list:
|
| 345 |
+
if qa['q_type']=='summary':
|
| 346 |
+
related_context_documents,related_context_keys = summarization_external_memory.search_by_similarity(qa['Q'])
|
| 347 |
+
summary_related_context_documents_list.append(related_context_documents)
|
| 348 |
+
summary_related_context_keys_list.append(related_context_keys)
|
| 349 |
+
prompt=self.prepare_prompt(qa)
|
| 350 |
+
qa_summary_prompts.append(prompt)
|
| 351 |
+
qa_summary.append(qa)
|
| 352 |
+
else:
|
| 353 |
+
related_context_documents,related_context_keys = QA_external_memory.search_by_similarity(qa['Q'])
|
| 354 |
+
general_related_context_keys_list.append(related_context_keys)
|
| 355 |
+
general_related_context_documents_list.append(related_context_documents)
|
| 356 |
+
prompt=self.prepare_prompt(qa)
|
| 357 |
+
qa_genearl_prompts.append(prompt)
|
| 358 |
+
qa_general.append(qa)
|
| 359 |
+
# if I have summary questions answer first, without the need to use clips for information
|
| 360 |
+
if len(qa_summary_prompts)>0:
|
| 361 |
+
# Here the retrieved clips are all movie clips
|
| 362 |
+
context_information_list=[]
|
| 363 |
+
for related_context_keys in summary_related_context_keys_list:
|
| 364 |
+
most_related_clips=self.get_most_related_clips(related_context_keys)
|
| 365 |
+
context_information=""
|
| 366 |
+
for clip_name in most_related_clips:
|
| 367 |
+
clip_conversation=""
|
| 368 |
+
general_sum=""
|
| 369 |
+
for key in related_context_keys:
|
| 370 |
+
if clip_name in key and 'caption' in key:
|
| 371 |
+
general_sum="Clip Summary: "+summarization_external_memory.documents[key]
|
| 372 |
+
if clip_name in key and 'subtitle' in key:
|
| 373 |
+
clip_conversation="Clip Subtitles: "+summarization_external_memory.documents[key]
|
| 374 |
+
|
| 375 |
+
if args.use_coherent_description:
|
| 376 |
+
context_information+=f"{general_sum}\n"
|
| 377 |
+
else:
|
| 378 |
+
if args.model_summary_only:
|
| 379 |
+
context_information+=f"{general_sum}\n"
|
| 380 |
+
elif args.subtitles_only:
|
| 381 |
+
context_information+=f"{clip_conversation}\n"
|
| 382 |
+
else:
|
| 383 |
+
context_information+=f"{general_sum},{clip_conversation}\n"
|
| 384 |
+
context_information_list.append(context_information)
|
| 385 |
+
if args.use_chatgpt :
|
| 386 |
+
batch_pred=self.inference_RAG_chatGPT(qa_summary_prompts,context_information_list)
|
| 387 |
+
else:
|
| 388 |
+
batch_pred=self.inference_RAG(qa_summary_prompts,context_information_list)
|
| 389 |
+
total_batch_pred.extend(batch_pred)
|
| 390 |
+
related_text.extend(context_information_list)
|
| 391 |
+
|
| 392 |
+
if args.use_clips_for_info:
|
| 393 |
+
batch_pred,general_related_context_keys_list=self.use_clips_for_info(qa_general,general_related_context_keys_list,QA_external_memory)
|
| 394 |
+
total_batch_pred.extend(batch_pred)
|
| 395 |
+
related_text.extend(general_related_context_keys_list)
|
| 396 |
+
else:
|
| 397 |
+
related_context_documents_text_list=[]
|
| 398 |
+
for related_context_documents,related_context_keys in zip(general_related_context_documents_list,general_related_context_keys_list):
|
| 399 |
+
related_information=""
|
| 400 |
+
most_related_clips=self.get_most_related_clips(related_context_keys)
|
| 401 |
+
for clip_name in most_related_clips:
|
| 402 |
+
clip_conversation=""
|
| 403 |
+
general_sum=""
|
| 404 |
+
for key in QA_external_memory.documents.keys():
|
| 405 |
+
if clip_name in key and 'caption' in key:
|
| 406 |
+
general_sum="Clip Summary: "+QA_external_memory.documents[key]
|
| 407 |
+
if clip_name in key and 'subtitle' in key:
|
| 408 |
+
clip_conversation="Clip Subtitles: "+QA_external_memory.documents[key]
|
| 409 |
+
if args.use_coherent_description:
|
| 410 |
+
related_information+=f"{general_sum}\n"
|
| 411 |
+
else:
|
| 412 |
+
if args.model_summary_only:
|
| 413 |
+
related_information+=f"{general_sum}\n"
|
| 414 |
+
elif args.subtitles_only:
|
| 415 |
+
related_information+=f"{clip_conversation}\n"
|
| 416 |
+
else:
|
| 417 |
+
related_information+=f"{general_sum},{clip_conversation}\n"
|
| 418 |
+
|
| 419 |
+
related_context_documents_text_list.append(related_information)
|
| 420 |
+
|
| 421 |
+
if len (qa_genearl_prompts) >0 and args.use_chatgpt :
|
| 422 |
+
batch_pred=self.inference_RAG_chatGPT(qa_genearl_prompts,related_context_documents_text_list)
|
| 423 |
+
elif len (qa_genearl_prompts) >0:
|
| 424 |
+
batch_pred=self.inference_RAG(qa_genearl_prompts,related_context_documents_text_list)
|
| 425 |
+
total_batch_pred.extend(batch_pred)
|
| 426 |
+
related_text.extend(related_context_documents_text_list)
|
| 427 |
+
assert len(total_batch_pred)==len(related_text)
|
| 428 |
+
return total_batch_pred, related_text
|
| 429 |
+
def get_most_related_clips(self,related_context_keys):
|
| 430 |
+
most_related_clips=[]
|
| 431 |
+
for context_key in related_context_keys:
|
| 432 |
+
if len(context_key.split('__'))>1:
|
| 433 |
+
most_related_clips.append(context_key.split('__')[1])
|
| 434 |
+
if len(most_related_clips)==args.neighbours:
|
| 435 |
+
break
|
| 436 |
+
assert len(most_related_clips)!=0, f"No related clips found {related_context_keys}"
|
| 437 |
+
return most_related_clips
|
| 438 |
+
|
| 439 |
+
def clip_inference(self,clips_name,prompts):
|
| 440 |
+
setup_seeds(seed)
|
| 441 |
+
images_batch, instructions_batch = [], []
|
| 442 |
+
for clip_name, prompt in zip(clips_name, prompts):
|
| 443 |
+
movie_name=clip_name.split('_')[0]
|
| 444 |
+
video_images_path,subtitle_path,movie_annotation,movie_clips_path=self._get_movie_data(movie_name)
|
| 445 |
+
clip_path=os.path.join(movie_clips_path,clip_name)
|
| 446 |
+
if movie_annotation['story'] is not None:
|
| 447 |
+
shots_subtitles=self._get_shots_subtitles(movie_annotation)
|
| 448 |
+
else:
|
| 449 |
+
shots_subtitles={}
|
| 450 |
+
images,img_placeholder=self.prepare_input_images(clip_path,shots_subtitles,use_subtitles=not args.vision_only)
|
| 451 |
+
instruction = img_placeholder + '\n' + prompt
|
| 452 |
+
images_batch.append(images)
|
| 453 |
+
instructions_batch.append(instruction)
|
| 454 |
+
# run inference for the batch
|
| 455 |
+
images_batch=torch.stack(images_batch)
|
| 456 |
+
batch_pred=self.run_images(images_batch,instructions_batch)
|
| 457 |
+
return batch_pred
|
| 458 |
+
def prepare_prompt(self,qa):
|
| 459 |
+
prompt=qa["Q"]
|
| 460 |
+
return prompt
|
| 461 |
+
def use_clips_for_info(self,qa_list,related_context_keys_list,external_memory):
|
| 462 |
+
total_batch_pred=[]
|
| 463 |
+
questions=[]
|
| 464 |
+
related_information_list=[]
|
| 465 |
+
related_context_keys_list_new=[]
|
| 466 |
+
for qa,related_context_keys in zip(qa_list,related_context_keys_list):
|
| 467 |
+
most_related_clips=self.get_most_related_clips(related_context_keys)
|
| 468 |
+
question=qa['Q']
|
| 469 |
+
# prompt=self.prepare_prompt(qa)
|
| 470 |
+
# prompt+=" and also provide an EXPLAINATION for your answer and If you don't know the answer, say that you don't know.\n\n"
|
| 471 |
+
prompt=f"From this video extract the related information to This question and provide an explaination for your answer and If you can't find related information, say 'I DON'T KNOW' as option 5 because maybe the questoin is not related to the video content.\n the question is :\n {question}\n your answer :"
|
| 472 |
+
# all_info=self.clip_inference(most_related_clips,[prompt]*len(most_related_clips))
|
| 473 |
+
# make the most_related_clips has unique elements (if retrival from vision summary and conversations)
|
| 474 |
+
most_related_clips=list(set(most_related_clips))
|
| 475 |
+
batch_inference=[]
|
| 476 |
+
all_info=[]
|
| 477 |
+
for related_clip in most_related_clips:
|
| 478 |
+
batch_inference.append(related_clip)
|
| 479 |
+
if len(batch_inference)<args.batch_size:
|
| 480 |
+
continue
|
| 481 |
+
all_info.extend(self.clip_inference(batch_inference,[prompt]*len(batch_inference)))
|
| 482 |
+
batch_inference=[]
|
| 483 |
+
if len(batch_inference)>0:
|
| 484 |
+
all_info.extend(self.clip_inference(batch_inference,[prompt]*len(batch_inference)))
|
| 485 |
+
|
| 486 |
+
related_information=""
|
| 487 |
+
for info,clip_name in zip(all_info,most_related_clips):
|
| 488 |
+
clip_conversation=""
|
| 489 |
+
general_sum=""
|
| 490 |
+
for key in external_memory.documents.keys():
|
| 491 |
+
if clip_name in key and 'caption' in key:
|
| 492 |
+
general_sum="Clip Summary: "+external_memory.documents[key]
|
| 493 |
+
if clip_name in key and 'subtitle' in key:
|
| 494 |
+
clip_conversation="Clip Subtitles: "+external_memory.documents[key]
|
| 495 |
+
|
| 496 |
+
if args.use_coherent_description:
|
| 497 |
+
related_information+=f"question_related_information: {info},{general_sum}\n"
|
| 498 |
+
else:
|
| 499 |
+
if args.model_summary_only:
|
| 500 |
+
related_information+=f"{general_sum},question_related_information: {info}\n"
|
| 501 |
+
elif args.info_only:
|
| 502 |
+
related_information+=f"question_related_information: {info}\n"
|
| 503 |
+
elif args.subtitles_only:
|
| 504 |
+
related_information+=f"{clip_conversation},question_related_information: {info}\n"
|
| 505 |
+
else:
|
| 506 |
+
related_information+=f"{general_sum},{clip_conversation},question_related_information: {info}\n"
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
# related_information+=f"question_related_information: {info},{clip_conversation}\n"
|
| 510 |
+
questions.append(question)
|
| 511 |
+
related_information_list.append(related_information)
|
| 512 |
+
related_context_keys.append(related_information)
|
| 513 |
+
related_context_keys_list_new.append(related_context_keys)
|
| 514 |
+
if len(questions)< args.batch_size:
|
| 515 |
+
continue
|
| 516 |
+
setup_seeds(seed)
|
| 517 |
+
if args.use_chatgpt :
|
| 518 |
+
batch_pred=self.inference_RAG_chatGPT(questions, related_information_list)
|
| 519 |
+
else:
|
| 520 |
+
batch_pred=self.inference_RAG(questions, related_information_list)
|
| 521 |
+
|
| 522 |
+
for pred in batch_pred:
|
| 523 |
+
total_batch_pred.append(pred)
|
| 524 |
+
questions=[]
|
| 525 |
+
related_information_list=[]
|
| 526 |
+
|
| 527 |
+
if len(questions)>0:
|
| 528 |
+
setup_seeds(seed)
|
| 529 |
+
if args.use_chatgpt :
|
| 530 |
+
batch_pred=self.inference_RAG_chatGPT(questions, related_information_list)
|
| 531 |
+
else:
|
| 532 |
+
batch_pred=self.inference_RAG(questions, related_information_list)
|
| 533 |
+
for pred in batch_pred:
|
| 534 |
+
total_batch_pred.append(pred)
|
| 535 |
+
return total_batch_pred,related_context_keys_list_new
|
| 536 |
+
def define_save_name(self):
|
| 537 |
+
save_name="subtitles" if args.index_subtitles_together else "no_subtitles"
|
| 538 |
+
save_name+="_clips_for_info" if args.use_clips_for_info else ""
|
| 539 |
+
save_name+="_chatgpt" if args.use_chatgpt else ""
|
| 540 |
+
save_name+="_vision_only" if args.vision_only else ""
|
| 541 |
+
save_name+="_model_summary_only" if args.model_summary_only else ""
|
| 542 |
+
save_name+="_subtitles_only" if args.subtitles_only else ""
|
| 543 |
+
save_name+="_info_only" if args.info_only else ""
|
| 544 |
+
print("save_name",save_name)
|
| 545 |
+
return save_name
|
| 546 |
+
def eval_llama_vid(self):
|
| 547 |
+
## LLAMa vid QA evaluation
|
| 548 |
+
full_questions_result=[]
|
| 549 |
+
movie_number=0
|
| 550 |
+
start=args.start
|
| 551 |
+
end=args.end
|
| 552 |
+
save_name=self.define_save_name()
|
| 553 |
+
for movie in tqdm(self.movies_dict.keys()):
|
| 554 |
+
if args.start <=movie_number < args.end:
|
| 555 |
+
save_dir=f"new_workspace/results/llama_vid/{args.exp_name}/{save_name}_{args.neighbours}_neighbours"
|
| 556 |
+
if os.path.exists( f"{save_dir}/{movie}.json" ):
|
| 557 |
+
print(f"Movie {movie} already processed")
|
| 558 |
+
with open(f"{save_dir}/{movie}.json", 'r') as f:
|
| 559 |
+
pred_json = json.load(f)
|
| 560 |
+
full_questions_result.extend(pred_json)
|
| 561 |
+
continue
|
| 562 |
+
use_subtitles_while_generating_summary=not args.vision_only
|
| 563 |
+
information_RAG_path,embedding_path=self.movie_inference(movie,use_subtitles_while_generating_summary)
|
| 564 |
+
external_memory=MemoryIndex(args.neighbours, use_openai=args.use_openai_embedding)
|
| 565 |
+
if os.path.exists(embedding_path):
|
| 566 |
+
external_memory.load_embeddings_from_pkl(embedding_path)
|
| 567 |
+
else:
|
| 568 |
+
external_memory.load_documents_from_json(information_RAG_path,emdedding_path=embedding_path)
|
| 569 |
+
save_dir=f"new_workspace/results/llama_vid/{args.exp_name}/{save_name}_{args.neighbours}_neighbours"
|
| 570 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 571 |
+
pred_json=[]
|
| 572 |
+
batch_questions=[]
|
| 573 |
+
for qa in tqdm(self.movies_dict[movie],desc="Inference questions"):
|
| 574 |
+
batch_questions.append(qa)
|
| 575 |
+
if len(batch_questions)<args.batch_size:
|
| 576 |
+
continue
|
| 577 |
+
model_ans,related_text=self.answer_movie_questions_RAG(batch_questions,information_RAG_path,embedding_path)
|
| 578 |
+
for qa,ans,related_info in zip(batch_questions,model_ans,related_text):
|
| 579 |
+
qa.update({'pred':ans})
|
| 580 |
+
qa.update({'related_info':related_info})
|
| 581 |
+
pred_json.append(qa)
|
| 582 |
+
batch_questions=[]
|
| 583 |
+
if len(batch_questions)>0:
|
| 584 |
+
model_ans,related_text=self.answer_movie_questions_RAG(batch_questions,information_RAG_path,embedding_path)
|
| 585 |
+
for qa,ans,related_info in zip(batch_questions,model_ans,related_text):
|
| 586 |
+
qa.update({'pred':ans})
|
| 587 |
+
qa.update({'related_info':related_info})
|
| 588 |
+
pred_json.append(qa)
|
| 589 |
+
full_questions_result.extend(pred_json)
|
| 590 |
+
with open(f"{save_dir}/{movie}.json", 'w') as fp:
|
| 591 |
+
json.dump(pred_json, fp)
|
| 592 |
+
print(f"Movie {movie} prediction saved to {save_dir}/{movie}.json")
|
| 593 |
+
movie_number+=1
|
| 594 |
+
with open(f"{save_dir}/full_pred_s{start}_end{end}.json", 'w') as fp:
|
| 595 |
+
json.dump(full_questions_result, fp)
|
| 596 |
+
args=get_arguments()
|
| 597 |
+
|
| 598 |
+
def setup_seeds(seed):
|
| 599 |
+
random.seed(seed)
|
| 600 |
+
np.random.seed(seed)
|
| 601 |
+
torch.manual_seed(seed)
|
| 602 |
+
torch.cuda.manual_seed(seed)
|
| 603 |
+
cudnn.benchmark = False
|
| 604 |
+
cudnn.deterministic = True
|
| 605 |
+
|
| 606 |
+
import yaml
|
| 607 |
+
# read this file test_configs/llama2_test_config.yaml
|
| 608 |
+
with open('test_configs/llama2_test_config.yaml') as file:
|
| 609 |
+
config = yaml.load(file, Loader=yaml.FullLoader)
|
| 610 |
+
seed=config['run']['seed']
|
| 611 |
+
print("seed",seed)
|
| 612 |
+
|
| 613 |
+
if __name__ == "__main__":
|
| 614 |
+
setup_seeds(seed)
|
| 615 |
+
llama_vid_eval=LlamaVidQAEval(args)
|
| 616 |
+
llama_vid_eval.eval_llama_vid()
|
evaluation/eval_goldfish_movie_chat.py
ADDED
|
@@ -0,0 +1,453 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
project_dir = os.getcwd()
|
| 4 |
+
sys.path.append(project_dir)
|
| 5 |
+
import json
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from goldfish_lv import GoldFish_LV,split_subtitles,time_to_seconds
|
| 8 |
+
import argparse
|
| 9 |
+
import json
|
| 10 |
+
import argparse
|
| 11 |
+
import torch
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
# from openai import OpenAI
|
| 14 |
+
from minigpt4.common.eval_utils import init_model
|
| 15 |
+
from minigpt4.conversation.conversation import CONV_VISION
|
| 16 |
+
from index import MemoryIndex
|
| 17 |
+
import pysrt
|
| 18 |
+
import chardet
|
| 19 |
+
import torch
|
| 20 |
+
import random
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch.backends.cudnn as cudnn
|
| 23 |
+
def str2bool(v):
|
| 24 |
+
if isinstance(v, bool):
|
| 25 |
+
return v
|
| 26 |
+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
| 27 |
+
return True
|
| 28 |
+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
| 29 |
+
return False
|
| 30 |
+
else:
|
| 31 |
+
raise argparse.ArgumentTypeError('Boolean value expected.')
|
| 32 |
+
|
| 33 |
+
def get_arguments():
|
| 34 |
+
parser = argparse.ArgumentParser(description="Inference parameters")
|
| 35 |
+
parser.add_argument("--neighbours", type=int, default=-1)
|
| 36 |
+
parser.add_argument("--neighbours_global", type=int, default=-1)
|
| 37 |
+
parser.add_argument("--fps", type=float, default=0.5)
|
| 38 |
+
parser.add_argument("--name", type=str,default="ckpt_92",help="name of the experiment")
|
| 39 |
+
parser.add_argument("--add_unknown", action='store_true')
|
| 40 |
+
parser.add_argument("--use_chatgpt", action='store_true')
|
| 41 |
+
parser.add_argument("--use_choices_for_info", action='store_true')
|
| 42 |
+
parser.add_argument("--use_gt_information", action='store_true')
|
| 43 |
+
parser.add_argument("--inference_text", action='store_true')
|
| 44 |
+
parser.add_argument("--use_gt_information_with_distraction", action='store_true')
|
| 45 |
+
parser.add_argument("--num_distraction", type=int, default=2)
|
| 46 |
+
parser.add_argument("--add_confidance_score", action='store_true')
|
| 47 |
+
parser.add_argument("--use_original_video", action='store_true')
|
| 48 |
+
parser.add_argument("--use_video_embedding", action='store_true')
|
| 49 |
+
parser.add_argument("--use_clips_for_info", action='store_true')
|
| 50 |
+
parser.add_argument("--use_GT_video", action='store_true')
|
| 51 |
+
parser.add_argument("--use_gt_summary", action='store_true')
|
| 52 |
+
parser.add_argument("--index_subtitles", action='store_true')
|
| 53 |
+
parser.add_argument("--index_subtitles_together", action='store_true')
|
| 54 |
+
|
| 55 |
+
parser.add_argument("--ask_the_question_early", action='store_true')
|
| 56 |
+
parser.add_argument("--clip_in_ask_early", action='store_true')
|
| 57 |
+
parser.add_argument("--summary_with_subtitles_only", action='store_true')
|
| 58 |
+
parser.add_argument("--use_coherent_description", action='store_true')
|
| 59 |
+
parser.add_argument("--v_sum_and_info", action='store_true')
|
| 60 |
+
|
| 61 |
+
parser.add_argument("--start", default=0, type=int)
|
| 62 |
+
parser.add_argument("--end", default=100000, type=int)
|
| 63 |
+
parser.add_argument("--exp_name", type=str,default="",help="name of eval folder")
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
parser.add_argument("--cfg-path", default="test_configs/llama2_test_config.yaml")
|
| 67 |
+
parser.add_argument("--ckpt", type=str, default="checkpoints/video_llama_checkpoint_last.pth")
|
| 68 |
+
parser.add_argument("--add_subtitles", action='store_true')
|
| 69 |
+
parser.add_argument("--eval_opt", type=str, default='all')
|
| 70 |
+
parser.add_argument("--max_new_tokens", type=int, default=300)
|
| 71 |
+
parser.add_argument("--batch_size", type=int, default=8)
|
| 72 |
+
parser.add_argument("--lora_r", type=int, default=64)
|
| 73 |
+
parser.add_argument("--lora_alpha", type=int, default=16)
|
| 74 |
+
parser.add_argument("--video_path", type=str, help="path to the video")
|
| 75 |
+
parser.add_argument("--use_openai_embedding",type=str2bool, default=False)
|
| 76 |
+
parser.add_argument("--dataset_videos_path", type=str, help="path to the dataset videos")
|
| 77 |
+
parser.add_argument("--annotation_json_folder", type=str, help="path to the annotation folder")
|
| 78 |
+
parser.add_argument("--options", nargs="+")
|
| 79 |
+
return parser.parse_args()
|
| 80 |
+
|
| 81 |
+
def get_movie_time(subtitle_path):
|
| 82 |
+
# read the subtitle file and detect the encoding
|
| 83 |
+
with open(subtitle_path, 'rb') as f:
|
| 84 |
+
result = chardet.detect(f.read())
|
| 85 |
+
subtitles = pysrt.open(subtitle_path, encoding=result['encoding'])
|
| 86 |
+
video_time=time_to_seconds(subtitles[-1].end)
|
| 87 |
+
return video_time
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
import torch
|
| 91 |
+
from torch.utils.data import Dataset, DataLoader
|
| 92 |
+
from torchvision.transforms import Compose
|
| 93 |
+
import h5py
|
| 94 |
+
import torch
|
| 95 |
+
import os
|
| 96 |
+
|
| 97 |
+
def numerical_sort_key(filename):
|
| 98 |
+
base_name = os.path.splitext(filename)[0]
|
| 99 |
+
return int(base_name)
|
| 100 |
+
|
| 101 |
+
class MovieChatDataset(Dataset):
|
| 102 |
+
def __init__(self, dataset_path, annotation_path,fps, transform=None,start=0,end=100000):
|
| 103 |
+
self.dataset_path = dataset_path
|
| 104 |
+
self.annotation_path=annotation_path
|
| 105 |
+
self.transform = transform
|
| 106 |
+
self.movie_name = os.listdir(dataset_path)
|
| 107 |
+
self.movie_name = [file for file in self.movie_name if file != '.DS_Store']
|
| 108 |
+
self.fps = fps
|
| 109 |
+
self.len_clip = 45
|
| 110 |
+
self.start=start
|
| 111 |
+
self.end=end
|
| 112 |
+
def load_frames(self, movie_name):
|
| 113 |
+
filenames = sorted(os.listdir(os.path.join(self.dataset_path, movie_name)))
|
| 114 |
+
|
| 115 |
+
filenames.sort(key=numerical_sort_key)
|
| 116 |
+
# define torch tensor to store the frames of size(0,0,0)
|
| 117 |
+
data = []
|
| 118 |
+
for filename_number in tqdm(filenames,desc="Loading frames"):
|
| 119 |
+
file_path = os.path.join(self.dataset_path, movie_name, filename_number)
|
| 120 |
+
|
| 121 |
+
if not os.path.isfile(file_path):
|
| 122 |
+
print(f"Did not find file: {filename_number}")
|
| 123 |
+
try:
|
| 124 |
+
with h5py.File(file_path, 'r') as h5_file:
|
| 125 |
+
image_embeds=torch.tensor(h5_file[f"frames_{filename_number[:-3]}"][:])
|
| 126 |
+
image_embeds = image_embeds[:,1:,:] # remove the first token (CLS) (200,256,1408)
|
| 127 |
+
# concate each 4 neighbours image tokens
|
| 128 |
+
bs, pn, hs = image_embeds.shape
|
| 129 |
+
image_embeds = image_embeds.view(bs, int(pn/4), int(hs*4))
|
| 130 |
+
data.extend(image_embeds)
|
| 131 |
+
|
| 132 |
+
except Exception as e:
|
| 133 |
+
print(f"Failed to process {filename_number}: {e}")
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
frames=torch.stack(data)
|
| 137 |
+
return frames
|
| 138 |
+
|
| 139 |
+
def __len__(self):
|
| 140 |
+
return len(self.movie_name)
|
| 141 |
+
|
| 142 |
+
def _get_movie_questions(self,movie_annotations):
|
| 143 |
+
global_questions=movie_annotations['global']
|
| 144 |
+
local_questions=movie_annotations['breakpoint']
|
| 145 |
+
return global_questions,local_questions
|
| 146 |
+
def __getitem__(self, idx):
|
| 147 |
+
if self.start<=idx<self.end:
|
| 148 |
+
self.frames = self.load_frames(self.movie_name[idx])
|
| 149 |
+
movie_name=self.movie_name[idx]
|
| 150 |
+
with open(os.path.join(self.annotation_path,movie_name+".json"), 'r') as f:
|
| 151 |
+
movie_annotations = json.load(f)
|
| 152 |
+
global_questions,local_questions=self._get_movie_questions(movie_annotations)
|
| 153 |
+
sampling_value = int(movie_annotations['info']['fps']/self.fps)
|
| 154 |
+
clips_list=[]
|
| 155 |
+
current_clip=[]
|
| 156 |
+
for i in range(0,self.frames.shape[0], sampling_value):
|
| 157 |
+
current_clip.append(self.frames[i])
|
| 158 |
+
if len(current_clip) >= self.len_clip:
|
| 159 |
+
clips_list.append(torch.stack(current_clip))
|
| 160 |
+
current_clip=[]
|
| 161 |
+
if len(current_clip) > 0:
|
| 162 |
+
last_frame_current_clip = current_clip[-1]
|
| 163 |
+
while len(current_clip) < self.len_clip:
|
| 164 |
+
current_clip.append(last_frame_current_clip)
|
| 165 |
+
clips_list.append(torch.stack(current_clip))
|
| 166 |
+
return clips_list, movie_name,global_questions,local_questions
|
| 167 |
+
else:
|
| 168 |
+
return [], self.movie_name[idx],[],[]
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class MovieChat (GoldFish_LV):
|
| 172 |
+
|
| 173 |
+
def __init__(self,args):
|
| 174 |
+
super().__init__(args)
|
| 175 |
+
self.args=args
|
| 176 |
+
self.save_long_videos_path = "new_workspace/clips_summary/movie_chat/"
|
| 177 |
+
if args.use_openai_embedding:
|
| 178 |
+
self.save_embedding_path = "new_workspace/open_ai_embedding/movie_chat/"
|
| 179 |
+
else:
|
| 180 |
+
self.save_embedding_path = "new_workspace/embedding/movie_chat/"
|
| 181 |
+
os.makedirs(self.save_long_videos_path, exist_ok=True)
|
| 182 |
+
os.makedirs(self.save_embedding_path, exist_ok=True)
|
| 183 |
+
self.max_sub_len=400
|
| 184 |
+
self.max_num_images=45
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def _get_long_video_summaries(self,clips,save_path):
|
| 188 |
+
batch=[]
|
| 189 |
+
batch_instructions=[]
|
| 190 |
+
preds={}
|
| 191 |
+
clip_numbers=[]
|
| 192 |
+
max_caption_index=0
|
| 193 |
+
for i,clip_features in enumerate(clips):
|
| 194 |
+
if len(clip_features)!=self.max_num_images:
|
| 195 |
+
continue
|
| 196 |
+
batch.append(clip_features)
|
| 197 |
+
img_placeholder=""
|
| 198 |
+
for j in range(len(clip_features)):
|
| 199 |
+
img_placeholder+="<Img><ImageHere>"
|
| 200 |
+
instruction = img_placeholder + '\n' + self.summary_instruction
|
| 201 |
+
batch_instructions.append(instruction)
|
| 202 |
+
clip_numbers.append(i)
|
| 203 |
+
if len(batch)<args.batch_size:
|
| 204 |
+
continue
|
| 205 |
+
batch=torch.stack(batch)
|
| 206 |
+
batch_pred= self.run_images_features(batch,batch_instructions)
|
| 207 |
+
for j,pred in enumerate(batch_pred):
|
| 208 |
+
max_caption_index += 1
|
| 209 |
+
if pred !="":
|
| 210 |
+
preds[f'caption__clip_{str(clip_numbers[j]).zfill(2)}'] = pred
|
| 211 |
+
batch=[]
|
| 212 |
+
clip_numbers=[]
|
| 213 |
+
batch_instructions=[]
|
| 214 |
+
if len(batch)>0:
|
| 215 |
+
batch=torch.stack(batch)
|
| 216 |
+
batch_pred= self.run_images_features(batch,batch_instructions)
|
| 217 |
+
for j,pred in enumerate(batch_pred):
|
| 218 |
+
max_caption_index += 1
|
| 219 |
+
if pred !="":
|
| 220 |
+
preds[f'caption__clip_{str(clip_numbers[j]).zfill(2)}'] = pred
|
| 221 |
+
with open(save_path, 'w') as file:
|
| 222 |
+
json.dump(preds, file, indent=4)
|
| 223 |
+
return preds
|
| 224 |
+
def use_model_summary (self,qa_prompts,related_context_documents_list,related_context_keys_list,external_memory):
|
| 225 |
+
related_context_documents_text_list=[]
|
| 226 |
+
for related_context_documents,related_context_keys in zip(related_context_documents_list,related_context_keys_list):
|
| 227 |
+
related_information=""
|
| 228 |
+
most_related_clips=self.get_most_related_clips_index(related_context_keys,external_memory)
|
| 229 |
+
for clip_name in most_related_clips:
|
| 230 |
+
general_sum=""
|
| 231 |
+
clip_name=str(clip_name).zfill(2)
|
| 232 |
+
for key in external_memory.documents.keys():
|
| 233 |
+
if clip_name in key and 'caption' in key:
|
| 234 |
+
general_sum="Clip Summary: "+external_memory.documents[key]
|
| 235 |
+
break
|
| 236 |
+
related_information+=f"{general_sum}\n"
|
| 237 |
+
related_context_documents_text_list.append(related_information)
|
| 238 |
+
|
| 239 |
+
if args.use_chatgpt :
|
| 240 |
+
batch_pred=self.inference_RAG_chatGPT(qa_prompts,related_context_documents_text_list)
|
| 241 |
+
else:
|
| 242 |
+
batch_pred=self.inference_RAG(qa_prompts,related_context_documents_text_list)
|
| 243 |
+
return batch_pred, related_context_documents_text_list
|
| 244 |
+
def answer_movie_questions_RAG(self,qa_list,information_RAG_path,embedding_path,q_type):
|
| 245 |
+
if q_type=='local':
|
| 246 |
+
external_memory=MemoryIndex(args.neighbours, use_openai=self.args.use_openai_embedding)
|
| 247 |
+
else:
|
| 248 |
+
external_memory=MemoryIndex(args.neighbours_global, use_openai=self.args.use_openai_embedding)
|
| 249 |
+
if os.path.exists(embedding_path):
|
| 250 |
+
external_memory.load_embeddings_from_pkl(embedding_path)
|
| 251 |
+
else:
|
| 252 |
+
external_memory.load_documents_from_json(information_RAG_path,embedding_path)
|
| 253 |
+
# get the most similar context from the external memory to this instruction
|
| 254 |
+
related_context_documents_list=[]
|
| 255 |
+
related_context_keys_list=[]
|
| 256 |
+
total_batch_pred=[]
|
| 257 |
+
related_text=[]
|
| 258 |
+
qa_prompts=[]
|
| 259 |
+
for qa in qa_list:
|
| 260 |
+
related_context_documents,related_context_keys = external_memory.search_by_similarity(qa['question'])
|
| 261 |
+
related_context_documents_list.append(related_context_documents)
|
| 262 |
+
related_context_keys_list.append(related_context_keys)
|
| 263 |
+
prompt=self.prepare_prompt(qa)
|
| 264 |
+
qa_prompts.append(prompt)
|
| 265 |
+
if args.use_clips_for_info:
|
| 266 |
+
batch_pred,related_context_keys_list=self.use_clips_for_info(qa_list,related_context_keys_list,external_memory)
|
| 267 |
+
total_batch_pred.extend(batch_pred)
|
| 268 |
+
related_text.extend(related_context_keys_list)
|
| 269 |
+
else:
|
| 270 |
+
batch_pred, related_context_documents_text_list=self.use_model_summary (qa_prompts,
|
| 271 |
+
related_context_documents_list,related_context_keys_list,external_memory)
|
| 272 |
+
total_batch_pred.extend(batch_pred)
|
| 273 |
+
related_text.extend(related_context_documents_text_list)
|
| 274 |
+
assert len(total_batch_pred)==len(qa_list)
|
| 275 |
+
assert len(total_batch_pred)==len(related_text)
|
| 276 |
+
return total_batch_pred, related_text
|
| 277 |
+
def get_most_related_clips_index(self,related_context_keys,external_memory):
|
| 278 |
+
most_related_clips_index=[]
|
| 279 |
+
for context_key in related_context_keys:
|
| 280 |
+
# loop over memory keys to get the context key index
|
| 281 |
+
for i,key in enumerate(external_memory.documents.keys()):
|
| 282 |
+
if context_key in key:
|
| 283 |
+
most_related_clips_index.append(i)
|
| 284 |
+
break
|
| 285 |
+
|
| 286 |
+
return most_related_clips_index
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def clip_inference(self,clips_idx,prompts):
|
| 290 |
+
setup_seeds(seed)
|
| 291 |
+
images_batch, instructions_batch = [], []
|
| 292 |
+
for clip_idx, prompt in zip(clips_idx, prompts):
|
| 293 |
+
clip_features=self.video_clips[clip_idx]
|
| 294 |
+
img_placeholder=""
|
| 295 |
+
for j in range(len(clip_features)):
|
| 296 |
+
img_placeholder+='<Img><ImageHere>'
|
| 297 |
+
instruction = img_placeholder + '\n' + prompt
|
| 298 |
+
images_batch.append(clip_features)
|
| 299 |
+
instructions_batch.append(instruction)
|
| 300 |
+
# run inference for the batch
|
| 301 |
+
images_batch=torch.stack(images_batch)
|
| 302 |
+
batch_pred= self.run_images_features(images_batch,instructions_batch)
|
| 303 |
+
return batch_pred
|
| 304 |
+
def prepare_prompt(self,qa):
|
| 305 |
+
prompt=qa["question"]
|
| 306 |
+
return prompt
|
| 307 |
+
def use_clips_for_info(self,qa_list,related_context_keys_list,external_memory):
|
| 308 |
+
total_batch_pred=[]
|
| 309 |
+
questions=[]
|
| 310 |
+
related_information_list=[]
|
| 311 |
+
related_context_keys_list_new=[]
|
| 312 |
+
for qa,related_context_keys in zip(qa_list,related_context_keys_list):
|
| 313 |
+
most_related_clips_index=self.get_most_related_clips_index(related_context_keys,external_memory)
|
| 314 |
+
question=qa['question']
|
| 315 |
+
prompt=f"From this video extract the related information to This question and provide an explaination for your answer and If you can't find any related information, say 'I DON'T KNOW' as option 5 because maybe the questoin is not related to the video content.\n the question is :\n {question}\n your answer :"
|
| 316 |
+
batch_inference=[]
|
| 317 |
+
all_info=[]
|
| 318 |
+
for clip_idx in most_related_clips_index:
|
| 319 |
+
batch_inference.append(clip_idx)
|
| 320 |
+
if len(batch_inference)<args.batch_size:
|
| 321 |
+
continue
|
| 322 |
+
all_info.extend(self.clip_inference(batch_inference,[prompt]*len(batch_inference)))
|
| 323 |
+
batch_inference=[]
|
| 324 |
+
if len(batch_inference)>0:
|
| 325 |
+
all_info.extend(self.clip_inference(batch_inference,[prompt]*len(batch_inference)))
|
| 326 |
+
# all_info=self.clip_inference(most_related_clips_index,[prompt]*len(most_related_clips_index))
|
| 327 |
+
related_information=""
|
| 328 |
+
for info,clip_name in zip(all_info,most_related_clips_index):
|
| 329 |
+
general_sum=""
|
| 330 |
+
clip_name=str(clip_name).zfill(2)
|
| 331 |
+
for key in external_memory.documents.keys():
|
| 332 |
+
if clip_name in key and 'caption' in key:
|
| 333 |
+
general_sum="Clip Summary: "+external_memory.documents[key]
|
| 334 |
+
if args.v_sum_and_info:
|
| 335 |
+
related_information+=f"{general_sum},question_related_information: {info}\n"
|
| 336 |
+
else:
|
| 337 |
+
related_information+=f"question_related_information: {info}\n"
|
| 338 |
+
questions.append(question)
|
| 339 |
+
related_information_list.append(related_information)
|
| 340 |
+
related_context_keys.append(related_information)
|
| 341 |
+
related_context_keys_list_new.append(related_context_keys)
|
| 342 |
+
if len(questions)< args.batch_size:
|
| 343 |
+
continue
|
| 344 |
+
setup_seeds(seed)
|
| 345 |
+
if args.use_chatgpt :
|
| 346 |
+
batch_pred=self.inference_RAG_chatGPT(questions, related_information_list)
|
| 347 |
+
else:
|
| 348 |
+
batch_pred=self.inference_RAG(questions, related_information_list)
|
| 349 |
+
|
| 350 |
+
for pred in batch_pred:
|
| 351 |
+
total_batch_pred.append(pred)
|
| 352 |
+
questions=[]
|
| 353 |
+
related_information_list=[]
|
| 354 |
+
|
| 355 |
+
if len(questions)>0:
|
| 356 |
+
setup_seeds(seed)
|
| 357 |
+
if args.use_chatgpt :
|
| 358 |
+
batch_pred=self.inference_RAG_chatGPT(questions, related_information_list)
|
| 359 |
+
else:
|
| 360 |
+
batch_pred=self.inference_RAG(questions, related_information_list)
|
| 361 |
+
for pred in batch_pred:
|
| 362 |
+
total_batch_pred.append(pred)
|
| 363 |
+
return total_batch_pred,related_context_keys_list_new
|
| 364 |
+
def define_save_name(self):
|
| 365 |
+
save_name="subtitles" if args.index_subtitles else "no_subtitles"
|
| 366 |
+
save_name="subtitles_together" if args.index_subtitles_together else save_name
|
| 367 |
+
save_name="summary_with_subtitles_only" if args.summary_with_subtitles_only else save_name
|
| 368 |
+
save_name+="_unknown" if args.add_unknown else ""
|
| 369 |
+
save_name+="_clips_for_info" if args.use_clips_for_info else ""
|
| 370 |
+
save_name+="_chatgpt" if args.use_chatgpt else ""
|
| 371 |
+
save_name+="_choices_for_info" if args.use_choices_for_info else ""
|
| 372 |
+
save_name+="_v_sum_and_info" if args.v_sum_and_info else ""
|
| 373 |
+
save_name+='fps_'+str(args.fps)
|
| 374 |
+
save_dir=f"new_workspace/results/moviechat/{args.exp_name}/{save_name}_{args.neighbours_global}_neighbours"
|
| 375 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 376 |
+
return save_dir
|
| 377 |
+
|
| 378 |
+
def eval_moviechat(self):
|
| 379 |
+
start=args.start
|
| 380 |
+
end=args.end
|
| 381 |
+
dataset_path = args.dataset_videos_path
|
| 382 |
+
annotation_json_folder=args.annotation_json_folder
|
| 383 |
+
dataset = MovieChatDataset(dataset_path,annotation_json_folder, fps=args.fps,start=start,end=end)
|
| 384 |
+
# dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
|
| 385 |
+
full_questions_result=[]
|
| 386 |
+
save_dir=self.define_save_name()
|
| 387 |
+
|
| 388 |
+
for i,(clips ,video_name,global_questions,local_questions) in enumerate(dataset):
|
| 389 |
+
# code here
|
| 390 |
+
if start<=i < end:
|
| 391 |
+
print("video_name",video_name)
|
| 392 |
+
self.video_clips=clips
|
| 393 |
+
self.video_name=video_name
|
| 394 |
+
file_path=os.path.join(self.save_long_videos_path,self.video_name+f"_fps{args.fps}.json")
|
| 395 |
+
embedding_path=os.path.join(self.save_embedding_path,self.video_name+f"_fps{args.fps}.pkl")
|
| 396 |
+
if os.path.exists(file_path):
|
| 397 |
+
print("Already processed")
|
| 398 |
+
else:
|
| 399 |
+
self._get_long_video_summaries(clips,file_path)
|
| 400 |
+
batch_questions=[]
|
| 401 |
+
for qa in global_questions:
|
| 402 |
+
batch_questions.append(qa)
|
| 403 |
+
if len(batch_questions)<args.batch_size:
|
| 404 |
+
continue
|
| 405 |
+
model_answers, related_text=self.answer_movie_questions_RAG(batch_questions,file_path,embedding_path,q_type='global')
|
| 406 |
+
for qa,ans in zip(batch_questions,model_answers):
|
| 407 |
+
qa.update({'pred':ans})
|
| 408 |
+
qa['Q']=qa['question']
|
| 409 |
+
qa['A']=qa['answer']
|
| 410 |
+
qa.pop('question', None)
|
| 411 |
+
qa.pop('answer', None)
|
| 412 |
+
|
| 413 |
+
batch_questions=[]
|
| 414 |
+
if len(batch_questions)>0:
|
| 415 |
+
model_answers, related_text=self.answer_movie_questions_RAG(batch_questions,file_path,embedding_path,q_type='global')
|
| 416 |
+
for qa,ans in zip(batch_questions,model_answers):
|
| 417 |
+
qa.update({'pred':ans})
|
| 418 |
+
qa['Q']=qa['question']
|
| 419 |
+
qa['A']=qa['answer']
|
| 420 |
+
qa.pop('question', None)
|
| 421 |
+
qa.pop('answer', None)
|
| 422 |
+
|
| 423 |
+
full_questions_result.extend(global_questions)
|
| 424 |
+
print(f"Finished {i} out of {len(dataset)}")
|
| 425 |
+
# save the results
|
| 426 |
+
with open(f"{save_dir}/{self.video_name}.json", 'w') as file:
|
| 427 |
+
# json.dump(global_questions+local_questions, file, indent=4)
|
| 428 |
+
json.dump(global_questions, file, indent=4)
|
| 429 |
+
|
| 430 |
+
with open(f"{save_dir}/full_pred_{start}_{end}.json", 'w') as fp:
|
| 431 |
+
json.dump(full_questions_result, fp)
|
| 432 |
+
args=get_arguments()
|
| 433 |
+
|
| 434 |
+
def setup_seeds(seed):
|
| 435 |
+
random.seed(seed)
|
| 436 |
+
np.random.seed(seed)
|
| 437 |
+
torch.manual_seed(seed)
|
| 438 |
+
torch.cuda.manual_seed(seed)
|
| 439 |
+
cudnn.benchmark = False
|
| 440 |
+
cudnn.deterministic = True
|
| 441 |
+
|
| 442 |
+
import yaml
|
| 443 |
+
# read this file test_configs/llama2_test_config.yaml
|
| 444 |
+
with open('test_configs/llama2_test_config.yaml') as file:
|
| 445 |
+
config = yaml.load(file, Loader=yaml.FullLoader)
|
| 446 |
+
seed=config['run']['seed']
|
| 447 |
+
print("seed",seed)
|
| 448 |
+
|
| 449 |
+
if __name__ == "__main__":
|
| 450 |
+
setup_seeds(seed)
|
| 451 |
+
llama_vid_eval=MovieChat(args)
|
| 452 |
+
llama_vid_eval.eval_moviechat()
|
| 453 |
+
|
evaluation/eval_goldfish_movie_qa.py
ADDED
|
@@ -0,0 +1,591 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
project_dir = os.getcwd()
|
| 4 |
+
sys.path.append(project_dir)
|
| 5 |
+
import json
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from goldfish_lv import GoldFish_LV,split_subtitles,time_to_seconds
|
| 8 |
+
import argparse
|
| 9 |
+
import json
|
| 10 |
+
import argparse
|
| 11 |
+
import torch
|
| 12 |
+
import re
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
from PIL import Image
|
| 15 |
+
# from openai import OpenAI
|
| 16 |
+
from index import MemoryIndex
|
| 17 |
+
import pysrt
|
| 18 |
+
import chardet
|
| 19 |
+
import torch
|
| 20 |
+
import random
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch.backends.cudnn as cudnn
|
| 23 |
+
import shutil
|
| 24 |
+
def str2bool(v):
|
| 25 |
+
if isinstance(v, bool):
|
| 26 |
+
return v
|
| 27 |
+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
| 28 |
+
return True
|
| 29 |
+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
| 30 |
+
return False
|
| 31 |
+
else:
|
| 32 |
+
raise argparse.ArgumentTypeError('Boolean value expected.')
|
| 33 |
+
|
| 34 |
+
def get_arguments():
|
| 35 |
+
parser = argparse.ArgumentParser(description="Inference parameters")
|
| 36 |
+
parser.add_argument("--neighbours", type=int, default=-1)
|
| 37 |
+
parser.add_argument("--name", type=str,default="ckpt_92",help="name of the experiment")
|
| 38 |
+
parser.add_argument("--add_unknown", action='store_true')
|
| 39 |
+
parser.add_argument("--use_chatgpt", action='store_true')
|
| 40 |
+
parser.add_argument("--use_choices_for_info", action='store_true')
|
| 41 |
+
parser.add_argument("--use_gt_information", action='store_true')
|
| 42 |
+
parser.add_argument("--inference_text", action='store_true')
|
| 43 |
+
parser.add_argument("--use_gt_information_with_distraction", action='store_true')
|
| 44 |
+
parser.add_argument("--num_distraction", type=int, default=2)
|
| 45 |
+
parser.add_argument("--add_confidance_score", action='store_true')
|
| 46 |
+
parser.add_argument("--use_original_video", action='store_true')
|
| 47 |
+
parser.add_argument("--use_video_embedding", action='store_true')
|
| 48 |
+
parser.add_argument("--use_clips_for_info", action='store_true')
|
| 49 |
+
parser.add_argument("--use_GT_video", action='store_true')
|
| 50 |
+
parser.add_argument("--use_gt_summary", action='store_true')
|
| 51 |
+
parser.add_argument("--index_subtitles", action='store_true')
|
| 52 |
+
parser.add_argument("--index_subtitles_together", action='store_true')
|
| 53 |
+
|
| 54 |
+
parser.add_argument("--ask_the_question_early", action='store_true')
|
| 55 |
+
parser.add_argument("--clip_in_ask_early", action='store_true')
|
| 56 |
+
parser.add_argument("--summary_with_subtitles_only", action='store_true')
|
| 57 |
+
parser.add_argument("--use_coherent_description", action='store_true')
|
| 58 |
+
|
| 59 |
+
parser.add_argument("--start", default=0, type=int)
|
| 60 |
+
parser.add_argument("--end", default=100000, type=int)
|
| 61 |
+
parser.add_argument("--exp_name", type=str,default="",help="name of eval folder")
|
| 62 |
+
|
| 63 |
+
parser.add_argument("--vision_only", action='store_true')
|
| 64 |
+
parser.add_argument("--model_summary_only", action='store_true')
|
| 65 |
+
parser.add_argument("--subtitles_only", action='store_true')
|
| 66 |
+
parser.add_argument("--info_only", action='store_true')
|
| 67 |
+
|
| 68 |
+
parser.add_argument("--cfg-path", default="test_configs/llama2_test_config.yaml")
|
| 69 |
+
parser.add_argument("--ckpt", type=str, default="checkpoints/video_llama_checkpoint_last.pth")
|
| 70 |
+
parser.add_argument("--add_subtitles", action='store_true')
|
| 71 |
+
parser.add_argument("--eval_opt", type=str, default='all')
|
| 72 |
+
parser.add_argument("--max_new_tokens", type=int, default=300)
|
| 73 |
+
parser.add_argument("--batch_size", type=int, default=8)
|
| 74 |
+
parser.add_argument("--lora_r", type=int, default=64)
|
| 75 |
+
parser.add_argument("--lora_alpha", type=int, default=16)
|
| 76 |
+
parser.add_argument("--video_path", type=str, help="path to the video")
|
| 77 |
+
parser.add_argument("--use_openai_embedding",type=str2bool, default=False)
|
| 78 |
+
parser.add_argument("--annotation_path", type=str, help="path to the annotation file")
|
| 79 |
+
parser.add_argument("--videos_path", type=str, help="path to the videos directory")
|
| 80 |
+
parser.add_argument("--subtitle_path", type=str, help="path to the subtitles directory")
|
| 81 |
+
parser.add_argument("--movienet_annotations_dir", type=str, help="path to the movienet annotations directory")
|
| 82 |
+
parser.add_argument("--video_clips_saving_path", type=str, help="path to save the splitted small video clips")
|
| 83 |
+
parser.add_argument("--options", nargs="+")
|
| 84 |
+
return parser.parse_args()
|
| 85 |
+
|
| 86 |
+
def time_to_seconds(subrip_time):
|
| 87 |
+
return subrip_time.hours * 3600 + subrip_time.minutes * 60 + subrip_time.seconds + subrip_time.milliseconds / 1000
|
| 88 |
+
|
| 89 |
+
def get_movie_time(subtitle_path):
|
| 90 |
+
# read the subtitle file and detect the encoding
|
| 91 |
+
with open(subtitle_path, 'rb') as f:
|
| 92 |
+
result = chardet.detect(f.read())
|
| 93 |
+
subtitles = pysrt.open(subtitle_path, encoding=result['encoding'])
|
| 94 |
+
video_time=time_to_seconds(subtitles[-1].end)
|
| 95 |
+
return video_time
|
| 96 |
+
def clean_text(subtitles_text):
|
| 97 |
+
# Remove unwanted characters except for letters, digits, and single quotes
|
| 98 |
+
subtitles_text = re.sub(r'[^a-zA-Z0-9\s\']', '', subtitles_text)
|
| 99 |
+
# Replace multiple spaces with a single space
|
| 100 |
+
subtitles_text = re.sub(r'\s+', ' ', subtitles_text)
|
| 101 |
+
return subtitles_text.strip()
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class MovieQAEval (GoldFish_LV):
|
| 105 |
+
|
| 106 |
+
def __init__(self,args):
|
| 107 |
+
super().__init__(args)
|
| 108 |
+
self.save_json_path = "new_workspace/clips_summary/movienet"
|
| 109 |
+
if args.use_openai_embedding:
|
| 110 |
+
self.save_pkls_path = "new_workspace/open_ai_embedding/movienet"
|
| 111 |
+
else:
|
| 112 |
+
self.save_pkls_path = "new_workspace/embedding/movienet"
|
| 113 |
+
os.makedirs(self.save_json_path, exist_ok=True)
|
| 114 |
+
movie_qa_dataset_path=args.annotation_path
|
| 115 |
+
with open(movie_qa_dataset_path, 'r') as f:
|
| 116 |
+
self.movies_dict = json.load(f)
|
| 117 |
+
self.max_sub_len=400
|
| 118 |
+
self.max_num_images=45
|
| 119 |
+
|
| 120 |
+
def _get_movie_data(self,videoname):
|
| 121 |
+
video_images_path =f"{args.videos_path}/{videoname}"
|
| 122 |
+
movie_clips_path =f"{args.video_clips_saving_path}/{videoname}"
|
| 123 |
+
subtitle_path = f"{args.subtitle_path}/{videoname}.srt"
|
| 124 |
+
annotation_file=f"{args.movienet_annotations_dir}/{videoname}.json"
|
| 125 |
+
# load the annotation file
|
| 126 |
+
with open(annotation_file, 'r') as f:
|
| 127 |
+
movie_annotation = json.load(f)
|
| 128 |
+
return video_images_path,subtitle_path,movie_annotation,movie_clips_path
|
| 129 |
+
def _store_subtitles_paragraphs(self,subtitle_path,important_data,number_of_paragraphs):
|
| 130 |
+
paragraphs=[]
|
| 131 |
+
movie_name=subtitle_path.split('/')[-1].split('.')[0]
|
| 132 |
+
# if there is no story, split the subtitles into paragraphs
|
| 133 |
+
paragraphs = split_subtitles(subtitle_path, number_of_paragraphs)
|
| 134 |
+
for i,paragraph in enumerate(paragraphs):
|
| 135 |
+
paragraph=clean_text(paragraph)
|
| 136 |
+
important_data.update({f"subtitle_{i}__{movie_name}_clip_{str(i).zfill(2)}": paragraph})
|
| 137 |
+
return important_data
|
| 138 |
+
def _get_shots_subtitles(self,movie_annotation):
|
| 139 |
+
shots_subtitles={}
|
| 140 |
+
if movie_annotation['story'] is not None:
|
| 141 |
+
for section in movie_annotation['story']:
|
| 142 |
+
for shot in section['subtitle']:
|
| 143 |
+
shot_number=shot['shot']
|
| 144 |
+
shot_subtitle=' '.join(shot['sentences'])
|
| 145 |
+
shots_subtitles[shot_number]=clean_text(shot_subtitle)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
return shots_subtitles
|
| 149 |
+
|
| 150 |
+
def prepare_input_images(self,clip_path,shots_subtitles,use_subtitles):
|
| 151 |
+
total_frames=len(os.listdir(clip_path))
|
| 152 |
+
sampling_interval=int(total_frames//self.max_num_images)
|
| 153 |
+
if sampling_interval==0:
|
| 154 |
+
sampling_interval=1
|
| 155 |
+
images=[]
|
| 156 |
+
img_placeholder = ""
|
| 157 |
+
video_frames_path = os.path.join(clip_path)
|
| 158 |
+
total_num_frames=len(os.listdir(video_frames_path))
|
| 159 |
+
sampling_interval = round(total_num_frames / self.max_num_images)
|
| 160 |
+
if sampling_interval == 0:
|
| 161 |
+
sampling_interval = 1
|
| 162 |
+
number_of_words=0
|
| 163 |
+
video_images_list=sorted(os.listdir(video_frames_path))
|
| 164 |
+
for i,frame in enumerate(video_images_list):
|
| 165 |
+
if i % sampling_interval == 0:
|
| 166 |
+
frame = Image.open(os.path.join(video_frames_path,frame)).convert("RGB")
|
| 167 |
+
frame = self.vis_processor(frame)
|
| 168 |
+
images.append(frame)
|
| 169 |
+
img_placeholder += '<Img><ImageHere>'
|
| 170 |
+
shot_num=video_images_list[i].split('_')[1]
|
| 171 |
+
if shots_subtitles.get(shot_num) is not None:
|
| 172 |
+
sub=clean_text(shots_subtitles[shot_num])
|
| 173 |
+
number_of_words+=len(sub.split(' '))
|
| 174 |
+
if number_of_words<= self.max_sub_len and use_subtitles:
|
| 175 |
+
img_placeholder+=f'<Cap>{sub}'
|
| 176 |
+
if len(images) >= self.max_num_images:
|
| 177 |
+
break
|
| 178 |
+
if len(images) ==0:
|
| 179 |
+
print("Video not found",video_frames_path)
|
| 180 |
+
|
| 181 |
+
if 0 <len(images) < self.max_num_images:
|
| 182 |
+
last_item = images[-1]
|
| 183 |
+
while len(images) < self.max_num_images:
|
| 184 |
+
images.append(last_item)
|
| 185 |
+
img_placeholder += '<Img><ImageHere>'
|
| 186 |
+
images = torch.stack(images)
|
| 187 |
+
return images,img_placeholder
|
| 188 |
+
|
| 189 |
+
def _get_movie_summaries(self,video_images_path,use_subtitles,shots_subtitles,movie_clips_path):
|
| 190 |
+
video_images_list=sorted(os.listdir(video_images_path))
|
| 191 |
+
max_caption_index = 0
|
| 192 |
+
preds = {}
|
| 193 |
+
movie_name=movie_clips_path.split('/')[-1]
|
| 194 |
+
videos_summaries=[]
|
| 195 |
+
previous_caption=""
|
| 196 |
+
batch_size=args.batch_size
|
| 197 |
+
batch_images=[]
|
| 198 |
+
batch_instructions=[]
|
| 199 |
+
clip_numbers=[]
|
| 200 |
+
clip_number=0
|
| 201 |
+
conversations=[]
|
| 202 |
+
for i in tqdm(range(0,len(video_images_list),135), desc="Inference video clips", total=len(video_images_list)/135):
|
| 203 |
+
images=[]
|
| 204 |
+
img_placeholder = ""
|
| 205 |
+
number_of_words=0
|
| 206 |
+
clip_number_str=str(clip_number).zfill(2)
|
| 207 |
+
clip_path=os.path.join(movie_clips_path,f"{movie_name}_clip_{clip_number_str}")
|
| 208 |
+
os.makedirs(clip_path, exist_ok=True)
|
| 209 |
+
conversation=""
|
| 210 |
+
for j in range(i,i+135,3):
|
| 211 |
+
if j >= len(video_images_list):
|
| 212 |
+
break
|
| 213 |
+
image_path = os.path.join(video_images_path, video_images_list[j])
|
| 214 |
+
# copy the images to clip folder
|
| 215 |
+
shutil.copy(image_path,clip_path)
|
| 216 |
+
img=Image.open(image_path)
|
| 217 |
+
images.append(self.vis_processor(img))
|
| 218 |
+
img_placeholder += '<Img><ImageHere>'
|
| 219 |
+
shot_num=int(video_images_list[j].split('_')[1])
|
| 220 |
+
if use_subtitles:
|
| 221 |
+
if shots_subtitles.get(shot_num) is not None:
|
| 222 |
+
sub=clean_text(shots_subtitles[shot_num])
|
| 223 |
+
number_of_words+=len(sub.split(' '))
|
| 224 |
+
if number_of_words<= self.max_num_words :
|
| 225 |
+
img_placeholder+=f'<Cap>{sub}'
|
| 226 |
+
conversation+=sub+" "
|
| 227 |
+
if len(images) >= self.max_num_images:
|
| 228 |
+
break
|
| 229 |
+
if len(images) ==0:
|
| 230 |
+
print("Video not found",video_images_path)
|
| 231 |
+
continue
|
| 232 |
+
if 0 <len(images) < self.max_num_images:
|
| 233 |
+
last_item = images[-1]
|
| 234 |
+
while len(images) < self.max_num_images:
|
| 235 |
+
images.append(last_item)
|
| 236 |
+
img_placeholder += '<Img><ImageHere>'
|
| 237 |
+
|
| 238 |
+
images = torch.stack(images)
|
| 239 |
+
print(images.shape)
|
| 240 |
+
clip_numbers.append(clip_number_str)
|
| 241 |
+
clip_number+=1
|
| 242 |
+
conversations.append(clean_text(conversation))
|
| 243 |
+
instruction = img_placeholder + '\n' + self.summary_instruction
|
| 244 |
+
batch_images.append(images)
|
| 245 |
+
batch_instructions.append(instruction)
|
| 246 |
+
if len(batch_images) < batch_size:
|
| 247 |
+
continue
|
| 248 |
+
# run inference for the batch
|
| 249 |
+
batch_images = torch.stack(batch_images)
|
| 250 |
+
batch_pred=self.run_images(batch_images,batch_instructions)
|
| 251 |
+
for i,pred in enumerate(batch_pred):
|
| 252 |
+
max_caption_index += 1
|
| 253 |
+
videos_summaries.append(pred)
|
| 254 |
+
if args.use_coherent_description:
|
| 255 |
+
preds[f'caption_{max_caption_index}__{movie_name}_clip_{clip_numbers[i]}'] = f"model_summary :{pred}\nVideo conversation :{conversations[i]}"
|
| 256 |
+
else:
|
| 257 |
+
preds[f'caption_{max_caption_index}__{movie_name}_clip_{clip_numbers[i]}'] = pred
|
| 258 |
+
if conversations[i]!="" and use_subtitles:
|
| 259 |
+
preds[f'subtitle_{max_caption_index}__{movie_name}_clip_{clip_numbers[i]}'] = conversations[i]
|
| 260 |
+
|
| 261 |
+
batch_images=[]
|
| 262 |
+
batch_instructions=[]
|
| 263 |
+
clip_numbers=[]
|
| 264 |
+
conversations=[]
|
| 265 |
+
|
| 266 |
+
# run inference for the last batch
|
| 267 |
+
if len(batch_images)>0:
|
| 268 |
+
batch_images = torch.stack(batch_images)
|
| 269 |
+
batch_pred=self.run_images(batch_images,batch_instructions)
|
| 270 |
+
for k,pred in enumerate(batch_pred):
|
| 271 |
+
max_caption_index += 1
|
| 272 |
+
videos_summaries.append(pred)
|
| 273 |
+
if args.use_coherent_description:
|
| 274 |
+
preds[f'caption_{max_caption_index}__{movie_name}_clip_{clip_numbers[k]}'] = f"model_summary :{pred}\nVideo conversation :{conversations[k]}"
|
| 275 |
+
else:
|
| 276 |
+
preds[f'caption_{max_caption_index}__{movie_name}_clip_{clip_numbers[k]}'] = pred
|
| 277 |
+
if conversations[k]!="" and use_subtitles:
|
| 278 |
+
preds[f'subtitle_{max_caption_index}__{movie_name}_clip_{clip_numbers[k]}'] = conversations[k]
|
| 279 |
+
batch_images=[]
|
| 280 |
+
batch_instructions=[]
|
| 281 |
+
return preds
|
| 282 |
+
def movie_inference(self,videoname,use_subtitles):
|
| 283 |
+
|
| 284 |
+
embedding_path=os.path.join(self.save_pkls_path,f"{videoname}.pkl")
|
| 285 |
+
if args.index_subtitles_together:
|
| 286 |
+
file_path=os.path.join(self.save_json_path,f"{videoname}.json")
|
| 287 |
+
embedding_path=os.path.join(self.save_pkls_path,f"{videoname}.pkl")
|
| 288 |
+
else:
|
| 289 |
+
file_path=os.path.join(self.save_json_path,f"no_subtiltles_{videoname}.json")
|
| 290 |
+
embedding_path=os.path.join(self.save_pkls_path,f"no_subtiltles_{videoname}.pkl")
|
| 291 |
+
|
| 292 |
+
if args.subtitles_only:
|
| 293 |
+
file_path=os.path.join(self.save_json_path,f"subtiltles_only_{videoname}.json")
|
| 294 |
+
embedding_path=os.path.join(self.save_pkls_path,f"subtiltles_only_{videoname}.pkl")
|
| 295 |
+
|
| 296 |
+
if os.path.exists(file_path):
|
| 297 |
+
print("Already processed")
|
| 298 |
+
return file_path,embedding_path
|
| 299 |
+
|
| 300 |
+
important_data = {}
|
| 301 |
+
video_images_path,subtitle_path,movie_annotation,movie_clips_path=self._get_movie_data(videoname)
|
| 302 |
+
shots_subtitles={}
|
| 303 |
+
if use_subtitles:
|
| 304 |
+
if movie_annotation['story'] is not None:
|
| 305 |
+
shots_subtitles=self._get_shots_subtitles(movie_annotation)
|
| 306 |
+
if args.subtitles_only:
|
| 307 |
+
number_of_paragraphs=20
|
| 308 |
+
important_data=self._store_subtitles_paragraphs(subtitle_path,important_data,number_of_paragraphs)
|
| 309 |
+
else:
|
| 310 |
+
preds=self._get_movie_summaries(video_images_path,use_subtitles,shots_subtitles,movie_clips_path)
|
| 311 |
+
if len(shots_subtitles)==0 and use_subtitles:
|
| 312 |
+
number_of_paragraphs=len(preds)
|
| 313 |
+
important_data=self._store_subtitles_paragraphs(subtitle_path,important_data,number_of_paragraphs)
|
| 314 |
+
important_data.update(preds)
|
| 315 |
+
with open(file_path, 'w') as file:
|
| 316 |
+
json.dump(important_data, file, indent=4)
|
| 317 |
+
return file_path,embedding_path
|
| 318 |
+
def answer_movie_questions_RAG(self,qa_list,external_memory):
|
| 319 |
+
# get the most similar context from the external memory to this instruction
|
| 320 |
+
related_context_keys_list=[]
|
| 321 |
+
related_context_documents_list=[]
|
| 322 |
+
related_text=[]
|
| 323 |
+
questions=[]
|
| 324 |
+
prompts=[]
|
| 325 |
+
for qa in qa_list:
|
| 326 |
+
related_context_documents,related_context_keys = external_memory.search_by_similarity(qa['question'])
|
| 327 |
+
related_context_documents_list.append(related_context_documents)
|
| 328 |
+
related_context_keys_list.append(related_context_keys)
|
| 329 |
+
questions.append(qa)
|
| 330 |
+
prompt=self.prepare_prompt(qa)
|
| 331 |
+
prompts.append(prompt)
|
| 332 |
+
if args.use_clips_for_info:
|
| 333 |
+
batch_pred,related_context_keys_list=self.use_clips_for_info(qa_list,related_context_keys_list,external_memory)
|
| 334 |
+
related_text.extend(related_context_keys_list)
|
| 335 |
+
else:
|
| 336 |
+
related_context_documents_text_list=[]
|
| 337 |
+
for related_context_documents,related_context_keys in zip(related_context_documents_list,related_context_keys_list):
|
| 338 |
+
related_information=""
|
| 339 |
+
most_related_clips=self.get_most_related_clips(related_context_keys)
|
| 340 |
+
for clip_name in most_related_clips:
|
| 341 |
+
clip_conversation=""
|
| 342 |
+
general_sum=""
|
| 343 |
+
for key in external_memory.documents.keys():
|
| 344 |
+
if clip_name in key and 'caption' in key:
|
| 345 |
+
general_sum="Clip Summary: "+external_memory.documents[key]
|
| 346 |
+
if clip_name in key and 'subtitle' in key:
|
| 347 |
+
clip_conversation="Clip Subtitles: "+external_memory.documents[key]
|
| 348 |
+
related_information+=f"{general_sum},{clip_conversation}\n"
|
| 349 |
+
|
| 350 |
+
if args.model_summary_only:
|
| 351 |
+
related_information+=f"{general_sum}\n"
|
| 352 |
+
elif args.subtitles_only:
|
| 353 |
+
related_information+=f"{clip_conversation}\n"
|
| 354 |
+
else:
|
| 355 |
+
related_information+=f"{general_sum},{clip_conversation}\n"
|
| 356 |
+
|
| 357 |
+
related_context_documents_text_list.append(related_information)
|
| 358 |
+
|
| 359 |
+
if args.use_chatgpt :
|
| 360 |
+
batch_pred=self.inference_RAG_chatGPT(prompts,related_context_documents_text_list)
|
| 361 |
+
related_text.extend(related_context_documents_text_list)
|
| 362 |
+
else:
|
| 363 |
+
batch_pred=self.inference_RAG(prompts,related_context_documents_text_list)
|
| 364 |
+
related_text.extend(related_context_documents_text_list)
|
| 365 |
+
return batch_pred ,related_text
|
| 366 |
+
def get_most_related_clips(self,related_context_keys):
|
| 367 |
+
most_related_clips=[]
|
| 368 |
+
for context_key in related_context_keys:
|
| 369 |
+
if len(context_key.split('__'))>1:
|
| 370 |
+
most_related_clips.append(context_key.split('__')[1])
|
| 371 |
+
if len(most_related_clips)==args.neighbours:
|
| 372 |
+
break
|
| 373 |
+
assert len(most_related_clips)!=0, f"No related clips found {related_context_keys}"
|
| 374 |
+
return most_related_clips
|
| 375 |
+
|
| 376 |
+
def clip_inference(self,clips_name,prompts):
|
| 377 |
+
setup_seeds(seed)
|
| 378 |
+
images_batch, instructions_batch = [], []
|
| 379 |
+
for clip_name, prompt in zip(clips_name, prompts):
|
| 380 |
+
movie_name=clip_name.split('_')[0]
|
| 381 |
+
video_images_path,subtitle_path,movie_annotation,movie_clips_path=self._get_movie_data(movie_name)
|
| 382 |
+
clip_path=os.path.join(movie_clips_path,clip_name)
|
| 383 |
+
if movie_annotation['story'] is not None:
|
| 384 |
+
shots_subtitles=self._get_shots_subtitles(movie_annotation)
|
| 385 |
+
else:
|
| 386 |
+
shots_subtitles={}
|
| 387 |
+
images,img_placeholder=self.prepare_input_images(clip_path,shots_subtitles,use_subtitles=not args.vision_only)
|
| 388 |
+
instruction = img_placeholder + '\n' + prompt
|
| 389 |
+
images_batch.append(images)
|
| 390 |
+
instructions_batch.append(instruction)
|
| 391 |
+
# run inference for the batch
|
| 392 |
+
images_batch=torch.stack(images_batch)
|
| 393 |
+
batch_pred=self.run_images(images_batch,instructions_batch)
|
| 394 |
+
return batch_pred
|
| 395 |
+
def prepare_prompt(self,qa):
|
| 396 |
+
prompt=qa["question"]+" \n As you watched in this video Choose ONE suitable answer from these mutiple choices \n"
|
| 397 |
+
for i,choice in enumerate(qa['choices']):
|
| 398 |
+
prompt+=f"option {i}: {choice} \n"
|
| 399 |
+
if args.add_unknown and args.add_confidance_score:
|
| 400 |
+
# Add unknown option
|
| 401 |
+
prompt+=f"option 5: Can't answer based on the provided information\n"
|
| 402 |
+
prompt+="Your output should be THE NUMBER OF THE CORRECT ANSWER FROM THE CHOICES FROM 0 TO 5 INCLUSIVE and aslo output a CONFIDANCE SCORE FROM 0 TO 5 representing how confident you are with your answer where 0 is the least confident and 5 is the most confident"
|
| 403 |
+
elif args.add_unknown:
|
| 404 |
+
prompt+=f"option 5: Can't answer based on the provided information\n"
|
| 405 |
+
prompt+="Your output should be THE NUMBER OF THE CORRECT ANSWER FROM THE CHOICES FROM 0 TO 5 INCLUSIVE"
|
| 406 |
+
elif args.add_confidance_score:
|
| 407 |
+
prompt+="Your output should be THE NUMBER OF THE CORRECT ANSWER FROM THE CHOICES FROM 0 TO 4 INCLUSIVE and aslo output a CONFIDANCE SCORE FROM 0 TO 5 representing how confident you are with your answer where 0 is the least confident and 5 is the most confident"
|
| 408 |
+
else:
|
| 409 |
+
prompt+="Your output should be THE NUMBER OF THE CORRECT ANSWER FROM THE CHOICES FROM 0 TO 4 INCLUSIVE"
|
| 410 |
+
return prompt
|
| 411 |
+
def use_clips_for_info(self,qa_list,related_context_keys_list,external_memory):
|
| 412 |
+
total_batch_pred=[]
|
| 413 |
+
questions=[]
|
| 414 |
+
related_information_list=[]
|
| 415 |
+
related_context_keys_list_new=[]
|
| 416 |
+
for qa,related_context_keys in zip(qa_list,related_context_keys_list):
|
| 417 |
+
most_related_clips=self.get_most_related_clips(related_context_keys)
|
| 418 |
+
|
| 419 |
+
question=qa['question']+ "\n and these are the options for the question\n\n"
|
| 420 |
+
for i,choice in enumerate(qa['choices']):
|
| 421 |
+
question+=f"option {i}: {choice} \n\n"
|
| 422 |
+
if args.add_unknown:
|
| 423 |
+
question+= "option 5: Can't answer based on the provided information\n\n"
|
| 424 |
+
question+="\n Your output should be THE NUMBER OF THE CORRECT ANSWER FROM THE CHOICES FROM 0 TO 5 INCLUSIVE"
|
| 425 |
+
else:
|
| 426 |
+
question+="\n Your output should be THE NUMBER OF THE CORRECT ANSWER FROM THE CHOICES FROM 0 TO 4 INCLUSIVE"
|
| 427 |
+
|
| 428 |
+
if args.use_choices_for_info:
|
| 429 |
+
# prompt=self.prepare_prompt(qa)
|
| 430 |
+
# prompt+=" and also provide an EXPLAINATION for your answer and If you don't know the answer, say that you don't know.\n\n"
|
| 431 |
+
prompt=f"From this video extract the related information to This multichioce question and provide an explaination for your answer and If you can't find any related inforamtion, say 'I DON'T KNOW' as option 5 because maybe the questoin is not related to the video content.\n the question is :\n {question}\n your answer :"
|
| 432 |
+
else:
|
| 433 |
+
prompt=f"As you watched in this video answer this {qa['q']}\n\n and also provide an EXPLAINATION for your answer and If you don't know the answer, say that you don't know.\n\n"
|
| 434 |
+
# if args.use_choices_for_info:
|
| 435 |
+
# prompt=self.prepare_prompt(qa)
|
| 436 |
+
# prompt+=" and also provide an EXPLAINATION for your answer and If you don't know the answer, say that you don't know.\n\n"
|
| 437 |
+
# else:
|
| 438 |
+
# prompt=f"As you watched in this video {qa['question']}\n\n and also provide an EXPLAINATION for your answer and If you don't know the answer, say that you don't know.\n\n"
|
| 439 |
+
# make the most_related_clips has unique elements (if retrival from vision summary and conversations)
|
| 440 |
+
most_related_clips=list(set(most_related_clips))
|
| 441 |
+
|
| 442 |
+
# all_info=self.clip_inference(most_related_clips,[prompt]*len(most_related_clips))
|
| 443 |
+
batch_inference=[]
|
| 444 |
+
all_info=[]
|
| 445 |
+
for related_clip in most_related_clips:
|
| 446 |
+
batch_inference.append(related_clip)
|
| 447 |
+
if len(batch_inference)<args.batch_size:
|
| 448 |
+
continue
|
| 449 |
+
all_info.extend(self.clip_inference(batch_inference,[prompt]*len(batch_inference)))
|
| 450 |
+
batch_inference=[]
|
| 451 |
+
if len(batch_inference)>0:
|
| 452 |
+
all_info.extend(self.clip_inference(batch_inference,[prompt]*len(batch_inference)))
|
| 453 |
+
|
| 454 |
+
related_information=""
|
| 455 |
+
for info,clip_name in zip(all_info,most_related_clips):
|
| 456 |
+
clip_conversation=""
|
| 457 |
+
general_sum=""
|
| 458 |
+
for key in external_memory.documents.keys():
|
| 459 |
+
if clip_name in key and 'caption' in key:
|
| 460 |
+
general_sum="Clip Summary: "+external_memory.documents[key]
|
| 461 |
+
if clip_name in key and 'subtitle' in key:
|
| 462 |
+
clip_conversation="Clip Subtitles: "+external_memory.documents[key]
|
| 463 |
+
|
| 464 |
+
if args.use_coherent_description:
|
| 465 |
+
related_information+=f"question_related_information: {info},{general_sum}\n"
|
| 466 |
+
else:
|
| 467 |
+
# related_information+=f"{general_sum},{clip_conversation},question_related_information: {info}\n"
|
| 468 |
+
# related_information+=f"question_related_information: {info},{clip_conversation}\n"
|
| 469 |
+
if args.model_summary_only:
|
| 470 |
+
related_information+=f"{general_sum},question_related_information: {info}\n"
|
| 471 |
+
elif args.info_only:
|
| 472 |
+
related_information+=f"question_related_information: {info}\n"
|
| 473 |
+
elif args.subtitles_only:
|
| 474 |
+
related_information+=f"{clip_conversation},question_related_information: {info}\n"
|
| 475 |
+
else:
|
| 476 |
+
related_information+=f"{general_sum},{clip_conversation},question_related_information: {info}\n"
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
questions.append(question)
|
| 480 |
+
related_information_list.append(related_information)
|
| 481 |
+
related_context_keys.append(related_information)
|
| 482 |
+
related_context_keys_list_new.append(related_context_keys)
|
| 483 |
+
if len(questions)< args.batch_size:
|
| 484 |
+
continue
|
| 485 |
+
setup_seeds(seed)
|
| 486 |
+
if args.use_chatgpt :
|
| 487 |
+
batch_pred=self.inference_RAG_chatGPT(questions, related_information_list)
|
| 488 |
+
else:
|
| 489 |
+
batch_pred=self.inference_RAG(questions, related_information_list)
|
| 490 |
+
|
| 491 |
+
for pred in batch_pred:
|
| 492 |
+
total_batch_pred.append(pred)
|
| 493 |
+
questions=[]
|
| 494 |
+
related_information_list=[]
|
| 495 |
+
|
| 496 |
+
if len(questions)>0:
|
| 497 |
+
setup_seeds(seed)
|
| 498 |
+
if args.use_chatgpt :
|
| 499 |
+
batch_pred=self.inference_RAG_chatGPT(questions, related_information_list)
|
| 500 |
+
else:
|
| 501 |
+
batch_pred=self.inference_RAG(questions, related_information_list)
|
| 502 |
+
for pred in batch_pred:
|
| 503 |
+
total_batch_pred.append(pred)
|
| 504 |
+
return total_batch_pred,related_context_keys_list_new
|
| 505 |
+
|
| 506 |
+
def define_save_name(self):
|
| 507 |
+
save_name="subtitles" if args.index_subtitles_together else "no_subtitles"
|
| 508 |
+
save_name+="_clips_for_info" if args.use_clips_for_info else ""
|
| 509 |
+
save_name+="_chatgpt" if args.use_chatgpt else ""
|
| 510 |
+
save_name+="_vision_only" if args.vision_only else ""
|
| 511 |
+
save_name+="_model_summary_only" if args.model_summary_only else ""
|
| 512 |
+
save_name+="_subtitles_only" if args.subtitles_only else ""
|
| 513 |
+
save_name+="_choices_for_info" if args.use_choices_for_info else ""
|
| 514 |
+
save_name+="_unknown" if args.add_unknown else ""
|
| 515 |
+
save_name+="_info_only" if args.info_only else ""
|
| 516 |
+
print("save_name",save_name)
|
| 517 |
+
return save_name
|
| 518 |
+
def eval_movie_qa(self):
|
| 519 |
+
## Movie QA evaluation
|
| 520 |
+
full_questions_result=[]
|
| 521 |
+
movie_number=0
|
| 522 |
+
start=args.start
|
| 523 |
+
end=args.end
|
| 524 |
+
for movie in tqdm(self.movies_dict.keys()):
|
| 525 |
+
# if the movie has no answer, skip it
|
| 526 |
+
if self.movies_dict[movie][0]['answer'] is None:
|
| 527 |
+
continue
|
| 528 |
+
if args.start <=movie_number < args.end:
|
| 529 |
+
save_name=self.define_save_name()
|
| 530 |
+
save_dir=f"new_workspace/results/movie_qa/{args.exp_name}/{save_name}_{args.neighbours}_neighbours"
|
| 531 |
+
if os.path.exists( f"{save_dir}/{movie}.json" ):
|
| 532 |
+
print(f"Movie {movie} already processed")
|
| 533 |
+
with open(f"{save_dir}/{movie}.json", 'r') as f:
|
| 534 |
+
pred_json = json.load(f)
|
| 535 |
+
full_questions_result.extend(pred_json)
|
| 536 |
+
continue
|
| 537 |
+
use_subtitles_while_generating_summary=not args.vision_only
|
| 538 |
+
information_RAG_path,embedding_path=self.movie_inference(movie,use_subtitles_while_generating_summary)
|
| 539 |
+
external_memory=MemoryIndex(args.neighbours, use_openai=args.use_openai_embedding)
|
| 540 |
+
if os.path.exists(embedding_path):
|
| 541 |
+
external_memory.load_embeddings_from_pkl(embedding_path)
|
| 542 |
+
else:
|
| 543 |
+
external_memory.load_documents_from_json(information_RAG_path,emdedding_path=embedding_path)
|
| 544 |
+
|
| 545 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 546 |
+
pred_json=[]
|
| 547 |
+
batch_questions=[]
|
| 548 |
+
for qa in tqdm(self.movies_dict[movie]):
|
| 549 |
+
batch_questions.append(qa)
|
| 550 |
+
if len(batch_questions)<args.batch_size:
|
| 551 |
+
continue
|
| 552 |
+
model_ans,related_text=self.answer_movie_questions_RAG(batch_questions,external_memory)
|
| 553 |
+
for qa,ans,related_info in zip(batch_questions,model_ans,related_text):
|
| 554 |
+
qa.update({'pred':ans})
|
| 555 |
+
qa.update({'related_info':related_info})
|
| 556 |
+
pred_json.append(qa)
|
| 557 |
+
batch_questions=[]
|
| 558 |
+
if len(batch_questions)>0:
|
| 559 |
+
model_ans,related_text=self.answer_movie_questions_RAG(batch_questions,external_memory)
|
| 560 |
+
for qa,ans,related_info in zip(batch_questions,model_ans,related_text):
|
| 561 |
+
qa.update({'pred':ans})
|
| 562 |
+
qa.update({'related_info':related_info})
|
| 563 |
+
pred_json.append(qa)
|
| 564 |
+
full_questions_result.extend(pred_json)
|
| 565 |
+
with open(f"{save_dir}/{movie}.json", 'w') as fp:
|
| 566 |
+
json.dump(pred_json, fp)
|
| 567 |
+
print(f"Movie {movie} prediction saved to {save_dir}/{movie}_pred_{args.neighbours}.json")
|
| 568 |
+
movie_number+=1
|
| 569 |
+
with open(f"{save_dir}/full_pred_s{start}_end{end}.json", 'w') as fp:
|
| 570 |
+
json.dump(full_questions_result, fp)
|
| 571 |
+
|
| 572 |
+
args=get_arguments()
|
| 573 |
+
|
| 574 |
+
def setup_seeds(seed):
|
| 575 |
+
random.seed(seed)
|
| 576 |
+
np.random.seed(seed)
|
| 577 |
+
torch.manual_seed(seed)
|
| 578 |
+
torch.cuda.manual_seed(seed)
|
| 579 |
+
cudnn.benchmark = False
|
| 580 |
+
cudnn.deterministic = True
|
| 581 |
+
|
| 582 |
+
import yaml
|
| 583 |
+
with open('test_configs/llama2_test_config.yaml') as file:
|
| 584 |
+
config = yaml.load(file, Loader=yaml.FullLoader)
|
| 585 |
+
seed=config['run']['seed']
|
| 586 |
+
print("seed",seed)
|
| 587 |
+
|
| 588 |
+
if __name__ == "__main__":
|
| 589 |
+
setup_seeds(seed)
|
| 590 |
+
movie_qa_eval=MovieQAEval(args)
|
| 591 |
+
movie_qa_eval.eval_movie_qa()
|
evaluation/eval_goldfish_tvqa_long.py
ADDED
|
@@ -0,0 +1,535 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
project_dir = os.getcwd()
|
| 4 |
+
sys.path.append(project_dir)
|
| 5 |
+
import json
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from goldfish_lv import GoldFish_LV,split_subtitles,time_to_seconds
|
| 8 |
+
import argparse
|
| 9 |
+
import json
|
| 10 |
+
import argparse
|
| 11 |
+
import torch
|
| 12 |
+
import re
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
from PIL import Image
|
| 15 |
+
# from openai import OpenAI
|
| 16 |
+
from index import MemoryIndex
|
| 17 |
+
import pysrt
|
| 18 |
+
import chardet
|
| 19 |
+
import torch
|
| 20 |
+
import random
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch.backends.cudnn as cudnn
|
| 23 |
+
def str2bool(v):
|
| 24 |
+
if isinstance(v, bool):
|
| 25 |
+
return v
|
| 26 |
+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
| 27 |
+
return True
|
| 28 |
+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
| 29 |
+
return False
|
| 30 |
+
else:
|
| 31 |
+
raise argparse.ArgumentTypeError('Boolean value expected.')
|
| 32 |
+
|
| 33 |
+
def get_arguments():
|
| 34 |
+
parser = argparse.ArgumentParser(description="Inference parameters")
|
| 35 |
+
parser.add_argument("--neighbours", type=int, default=-1)
|
| 36 |
+
parser.add_argument("--name", type=str,default="ckpt_92",help="name of the experiment")
|
| 37 |
+
parser.add_argument("--exp_name", type=str,default="",help="name of the experiment")
|
| 38 |
+
parser.add_argument("--add_unknown", action='store_true')
|
| 39 |
+
parser.add_argument("--use_chatgpt", action='store_true')
|
| 40 |
+
parser.add_argument("--use_choices_for_info", action='store_true')
|
| 41 |
+
parser.add_argument("--use_gt_information", action='store_true')
|
| 42 |
+
parser.add_argument("--inference_text", action='store_true')
|
| 43 |
+
parser.add_argument("--use_gt_information_with_distraction", action='store_true')
|
| 44 |
+
parser.add_argument("--num_distraction", type=int, default=2)
|
| 45 |
+
parser.add_argument("--add_confidance_score", action='store_true')
|
| 46 |
+
parser.add_argument("--use_original_video", action='store_true')
|
| 47 |
+
parser.add_argument("--use_video_embedding", action='store_true')
|
| 48 |
+
parser.add_argument("--use_clips_for_info", action='store_true')
|
| 49 |
+
parser.add_argument("--use_GT_video", action='store_true')
|
| 50 |
+
parser.add_argument("--use_gt_summary", action='store_true')
|
| 51 |
+
parser.add_argument("--index_subtitles_together", action='store_true')
|
| 52 |
+
|
| 53 |
+
parser.add_argument("--ask_the_question_early", action='store_true')
|
| 54 |
+
parser.add_argument("--clip_in_ask_early", action='store_true')
|
| 55 |
+
parser.add_argument("--use_coherent_description", action='store_true')
|
| 56 |
+
|
| 57 |
+
parser.add_argument("--start", default=0, type=int)
|
| 58 |
+
parser.add_argument("--end", default=100000, type=int)
|
| 59 |
+
|
| 60 |
+
parser.add_argument("--vision_only", action='store_true')
|
| 61 |
+
parser.add_argument("--model_summary_only", action='store_true')
|
| 62 |
+
parser.add_argument("--subtitles_only", action='store_true')
|
| 63 |
+
parser.add_argument("--subtitles_only_after_retrieval", action='store_true')
|
| 64 |
+
parser.add_argument("--info_only", action='store_true')
|
| 65 |
+
|
| 66 |
+
parser.add_argument("--cfg-path", default="test_configs/llama2_test_config.yaml")
|
| 67 |
+
parser.add_argument("--ckpt", type=str, default="checkpoints/video_llama_checkpoint_last.pth")
|
| 68 |
+
parser.add_argument("--add_subtitles", action='store_true')
|
| 69 |
+
parser.add_argument("--eval_opt", type=str, default='all')
|
| 70 |
+
parser.add_argument("--max_new_tokens", type=int, default=300)
|
| 71 |
+
parser.add_argument("--batch_size", type=int, default=8)
|
| 72 |
+
parser.add_argument("--lora_r", type=int, default=64)
|
| 73 |
+
parser.add_argument("--lora_alpha", type=int, default=16)
|
| 74 |
+
parser.add_argument("--video_path", type=str, help="path to the video")
|
| 75 |
+
parser.add_argument("--use_openai_embedding",type=str2bool, default=False)
|
| 76 |
+
parser.add_argument("--annotation_path", type=str, help="path to the annotation file")
|
| 77 |
+
parser.add_argument("--videos_frames", type=str, help="path to the dataset extracted frames")
|
| 78 |
+
parser.add_argument("--tvqa_json_subtitles", type=str, help="path to the tvqa json subtitles")
|
| 79 |
+
parser.add_argument("--tvqa_clips_subtitles", type=str, help="path to the tvqa json")
|
| 80 |
+
parser.add_argument("--options", nargs="+")
|
| 81 |
+
return parser.parse_args()
|
| 82 |
+
|
| 83 |
+
def clean_text(subtitles_text):
|
| 84 |
+
# Remove unwanted characters except for letters, digits, and single quotes
|
| 85 |
+
subtitles_text = re.sub(r'[^a-zA-Z0-9\s\']', '', subtitles_text)
|
| 86 |
+
# Replace multiple spaces with a single space
|
| 87 |
+
subtitles_text = re.sub(r'\s+', ' ', subtitles_text)
|
| 88 |
+
return subtitles_text.strip()
|
| 89 |
+
|
| 90 |
+
class TVQAEVAL (GoldFish_LV):
|
| 91 |
+
def __init__(self, args: argparse.Namespace) -> None:
|
| 92 |
+
super().__init__(args)
|
| 93 |
+
self.tv_shows_mapping={"Grey's Anatomy":"grey_frames", 'How I Met You Mother':"met_frames", 'Friends':"friends_frames", 'The Big Bang Theory':"bbt_frames", 'House M.D.':"house_frames", 'Castle':"castle_frames"}
|
| 94 |
+
self.save_long_videos_path = f"new_workspace/clips_summary/tvqa"
|
| 95 |
+
if args.use_openai_embedding:
|
| 96 |
+
self.save_embedding_path = f"new_workspace/open_ai_embedding/tvqa"
|
| 97 |
+
else:
|
| 98 |
+
self.save_embedding_path = f"new_workspace/embedding/tvqa"
|
| 99 |
+
os.makedirs(self.save_long_videos_path, exist_ok=True)
|
| 100 |
+
self.max_sub_len=400
|
| 101 |
+
self.max_num_images=45
|
| 102 |
+
self.fps=3
|
| 103 |
+
with open(args.tvqa_json_subtitles) as f:
|
| 104 |
+
self.subtitles_list=json.load(f)
|
| 105 |
+
self.subtitles={}
|
| 106 |
+
for sub in self.subtitles_list:
|
| 107 |
+
self.subtitles[sub["vid_name"]]=sub["sub"]
|
| 108 |
+
|
| 109 |
+
def _get_TVs_data(self):
|
| 110 |
+
json_file_path=args.annotation_path
|
| 111 |
+
frames_path=args.videos_frames
|
| 112 |
+
subtitle_path=args.tvqa_clips_subtitles
|
| 113 |
+
with open (json_file_path) as f:
|
| 114 |
+
tv_shows_data=json.load(f)
|
| 115 |
+
return tv_shows_data,frames_path,subtitle_path
|
| 116 |
+
def _get_shows_subtitles(self,clip_subtitles_path):
|
| 117 |
+
try :
|
| 118 |
+
with open(clip_subtitles_path, 'rb') as f:
|
| 119 |
+
result = chardet.detect(f.read())
|
| 120 |
+
clip_subtitles = pysrt.open(clip_subtitles_path, encoding=result['encoding'])
|
| 121 |
+
return clip_subtitles
|
| 122 |
+
except:
|
| 123 |
+
print("No subtitles found")
|
| 124 |
+
return []
|
| 125 |
+
def episode_inference(self,clips,folder_name,use_subtitles):
|
| 126 |
+
max_caption_index = 0
|
| 127 |
+
max_subtitle_index = 0
|
| 128 |
+
preds={}
|
| 129 |
+
important_data = {}
|
| 130 |
+
videos_summaries=[]
|
| 131 |
+
batch_size=args.batch_size
|
| 132 |
+
batch_images=[]
|
| 133 |
+
batch_instructions=[]
|
| 134 |
+
conversations=[]
|
| 135 |
+
clips_names=[]
|
| 136 |
+
for clip_name in tqdm(clips,desc="Inference Episode clips"):
|
| 137 |
+
conversation=""
|
| 138 |
+
try:
|
| 139 |
+
for subtitle in self.subtitles[clip_name]:
|
| 140 |
+
conversation+=subtitle['text']+" "
|
| 141 |
+
except:
|
| 142 |
+
pass
|
| 143 |
+
conversations.append(clean_text(conversation))
|
| 144 |
+
images,img_placeholder=self.prepare_input_images(clip_name,folder_name,use_subtitles)
|
| 145 |
+
instruction = img_placeholder + '\n' + self.summary_instruction
|
| 146 |
+
batch_images.append(images)
|
| 147 |
+
batch_instructions.append(instruction)
|
| 148 |
+
clips_names.append(clip_name)
|
| 149 |
+
if len(batch_images) < batch_size:
|
| 150 |
+
continue
|
| 151 |
+
batch_images = torch.stack(batch_images)
|
| 152 |
+
batch_pred=self.run_images(batch_images,batch_instructions)
|
| 153 |
+
for i,pred in enumerate(batch_pred):
|
| 154 |
+
max_caption_index += 1
|
| 155 |
+
videos_summaries.append(pred)
|
| 156 |
+
if args.use_coherent_description:
|
| 157 |
+
preds[f'caption_{max_caption_index}__{clips_names[i]}'] = f"model_summary :{pred}\nVideo conversation :{conversations[i]}"
|
| 158 |
+
else:
|
| 159 |
+
if args.index_subtitles_together and use_subtitles:
|
| 160 |
+
if conversations[i] != "":
|
| 161 |
+
max_subtitle_index+=1
|
| 162 |
+
important_data.update({f"subtitle_{max_subtitle_index}__{clips_names[i]}": conversations[i]})
|
| 163 |
+
preds[f'caption_{max_caption_index}__{clips_names[i]}'] = pred
|
| 164 |
+
|
| 165 |
+
batch_images=[]
|
| 166 |
+
batch_instructions=[]
|
| 167 |
+
clips_names=[]
|
| 168 |
+
conversations=[]
|
| 169 |
+
# run inference for the last batch
|
| 170 |
+
if len(batch_images)>0:
|
| 171 |
+
batch_images = torch.stack(batch_images)
|
| 172 |
+
batch_pred=self.run_images(batch_images,batch_instructions)
|
| 173 |
+
for i,pred in enumerate(batch_pred):
|
| 174 |
+
max_caption_index += 1
|
| 175 |
+
videos_summaries.append(pred)
|
| 176 |
+
if args.use_coherent_description:
|
| 177 |
+
preds[f'caption_{max_caption_index}__{clips_names[i]}'] = f"model_summary :{pred}\nVideo conversation :{conversations[i]}"
|
| 178 |
+
else:
|
| 179 |
+
if args.index_subtitles_together and use_subtitles:
|
| 180 |
+
if conversations[i] != "":
|
| 181 |
+
max_subtitle_index+=1
|
| 182 |
+
important_data.update({f"subtitle_{max_subtitle_index}__{clips_names[i]}": conversations[i]})
|
| 183 |
+
preds[f'caption_{max_caption_index}__{clips_names[i]}'] = pred
|
| 184 |
+
batch_images=[]
|
| 185 |
+
batch_instructions=[]
|
| 186 |
+
clips_names=[]
|
| 187 |
+
return preds,important_data
|
| 188 |
+
|
| 189 |
+
def episode_inference_only_subtitles(self,clips,tv_images_path,subtitle_path):
|
| 190 |
+
max_subtitle_index = 0
|
| 191 |
+
important_data = {}
|
| 192 |
+
for c_name in tqdm(clips,desc="Inference Episode clips"):
|
| 193 |
+
clip_subtitles_path=os.path.join(subtitle_path,c_name+".srt")
|
| 194 |
+
clip_subtitles=self._get_shows_subtitles(clip_subtitles_path)
|
| 195 |
+
conversation=""
|
| 196 |
+
if args.index_subtitles_together:
|
| 197 |
+
if self.subtitles.get(c_name,False):
|
| 198 |
+
for subtitle in self.subtitles[c_name]:
|
| 199 |
+
conversation+=subtitle['text']+" "
|
| 200 |
+
conversation=clean_text(conversation)
|
| 201 |
+
if conversation != "":
|
| 202 |
+
max_subtitle_index+=1
|
| 203 |
+
important_data.update({f"subtitle_{max_subtitle_index}__{c_name}": conversation})
|
| 204 |
+
return important_data
|
| 205 |
+
def prepare_input_images(self,clip_name,folder_name,use_subtitles):
|
| 206 |
+
tv_shows_data,frames_path,subtitle_path=self._get_TVs_data()
|
| 207 |
+
tv_images_path =os.path.join(frames_path,folder_name)
|
| 208 |
+
clip_path=os.path.join(tv_images_path,clip_name)
|
| 209 |
+
total_frames=len(os.listdir(clip_path))
|
| 210 |
+
sampling_interval=int(total_frames//self.max_num_images)
|
| 211 |
+
if sampling_interval==0:
|
| 212 |
+
sampling_interval=1
|
| 213 |
+
images=[]
|
| 214 |
+
img_placeholder = ""
|
| 215 |
+
video_frames_path = os.path.join(frames_path,folder_name,clip_name)
|
| 216 |
+
total_num_frames=len(os.listdir(video_frames_path))
|
| 217 |
+
sampling_interval = round(total_num_frames / self.max_num_images)
|
| 218 |
+
if sampling_interval == 0:
|
| 219 |
+
sampling_interval = 1
|
| 220 |
+
subtitle_text_in_interval = ""
|
| 221 |
+
history_subtitles = {}
|
| 222 |
+
number_of_sub_words=0
|
| 223 |
+
for i,frame in enumerate(sorted(os.listdir(video_frames_path))):
|
| 224 |
+
# Find the corresponding subtitle for the frame and combine the interval subtitles into one subtitle
|
| 225 |
+
# we choose 1 frame for every 2 seconds,so we need to combine the subtitles in the interval of 2 seconds
|
| 226 |
+
if self.subtitles.get(clip_name,False) and use_subtitles:
|
| 227 |
+
for subtitle in self.subtitles[clip_name]:
|
| 228 |
+
if (subtitle['start'] <= (i / self.fps) <= subtitle['end']) and subtitle['text'] not in subtitle_text_in_interval:
|
| 229 |
+
if not history_subtitles.get(subtitle['text'],False):
|
| 230 |
+
subtitle_text_in_interval+=subtitle['text']+" "
|
| 231 |
+
history_subtitles[subtitle['text']]=True
|
| 232 |
+
break
|
| 233 |
+
if i % sampling_interval == 0:
|
| 234 |
+
frame = Image.open(os.path.join(video_frames_path,frame)).convert("RGB")
|
| 235 |
+
frame = self.vis_processor(frame)
|
| 236 |
+
images.append(frame)
|
| 237 |
+
img_placeholder += '<Img><ImageHere>'
|
| 238 |
+
if number_of_sub_words<self.max_sub_len and use_subtitles:
|
| 239 |
+
if subtitle_text_in_interval != "":
|
| 240 |
+
subtitle_text_in_interval=clean_text(subtitle_text_in_interval)
|
| 241 |
+
img_placeholder+=f'<Cap>{subtitle_text_in_interval}'
|
| 242 |
+
number_of_sub_words+=len(subtitle_text_in_interval.split(' '))
|
| 243 |
+
subtitle_text_in_interval = ""
|
| 244 |
+
if len(images) >= self.max_num_images:
|
| 245 |
+
break
|
| 246 |
+
if len(images) ==0:
|
| 247 |
+
print("Video not found",video_frames_path)
|
| 248 |
+
|
| 249 |
+
if 0 <len(images) < self.max_num_images:
|
| 250 |
+
last_item = images[-1]
|
| 251 |
+
while len(images) < self.max_num_images:
|
| 252 |
+
images.append(last_item)
|
| 253 |
+
img_placeholder += '<Img><ImageHere>'
|
| 254 |
+
images = torch.stack(images)
|
| 255 |
+
return images,img_placeholder
|
| 256 |
+
def clip_inference(self,clips_name,folders_name,prompts):
|
| 257 |
+
setup_seeds(seed)
|
| 258 |
+
images_batch, instructions_batch = [], []
|
| 259 |
+
for clip_name,folder_name, prompt in zip(clips_name,folders_name, prompts):
|
| 260 |
+
images,img_placeholder=self.prepare_input_images(clip_name,folder_name,use_subtitles=not args.vision_only)
|
| 261 |
+
instruction = img_placeholder + '\n' + prompt
|
| 262 |
+
images_batch.append(images)
|
| 263 |
+
instructions_batch.append(instruction)
|
| 264 |
+
# run inference for the batch
|
| 265 |
+
images_batch=torch.stack(images_batch)
|
| 266 |
+
batch_pred=self.run_images(images_batch,instructions_batch)
|
| 267 |
+
return batch_pred
|
| 268 |
+
def prepare_prompt(self,qa):
|
| 269 |
+
prompt=qa["q"]+" \n\n As you watched in this video Choose ONE suitable answer from these mutiple choices \n"
|
| 270 |
+
for i,choice in enumerate(["a0","a1","a2","a3","a4"]):
|
| 271 |
+
prompt+=f"option {i}: {qa[choice]} \n"
|
| 272 |
+
if args.add_unknown and args.add_confidance_score:
|
| 273 |
+
# Add unknown option
|
| 274 |
+
prompt+=f"option 5: Can't answer based on the provided information\n"
|
| 275 |
+
prompt+="\n Your output should be THE NUMBER OF THE CORRECT ANSWER FROM THE CHOICES FROM 0 TO 5 INCLUSIVE and aslo output a CONFIDANCE SCORE FROM 0 TO 5 representing how confident you are with your answer where 0 is the least confident and 5 is the most confident"
|
| 276 |
+
elif args.add_unknown:
|
| 277 |
+
prompt+=f"option 5: Can't answer based on the provided information\n"
|
| 278 |
+
prompt+="\n Your output should be THE NUMBER OF THE CORRECT ANSWER FROM THE CHOICES FROM 0 TO 5 INCLUSIVE"
|
| 279 |
+
elif args.add_confidance_score:
|
| 280 |
+
prompt+="\n Your output should be THE NUMBER OF THE CORRECT ANSWER FROM THE CHOICES FROM 0 TO 4 INCLUSIVE and aslo output a CONFIDANCE SCORE FROM 0 TO 5 representing how confident you are with your answer where 0 is the least confident and 5 is the most confident"
|
| 281 |
+
else:
|
| 282 |
+
prompt+="\n Your output should be THE NUMBER OF THE CORRECT ANSWER FROM THE CHOICES FROM 0 TO 4 INCLUSIVE"
|
| 283 |
+
return prompt
|
| 284 |
+
def get_most_related_clips(self,qa,related_context_keys):
|
| 285 |
+
if args.use_gt_information:
|
| 286 |
+
most_related_clips=[qa['vid_name']]
|
| 287 |
+
elif args.use_gt_information_with_distraction:
|
| 288 |
+
most_related_clips=[qa['vid_name']]
|
| 289 |
+
for context_key in related_context_keys:
|
| 290 |
+
if len(context_key.split('__'))>1:
|
| 291 |
+
most_related_clips.append(context_key.split('__')[1])
|
| 292 |
+
if len(most_related_clips)==args.num_distraction+1:
|
| 293 |
+
break
|
| 294 |
+
else:
|
| 295 |
+
most_related_clips=[]
|
| 296 |
+
for context_key in related_context_keys:
|
| 297 |
+
if len(context_key.split('__'))>1:
|
| 298 |
+
most_related_clips.append(context_key.split('__')[1])
|
| 299 |
+
if len(most_related_clips)==args.neighbours:
|
| 300 |
+
break
|
| 301 |
+
assert len(most_related_clips)!=0, f"No related clips found {related_context_keys}"
|
| 302 |
+
return most_related_clips
|
| 303 |
+
def use_clips_for_info(self,qa_list,related_context_keys_list,external_memory):
|
| 304 |
+
total_batch_pred=[]
|
| 305 |
+
questions=[]
|
| 306 |
+
related_information_list=[]
|
| 307 |
+
related_context_keys_list_new=[]
|
| 308 |
+
for qa,related_context_keys in zip(qa_list,related_context_keys_list):
|
| 309 |
+
most_related_clips=self.get_most_related_clips(qa,related_context_keys)
|
| 310 |
+
folder_name=self.tv_shows_mapping[qa['show_name']]
|
| 311 |
+
question=qa['q']+ "\nand these are the choices :\n"
|
| 312 |
+
for i,choice in enumerate(["a0","a1","a2","a3","a4"]):
|
| 313 |
+
question+=f"option {i}: {qa[choice]} \n"
|
| 314 |
+
if args.add_unknown:
|
| 315 |
+
question+= "option 5: Can't answer based on the provided information\n"
|
| 316 |
+
question+="\n Your output should be THE NUMBER OF THE CORRECT ANSWER FROM THE CHOICES FROM 0 TO 5 INCLUSIVE"
|
| 317 |
+
else:
|
| 318 |
+
question+="\n Your output should be THE NUMBER OF THE CORRECT ANSWER FROM THE CHOICES FROM 0 TO 4 INCLUSIVE"
|
| 319 |
+
if args.use_choices_for_info:
|
| 320 |
+
# prompt=self.prepare_prompt(qa)
|
| 321 |
+
# prompt+=" and also provide an EXPLAINATION for your answer and If you don't know the answer, say that you don't know.\n\n"
|
| 322 |
+
prompt=f"From this video extract the related information to This multichioce question and provide an explaination for your answer and If you don't know the answer, say 'I DON'T KNOW' as option 5 because maybe the questoin is not related to the video content.\n the question is :\n {question}\n your answer :"
|
| 323 |
+
|
| 324 |
+
else:
|
| 325 |
+
prompt=f"As you watched in this video answer this {qa['q']}\n\n and also provide an EXPLAINATION for your answer and If you don't know the answer, say that you don't know.\n\n"
|
| 326 |
+
all_info=self.clip_inference(most_related_clips,[folder_name]*len(most_related_clips),[prompt]*len(most_related_clips))
|
| 327 |
+
# concatinate all the information together
|
| 328 |
+
related_information=""
|
| 329 |
+
for info,clip_name in zip(all_info,most_related_clips):
|
| 330 |
+
clip_conversation=""
|
| 331 |
+
general_sum=""
|
| 332 |
+
for key in external_memory.documents.keys():
|
| 333 |
+
if clip_name in key and 'caption' in key:
|
| 334 |
+
general_sum="Clip Summary: "+external_memory.documents[key]
|
| 335 |
+
if clip_name in key and 'subtitle' in key:
|
| 336 |
+
clip_conversation="Clip Subtitles: "+external_memory.documents[key]
|
| 337 |
+
|
| 338 |
+
if args.use_coherent_description:
|
| 339 |
+
related_information+=f"question_related_information: {info},{general_sum}\n"
|
| 340 |
+
else:
|
| 341 |
+
# related_information+=f"{general_sum},{clip_conversation},question_related_information: {info}\n"
|
| 342 |
+
# related_information+=f"question_related_information: {info},{clip_conversation}\n"
|
| 343 |
+
if args.model_summary_only:
|
| 344 |
+
related_information+=f"{general_sum},question_related_information: {info}\n"
|
| 345 |
+
elif args.info_only:
|
| 346 |
+
related_information+=f"question_related_information: {info}\n"
|
| 347 |
+
elif args.subtitles_only:
|
| 348 |
+
related_information+=f"{clip_conversation},question_related_information: {info}\n"
|
| 349 |
+
elif args.subtitles_only_after_retrieval:
|
| 350 |
+
related_information+=f"{clip_conversation},question_related_information: {info}\n"
|
| 351 |
+
else:
|
| 352 |
+
related_information+=f"{general_sum},{clip_conversation},question_related_information: {info}\n"
|
| 353 |
+
|
| 354 |
+
questions.append(question)
|
| 355 |
+
related_information_list.append(related_information)
|
| 356 |
+
related_context_keys.append(related_information)
|
| 357 |
+
related_context_keys_list_new.append(related_context_keys)
|
| 358 |
+
if len(questions)< args.batch_size:
|
| 359 |
+
continue
|
| 360 |
+
setup_seeds(seed)
|
| 361 |
+
if args.use_chatgpt :
|
| 362 |
+
batch_pred=self.inference_RAG_chatGPT(questions, related_information_list)
|
| 363 |
+
else:
|
| 364 |
+
batch_pred=self.inference_RAG(questions, related_information_list)
|
| 365 |
+
|
| 366 |
+
for pred in batch_pred:
|
| 367 |
+
total_batch_pred.append(pred)
|
| 368 |
+
questions=[]
|
| 369 |
+
related_information_list=[]
|
| 370 |
+
|
| 371 |
+
if len(questions)>0:
|
| 372 |
+
setup_seeds(seed)
|
| 373 |
+
if args.use_chatgpt :
|
| 374 |
+
batch_pred=self.inference_RAG_chatGPT(questions, related_information_list)
|
| 375 |
+
else:
|
| 376 |
+
batch_pred=self.inference_RAG(questions, related_information_list)
|
| 377 |
+
for pred in batch_pred:
|
| 378 |
+
total_batch_pred.append(pred)
|
| 379 |
+
return total_batch_pred,related_context_keys_list_new
|
| 380 |
+
def answer_TV_questions_RAG(self,qa_list,external_memory,episode_clips,episode_name):
|
| 381 |
+
related_context_keys_list,related_context_documents_list=[],[]
|
| 382 |
+
setup_seeds(seed)
|
| 383 |
+
for qa in qa_list:
|
| 384 |
+
question_choices=qa['q']+ "\n and these are the options for the question\n\n"
|
| 385 |
+
for i,choice in enumerate(["a0","a1","a2","a3","a4"]):
|
| 386 |
+
question_choices+=f"option {i}: {qa[choice]} \n\n"
|
| 387 |
+
related_context_documents,related_context_keys = external_memory.search_by_similarity(question_choices)
|
| 388 |
+
|
| 389 |
+
related_context_documents_list.append(related_context_documents)
|
| 390 |
+
related_context_keys_list.append(related_context_keys)
|
| 391 |
+
|
| 392 |
+
if args.use_clips_for_info:
|
| 393 |
+
batch_pred,related_context_keys_list=self.use_clips_for_info(qa_list,related_context_keys_list,external_memory)
|
| 394 |
+
else:
|
| 395 |
+
prompts=[]
|
| 396 |
+
related_context_documents_text_list=[]
|
| 397 |
+
for qa,related_context_documents,related_context_keys in zip(qa_list,related_context_documents_list,related_context_keys_list):
|
| 398 |
+
|
| 399 |
+
related_information=""
|
| 400 |
+
most_related_clips=self.get_most_related_clips(qa,related_context_keys)
|
| 401 |
+
for clip_name in most_related_clips:
|
| 402 |
+
clip_conversation=""
|
| 403 |
+
general_sum=""
|
| 404 |
+
for key in external_memory.documents.keys():
|
| 405 |
+
if clip_name in key and 'caption' in key:
|
| 406 |
+
general_sum="Clip Summary: "+external_memory.documents[key]
|
| 407 |
+
if clip_name in key and 'subtitle' in key:
|
| 408 |
+
clip_conversation="Clip Subtitles: "+external_memory.documents[key]
|
| 409 |
+
# related_information+=f"{general_sum},{clip_conversation}\n"
|
| 410 |
+
if args.use_coherent_description:
|
| 411 |
+
related_information+=f"{general_sum}\n"
|
| 412 |
+
else:
|
| 413 |
+
if args.model_summary_only:
|
| 414 |
+
related_information+=f"{general_sum}\n"
|
| 415 |
+
elif args.subtitles_only:
|
| 416 |
+
related_information+=f"{clip_conversation}\n"
|
| 417 |
+
else:
|
| 418 |
+
related_information+=f"{general_sum},{clip_conversation}\n"
|
| 419 |
+
|
| 420 |
+
prompt=self.prepare_prompt(qa)
|
| 421 |
+
prompts.append(prompt)
|
| 422 |
+
related_context_documents_text_list.append(related_information)
|
| 423 |
+
|
| 424 |
+
setup_seeds(seed)
|
| 425 |
+
if args.use_chatgpt:
|
| 426 |
+
batch_pred=self.inference_RAG_chatGPT(prompts, related_context_documents_text_list)
|
| 427 |
+
else:
|
| 428 |
+
batch_pred=self.inference_RAG(prompts, related_context_documents_text_list)
|
| 429 |
+
return batch_pred ,related_context_keys_list
|
| 430 |
+
def answer_episode_questions(self,questions,information_RAG_path,embedding_path,episode_clips):
|
| 431 |
+
external_memory=MemoryIndex(args.neighbours, use_openai=args.use_openai_embedding)
|
| 432 |
+
if os.path.exists(embedding_path):
|
| 433 |
+
external_memory.load_embeddings_from_pkl(embedding_path)
|
| 434 |
+
else:
|
| 435 |
+
external_memory.load_documents_from_json(information_RAG_path,embedding_path)
|
| 436 |
+
episode_name=information_RAG_path.split('/')[-1].split('.')[0]
|
| 437 |
+
pred_json=[]
|
| 438 |
+
batch_questions=[]
|
| 439 |
+
for qa in tqdm(questions,desc="Answering questions"):
|
| 440 |
+
batch_questions.append(qa)
|
| 441 |
+
if len(batch_questions)<args.batch_size:
|
| 442 |
+
continue
|
| 443 |
+
batch_pred,batch_related_context_keys = self.answer_TV_questions_RAG(batch_questions,external_memory,episode_clips,episode_name)
|
| 444 |
+
for pred,related_context_keys,qa in zip(batch_pred,batch_related_context_keys,batch_questions):
|
| 445 |
+
qa['pred']=pred
|
| 446 |
+
qa['related_context_keys']=related_context_keys
|
| 447 |
+
pred_json.append(qa)
|
| 448 |
+
batch_questions=[]
|
| 449 |
+
if len(batch_questions)>0:
|
| 450 |
+
batch_pred,batch_related_context_keys = self.answer_TV_questions_RAG(batch_questions,external_memory,episode_clips,episode_name)
|
| 451 |
+
for pred,related_context_keys,qa in zip(batch_pred,batch_related_context_keys,batch_questions):
|
| 452 |
+
qa['pred']=pred
|
| 453 |
+
qa['related_context_keys']=related_context_keys
|
| 454 |
+
pred_json.append(qa)
|
| 455 |
+
return pred_json
|
| 456 |
+
|
| 457 |
+
def eval_tv_shows(self,):
|
| 458 |
+
tv_shows_data,frames_path,subtitle_path=self._get_TVs_data()
|
| 459 |
+
full_questions_result=[]
|
| 460 |
+
number_of_episodes=0
|
| 461 |
+
start=args.start
|
| 462 |
+
end=args.end
|
| 463 |
+
for show in tqdm(tv_shows_data,desc="Inference TV shows"):
|
| 464 |
+
for season in tqdm(tv_shows_data[show],desc=f"Inference {show} seasons"):
|
| 465 |
+
for episode in tqdm(tv_shows_data[show][season],desc=f"Inference {show} {season} episodes"):
|
| 466 |
+
# Generate clips summary and store the important data (summary and subtitles) in json file
|
| 467 |
+
if start<=number_of_episodes<end:
|
| 468 |
+
folder_name=self.tv_shows_mapping[show]
|
| 469 |
+
tv_images_path =os.path.join(frames_path,folder_name)
|
| 470 |
+
os.makedirs(self.save_long_videos_path, exist_ok=True)
|
| 471 |
+
save_name="" if args.index_subtitles_together else "no_subtitles_"
|
| 472 |
+
save_name="subtitles_only" if args.subtitles_only else save_name
|
| 473 |
+
save_name="use_coherent_description" if args.use_coherent_description else save_name
|
| 474 |
+
file_path=os.path.join(self.save_long_videos_path,save_name+folder_name+"_"+season+"_"+episode+".json")
|
| 475 |
+
embedding_path=os.path.join(self.save_embedding_path,save_name+folder_name+"_"+season+"_"+episode+".pkl")
|
| 476 |
+
# options don't require rerunning the inference
|
| 477 |
+
save_name+="_unknown" if args.add_unknown else ""
|
| 478 |
+
save_name+="_clips_for_info" if args.use_clips_for_info else ""
|
| 479 |
+
save_name+="_chatgpt" if args.use_chatgpt else ""
|
| 480 |
+
save_name+="_choices_for_info" if args.use_choices_for_info else ""
|
| 481 |
+
save_name+="_info_only" if args.info_only else ""
|
| 482 |
+
save_name+="_subtitles_only" if args.subtitles_only else ""
|
| 483 |
+
save_name+="_subtitles_only_after_retrieval" if args.subtitles_only_after_retrieval else ""
|
| 484 |
+
if os.path.exists(file_path):
|
| 485 |
+
with open(file_path, 'r') as file:
|
| 486 |
+
important_data = json.load(file)
|
| 487 |
+
print("Already processed")
|
| 488 |
+
else:
|
| 489 |
+
episode_clips=tv_shows_data[show][season][episode]['clips']
|
| 490 |
+
if args.subtitles_only :
|
| 491 |
+
important_data=self.episode_inference_only_subtitles(episode_clips,tv_images_path,subtitle_path)
|
| 492 |
+
else:
|
| 493 |
+
preds,important_data=self.episode_inference(episode_clips,folder_name,use_subtitles=not args.vision_only)
|
| 494 |
+
important_data.update(preds)
|
| 495 |
+
# if not args.subtitles_only :
|
| 496 |
+
# summary = self.compine_summaries(important_data)
|
| 497 |
+
# preds['summary'] = summary
|
| 498 |
+
# important_data["summary"]=summary
|
| 499 |
+
with open(file_path, 'w') as file:
|
| 500 |
+
json.dump(important_data, file, indent=4)
|
| 501 |
+
# Answer questions
|
| 502 |
+
questions=tv_shows_data[show][season][episode]['questions']
|
| 503 |
+
episode_clips=tv_shows_data[show][season][episode]['clips']
|
| 504 |
+
episode_name=file_path.split('/')[-1].split('.')[0]
|
| 505 |
+
pred_json=self.answer_episode_questions(questions,file_path,embedding_path,episode_clips)
|
| 506 |
+
full_questions_result.extend(pred_json)
|
| 507 |
+
save_dir=f"new_workspace/results/tvqa/{args.exp_name}/{save_name}_{args.neighbours}_neighbours"
|
| 508 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 509 |
+
with open(f"{save_dir}/{episode_name}.json", 'w') as fp:
|
| 510 |
+
json.dump(pred_json, fp)
|
| 511 |
+
print(f"Episode {episode_name} prediction saved to {save_dir}/{episode_name}_pred_{args.neighbours}.json")
|
| 512 |
+
number_of_episodes+=1
|
| 513 |
+
with open(f"{save_dir}/full_pred_{start}_{end}.json", 'w') as fp:
|
| 514 |
+
json.dump(full_questions_result, fp)
|
| 515 |
+
print(f"TV shows prediction saved to {save_dir}/full_pred_{start}{end}.json")
|
| 516 |
+
args=get_arguments()
|
| 517 |
+
|
| 518 |
+
def setup_seeds(seed):
|
| 519 |
+
random.seed(seed)
|
| 520 |
+
np.random.seed(seed)
|
| 521 |
+
torch.manual_seed(seed)
|
| 522 |
+
torch.cuda.manual_seed(seed)
|
| 523 |
+
cudnn.benchmark = False
|
| 524 |
+
cudnn.deterministic = True
|
| 525 |
+
|
| 526 |
+
import yaml
|
| 527 |
+
with open('test_configs/llama2_test_config.yaml') as file:
|
| 528 |
+
config = yaml.load(file, Loader=yaml.FullLoader)
|
| 529 |
+
seed=config['run']['seed']
|
| 530 |
+
print("seed",seed)
|
| 531 |
+
|
| 532 |
+
if __name__ == "__main__":
|
| 533 |
+
setup_seeds(seed)
|
| 534 |
+
tvqa_eval=TVQAEVAL(args)
|
| 535 |
+
tvqa_eval.eval_tv_shows()
|
evaluation/eval_minigpt4_video.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import sys
|
| 5 |
+
project_dir = os.getcwd()
|
| 6 |
+
sys.path.append(project_dir)
|
| 7 |
+
from torch.utils.data import DataLoader
|
| 8 |
+
from minigpt4.common.eval_utils import prepare_texts, init_model, eval_parser
|
| 9 |
+
from minigpt4.conversation.conversation import CONV_VISION
|
| 10 |
+
from minigpt4.processors.blip_processors import Blip2ImageTrainProcessor,BlipCaptionProcessor
|
| 11 |
+
from minigpt4.datasets.datasets.video_datasets import VideoChatGPTEvalDataset,VideoChatGPTEval_consistancy,Video_validation_Dataset,TVQAEVAL
|
| 12 |
+
|
| 13 |
+
parser = eval_parser()
|
| 14 |
+
parser.add_argument("--dataset", type=str, default='msvd', help="dataset to evaluate")
|
| 15 |
+
parser.add_argument("--add_subtitles",action='store_true',help="whether to add subtitles to the video")
|
| 16 |
+
parser.add_argument("--name", type=str, default='test', help="evaluation name")
|
| 17 |
+
parser.add_argument("--videos_path", type=str, default='videos path', help="path to videos")
|
| 18 |
+
parser.add_argument("--subtitles_path", type=str, default='subtitles path', help="path to subtitles")
|
| 19 |
+
parser.add_argument("--ann_path", type=str, default='annotations path', help="path to annotations")
|
| 20 |
+
|
| 21 |
+
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
| 22 |
+
parser.add_argument("--start", type=int, default=0, help="start from video number")
|
| 23 |
+
parser.add_argument("--end", type=int, default=10000000, help="end at video number")
|
| 24 |
+
args = parser.parse_args()
|
| 25 |
+
|
| 26 |
+
print(args.ckpt)
|
| 27 |
+
print(args.name)
|
| 28 |
+
print(args.cfg_path)
|
| 29 |
+
if "test_configs/mistral_test_config.yaml" == args.cfg_path:
|
| 30 |
+
llm_name="mistral"
|
| 31 |
+
else:
|
| 32 |
+
llm_name="llama2"
|
| 33 |
+
print("using captions",args.add_subtitles)
|
| 34 |
+
model, vis_processor,whisper_gpu_id,minigpt4_gpu_id,answer_module_gpu_id = init_model(args)
|
| 35 |
+
conv_temp = CONV_VISION.copy()
|
| 36 |
+
conv_temp.system = ""
|
| 37 |
+
if args.dataset == 'video_chatgpt_generic':
|
| 38 |
+
ann_path=args.ann_path
|
| 39 |
+
videos_path= args.videos_path
|
| 40 |
+
subtitles_path=args.subtitles_path
|
| 41 |
+
annotations_keys=['Q','A','video_name']
|
| 42 |
+
data = VideoChatGPTEvalDataset(vis_processor, videos_path, ann_path,subtitles_path,annotations_keys, add_subtitles=args.add_subtitles,llm_name=llm_name)
|
| 43 |
+
elif args.dataset == 'video_chatgpt_temporal':
|
| 44 |
+
ann_path=args.ann_path
|
| 45 |
+
videos_path= args.videos_path
|
| 46 |
+
subtitles_path=args.subtitles_path
|
| 47 |
+
annotations_keys=['Q','A','video_name']
|
| 48 |
+
data = VideoChatGPTEvalDataset(vis_processor, videos_path, ann_path,subtitles_path,annotations_keys, add_subtitles=args.add_subtitles,llm_name=llm_name)
|
| 49 |
+
elif args.dataset == 'video_chatgpt_consistency':
|
| 50 |
+
ann_path=args.ann_path
|
| 51 |
+
videos_path= args.videos_path
|
| 52 |
+
subtitles_path=args.subtitles_path
|
| 53 |
+
annotations_keys=[['Q1','Q2'],'A','video_name']
|
| 54 |
+
data = VideoChatGPTEval_consistancy(vis_processor, videos_path, ann_path,subtitles_path,annotations_keys, add_subtitles=args.add_subtitles,llm_name=llm_name)
|
| 55 |
+
|
| 56 |
+
elif args.dataset == 'msrvtt':
|
| 57 |
+
ann_path=args.ann_path
|
| 58 |
+
videos_path= args.videos_path
|
| 59 |
+
subtitles_path=args.subtitles_path
|
| 60 |
+
annotations_keys=['question','answer','video_id']
|
| 61 |
+
data = VideoChatGPTEvalDataset(vis_processor, videos_path, ann_path,subtitles_path,annotations_keys, add_subtitles=args.add_subtitles,llm_name=llm_name)
|
| 62 |
+
|
| 63 |
+
elif args.dataset == 'msvd':
|
| 64 |
+
ann_path=args.ann_path
|
| 65 |
+
videos_path= args.videos_path
|
| 66 |
+
subtitles_path="" # no subtitles for msvd as these videos don't have audio
|
| 67 |
+
annotations_keys=['question','answer','video_id']
|
| 68 |
+
data = VideoChatGPTEvalDataset(vis_processor, videos_path, ann_path,subtitles_path,annotations_keys, add_subtitles=args.add_subtitles,llm_name=llm_name)
|
| 69 |
+
elif args.dataset == 'activitynet':
|
| 70 |
+
ann_path=args.ann_path
|
| 71 |
+
videos_path= args.videos_path
|
| 72 |
+
subtitles_path=args.subtitles_path
|
| 73 |
+
annotations_keys=['question','answer','video_id']
|
| 74 |
+
data = VideoChatGPTEvalDataset(vis_processor, videos_path, ann_path,subtitles_path,annotations_keys, add_subtitles=args.add_subtitles,llm_name=llm_name)
|
| 75 |
+
elif args.dataset == 'tgif':
|
| 76 |
+
ann_path="datasets/evaluation_datasets/tgif/Test_frameqa_question.json"
|
| 77 |
+
videos_path= args.videos_path
|
| 78 |
+
subtitles_path="" # no subtitles for TGIF as these videos don't have audio
|
| 79 |
+
annotations_keys=['question','answer','gif_name']
|
| 80 |
+
data = VideoChatGPTEvalDataset(vis_processor, videos_path, ann_path,subtitles_path,annotations_keys, add_subtitles=False,llm_name=llm_name)
|
| 81 |
+
elif args.dataset == 'tvqa':
|
| 82 |
+
# TVQA dataset
|
| 83 |
+
ann_path="datasets/evaluation_datasets/tvqa_short/tvqa_val.json"
|
| 84 |
+
videos_path= args.videos_path
|
| 85 |
+
subtitles_path=args.subtitles_path
|
| 86 |
+
data = TVQAEVAL(vis_processor, videos_path, ann_path,subtitles_path,add_subtitles=args.add_subtitles,llm_name=llm_name)
|
| 87 |
+
|
| 88 |
+
eval_dataloader = DataLoader(data, batch_size=args.batch_size, shuffle=False)
|
| 89 |
+
|
| 90 |
+
minigpt4_predict = []
|
| 91 |
+
sub="subtitles" if args.add_subtitles else "no_subtitles"
|
| 92 |
+
if args.start == 0 and args.end == 10000000:
|
| 93 |
+
save_path = f'results/{args.name}_{args.dataset}_{sub}.json'
|
| 94 |
+
else:
|
| 95 |
+
print("start from video number",args.start)
|
| 96 |
+
print("end at video number",args.end)
|
| 97 |
+
save_path = f'results/{args.name}_{args.dataset}_{sub}_{args.start}_{args.end}.json'
|
| 98 |
+
|
| 99 |
+
os.makedirs("results", exist_ok=True)
|
| 100 |
+
c=0
|
| 101 |
+
pred_result = {}
|
| 102 |
+
gt_result = {}
|
| 103 |
+
if args.dataset == 'video_chatgpt_consistency':
|
| 104 |
+
for images, texts_1,texts_2, gt_answers, lengths,videos_ids in tqdm(eval_dataloader,desc=f"Eval {args.dataset}"):
|
| 105 |
+
if args.start<= c <args.end :
|
| 106 |
+
texts_q1 = prepare_texts(texts_1, conv_temp, template='', lengths=lengths) # warp the texts with conversation template
|
| 107 |
+
texts_q2 = prepare_texts(texts_2, conv_temp, template='', lengths=lengths) # warp the texts with conversation template
|
| 108 |
+
models_answers_q1 = model.generate(images, texts_q1, max_new_tokens=args.max_new_tokens, do_sample=False, lengths=lengths,num_beams=1)
|
| 109 |
+
models_answers_q2 = model.generate(images, texts_q2, max_new_tokens=args.max_new_tokens, do_sample=False, lengths=lengths,num_beams=1)
|
| 110 |
+
for video_id,model_answer_q1,model_answer_q2, gt_answer,text_q1,text_q2 in zip(videos_ids,models_answers_q1,models_answers_q2, gt_answers,texts_q1,texts_q2):
|
| 111 |
+
result = dict()
|
| 112 |
+
result['video_name'] = video_id
|
| 113 |
+
result['Q1'] = text_q1.split('\n')[-1].replace('[/INST]','')
|
| 114 |
+
result['Q2'] = text_q2.split('\n')[-1].replace('[/INST]','')
|
| 115 |
+
result['A'] = gt_answer
|
| 116 |
+
result['pred1'] = model_answer_q1
|
| 117 |
+
result['pred2'] = model_answer_q2
|
| 118 |
+
pred_result[video_id] = [model_answer_q1,model_answer_q2]
|
| 119 |
+
gt_result[video_id] = [gt_answer]
|
| 120 |
+
minigpt4_predict.append(result)
|
| 121 |
+
# save results every 100 videos to avoid losing results
|
| 122 |
+
if c%100==0:
|
| 123 |
+
with open(save_path, 'w') as f:
|
| 124 |
+
json.dump(minigpt4_predict, f)
|
| 125 |
+
if c >= args.end :
|
| 126 |
+
break
|
| 127 |
+
c+=1
|
| 128 |
+
|
| 129 |
+
elif args.dataset == 'tvr':
|
| 130 |
+
for images, texts, gt_answers, lengths,videos_ids in tqdm(eval_dataloader,desc=f"Eval {args.dataset}"):
|
| 131 |
+
if args.start<= c <args.end :
|
| 132 |
+
texts = prepare_texts(texts, conv_temp, template='', lengths=lengths) # warp the texts with conversation template
|
| 133 |
+
models_answers = model.generate(images, texts, max_new_tokens=args.max_new_tokens, do_sample=False, lengths=lengths,num_beams=1)
|
| 134 |
+
for video_id,model_answer, gt_answer,text in zip(videos_ids,models_answers, gt_answers,texts):
|
| 135 |
+
result = dict()
|
| 136 |
+
result['video_name'] = video_id
|
| 137 |
+
result['Q'] = text.split('\n')[-1].replace('[/INST]','')
|
| 138 |
+
result['A'] = gt_answer
|
| 139 |
+
result['pred'] = model_answer
|
| 140 |
+
pred_result[video_id] = [model_answer]
|
| 141 |
+
gt_result[video_id] = [gt_answer]
|
| 142 |
+
minigpt4_predict.append(result)
|
| 143 |
+
# save results every 100 videos to avoid losing results
|
| 144 |
+
if c%100==0:
|
| 145 |
+
with open(save_path, 'w') as f:
|
| 146 |
+
json.dump(minigpt4_predict, f)
|
| 147 |
+
if c >= args.end :
|
| 148 |
+
break
|
| 149 |
+
c+=1
|
| 150 |
+
elif args.dataset == 'ego_schema' or args.dataset == 'tvqa' or args.dataset == 'tvqa_long_videos':
|
| 151 |
+
for images, texts, gt_answers, lengths,videos_ids in tqdm(eval_dataloader,desc=f"Eval {args.dataset}"):
|
| 152 |
+
if args.start<= c <args.end :
|
| 153 |
+
texts = prepare_texts(texts, conv_temp, template='', lengths=lengths) # warp the texts with conversation template
|
| 154 |
+
models_answers = model.generate(images, texts, max_new_tokens=args.max_new_tokens, do_sample=False, lengths=lengths,num_beams=1)
|
| 155 |
+
for video_id,model_answer, gt_answer,text in zip(videos_ids,models_answers, gt_answers,texts):
|
| 156 |
+
result = dict()
|
| 157 |
+
result['video_name'] = video_id
|
| 158 |
+
if args.dataset == 'tvqa_long_videos':
|
| 159 |
+
result['Q'] = text.split('\n\n')[1:]
|
| 160 |
+
else:
|
| 161 |
+
result['Q'] = text.split('\n')[1:]
|
| 162 |
+
result['A'] = gt_answer
|
| 163 |
+
result['pred'] = model_answer
|
| 164 |
+
pred_result[video_id] = [model_answer]
|
| 165 |
+
gt_result[video_id] = [gt_answer]
|
| 166 |
+
minigpt4_predict.append(result)
|
| 167 |
+
# save results every 100 videos to avoid losing results
|
| 168 |
+
if c%100==0:
|
| 169 |
+
with open(save_path, 'w') as f:
|
| 170 |
+
json.dump(minigpt4_predict, f)
|
| 171 |
+
if c >= args.end :
|
| 172 |
+
break
|
| 173 |
+
c+=1
|
| 174 |
+
else:
|
| 175 |
+
for images, texts, gt_answers, lengths,videos_ids in tqdm(eval_dataloader,desc=f"Eval {args.dataset}"):
|
| 176 |
+
if args.start<= c <args.end :
|
| 177 |
+
texts = prepare_texts(texts, conv_temp, template='', lengths=lengths) # warp the texts with conversation template
|
| 178 |
+
models_answers = model.generate(images, texts, max_new_tokens=args.max_new_tokens, do_sample=False, lengths=lengths,num_beams=1)
|
| 179 |
+
for video_id,model_answer, gt_answer,text in zip(videos_ids,models_answers, gt_answers,texts):
|
| 180 |
+
result = dict()
|
| 181 |
+
result['video_name'] = video_id
|
| 182 |
+
result['Q'] = text.split('\n')[-1].replace('[/INST]','')
|
| 183 |
+
result['A'] = gt_answer
|
| 184 |
+
result['pred'] = model_answer
|
| 185 |
+
pred_result[video_id] = [model_answer]
|
| 186 |
+
gt_result[video_id] = [gt_answer]
|
| 187 |
+
minigpt4_predict.append(result)
|
| 188 |
+
# save results every 100 videos to avoid losing results
|
| 189 |
+
if c%100==0:
|
| 190 |
+
with open(save_path, 'w') as f:
|
| 191 |
+
json.dump(minigpt4_predict, f)
|
| 192 |
+
if c >= args.end :
|
| 193 |
+
break
|
| 194 |
+
c+=1
|
| 195 |
+
|
| 196 |
+
with open(save_path, 'w') as f:
|
| 197 |
+
json.dump(minigpt4_predict, f)
|
| 198 |
+
print("saved results to",save_path)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
|
evaluation/eval_retrieval_acc_tvqa.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
project_dir = os.getcwd()
|
| 4 |
+
sys.path.append(project_dir)
|
| 5 |
+
import json
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from goldfish_lv import GoldFish_LV,split_subtitles,time_to_seconds
|
| 8 |
+
import argparse
|
| 9 |
+
import json
|
| 10 |
+
import argparse
|
| 11 |
+
import torch
|
| 12 |
+
import re
|
| 13 |
+
from PIL import Image
|
| 14 |
+
# from openai import OpenAI
|
| 15 |
+
from index import MemoryIndex
|
| 16 |
+
import torch
|
| 17 |
+
import random
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch.backends.cudnn as cudnn
|
| 20 |
+
|
| 21 |
+
def get_arguments():
|
| 22 |
+
parser = argparse.ArgumentParser(description="Inference parameters")
|
| 23 |
+
parser.add_argument("--neighbours", type=int, default=-1)
|
| 24 |
+
parser.add_argument("--name", type=str,default="ckpt_92",help="name of the experiment")
|
| 25 |
+
parser.add_argument("--exp_name", type=str,default="",help="name of the experiment")
|
| 26 |
+
parser.add_argument("--add_unknown", action='store_true')
|
| 27 |
+
parser.add_argument("--use_chatgpt", action='store_true')
|
| 28 |
+
parser.add_argument("--use_choices_for_info", action='store_true')
|
| 29 |
+
parser.add_argument("--use_gt_information", action='store_true')
|
| 30 |
+
parser.add_argument("--inference_text", action='store_true')
|
| 31 |
+
parser.add_argument("--use_gt_information_with_distraction", action='store_true')
|
| 32 |
+
parser.add_argument("--num_distraction", type=int, default=2)
|
| 33 |
+
parser.add_argument("--add_confidance_score", action='store_true')
|
| 34 |
+
parser.add_argument("--use_original_video", action='store_true')
|
| 35 |
+
parser.add_argument("--use_video_embedding", action='store_true')
|
| 36 |
+
parser.add_argument("--use_clips_for_info", action='store_true')
|
| 37 |
+
parser.add_argument("--use_GT_video", action='store_true')
|
| 38 |
+
parser.add_argument("--use_gt_summary", action='store_true')
|
| 39 |
+
|
| 40 |
+
parser.add_argument("--ask_the_question_early", action='store_true')
|
| 41 |
+
parser.add_argument("--clip_in_ask_early", action='store_true')
|
| 42 |
+
parser.add_argument("--use_coherent_description", action='store_true')
|
| 43 |
+
|
| 44 |
+
parser.add_argument("--start", default=0, type=int)
|
| 45 |
+
parser.add_argument("--end", default=100000, type=int)
|
| 46 |
+
|
| 47 |
+
parser.add_argument("--vision_only", action='store_true')
|
| 48 |
+
parser.add_argument("--model_summary_only", action='store_true')
|
| 49 |
+
parser.add_argument("--subtitles_only", action='store_true')
|
| 50 |
+
parser.add_argument("--subtitles_only_after_retrieval", action='store_true')
|
| 51 |
+
parser.add_argument("--info_only", action='store_true')
|
| 52 |
+
|
| 53 |
+
parser.add_argument("--cfg-path", default="test_configs/llama2_test_config.yaml")
|
| 54 |
+
parser.add_argument("--ckpt", type=str, default="checkpoints/video_llama_checkpoint_last.pth")
|
| 55 |
+
parser.add_argument("--add_subtitles", action='store_true')
|
| 56 |
+
parser.add_argument("--eval_opt", type=str, default='all')
|
| 57 |
+
parser.add_argument("--max_new_tokens", type=int, default=300)
|
| 58 |
+
parser.add_argument("--batch_size", type=int, default=8)
|
| 59 |
+
parser.add_argument("--lora_r", type=int, default=64)
|
| 60 |
+
parser.add_argument("--lora_alpha", type=int, default=16)
|
| 61 |
+
parser.add_argument("--video_path", type=str, help="path to the video")
|
| 62 |
+
parser.add_argument("--options", nargs="+")
|
| 63 |
+
return parser.parse_args()
|
| 64 |
+
|
| 65 |
+
def clean_text(subtitles_text):
|
| 66 |
+
# Remove unwanted characters except for letters, digits, and single quotes
|
| 67 |
+
subtitles_text = re.sub(r'[^a-zA-Z0-9\s\']', '', subtitles_text)
|
| 68 |
+
# Replace multiple spaces with a single space
|
| 69 |
+
subtitles_text = re.sub(r'\s+', ' ', subtitles_text)
|
| 70 |
+
return subtitles_text.strip()
|
| 71 |
+
|
| 72 |
+
class TVQAEVALRetrieval (GoldFish_LV):
|
| 73 |
+
def __init__(self, args: argparse.Namespace) -> None:
|
| 74 |
+
super().__init__(args)
|
| 75 |
+
self.tv_shows_mapping={"Grey's Anatomy":"grey_frames", 'How I Met You Mother':"met_frames", 'Friends':"friends_frames", 'The Big Bang Theory':"bbt_frames", 'House M.D.':"house_frames", 'Castle':"castle_frames"}
|
| 76 |
+
self.save_long_videos_path = f"workspace/results/tv_shows/{args.name}"
|
| 77 |
+
os.makedirs(self.save_long_videos_path, exist_ok=True)
|
| 78 |
+
self.max_sub_len=400
|
| 79 |
+
self.max_num_images=45
|
| 80 |
+
self.fps=3
|
| 81 |
+
with open("datasets/evaluation_datasets/goldfish_eval_datasets/tvqa/tvqa_preprocessed_subtitles.json") as f:
|
| 82 |
+
self.subtitles_list=json.load(f)
|
| 83 |
+
self.subtitles={}
|
| 84 |
+
for sub in self.subtitles_list:
|
| 85 |
+
self.subtitles[sub["vid_name"]]=sub["sub"]
|
| 86 |
+
|
| 87 |
+
def _get_TVs_data(self):
|
| 88 |
+
json_file_path="datasets/evaluation_datasets/long_video_datasets/tvqa/tvqa_val_edited.json"
|
| 89 |
+
frames_path="/ibex/project/c2090/datasets/TVR_dataset/videos/video_files/frames_hq/"
|
| 90 |
+
subtitle_path="/ibex/project/c2090/datasets/TVR_dataset/videos/tvqa_subtitles"
|
| 91 |
+
with open (json_file_path) as f:
|
| 92 |
+
tv_shows_data=json.load(f)
|
| 93 |
+
return tv_shows_data,frames_path,subtitle_path
|
| 94 |
+
|
| 95 |
+
return vision_questions,subtitle_questions,frames_path
|
| 96 |
+
def episode_inference(self,video_frames_path,qa,use_subtitles):
|
| 97 |
+
batch_prepared_images,batch_img_placeholder,gt_clip_numbers=self.prepare_input_images(video_frames_path,qa,use_subtitles,n_clips=10)
|
| 98 |
+
preds={}
|
| 99 |
+
batch_instructions=[]
|
| 100 |
+
batch_images=[]
|
| 101 |
+
important_data = {}
|
| 102 |
+
conversations=[]
|
| 103 |
+
clips_numbers=[]
|
| 104 |
+
for clip_number,images,img_placeholder in zip(range(len(batch_prepared_images)),batch_prepared_images,batch_img_placeholder):
|
| 105 |
+
instruction = img_placeholder + '\n' + self.summary_instruction
|
| 106 |
+
batch_images.append(images)
|
| 107 |
+
batch_instructions.append(instruction)
|
| 108 |
+
conv=img_placeholder.replace('<Img><ImageHere>','')
|
| 109 |
+
conv=conv.replace('<Cap>',' ')
|
| 110 |
+
conversations.append(conv.strip())
|
| 111 |
+
clips_numbers.append(clip_number)
|
| 112 |
+
if len(batch_images) < args.batch_size:
|
| 113 |
+
continue
|
| 114 |
+
batch_images = torch.stack(batch_images)
|
| 115 |
+
setup_seeds(seed)
|
| 116 |
+
batch_pred=self.run_images(batch_images,batch_instructions)
|
| 117 |
+
for i,pred in enumerate(batch_pred):
|
| 118 |
+
if args.use_coherent_description:
|
| 119 |
+
preds[f'caption__{clips_numbers[i]}'] = f"model_summary :{pred}\nVideo conversation :{conversations[i]}"
|
| 120 |
+
else:
|
| 121 |
+
if use_subtitles:
|
| 122 |
+
if conversations[i] != "":
|
| 123 |
+
important_data.update({f"subtitle__{clips_numbers[i]}": conversations[i]})
|
| 124 |
+
preds[f'caption__{clips_numbers[i]}'] = pred
|
| 125 |
+
|
| 126 |
+
batch_images=[]
|
| 127 |
+
batch_instructions=[]
|
| 128 |
+
conversations=[]
|
| 129 |
+
clips_numbers=[]
|
| 130 |
+
# run inference for the last batch
|
| 131 |
+
if len(batch_images)>0:
|
| 132 |
+
batch_images = torch.stack(batch_images)
|
| 133 |
+
batch_pred=self.run_images(batch_images,batch_instructions)
|
| 134 |
+
for i,pred in enumerate(batch_pred):
|
| 135 |
+
if args.use_coherent_description:
|
| 136 |
+
preds[f'caption__{clips_numbers[i]}'] = f"model_summary :{pred}\nVideo conversation :{conversations[i]}"
|
| 137 |
+
else:
|
| 138 |
+
if use_subtitles:
|
| 139 |
+
if conversations[i] != "":
|
| 140 |
+
important_data.update({f"subtitle__{clips_numbers[i]}": conversations[i]})
|
| 141 |
+
preds[f'caption__{clips_numbers[i]}'] = pred
|
| 142 |
+
batch_images=[]
|
| 143 |
+
batch_instructions=[]
|
| 144 |
+
clips_numbers=[]
|
| 145 |
+
return preds,important_data ,gt_clip_numbers
|
| 146 |
+
|
| 147 |
+
def episode_inference_only_subtitles(self,video_frames_path,qa):
|
| 148 |
+
use_subtitles=True
|
| 149 |
+
batch_prepared_images,batch_img_placeholder,gt_clip_numbers=self.prepare_input_images(video_frames_path,qa,use_subtitles,n_clips=10)
|
| 150 |
+
important_data = {}
|
| 151 |
+
for clip_number,img_placeholder in enumerate(batch_img_placeholder) :
|
| 152 |
+
conv=img_placeholder.replace('<Img><ImageHere>','')
|
| 153 |
+
conv=conv.replace('<Cap>',' ')
|
| 154 |
+
conversation=conv.strip()
|
| 155 |
+
conversation=clean_text(conversation)
|
| 156 |
+
if conversation != "":
|
| 157 |
+
important_data.update({f"subtitle__{clip_number}": conversation})
|
| 158 |
+
return important_data ,gt_clip_numbers
|
| 159 |
+
def prepare_input_images(self,video_frames_path,qa,use_subtitles,n_clips=10):
|
| 160 |
+
batch_images=[]
|
| 161 |
+
batch_img_placeholder = []
|
| 162 |
+
clip_name=video_frames_path.split('/')[-1]
|
| 163 |
+
images=[]
|
| 164 |
+
img_placeholders = []
|
| 165 |
+
gt_clip_numbers = set()
|
| 166 |
+
gt_start_time=qa['ts'][0]
|
| 167 |
+
gt_end_time=qa['ts'][1]
|
| 168 |
+
total_num_frames=len(os.listdir(video_frames_path))
|
| 169 |
+
subtitle_text_in_interval = ""
|
| 170 |
+
history_subtitles = {}
|
| 171 |
+
number_of_sub_words=0
|
| 172 |
+
# samples_per_clip = total_num_frames // n_clips
|
| 173 |
+
samples_per_clip=45
|
| 174 |
+
clip_num=0
|
| 175 |
+
for i,frame in enumerate(sorted(os.listdir(video_frames_path))):
|
| 176 |
+
# Find the corresponding subtitle for the frame and combine the interval subtitles into one subtitle
|
| 177 |
+
# we choose 1 frame for every 2 seconds,so we need to combine the subtitles in the interval of 2 seconds
|
| 178 |
+
if self.subtitles.get(clip_name,False) and use_subtitles:
|
| 179 |
+
for subtitle in self.subtitles[clip_name]:
|
| 180 |
+
if (subtitle['start'] <= (i / self.fps) <= subtitle['end']) and subtitle['text'] not in subtitle_text_in_interval:
|
| 181 |
+
if not history_subtitles.get(subtitle['text'],False):
|
| 182 |
+
subtitle_text_in_interval+=subtitle['text']+" "
|
| 183 |
+
history_subtitles[subtitle['text']]=True
|
| 184 |
+
break
|
| 185 |
+
if gt_start_time<=(i/self.fps)<= gt_end_time:
|
| 186 |
+
gt_clip_numbers.add(clip_num)
|
| 187 |
+
if i % samples_per_clip == 0 and i != 0:
|
| 188 |
+
# here we have one clip , let's sample 45 frames from images array
|
| 189 |
+
sample_value=len(images)//self.max_num_images
|
| 190 |
+
if sample_value==0:
|
| 191 |
+
sample_value=1
|
| 192 |
+
frames_indices = [i for i in range(0, len(images), sample_value)]
|
| 193 |
+
samples_images=[]
|
| 194 |
+
img_placeholder=''
|
| 195 |
+
for j in frames_indices:
|
| 196 |
+
samples_images.append(images[j])
|
| 197 |
+
img_placeholder+=img_placeholders[j]
|
| 198 |
+
if len(samples_images) >= self.max_num_images:
|
| 199 |
+
break
|
| 200 |
+
if 0 <len(samples_images) < self.max_num_images:
|
| 201 |
+
last_item = samples_images[-1]
|
| 202 |
+
while len(samples_images) < self.max_num_images:
|
| 203 |
+
samples_images.append(last_item)
|
| 204 |
+
img_placeholder += '<Img><ImageHere>'
|
| 205 |
+
samples_images = torch.stack(samples_images)
|
| 206 |
+
batch_images.append(samples_images)
|
| 207 |
+
batch_img_placeholder.append(img_placeholder)
|
| 208 |
+
img_placeholders =[]
|
| 209 |
+
images = []
|
| 210 |
+
clip_num+=1
|
| 211 |
+
|
| 212 |
+
frame = Image.open(os.path.join(video_frames_path,frame)).convert("RGB")
|
| 213 |
+
frame = self.vis_processor(frame)
|
| 214 |
+
images.append(frame)
|
| 215 |
+
img_placeholder = '<Img><ImageHere>'
|
| 216 |
+
if number_of_sub_words<self.max_sub_len and use_subtitles:
|
| 217 |
+
if subtitle_text_in_interval != "":
|
| 218 |
+
subtitle_text_in_interval=clean_text(subtitle_text_in_interval)
|
| 219 |
+
img_placeholder+=f'<Cap>{subtitle_text_in_interval}'
|
| 220 |
+
number_of_sub_words+=len(subtitle_text_in_interval.split(' '))
|
| 221 |
+
subtitle_text_in_interval = ""
|
| 222 |
+
img_placeholders.append(img_placeholder)
|
| 223 |
+
return batch_images,batch_img_placeholder,list(gt_clip_numbers)
|
| 224 |
+
|
| 225 |
+
def test_retrieval(self,indexed_data_path,qa,gt_clip_numbers):
|
| 226 |
+
external_memory=MemoryIndex(args.neighbours, use_openai=True)
|
| 227 |
+
external_memory.load_documents_from_json(indexed_data_path)
|
| 228 |
+
question=qa['desc']
|
| 229 |
+
related_context_documents,related_context_keys = external_memory.search_by_similarity(question)
|
| 230 |
+
print(f"related_context_keys {related_context_keys}")
|
| 231 |
+
print(f"gt_clip_numbers {gt_clip_numbers}")
|
| 232 |
+
for key in related_context_keys:
|
| 233 |
+
clip_idx=int(key.split('__')[-1])
|
| 234 |
+
if clip_idx in gt_clip_numbers:
|
| 235 |
+
return True
|
| 236 |
+
return False
|
| 237 |
+
|
| 238 |
+
def get_ground_truth_clip(self,video_frames_path,qa):
|
| 239 |
+
gt_clip_numbers = set()
|
| 240 |
+
gt_start_time=qa['ts'][0]
|
| 241 |
+
gt_end_time=qa['ts'][1]
|
| 242 |
+
samples_per_clip=45
|
| 243 |
+
clip_num=0
|
| 244 |
+
for i in range(len(os.listdir(video_frames_path))):
|
| 245 |
+
if gt_start_time<=(i/self.fps)<= gt_end_time:
|
| 246 |
+
gt_clip_numbers.add(clip_num)
|
| 247 |
+
if i % samples_per_clip == 0 and i != 0:
|
| 248 |
+
clip_num+=1
|
| 249 |
+
return list(gt_clip_numbers)
|
| 250 |
+
|
| 251 |
+
def eval_tv_shows(self,):
|
| 252 |
+
vision_questions,subtitle_questions,frames_path=self._get_TVs_data()
|
| 253 |
+
number_of_videos=0
|
| 254 |
+
start=args.start
|
| 255 |
+
end=args.end
|
| 256 |
+
if args.exp_name=="vision":
|
| 257 |
+
questions=vision_questions
|
| 258 |
+
else:
|
| 259 |
+
questions=subtitle_questions
|
| 260 |
+
correct_retrieval=0
|
| 261 |
+
wrong_retrieval=0
|
| 262 |
+
for qa in questions:
|
| 263 |
+
# Generate clips summary and store the important data (summary and subtitles) in json file
|
| 264 |
+
if start<=number_of_videos<end:
|
| 265 |
+
show_name=qa['vid_name'].split('_')[0]
|
| 266 |
+
if self.tv_shows_mapping.get(show_name,False):
|
| 267 |
+
folder_name=self.tv_shows_mapping[show_name]
|
| 268 |
+
else:
|
| 269 |
+
folder_name=self.tv_shows_mapping['bbt']
|
| 270 |
+
|
| 271 |
+
clip_frames_path =os.path.join(frames_path,folder_name,qa['vid_name'])
|
| 272 |
+
save_name="subtitles_only" if args.subtitles_only else "vision_only" if args.vision_only else "vision_subtitles"
|
| 273 |
+
indexed_data_path=os.path.join(self.save_long_videos_path,f"{qa['vid_name']}_{args.exp_name}_{save_name}_num_{number_of_videos}.json")
|
| 274 |
+
if not os.path.exists(indexed_data_path):
|
| 275 |
+
if args.subtitles_only :
|
| 276 |
+
# TODO
|
| 277 |
+
important_data,gt_clip_numbers=self.episode_inference_only_subtitles(clip_frames_path,qa)
|
| 278 |
+
else:
|
| 279 |
+
preds,important_data ,gt_clip_numbers=self.episode_inference(clip_frames_path,qa,use_subtitles=not args.vision_only)
|
| 280 |
+
important_data.update(preds)
|
| 281 |
+
with open(indexed_data_path, 'w') as file:
|
| 282 |
+
json.dump(important_data, file, indent=4)
|
| 283 |
+
else:
|
| 284 |
+
gt_clip_numbers=self.get_ground_truth_clip(clip_frames_path,qa)
|
| 285 |
+
retrieval_res=self.test_retrieval(indexed_data_path,qa,gt_clip_numbers)
|
| 286 |
+
if retrieval_res==True:
|
| 287 |
+
correct_retrieval+=1
|
| 288 |
+
else:
|
| 289 |
+
wrong_retrieval+=1
|
| 290 |
+
number_of_videos+=1
|
| 291 |
+
|
| 292 |
+
save_dir=f"workspace/eval/retrieval/{args.exp_name}_neighbors_{args.neighbours}"
|
| 293 |
+
save_dir+="_subtitles_only" if args.subtitles_only else "_vision_only" if args.vision_only else "_vision_subtitles"
|
| 294 |
+
os.makedirs(save_dir,exist_ok=True)
|
| 295 |
+
with open(f"{save_dir}/s{start}_end{end}.json", 'w') as fp:
|
| 296 |
+
json.dump({"correct":correct_retrieval,"wrong":wrong_retrieval}, fp)
|
| 297 |
+
args=get_arguments()
|
| 298 |
+
|
| 299 |
+
def setup_seeds(seed):
|
| 300 |
+
random.seed(seed)
|
| 301 |
+
np.random.seed(seed)
|
| 302 |
+
torch.manual_seed(seed)
|
| 303 |
+
torch.cuda.manual_seed(seed)
|
| 304 |
+
cudnn.benchmark = False
|
| 305 |
+
cudnn.deterministic = True
|
| 306 |
+
|
| 307 |
+
import yaml
|
| 308 |
+
with open('test_configs/llama2_test_config.yaml') as file:
|
| 309 |
+
config = yaml.load(file, Loader=yaml.FullLoader)
|
| 310 |
+
seed=config['run']['seed']
|
| 311 |
+
print("seed",seed)
|
| 312 |
+
|
| 313 |
+
if __name__ == "__main__":
|
| 314 |
+
setup_seeds(seed)
|
| 315 |
+
tvqa_eval=TVQAEVALRetrieval(args)
|
| 316 |
+
tvqa_eval.eval_tv_shows()
|
filter_json.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
# === 配置路径 ===
|
| 5 |
+
# 视频所在的文件夹
|
| 6 |
+
video_dir = 'datasets/stage3/videos'
|
| 7 |
+
# 原始 JSON 文件路径
|
| 8 |
+
json_path = 'datasets/stage3/video_instruct_data.json'
|
| 9 |
+
|
| 10 |
+
def filter_data():
|
| 11 |
+
print(f"正在扫描视频文件夹: {video_dir} ...")
|
| 12 |
+
|
| 13 |
+
if not os.path.exists(video_dir):
|
| 14 |
+
print(f"错误: 找不到视频文件夹 {video_dir}")
|
| 15 |
+
return
|
| 16 |
+
|
| 17 |
+
# 1. 获取所有存在的视频 ID (去掉文件名后缀,比如 .mp4)
|
| 18 |
+
existing_video_ids = set()
|
| 19 |
+
files = os.listdir(video_dir)
|
| 20 |
+
for f in files:
|
| 21 |
+
# 跳过隐藏文件
|
| 22 |
+
if f.startswith('.'):
|
| 23 |
+
continue
|
| 24 |
+
# 获取文件名作为 ID (例如 v_xyz.mp4 -> v_xyz)
|
| 25 |
+
vid_id = os.path.splitext(f)[0]
|
| 26 |
+
existing_video_ids.add(vid_id)
|
| 27 |
+
|
| 28 |
+
print(f"找到 {len(existing_video_ids)} 个视频文件。")
|
| 29 |
+
|
| 30 |
+
# 2. 读取原始 JSON
|
| 31 |
+
print(f"正在读取 JSON: {json_path} ...")
|
| 32 |
+
if not os.path.exists(json_path):
|
| 33 |
+
print(f"错误: 找不到 JSON 文件 {json_path}")
|
| 34 |
+
return
|
| 35 |
+
|
| 36 |
+
with open(json_path, 'r', encoding='utf-8') as f:
|
| 37 |
+
data = json.load(f)
|
| 38 |
+
|
| 39 |
+
original_count = len(data)
|
| 40 |
+
print(f"原始 JSON 包含 {original_count} 条数据。")
|
| 41 |
+
|
| 42 |
+
# 3. 进行过滤
|
| 43 |
+
filtered_data = []
|
| 44 |
+
for item in data:
|
| 45 |
+
# 获取 JSON 里的 video_id
|
| 46 |
+
vid = item.get('video_id')
|
| 47 |
+
# 检查是否在刚才扫描的集合里
|
| 48 |
+
if vid in existing_video_ids:
|
| 49 |
+
filtered_data.append(item)
|
| 50 |
+
|
| 51 |
+
filtered_count = len(filtered_data)
|
| 52 |
+
print(f"过滤后剩余 {filtered_count} 条数据 (剔除了 {original_count - filtered_count} 条)。")
|
| 53 |
+
|
| 54 |
+
# 4. 覆盖保存
|
| 55 |
+
if filtered_count > 0:
|
| 56 |
+
with open(json_path, 'w', encoding='utf-8') as f:
|
| 57 |
+
json.dump(filtered_data, f, indent=4)
|
| 58 |
+
print("✅ JSON 文件已更新!现在可以开始训练了。")
|
| 59 |
+
else:
|
| 60 |
+
print("⚠️ 警告: 过滤后数据为空!请检查视频文件夹路径是否正确,或视频文件名是否与 JSON 中的 ID 匹配。")
|
| 61 |
+
|
| 62 |
+
if __name__ == "__main__":
|
| 63 |
+
filter_data()
|
goldfish_demo.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
import spaces
|
| 4 |
+
import os
|
| 5 |
+
import argparse
|
| 6 |
+
import gradio as gr
|
| 7 |
+
from goldfish_lv import GoldFish_LV
|
| 8 |
+
from theme import minigptlv_style, custom_css,text_css
|
| 9 |
+
import re
|
| 10 |
+
from huggingface_hub import login, hf_hub_download
|
| 11 |
+
import time
|
| 12 |
+
import moviepy.editor as mp
|
| 13 |
+
from index import MemoryIndex
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# hf_token = os.environ.get('HF_TKN')
|
| 17 |
+
# login(token=hf_token)
|
| 18 |
+
def str2bool(v):
|
| 19 |
+
if isinstance(v, bool):
|
| 20 |
+
return v
|
| 21 |
+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
| 22 |
+
return True
|
| 23 |
+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
| 24 |
+
return False
|
| 25 |
+
else:
|
| 26 |
+
raise argparse.ArgumentTypeError('Boolean value expected.')
|
| 27 |
+
|
| 28 |
+
def get_arguments():
|
| 29 |
+
parser = argparse.ArgumentParser(description="Inference parameters")
|
| 30 |
+
parser.add_argument("--cfg-path", default="test_configs/llama2_test_config.yaml")
|
| 31 |
+
parser.add_argument("--name", type=str, default='test')
|
| 32 |
+
parser.add_argument("--ckpt", type=str, default="checkpoints/video_llama_checkpoint_last.pth")
|
| 33 |
+
parser.add_argument("--add_subtitles", action='store_true')
|
| 34 |
+
parser.add_argument("--neighbours", type=int, default=3)
|
| 35 |
+
parser.add_argument("--eval_opt", type=str, default='all')
|
| 36 |
+
parser.add_argument("--max_new_tokens", type=int, default=512)
|
| 37 |
+
parser.add_argument("--use_openai_embedding",type=str2bool, default=False)
|
| 38 |
+
parser.add_argument("--batch_size", type=int, default=2, help="Batch size for short video clips")
|
| 39 |
+
parser.add_argument("--lora_r", type=int, default=64)
|
| 40 |
+
parser.add_argument("--lora_alpha", type=int, default=16)
|
| 41 |
+
parser.add_argument("--video_path", type=str, help="Path to the video file")
|
| 42 |
+
parser.add_argument("--options", nargs="+")
|
| 43 |
+
return parser.parse_args()
|
| 44 |
+
|
| 45 |
+
def download_video(youtube_url, download_finish):
|
| 46 |
+
if is_youtube_url(youtube_url):
|
| 47 |
+
processed_video_path = goldfish_obj.process_video_url(youtube_url)
|
| 48 |
+
download_finish = gr.State(value=True)
|
| 49 |
+
return processed_video_path, download_finish
|
| 50 |
+
else:
|
| 51 |
+
return None, download_finish
|
| 52 |
+
def is_youtube_url(url: str) -> bool:
|
| 53 |
+
youtube_regex = (
|
| 54 |
+
r'(https?://)?(www\.)?'
|
| 55 |
+
'(youtube|youtu|youtube-nocookie)\.(com|be)/'
|
| 56 |
+
'(watch\?v=|embed/|v/|.+\?v=)?([^&=%\?]{11})'
|
| 57 |
+
)
|
| 58 |
+
return bool(re.match(youtube_regex, url))
|
| 59 |
+
|
| 60 |
+
@spaces.GPU(duration=60*5)
|
| 61 |
+
def gradio_long_inference_video(videos_list,tmp_save_path,subtitle_paths, use_subtitles=True):
|
| 62 |
+
clips_summary = goldfish_obj.long_inference_video(videos_list,tmp_save_path,subtitle_paths)
|
| 63 |
+
return clips_summary
|
| 64 |
+
|
| 65 |
+
@spaces.GPU(duration=60*3)
|
| 66 |
+
def gradio_short_inference_video(video_path, instruction, use_subtitles=True):
|
| 67 |
+
pred = goldfish_obj.short_video_inference(video_path, instruction, use_subtitles)
|
| 68 |
+
return pred
|
| 69 |
+
|
| 70 |
+
@spaces.GPU(duration=60*3)
|
| 71 |
+
def gradio_inference_RAG (instruction,related_information):
|
| 72 |
+
pred=goldfish_obj.inference_RAG([instruction], [related_information])[0]
|
| 73 |
+
return pred
|
| 74 |
+
def inference(video_path, use_subtitles=True, instruction="", number_of_neighbours=3):
|
| 75 |
+
start_time = time.time()
|
| 76 |
+
video_name = os.path.splitext(os.path.basename(video_path))[0]
|
| 77 |
+
goldfish_obj.args.neighbours = number_of_neighbours
|
| 78 |
+
print(f"Video name: {video_name}")
|
| 79 |
+
video_duration = mp.VideoFileClip(video_path).duration
|
| 80 |
+
print(f"Video duration: {video_duration:.2f} seconds")
|
| 81 |
+
# if the video duration is more than 2 minutes we need to run the long inference
|
| 82 |
+
if video_duration > 180 :
|
| 83 |
+
print("Long video")
|
| 84 |
+
# if the video data is already stored in the external memory, we can use it directly else we need to run the long inference
|
| 85 |
+
file_path=f'new_workspace/clips_summary/demo/{video_name}.json'
|
| 86 |
+
if not os.path.exists(file_path):
|
| 87 |
+
print("Clips summary is not ready")
|
| 88 |
+
videos_list,tmp_save_path=goldfish_obj.split_long_video_into_clips(video_path)
|
| 89 |
+
subtitle_paths = []
|
| 90 |
+
for video_p in videos_list:
|
| 91 |
+
clip_path = os.path.join(tmp_save_path, video_p)
|
| 92 |
+
subtitle_path = goldfish_obj.get_subtitles(clip_path) if use_subtitles else None
|
| 93 |
+
subtitle_paths.append(subtitle_path)
|
| 94 |
+
gradio_long_inference_video(videos_list,tmp_save_path,subtitle_paths, use_subtitles=use_subtitles)
|
| 95 |
+
else:
|
| 96 |
+
print("External memory is ready")
|
| 97 |
+
os.makedirs("new_workspace/embedding/demo", exist_ok=True)
|
| 98 |
+
os.makedirs("new_workspace/open_ai_embedding/demo", exist_ok=True)
|
| 99 |
+
if goldfish_obj.args.use_openai_embedding:
|
| 100 |
+
embedding_path=f"new_workspace/open_ai_embedding/demo/{video_name}.pkl"
|
| 101 |
+
else:
|
| 102 |
+
embedding_path=f"new_workspace/embedding/demo/{video_name}.pkl"
|
| 103 |
+
external_memory=MemoryIndex(goldfish_obj.args.neighbours,use_openai=goldfish_obj.args.use_openai_embedding)
|
| 104 |
+
if os.path.exists(embedding_path):
|
| 105 |
+
print("Loading embeddings from pkl file")
|
| 106 |
+
external_memory.load_embeddings_from_pkl(embedding_path)
|
| 107 |
+
else:
|
| 108 |
+
# will embed the information and save it in the pkl file
|
| 109 |
+
external_memory.load_documents_from_json(file_path,embedding_path)
|
| 110 |
+
# get the most similar context from the external memory to this instruction
|
| 111 |
+
|
| 112 |
+
related_context_documents,related_context_keys = external_memory.search_by_similarity(instruction)
|
| 113 |
+
related_information=goldfish_obj.get_related_context(external_memory,related_context_keys)
|
| 114 |
+
pred=gradio_inference_RAG(instruction,related_information)
|
| 115 |
+
# remove stored data
|
| 116 |
+
# os.remove(file_path)
|
| 117 |
+
# os.system(f"rm -r workspace/tmp/{self.video_name}")
|
| 118 |
+
# os.system(f"rm -r workspace/subtitles/{self.video_name}")
|
| 119 |
+
# os.system(f"rm workspace/tmp/{self.video_id}.mp4")
|
| 120 |
+
else:
|
| 121 |
+
print("Short video")
|
| 122 |
+
goldfish_obj.video_name=video_path.split('/')[-1].split('.')[0]
|
| 123 |
+
pred=gradio_short_inference_video(video_path,instruction,use_subtitles)
|
| 124 |
+
processing_time = time.time() - start_time
|
| 125 |
+
print(f"Processing time: {processing_time:.2f} seconds")
|
| 126 |
+
return pred
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def process_video(path_url, has_subtitles, instruction, number_of_neighbours):
|
| 130 |
+
if is_youtube_url(path_url):
|
| 131 |
+
video_path = return_video_path(path_url)
|
| 132 |
+
else:
|
| 133 |
+
video_path = path_url
|
| 134 |
+
pred = inference(video_path, has_subtitles, instruction, number_of_neighbours)
|
| 135 |
+
return pred
|
| 136 |
+
|
| 137 |
+
def return_video_path(youtube_url):
|
| 138 |
+
video_id = youtube_url.split("https://www.youtube.com/watch?v=")[-1].split('&')[0]
|
| 139 |
+
if video_id:
|
| 140 |
+
return os.path.join("workspace", "tmp", f"{video_id}.mp4")
|
| 141 |
+
else:
|
| 142 |
+
raise ValueError("Invalid YouTube URL provided.")
|
| 143 |
+
|
| 144 |
+
def run_gradio():
|
| 145 |
+
title = """<h1 align="center">Goldfish Demo </h1>"""
|
| 146 |
+
description = """<h5>[ECCV 2024 Accepted]Goldfish: Vision-Language Understanding of Arbitrarily Long Videos</h5>"""
|
| 147 |
+
project_page = """<p><a href='https://vision-cair.github.io/MiniGPT4-video/'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p>"""
|
| 148 |
+
code_link="""<p><a href='https://github.com/Vision-CAIR/MiniGPT4-video'><img src='repo_imgs/goldfishai_png.png'></a></p>"""
|
| 149 |
+
paper_link="""<p><a href=''><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p>"""
|
| 150 |
+
with gr.Blocks(title="Goldfish demo",css=text_css ) as demo :
|
| 151 |
+
gr.Markdown(title)
|
| 152 |
+
gr.Markdown(description)
|
| 153 |
+
with gr.Tab("Youtube videos") as youtube_tab:
|
| 154 |
+
with gr.Row():
|
| 155 |
+
with gr.Column():
|
| 156 |
+
youtube_link = gr.Textbox(label="YouTube link", placeholder="Paste YouTube URL here")
|
| 157 |
+
video_player = gr.Video(autoplay=False)
|
| 158 |
+
download_finish = gr.State(value=False)
|
| 159 |
+
youtube_link.change(
|
| 160 |
+
fn=download_video,
|
| 161 |
+
inputs=[youtube_link, download_finish],
|
| 162 |
+
outputs=[video_player, download_finish]
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
with gr.Row():
|
| 166 |
+
with gr.Column(scale=2) :
|
| 167 |
+
youtube_question = gr.Textbox(label="Your Question", placeholder="Default: What's this video talking about?")
|
| 168 |
+
youtube_has_subtitles = gr.Checkbox(label="Use subtitles", value=True)
|
| 169 |
+
youtube_input_note = """<p>For the global questions set the number of neighbours=-1 otherwise use 3 as the defualt.</p>"""
|
| 170 |
+
gr.Markdown(youtube_input_note)
|
| 171 |
+
# input number
|
| 172 |
+
youtube_number_of_neighbours=gr.Number(label="Number of Neighbours",interactive=True,value=3)
|
| 173 |
+
youtube_process_button = gr.Button("⛓️ Answer the Question (QA)")
|
| 174 |
+
with gr.Column(scale=3):
|
| 175 |
+
youtube_answer = gr.Textbox(label="Answer of the question", lines=8, interactive=True, placeholder="Answer of the question will show up here.")
|
| 176 |
+
youtube_process_button.click(fn=process_video, inputs=[youtube_link, youtube_has_subtitles, youtube_question,youtube_number_of_neighbours], outputs=[youtube_answer])
|
| 177 |
+
with gr.Tab("Local videos") as local_tab:
|
| 178 |
+
with gr.Row():
|
| 179 |
+
with gr.Column():
|
| 180 |
+
local_video_player = gr.Video(sources=["upload"])
|
| 181 |
+
with gr.Row():
|
| 182 |
+
with gr.Column(scale=2):
|
| 183 |
+
local_question = gr.Textbox(label="Your Question", placeholder="Default: What's this video talking about?")
|
| 184 |
+
local_has_subtitles = gr.Checkbox(label="Use subtitles", value=True)
|
| 185 |
+
local_input_note = """<p>For the global questions set the number of neighbours=-1 otherwise use 3 as the defualt.</p>"""
|
| 186 |
+
gr.Markdown(local_input_note)
|
| 187 |
+
local_number_of_neighbours=gr.Number(label="Number of Neighbours",interactive=True,value=3)
|
| 188 |
+
local_process_button = gr.Button("⛓️ Answer the Question (QA)")
|
| 189 |
+
with gr.Column(scale=3):
|
| 190 |
+
local_answer = gr.Textbox(label="Answer of the question", lines=8, interactive=True, placeholder="Answer of the question will show up here.")
|
| 191 |
+
local_process_button.click(fn=process_video, inputs=[local_video_player, local_has_subtitles, local_question,local_number_of_neighbours], outputs=[local_answer])
|
| 192 |
+
|
| 193 |
+
demo.queue(max_size=10).launch(show_error=True,share=True, show_api=False,server_port=5000)
|
| 194 |
+
|
| 195 |
+
if __name__ == "__main__":
|
| 196 |
+
args=get_arguments()
|
| 197 |
+
goldfish_obj = GoldFish_LV(args)
|
| 198 |
+
run_gradio()
|
goldfish_inference.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import argparse
|
| 6 |
+
import gradio as gr
|
| 7 |
+
from goldfish_lv import GoldFish_LV
|
| 8 |
+
from theme import minigptlv_style
|
| 9 |
+
import time
|
| 10 |
+
def str2bool(v):
|
| 11 |
+
if isinstance(v, bool):
|
| 12 |
+
return v
|
| 13 |
+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
| 14 |
+
return True
|
| 15 |
+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
| 16 |
+
return False
|
| 17 |
+
else:
|
| 18 |
+
raise argparse.ArgumentTypeError('Boolean value expected.')
|
| 19 |
+
|
| 20 |
+
def get_arguments():
|
| 21 |
+
parser = argparse.ArgumentParser(description="Inference parameters")
|
| 22 |
+
parser.add_argument("--cfg-path", default="test_configs/llama2_test_config.yaml")
|
| 23 |
+
parser.add_argument("--neighbours", type=int, default=3)
|
| 24 |
+
parser.add_argument("--ckpt", type=str, default="checkpoints/video_llama_checkpoint_last.pth")
|
| 25 |
+
parser.add_argument("--add_subtitles", action='store_true')
|
| 26 |
+
parser.add_argument("--max_new_tokens", type=int, default=512)
|
| 27 |
+
parser.add_argument("--use_openai_embedding",type=str2bool, default=False)
|
| 28 |
+
parser.add_argument("--batch_size", type=int, default=2, help="Batch size for short video clips")
|
| 29 |
+
parser.add_argument("--lora_r", type=int, default=64)
|
| 30 |
+
parser.add_argument("--lora_alpha", type=int, default=16)
|
| 31 |
+
parser.add_argument("--video_path", type=str,default="path for video.mp4", help="Path to the video file or youtube url")
|
| 32 |
+
parser.add_argument("--question", type=str, default="Why rachel is wearing a wedding dress?")
|
| 33 |
+
parser.add_argument("--options", nargs="+")
|
| 34 |
+
return parser.parse_args()
|
| 35 |
+
|
| 36 |
+
def download_video(youtube_url):
|
| 37 |
+
processed_video_path = goldfish_lv.process_video_url(youtube_url)
|
| 38 |
+
return processed_video_path
|
| 39 |
+
|
| 40 |
+
def process_video(video_path, has_subtitles, instruction="",number_of_neighbours=-1):
|
| 41 |
+
result = goldfish_lv.inference(video_path, has_subtitles, instruction,number_of_neighbours)
|
| 42 |
+
pred = result["pred"]
|
| 43 |
+
return pred
|
| 44 |
+
|
| 45 |
+
def return_video_path(youtube_url):
|
| 46 |
+
video_id = youtube_url.split("https://www.youtube.com/watch?v=")[-1].split('&')[0]
|
| 47 |
+
if video_id:
|
| 48 |
+
return os.path.join("workspace", "tmp", f"{video_id}.mp4")
|
| 49 |
+
else:
|
| 50 |
+
raise ValueError("Invalid YouTube URL provided.")
|
| 51 |
+
|
| 52 |
+
args=get_arguments()
|
| 53 |
+
if __name__ == "__main__":
|
| 54 |
+
t1=time.time()
|
| 55 |
+
print("using openai: ", args.use_openai_embedding)
|
| 56 |
+
goldfish_lv = GoldFish_LV(args)
|
| 57 |
+
t2=time.time()
|
| 58 |
+
print("Time taken to load model: ", t2-t1)
|
| 59 |
+
processed_video_path = goldfish_lv.process_video_url(args.video_path)
|
| 60 |
+
pred=process_video(processed_video_path, args.add_subtitles, args.question,args.neighbours)
|
| 61 |
+
print("Question answer: ", pred)
|
| 62 |
+
print(f"Time taken for inference: ", time.time()-t2)
|
goldfish_lv.py
ADDED
|
@@ -0,0 +1,654 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
import json
|
| 6 |
+
import argparse
|
| 7 |
+
import torch
|
| 8 |
+
import cv2
|
| 9 |
+
import moviepy.editor as mp
|
| 10 |
+
import webvtt
|
| 11 |
+
import re
|
| 12 |
+
|
| 13 |
+
from typing import Optional, List
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
from PIL import Image
|
| 16 |
+
from torchvision import transforms
|
| 17 |
+
from pytubefix import YouTube
|
| 18 |
+
from minigpt4.common.eval_utils import init_model
|
| 19 |
+
from minigpt4.conversation.conversation import CONV_VISION
|
| 20 |
+
from index import MemoryIndex
|
| 21 |
+
import pysrt
|
| 22 |
+
import chardet
|
| 23 |
+
from openai import OpenAI
|
| 24 |
+
if os.getenv("OPENAI_API_KEY") is not None:
|
| 25 |
+
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
| 26 |
+
else:
|
| 27 |
+
client = OpenAI(api_key="")
|
| 28 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 29 |
+
import re
|
| 30 |
+
from transformers import BitsAndBytesConfig
|
| 31 |
+
# from split_long_video_in_parallel import split_video
|
| 32 |
+
import transformers
|
| 33 |
+
import whisper
|
| 34 |
+
from datetime import timedelta
|
| 35 |
+
# Function to format timestamps for VTT
|
| 36 |
+
def format_timestamp(seconds):
|
| 37 |
+
td = timedelta(seconds=seconds)
|
| 38 |
+
total_seconds = int(td.total_seconds())
|
| 39 |
+
milliseconds = int(td.microseconds / 1000)
|
| 40 |
+
hours, remainder = divmod(total_seconds, 3600)
|
| 41 |
+
minutes, seconds = divmod(remainder, 60)
|
| 42 |
+
return f"{hours:02}:{minutes:02}:{seconds:02}.{milliseconds:03}"
|
| 43 |
+
|
| 44 |
+
def clean_text(subtitles_text):
|
| 45 |
+
# Remove unwanted characters except for letters, digits, spaces, periods, commas, exclamation marks, and single quotes
|
| 46 |
+
subtitles_text = re.sub(r'[^a-zA-Z0-9\s\']', '', subtitles_text)
|
| 47 |
+
# Replace multiple spaces with a single space
|
| 48 |
+
subtitles_text = re.sub(r'\s+', ' ', subtitles_text)
|
| 49 |
+
return subtitles_text.strip()
|
| 50 |
+
def time_to_seconds(subrip_time):
|
| 51 |
+
return subrip_time.hours * 3600 + subrip_time.minutes * 60 + subrip_time.seconds + subrip_time.milliseconds / 1000
|
| 52 |
+
|
| 53 |
+
def split_subtitles(subtitle_path, n):
|
| 54 |
+
# read the subtitle file and detect the encoding
|
| 55 |
+
with open(subtitle_path, 'rb') as f:
|
| 56 |
+
result = chardet.detect(f.read())
|
| 57 |
+
subs = pysrt.open(subtitle_path, encoding=result['encoding'])
|
| 58 |
+
|
| 59 |
+
total_subs = len(subs)
|
| 60 |
+
|
| 61 |
+
if n <= 0 or n > total_subs:
|
| 62 |
+
print("Invalid value for n. It should be a positive integer less than or equal to the total number of subtitles.")
|
| 63 |
+
return None
|
| 64 |
+
|
| 65 |
+
subs_per_paragraph = total_subs // n
|
| 66 |
+
remainder = total_subs % n
|
| 67 |
+
|
| 68 |
+
paragraphs = []
|
| 69 |
+
|
| 70 |
+
current_index = 0
|
| 71 |
+
|
| 72 |
+
for i in range(n):
|
| 73 |
+
num_subs_in_paragraph = subs_per_paragraph + (1 if i < remainder else 0)
|
| 74 |
+
|
| 75 |
+
paragraph_subs = subs[current_index:current_index + num_subs_in_paragraph]
|
| 76 |
+
current_index += num_subs_in_paragraph
|
| 77 |
+
|
| 78 |
+
# Join subtitles using pysrt's built-in method for efficient formatting
|
| 79 |
+
paragraph = pysrt.SubRipFile(items=paragraph_subs).text
|
| 80 |
+
paragraphs.append(paragraph)
|
| 81 |
+
|
| 82 |
+
return paragraphs
|
| 83 |
+
class GoldFish_LV:
|
| 84 |
+
"""
|
| 85 |
+
'GoldFish_LV' class is to handle long video processing and subtitle management with MiniGPT4_video base model.
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
def __init__(self, args: argparse.Namespace) -> None:
|
| 89 |
+
self.args = args
|
| 90 |
+
self.model, self.vis_processor,whisper_gpu_id,minigpt4_gpu_id,answer_module_gpu_id = init_model(args)
|
| 91 |
+
self.whisper_gpu_id=whisper_gpu_id
|
| 92 |
+
self.minigpt4_gpu_id=minigpt4_gpu_id
|
| 93 |
+
self.answer_module_gpu_id=answer_module_gpu_id
|
| 94 |
+
# self.original_llama_model,self.original_llama_tokenizer=self.load_original_llama_model()
|
| 95 |
+
# self.original_llama_model=self.load_original_llama_model_vllm()
|
| 96 |
+
self.llama_3_1_model=self.load_llama3_1_model()
|
| 97 |
+
self.whisper_model=whisper.load_model("large",device=f"cuda:{self.whisper_gpu_id}")
|
| 98 |
+
# self.summary_instruction="Generate a description of this video .Pay close attention to the objects, actions, emotions portrayed in the video,providing a vivid description of key moments.Specify any visual cues or elements that stand out."
|
| 99 |
+
self.summary_instruction="I'm a blind person, please provide me with a detailed summary of the video content and try to be as descriptive as possible."
|
| 100 |
+
def load_original_llama_model(self):
|
| 101 |
+
model_name="meta-llama/Meta-Llama-3-8B-Instruct"
|
| 102 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 103 |
+
tokenizer.pad_token = "[PAD]"
|
| 104 |
+
tokenizer.padding_side = "left"
|
| 105 |
+
bnb_config = BitsAndBytesConfig(
|
| 106 |
+
load_in_8bit=True,
|
| 107 |
+
)
|
| 108 |
+
llama_model = AutoModelForCausalLM.from_pretrained(
|
| 109 |
+
model_name,
|
| 110 |
+
torch_dtype=torch.bfloat16,
|
| 111 |
+
device_map={'': f"cuda:{self.answer_module_gpu_id}"},
|
| 112 |
+
quantization_config=bnb_config,
|
| 113 |
+
)
|
| 114 |
+
return llama_model,tokenizer
|
| 115 |
+
|
| 116 |
+
def load_llama3_1_model(self):
|
| 117 |
+
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
|
| 118 |
+
bnb_config = BitsAndBytesConfig(
|
| 119 |
+
load_in_8bit=True,
|
| 120 |
+
)
|
| 121 |
+
self.llama3_tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 122 |
+
llama3_model = AutoModelForCausalLM.from_pretrained(
|
| 123 |
+
model_id,
|
| 124 |
+
torch_dtype=torch.bfloat16,
|
| 125 |
+
device_map={'': f"cuda:{self.answer_module_gpu_id}"},
|
| 126 |
+
quantization_config=bnb_config,
|
| 127 |
+
)
|
| 128 |
+
pipeline = transformers.pipeline(
|
| 129 |
+
"text-generation",
|
| 130 |
+
model=llama3_model,
|
| 131 |
+
tokenizer=self.llama3_tokenizer,
|
| 132 |
+
model_kwargs={"torch_dtype": torch.bfloat16},
|
| 133 |
+
device_map=f"cuda:{self.answer_module_gpu_id}",
|
| 134 |
+
)
|
| 135 |
+
return pipeline
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def _youtube_download(self, url: str) -> str:
|
| 140 |
+
try:
|
| 141 |
+
video_id = url.split('v=')[-1].split('&')[0]
|
| 142 |
+
video_id = video_id.strip()
|
| 143 |
+
print(f"Downloading video with ID: {video_id}")
|
| 144 |
+
youtube = YouTube(f"https://www.youtube.com/watch?v={video_id}")
|
| 145 |
+
video_stream = youtube.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first()
|
| 146 |
+
if not video_stream:
|
| 147 |
+
raise ValueError("No suitable video stream found.")
|
| 148 |
+
output_path = f"workspace/tmp/{video_id}.mp4"
|
| 149 |
+
self.video_id=video_id
|
| 150 |
+
video_stream.download(output_path="workspace/tmp", filename=f"{video_id}.mp4")
|
| 151 |
+
return output_path
|
| 152 |
+
except Exception as e:
|
| 153 |
+
print(f"Error downloading video: {e}")
|
| 154 |
+
return url
|
| 155 |
+
|
| 156 |
+
@staticmethod
|
| 157 |
+
def is_youtube_url(url: str) -> bool:
|
| 158 |
+
youtube_regex = (
|
| 159 |
+
r'(https?://)?(www\.)?'
|
| 160 |
+
'(youtube|youtu|youtube-nocookie)\.(com|be)/'
|
| 161 |
+
'(watch\?v=|embed/|v/|.+\?v=)?([^&=%\?]{11})'
|
| 162 |
+
)
|
| 163 |
+
return bool(re.match(youtube_regex, url))
|
| 164 |
+
|
| 165 |
+
def process_video_url(self, video_path: str) -> str:
|
| 166 |
+
if self.is_youtube_url(video_path):
|
| 167 |
+
return self._youtube_download(video_path)
|
| 168 |
+
else:
|
| 169 |
+
return video_path
|
| 170 |
+
|
| 171 |
+
def create_video_grid(self, images: list, rows: int, cols: int, save_path: str) -> Image.Image:
|
| 172 |
+
image_width, image_height = images[0].size
|
| 173 |
+
grid_width = cols * image_width
|
| 174 |
+
grid_height = rows * image_height
|
| 175 |
+
new_image = Image.new("RGB", (grid_width, grid_height))
|
| 176 |
+
for i in range(rows):
|
| 177 |
+
for j in range(cols):
|
| 178 |
+
index = i * cols + j
|
| 179 |
+
if index < len(images):
|
| 180 |
+
image = images[index]
|
| 181 |
+
x_offset = j * image_width
|
| 182 |
+
y_offset = i * image_height
|
| 183 |
+
new_image.paste(image, (x_offset, y_offset))
|
| 184 |
+
|
| 185 |
+
new_image.save(save_path)
|
| 186 |
+
return new_image
|
| 187 |
+
def get_subtitles(self, video_path) :
|
| 188 |
+
video_name=video_path.split('/')[-2]
|
| 189 |
+
video_id=video_path.split('/')[-1].split('.')[0]
|
| 190 |
+
audio_dir = f"workspace/audio/{video_name}"
|
| 191 |
+
subtitle_dir = f"workspace/subtitles/{video_name}"
|
| 192 |
+
os.makedirs(audio_dir, exist_ok=True)
|
| 193 |
+
os.makedirs(subtitle_dir, exist_ok=True)
|
| 194 |
+
# if the subtitles are already generated, return the path of the subtitles
|
| 195 |
+
subtitle_path = f"{subtitle_dir}/{video_id}"+'.vtt'
|
| 196 |
+
if os.path.exists(subtitle_path):
|
| 197 |
+
return f"{subtitle_dir}/{video_id}"+'.vtt'
|
| 198 |
+
audio_path = f"{audio_dir}/{video_id}"+'.mp3'
|
| 199 |
+
try:
|
| 200 |
+
self.extract_audio(video_path, audio_path)
|
| 201 |
+
subtitle_path = f"{subtitle_dir}/{video_id}"+'.vtt'
|
| 202 |
+
result = self.whisper_model.transcribe(audio_path,language="en")
|
| 203 |
+
# Create VTT file
|
| 204 |
+
with open(subtitle_path, "w", encoding="utf-8") as vtt_file:
|
| 205 |
+
vtt_file.write("WEBVTT\n\n")
|
| 206 |
+
for segment in result['segments']:
|
| 207 |
+
start = format_timestamp(segment['start'])
|
| 208 |
+
end = format_timestamp(segment['end'])
|
| 209 |
+
text = segment['text']
|
| 210 |
+
vtt_file.write(f"{start} --> {end}\n{text}\n\n")
|
| 211 |
+
return subtitle_path
|
| 212 |
+
except Exception as e:
|
| 213 |
+
print(f"Error during subtitle generation for {video_path}: {e}")
|
| 214 |
+
return None
|
| 215 |
+
|
| 216 |
+
def prepare_input(self,
|
| 217 |
+
video_path: str,
|
| 218 |
+
subtitle_path: Optional[str],
|
| 219 |
+
instruction: str,previous_caption=""):
|
| 220 |
+
# If a subtitle path is provided, read the VTT (Web Video Text Tracks) file, else set to an empty list
|
| 221 |
+
conversation=""
|
| 222 |
+
if subtitle_path:
|
| 223 |
+
vtt_file = webvtt.read(subtitle_path)
|
| 224 |
+
print("Subtitle loaded successfully")
|
| 225 |
+
try:
|
| 226 |
+
for subtitle in vtt_file:
|
| 227 |
+
sub = subtitle.text.replace('\n',' ')
|
| 228 |
+
conversation+=sub
|
| 229 |
+
except:
|
| 230 |
+
pass
|
| 231 |
+
if self.model.model_type == "Mistral":
|
| 232 |
+
max_images_length=90
|
| 233 |
+
max_sub_len = 800
|
| 234 |
+
else:
|
| 235 |
+
max_images_length = 45
|
| 236 |
+
max_sub_len = 400
|
| 237 |
+
# Load the video file using moviepy and calculate the total number of frames
|
| 238 |
+
clip = mp.VideoFileClip(video_path)
|
| 239 |
+
total_num_frames = int(clip.duration * clip.fps)
|
| 240 |
+
clip.close()
|
| 241 |
+
# Calculate how often to sample a frame based on the total number of frames and the maximum images length
|
| 242 |
+
cap = cv2.VideoCapture(video_path)
|
| 243 |
+
images = []
|
| 244 |
+
frame_count = 0
|
| 245 |
+
sampling_interval = int(total_num_frames / max_images_length)
|
| 246 |
+
if sampling_interval == 0:
|
| 247 |
+
sampling_interval = 1
|
| 248 |
+
# Initialize variables to hold image placeholders, current subtitle text, and subtitle history
|
| 249 |
+
if previous_caption != "":
|
| 250 |
+
img_placeholder = previous_caption+" "
|
| 251 |
+
else:
|
| 252 |
+
img_placeholder = ""
|
| 253 |
+
subtitle_text_in_interval = ""
|
| 254 |
+
history_subtitles = {}
|
| 255 |
+
raw_frames=[]
|
| 256 |
+
number_of_words=0
|
| 257 |
+
transform=transforms.Compose([
|
| 258 |
+
transforms.ToPILImage(),
|
| 259 |
+
])
|
| 260 |
+
# Loop through each frame in the video
|
| 261 |
+
while cap.isOpened():
|
| 262 |
+
ret, frame = cap.read()
|
| 263 |
+
if not ret:
|
| 264 |
+
break
|
| 265 |
+
# TODO: we need to add subtitles in external memory either
|
| 266 |
+
if subtitle_path is not None:
|
| 267 |
+
for i, subtitle in enumerate(vtt_file):
|
| 268 |
+
sub = subtitle.text.replace('\n',' ')
|
| 269 |
+
if (subtitle.start_in_seconds <= (frame_count / int(clip.fps)) <= subtitle.end_in_seconds) and sub not in subtitle_text_in_interval:
|
| 270 |
+
|
| 271 |
+
if not history_subtitles.get(sub, False):
|
| 272 |
+
subtitle_text_in_interval += sub + " "
|
| 273 |
+
|
| 274 |
+
history_subtitles[sub] = True
|
| 275 |
+
break
|
| 276 |
+
# Process and store the frame at specified intervals
|
| 277 |
+
if frame_count % sampling_interval == 0:
|
| 278 |
+
raw_frames.append(Image.fromarray(cv2.cvtColor(frame.copy(), cv2.COLOR_BGR2RGB)))
|
| 279 |
+
frame = transform(frame[:,:,::-1]) # convert to RGB
|
| 280 |
+
frame = self.vis_processor(frame)
|
| 281 |
+
images.append(frame)
|
| 282 |
+
img_placeholder += '<Img><ImageHere>'
|
| 283 |
+
if subtitle_path is not None and subtitle_text_in_interval != "" and number_of_words< max_sub_len:
|
| 284 |
+
img_placeholder+=f'<Cap>{subtitle_text_in_interval}'
|
| 285 |
+
number_of_words+=len(subtitle_text_in_interval.split(' '))
|
| 286 |
+
subtitle_text_in_interval = ""
|
| 287 |
+
frame_count += 1
|
| 288 |
+
|
| 289 |
+
# Break the loop if the maximum number of images is reached
|
| 290 |
+
if len(images) >= max_images_length:
|
| 291 |
+
break
|
| 292 |
+
|
| 293 |
+
cap.release()
|
| 294 |
+
cv2.destroyAllWindows()
|
| 295 |
+
|
| 296 |
+
# Return None if no images are extracted
|
| 297 |
+
if len(images) == 0:
|
| 298 |
+
return None, None
|
| 299 |
+
while len(images) < max_images_length:
|
| 300 |
+
images.append(images[-1])
|
| 301 |
+
img_placeholder += '<Img><ImageHere>'
|
| 302 |
+
images = torch.stack(images)
|
| 303 |
+
print("Input instruction length",len(instruction.split(' ')))
|
| 304 |
+
instruction = img_placeholder + '\n' + instruction
|
| 305 |
+
print("number of words",number_of_words)
|
| 306 |
+
print("number of images",len(images))
|
| 307 |
+
|
| 308 |
+
return images, instruction,conversation
|
| 309 |
+
|
| 310 |
+
def extract_audio(self, video_path: str, audio_path: str) -> None:
|
| 311 |
+
video_clip = mp.VideoFileClip(video_path)
|
| 312 |
+
audio_clip = video_clip.audio
|
| 313 |
+
audio_clip.write_audiofile(audio_path, codec="libmp3lame", bitrate="320k")
|
| 314 |
+
|
| 315 |
+
def short_video_inference (self,video_path,instruction,gen_subtitles=True):
|
| 316 |
+
if gen_subtitles:
|
| 317 |
+
subtitle_path=self.get_subtitles(video_path)
|
| 318 |
+
else :
|
| 319 |
+
subtitle_path=None
|
| 320 |
+
prepared_images,prepared_instruction,video_conversation=self.prepare_input(video_path,subtitle_path,instruction)
|
| 321 |
+
if prepared_images is None:
|
| 322 |
+
return "Video cann't be open ,check the video path again"
|
| 323 |
+
length=len(prepared_images)
|
| 324 |
+
prepared_images=prepared_images.unsqueeze(0)
|
| 325 |
+
conv = CONV_VISION.copy()
|
| 326 |
+
conv.system = ""
|
| 327 |
+
# if you want to make conversation comment the 2 lines above and make the conv is global variable
|
| 328 |
+
conv.append_message(conv.roles[0], prepared_instruction)
|
| 329 |
+
conv.append_message(conv.roles[1], None)
|
| 330 |
+
prompt = [conv.get_prompt()]
|
| 331 |
+
answers = self.model.generate(prepared_images, prompt, max_new_tokens=512, do_sample=False, lengths=[length],num_beams=1)
|
| 332 |
+
return answers[0]
|
| 333 |
+
|
| 334 |
+
def split_long_video_into_clips(self,video_path):
|
| 335 |
+
# Split the video into 90 seconds clips and make a queue of the videos and run the inference on each video
|
| 336 |
+
self.video_name=video_path.split('/')[-1].split('.')[0]
|
| 337 |
+
tmp_save_path=f"workspace/tmp/{self.video_name}"
|
| 338 |
+
os.makedirs(tmp_save_path, exist_ok=True)
|
| 339 |
+
print("tmp_save_path",tmp_save_path)
|
| 340 |
+
|
| 341 |
+
if len(os.listdir(tmp_save_path)) == 0:
|
| 342 |
+
print("Splitting Long video")
|
| 343 |
+
os.system(f"python split_long_video_in_parallel.py --video_path {video_path} --output_folder {tmp_save_path}")
|
| 344 |
+
# split_video(video_path, tmp_save_path, clip_duration=90)
|
| 345 |
+
videos_list = sorted(os.listdir(tmp_save_path))
|
| 346 |
+
return videos_list,tmp_save_path
|
| 347 |
+
def long_inference_video(self, videos_list,tmp_save_path,subtitle_paths) -> Optional[str]:
|
| 348 |
+
save_long_videos_path = "new_workspace/clips_summary/demo"
|
| 349 |
+
os.makedirs(save_long_videos_path, exist_ok=True)
|
| 350 |
+
file_path = f'{save_long_videos_path}/{self.video_name}.json'
|
| 351 |
+
|
| 352 |
+
if os.path.exists(file_path):
|
| 353 |
+
print("Clips inference already done")
|
| 354 |
+
with open(file_path, 'r') as file:
|
| 355 |
+
video_information = json.load(file)
|
| 356 |
+
else:
|
| 357 |
+
video_number = 0
|
| 358 |
+
batch_size = self.args.batch_size
|
| 359 |
+
batch_video_paths, batch_instructions ,batch_subtitles= [], [],[]
|
| 360 |
+
video_information = {}
|
| 361 |
+
video_captions = []
|
| 362 |
+
for i, video in tqdm(enumerate(videos_list), desc="Inference video clips", total=len(videos_list)):
|
| 363 |
+
clip_path = os.path.join(tmp_save_path, video)
|
| 364 |
+
batch_video_paths.append(clip_path)
|
| 365 |
+
# previous_caption = "You are analysing a one long video of mutiple clips and this is the summary from all previous clips :"+video_captions[-1]+"\n\n" if video_captions else ""
|
| 366 |
+
previous_caption=""
|
| 367 |
+
batch_instructions.append(self.summary_instruction)
|
| 368 |
+
batch_subtitles.append(subtitle_paths[i])
|
| 369 |
+
# Process each batch
|
| 370 |
+
if len(batch_video_paths) % batch_size == 0 and i != 0:
|
| 371 |
+
batch_preds,videos_conversation=self.run_batch(batch_video_paths,batch_instructions, batch_subtitles,previous_caption)
|
| 372 |
+
for pred,subtitle in zip(batch_preds,videos_conversation):
|
| 373 |
+
video_number += 1
|
| 374 |
+
save_name=f"{video_number}".zfill(5)
|
| 375 |
+
if pred != "":
|
| 376 |
+
video_information[f'caption__{save_name}'] = pred
|
| 377 |
+
if subtitle != "":
|
| 378 |
+
video_information[f'subtitle__{save_name}'] = subtitle
|
| 379 |
+
video_captions.append(pred)
|
| 380 |
+
batch_video_paths, batch_instructions,batch_subtitles = [], [],[]
|
| 381 |
+
|
| 382 |
+
# Process any remaining videos in the last batch
|
| 383 |
+
if batch_video_paths:
|
| 384 |
+
batch_preds,videos_conversation=self.run_batch(batch_video_paths,batch_instructions, batch_subtitles,previous_caption)
|
| 385 |
+
for pred,subtitle in zip(batch_preds,videos_conversation):
|
| 386 |
+
video_number += 1
|
| 387 |
+
save_name=f"{video_number}".zfill(5)
|
| 388 |
+
if pred != "":
|
| 389 |
+
video_information[f'caption__{save_name}'] = pred
|
| 390 |
+
if subtitle != "":
|
| 391 |
+
video_information[f'subtitle__{save_name}'] = subtitle
|
| 392 |
+
video_captions.append(pred)
|
| 393 |
+
with open(file_path, 'w') as file:
|
| 394 |
+
json.dump(video_information, file, indent=4)
|
| 395 |
+
print("Clips inference done")
|
| 396 |
+
return video_information
|
| 397 |
+
# def inference_RAG(self, instructions, context_list):
|
| 398 |
+
# context_promots=[]
|
| 399 |
+
# questions_prompts=[]
|
| 400 |
+
# try:
|
| 401 |
+
# for instruction,context in zip(instructions,context_list):
|
| 402 |
+
# context=clean_text(context)
|
| 403 |
+
# context_prompt=f"<s>[INST] Your task is to answer questions for one long video which is split into multiple clips.\nGiven these related information from the most related clips: \n{context}\n"
|
| 404 |
+
# question_prompt=f"\nAnswer this question :{instruction} \n your answer is: [/INST]"
|
| 405 |
+
# context_promots.append(context_prompt)
|
| 406 |
+
# questions_prompts.append(question_prompt)
|
| 407 |
+
# context_inputs = self.original_llama_tokenizer(context_promots, return_tensors="pt", padding=True, truncation=True,max_length=3500)
|
| 408 |
+
# # print(context_inputs.keys())
|
| 409 |
+
# print("context_inputs shape",context_inputs['input_ids'].shape)
|
| 410 |
+
# question_inputs = self.original_llama_tokenizer(questions_prompts, return_tensors="pt", padding=True, truncation=True,max_length=300)
|
| 411 |
+
# print("question_inputs shape",question_inputs['input_ids'].shape)
|
| 412 |
+
# # concate the context and the question together
|
| 413 |
+
# inputs_ids=torch.cat((context_inputs['input_ids'],question_inputs['input_ids']),dim=1).to('cuda')
|
| 414 |
+
# print("inputs shape",inputs_ids.shape)
|
| 415 |
+
# except Exception as e:
|
| 416 |
+
# print("error while tokenization",e)
|
| 417 |
+
# return self.inference_RAG_batch_size_1(instructions, context_list)
|
| 418 |
+
# with torch.no_grad():
|
| 419 |
+
# summary_ids = self.original_llama_model.generate(inputs_ids,max_new_tokens=512)
|
| 420 |
+
# answers=[]
|
| 421 |
+
# for i in range(len(summary_ids)):
|
| 422 |
+
# output_text=self.original_llama_tokenizer.decode(summary_ids[i], skip_special_tokens=True)
|
| 423 |
+
# output_text = output_text.split('</s>')[0] # remove the stop sign </s>
|
| 424 |
+
# output_text = output_text.replace("<s>", "")
|
| 425 |
+
# output_text = output_text.split(r'[/INST]')[-1].strip()
|
| 426 |
+
# answers.append(output_text)
|
| 427 |
+
# return answers
|
| 428 |
+
def inference_RAG(self, instructions, context_list):
|
| 429 |
+
messages=[]
|
| 430 |
+
for instruction,context in zip(instructions,context_list):
|
| 431 |
+
context=clean_text(context)
|
| 432 |
+
context_prompt=f"Your task is to answer a specific question based on one long video. While you cannot view the video yourself, I will supply you with the most relevant text information from the most pertinent clips. \n{context}\n"
|
| 433 |
+
question_prompt=f"\nPlease provide a detailed and accurate answer to the following question:{instruction} \n Your answer should be:"
|
| 434 |
+
# limit the context words to 10000 word duo to hardware limitation
|
| 435 |
+
context_words=context_prompt.split(' ')
|
| 436 |
+
truncated_context=' '.join(context_words[:10000])
|
| 437 |
+
print("Number of words",len((truncated_context+question_prompt).split(' ')))
|
| 438 |
+
messages.append([{"role": "user", "content": truncated_context+question_prompt}])
|
| 439 |
+
outputs=self.llama_3_1_model(messages, max_new_tokens=512)
|
| 440 |
+
answers=[]
|
| 441 |
+
for out in outputs:
|
| 442 |
+
answers.append(out[0]["generated_text"][-1]['content'])
|
| 443 |
+
return answers
|
| 444 |
+
# def inference_RAG(self, instructions, context_list):
|
| 445 |
+
# prompts=[]
|
| 446 |
+
# for instruction,context in zip(instructions,context_list):
|
| 447 |
+
# context=clean_text(context)
|
| 448 |
+
# context_prompt=f"Your task is to answer questions for one long video which is split into multiple clips.\nGiven these related information from the most related clips: \n{context}\n"
|
| 449 |
+
# question_prompt=f"\nAnswer this question :{instruction} \n your answer is:"
|
| 450 |
+
# prompts.append(context_prompt+question_prompt)
|
| 451 |
+
|
| 452 |
+
# with open('prompts.txt','w') as f:
|
| 453 |
+
# for prompt in prompts:
|
| 454 |
+
# f.write(prompt+'\n')
|
| 455 |
+
|
| 456 |
+
# outputs=self.original_llama_model.generate(prompts)
|
| 457 |
+
# answers=[]
|
| 458 |
+
# for out in outputs:
|
| 459 |
+
# answers.append(out.outputs[0].text)
|
| 460 |
+
# return answers
|
| 461 |
+
def inference_RAG_batch_size_1(self, instructions, context_list):
|
| 462 |
+
answers=[]
|
| 463 |
+
for instruction,context in zip(instructions,context_list):
|
| 464 |
+
context=clean_text(context)
|
| 465 |
+
context_prompt=f"<s>[INST] Your task is to answer questions for one long video which is split into multiple clips.\nGiven these related information from the most related clips: \n{context}\n"
|
| 466 |
+
question_prompt=f"\nAnswer this question :{instruction} \n your answer is: [/INST]"
|
| 467 |
+
context_inputs=self.original_llama_tokenizer([context_prompt], return_tensors="pt", padding=True, truncation=True,max_length=3500)['input_ids']
|
| 468 |
+
question_inputs=self.original_llama_tokenizer([question_prompt], return_tensors="pt", padding=True, truncation=True,max_length=300)['input_ids']
|
| 469 |
+
|
| 470 |
+
inputs_ids=torch.cat((context_inputs,question_inputs),dim=1).to('cuda')
|
| 471 |
+
with torch.no_grad():
|
| 472 |
+
summary_ids = self.original_llama_model.generate(inputs_ids,max_new_tokens=512,)
|
| 473 |
+
|
| 474 |
+
output_text=self.original_llama_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
| 475 |
+
output_text = output_text.split('</s>')[0] # remove the stop sign </s>
|
| 476 |
+
output_text = output_text.replace("<s>", "")
|
| 477 |
+
output_text = output_text.split(r'[/INST]')[-1].strip()
|
| 478 |
+
answers.append(output_text)
|
| 479 |
+
|
| 480 |
+
return answers
|
| 481 |
+
|
| 482 |
+
# def inference_RAG_text_only(self, instructions, context_list):
|
| 483 |
+
# # Use VideoLLM as the answer module
|
| 484 |
+
# seg_tokens=[]
|
| 485 |
+
# for instruction,context in zip(instructions,context_list):
|
| 486 |
+
# context=clean_text(context)
|
| 487 |
+
# context_prompt=f"<s>[INST] Your task is to answer questions for one long video which is split into multiple clips.\nGiven these related information from the most related clips: \n{context}\n"
|
| 488 |
+
# question_prompt=f"\nAnswer this question :{instruction} \n your answer is: [/INST]"
|
| 489 |
+
# context_inputs = self.model.llama_tokenizer(context_prompt,add_special_tokens=True, return_tensors="pt", padding=True, truncation=True,max_length=3500)
|
| 490 |
+
# question_inputs = self.model.llama_tokenizer(question_prompt, return_tensors="pt", padding=True, truncation=True,max_length=300)
|
| 491 |
+
# # concate the context and the question together
|
| 492 |
+
# inputs_ids=torch.cat((context_inputs['input_ids'],question_inputs['input_ids']),dim=1).to('cuda')
|
| 493 |
+
# seg_tokens.append(inputs_ids)
|
| 494 |
+
# with torch.no_grad():
|
| 495 |
+
# answers = self.model.generate_text_only(images=None,seg_tokens=seg_tokens,max_new_tokens=512)
|
| 496 |
+
# return answers
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
def inference_RAG_chatGPT(self, instructions: str, context_list) -> str:
|
| 500 |
+
batch_preds=[]
|
| 501 |
+
for context,instruction in zip(context_list,instructions):
|
| 502 |
+
prompt="Your task is to answer questions for long video \n\n Given these related information from the most related clips: \n "+context +"\n\n" +"Answer this question: "+instruction
|
| 503 |
+
while True:
|
| 504 |
+
try:
|
| 505 |
+
response = client.ChatCompletion.create(
|
| 506 |
+
model="gpt-4o",
|
| 507 |
+
messages=[
|
| 508 |
+
{
|
| 509 |
+
"role": "user",
|
| 510 |
+
"content": prompt
|
| 511 |
+
}],
|
| 512 |
+
)
|
| 513 |
+
answer=response.choices[0].message['content']
|
| 514 |
+
batch_preds.append(answer)
|
| 515 |
+
break
|
| 516 |
+
except Exception as e:
|
| 517 |
+
print("chat gpt error",e)
|
| 518 |
+
time.sleep(50)
|
| 519 |
+
|
| 520 |
+
return batch_preds
|
| 521 |
+
|
| 522 |
+
def get_most_related_clips(self,related_context_keys):
|
| 523 |
+
most_related_clips=set()
|
| 524 |
+
for context_key in related_context_keys:
|
| 525 |
+
if len(context_key.split('__'))>1:
|
| 526 |
+
most_related_clips.add(context_key.split('__')[1])
|
| 527 |
+
if len(most_related_clips)==self.args.neighbours:
|
| 528 |
+
break
|
| 529 |
+
assert len(most_related_clips)!=0, f"No related clips found {related_context_keys}"
|
| 530 |
+
return list(most_related_clips)
|
| 531 |
+
def get_related_context(self, external_memory,related_context_keys):
|
| 532 |
+
related_information=""
|
| 533 |
+
most_related_clips=self.get_most_related_clips(related_context_keys)
|
| 534 |
+
for clip_name in most_related_clips:
|
| 535 |
+
clip_conversation=""
|
| 536 |
+
general_sum=""
|
| 537 |
+
for key in external_memory.documents.keys():
|
| 538 |
+
if clip_name in key and 'caption' in key:
|
| 539 |
+
general_sum="Clip Summary: "+external_memory.documents[key]
|
| 540 |
+
if clip_name in key and 'subtitle' in key:
|
| 541 |
+
clip_conversation="Clip Subtitles: "+external_memory.documents[key]
|
| 542 |
+
related_information+=f"{general_sum},{clip_conversation}\n"
|
| 543 |
+
return related_information
|
| 544 |
+
def inference(self,video_path, use_subtitles=True, instruction="", number_of_neighbours=3):
|
| 545 |
+
start_time = time.time()
|
| 546 |
+
video_name = os.path.splitext(os.path.basename(video_path))[0]
|
| 547 |
+
self.args.neighbours = number_of_neighbours
|
| 548 |
+
print(f"Video name: {video_name}")
|
| 549 |
+
video_duration = mp.VideoFileClip(video_path).duration
|
| 550 |
+
print(f"Video duration: {video_duration:.2f} seconds")
|
| 551 |
+
# if the video duration is more than 2 minutes we need to run the long inference
|
| 552 |
+
if video_duration > 180 :
|
| 553 |
+
print("Long video")
|
| 554 |
+
# if the video data is already stored in the external memory, we can use it directly else we need to run the long inference
|
| 555 |
+
file_path=f'new_workspace/clips_summary/demo/{video_name}.json'
|
| 556 |
+
if not os.path.exists(file_path):
|
| 557 |
+
print("Clips summary is not ready")
|
| 558 |
+
videos_list,tmp_save_path=self.split_long_video_into_clips(video_path)
|
| 559 |
+
subtitle_paths = []
|
| 560 |
+
for video_p in videos_list:
|
| 561 |
+
clip_path = os.path.join(tmp_save_path, video_p)
|
| 562 |
+
subtitle_path = self.get_subtitles(clip_path) if use_subtitles else None
|
| 563 |
+
subtitle_paths.append(subtitle_path)
|
| 564 |
+
clips_summary = self.long_inference_video(videos_list,tmp_save_path,subtitle_paths)
|
| 565 |
+
else:
|
| 566 |
+
print("External memory is ready")
|
| 567 |
+
os.makedirs("new_workspace/embedding/demo", exist_ok=True)
|
| 568 |
+
os.makedirs("new_workspace/open_ai_embedding/demo", exist_ok=True)
|
| 569 |
+
if self.args.use_openai_embedding:
|
| 570 |
+
embedding_path=f"new_workspace/open_ai_embedding/demo/{video_name}.pkl"
|
| 571 |
+
else:
|
| 572 |
+
embedding_path=f"new_workspace/embedding/demo/{video_name}.pkl"
|
| 573 |
+
external_memory=MemoryIndex(self.args.neighbours,use_openai=self.args.use_openai_embedding)
|
| 574 |
+
if os.path.exists(embedding_path):
|
| 575 |
+
print("Loading embeddings from pkl file")
|
| 576 |
+
external_memory.load_embeddings_from_pkl(embedding_path)
|
| 577 |
+
else:
|
| 578 |
+
# will embed the information and save it in the pkl file
|
| 579 |
+
external_memory.load_documents_from_json(file_path,embedding_path)
|
| 580 |
+
# get the most similar context from the external memory to this instruction
|
| 581 |
+
|
| 582 |
+
related_context_documents,related_context_keys = external_memory.search_by_similarity(instruction)
|
| 583 |
+
related_information=self.get_related_context(external_memory,related_context_keys)
|
| 584 |
+
pred=self.inference_RAG([instruction],[related_information])
|
| 585 |
+
else:
|
| 586 |
+
print("Short video")
|
| 587 |
+
self.video_name=video_path.split('/')[-1].split('.')[0]
|
| 588 |
+
pred=self.short_video_inference(video_path,instruction,use_subtitles)
|
| 589 |
+
processing_time = time.time() - start_time
|
| 590 |
+
print(f"Processing time: {processing_time:.2f} seconds")
|
| 591 |
+
return {
|
| 592 |
+
'video_name': os.path.splitext(os.path.basename(video_path))[0],
|
| 593 |
+
'pred': pred,
|
| 594 |
+
}
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
def run_batch(self, video_paths, instructions,subtitle_paths,previous_caption="") -> List[str]:
|
| 598 |
+
|
| 599 |
+
prepared_images_batch = []
|
| 600 |
+
prepared_instructions_batch = []
|
| 601 |
+
lengths_batch = []
|
| 602 |
+
videos_conversations=[]
|
| 603 |
+
|
| 604 |
+
for i,video_path, instruction in zip(range(len(video_paths)),video_paths, instructions):
|
| 605 |
+
subtitle_path = subtitle_paths[i]
|
| 606 |
+
prepared_images, prepared_instruction,video_conversation = self.prepare_input( video_path, subtitle_path, instruction,previous_caption)
|
| 607 |
+
|
| 608 |
+
if prepared_images is None:
|
| 609 |
+
print(f"Error: Unable to open video at {video_path}. Check the path and try again.")
|
| 610 |
+
continue
|
| 611 |
+
videos_conversations.append(video_conversation)
|
| 612 |
+
conversation = CONV_VISION.copy()
|
| 613 |
+
conversation.system = ""
|
| 614 |
+
conversation.append_message(conversation.roles[0], prepared_instruction)
|
| 615 |
+
conversation.append_message(conversation.roles[1], None)
|
| 616 |
+
prepared_instructions_batch.append(conversation.get_prompt())
|
| 617 |
+
prepared_images_batch.append(prepared_images)
|
| 618 |
+
lengths_batch.append(len(prepared_images))
|
| 619 |
+
|
| 620 |
+
if not prepared_images_batch:
|
| 621 |
+
return []
|
| 622 |
+
|
| 623 |
+
prepared_images_batch = torch.stack(prepared_images_batch)
|
| 624 |
+
answers=self.model.generate(prepared_images_batch, prepared_instructions_batch, max_new_tokens=self.args.max_new_tokens, do_sample=False, lengths=lengths_batch, num_beams=1)
|
| 625 |
+
return answers , videos_conversations
|
| 626 |
+
|
| 627 |
+
def run_images_features (self,img_embeds,prepared_instruction):
|
| 628 |
+
lengths=[]
|
| 629 |
+
prompts=[]
|
| 630 |
+
for i in range(img_embeds.shape[0]):
|
| 631 |
+
conv = CONV_VISION.copy()
|
| 632 |
+
conv.system = ""
|
| 633 |
+
conv.append_message(conv.roles[0], prepared_instruction[i])
|
| 634 |
+
conv.append_message(conv.roles[1], None)
|
| 635 |
+
prompts.append(conv.get_prompt())
|
| 636 |
+
lengths.append(len(img_embeds[i]))
|
| 637 |
+
|
| 638 |
+
answers = self.model.generate(images=None,img_embeds=img_embeds,texts=prompts, max_new_tokens=300, do_sample=False, lengths=lengths,num_beams=1)
|
| 639 |
+
return answers
|
| 640 |
+
|
| 641 |
+
def run_images (self,prepared_images,prepared_instruction):
|
| 642 |
+
lengths=[]
|
| 643 |
+
prompts=[]
|
| 644 |
+
for i in range(prepared_images.shape[0]):
|
| 645 |
+
conv = CONV_VISION.copy()
|
| 646 |
+
conv.system = ""
|
| 647 |
+
conv.append_message(conv.roles[0], prepared_instruction[i])
|
| 648 |
+
conv.append_message(conv.roles[1], None)
|
| 649 |
+
prompts.append(conv.get_prompt())
|
| 650 |
+
lengths.append(len(prepared_images[i]))
|
| 651 |
+
answers = self.model.generate(prepared_images, prompts, max_new_tokens=300, do_sample=False, lengths=lengths,num_beams=1)
|
| 652 |
+
return answers
|
| 653 |
+
|
| 654 |
+
|
index.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import json
|
| 5 |
+
import torch
|
| 6 |
+
from sentence_transformers import SentenceTransformer
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
from typing import List, Dict, Tuple, Union
|
| 9 |
+
import torch
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import pickle
|
| 12 |
+
from openai import OpenAI
|
| 13 |
+
import os
|
| 14 |
+
import torch
|
| 15 |
+
import time
|
| 16 |
+
import yaml
|
| 17 |
+
|
| 18 |
+
class MemoryIndex:
|
| 19 |
+
def __init__(self,number_of_neighbours,use_openai=False):
|
| 20 |
+
self.documents = {}
|
| 21 |
+
self.document_vectors = {}
|
| 22 |
+
self.use_openai=use_openai
|
| 23 |
+
if use_openai:
|
| 24 |
+
api_key = os.getenv("OPENAI_API_KEY")
|
| 25 |
+
self.client = OpenAI(api_key=api_key)
|
| 26 |
+
self.model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
|
| 27 |
+
# self.model = SentenceTransformer('sentence-transformers/paraphrase-MiniLM-L6-v2')
|
| 28 |
+
with open('test_configs/llama2_test_config.yaml') as file:
|
| 29 |
+
config = yaml.load(file, Loader=yaml.FullLoader)
|
| 30 |
+
embedding_gpu_id=config['model']['minigpt4_gpu_id']
|
| 31 |
+
self.device = f"cuda:{embedding_gpu_id}" if torch.cuda.is_available() else "cpu"
|
| 32 |
+
self.number_of_neighbours=int(number_of_neighbours)
|
| 33 |
+
|
| 34 |
+
def load_documents_from_json(self, file_path,emdedding_path=""):
|
| 35 |
+
|
| 36 |
+
with open(file_path, 'r') as file:
|
| 37 |
+
data = json.load(file)
|
| 38 |
+
for doc_id, doc_data in data.items():
|
| 39 |
+
self.documents[doc_id] = doc_data
|
| 40 |
+
self.document_vectors[doc_id] = self._compute_sentence_embedding(doc_data)
|
| 41 |
+
|
| 42 |
+
# save self.documents and self.document_vectors to pkl file
|
| 43 |
+
m=[self.documents,self.document_vectors]
|
| 44 |
+
with open(emdedding_path, 'wb') as file:
|
| 45 |
+
pickle.dump(m, file)
|
| 46 |
+
return emdedding_path
|
| 47 |
+
def load_embeddings_from_pkl(self, pkl_file_path):
|
| 48 |
+
#read the pkl file
|
| 49 |
+
with open(pkl_file_path, 'rb') as file:
|
| 50 |
+
data = pickle.load(file)
|
| 51 |
+
self.documents=data[0]
|
| 52 |
+
self.document_vectors=data[1]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def load_data_from_pkl(self, pkl_file_path):
|
| 56 |
+
with open(pkl_file_path, 'rb') as file:
|
| 57 |
+
data = pickle.load(file)
|
| 58 |
+
for doc_id, doc_data in data.items():
|
| 59 |
+
self.documents[doc_id] = doc_data
|
| 60 |
+
self.document_vectors[doc_id] = doc_data
|
| 61 |
+
def _compute_sentence_embedding(self, text: str) -> torch.Tensor:
|
| 62 |
+
if self.use_openai:
|
| 63 |
+
done=False
|
| 64 |
+
while not done:
|
| 65 |
+
try:
|
| 66 |
+
embedding=self.client.embeddings.create(input = [text], model="text-embedding-3-small").data[0].embedding
|
| 67 |
+
# Convert the list to a PyTorch tensor
|
| 68 |
+
embedding = torch.tensor(embedding)
|
| 69 |
+
done=True
|
| 70 |
+
except Exception as e:
|
| 71 |
+
print("error",e)
|
| 72 |
+
print("text",text)
|
| 73 |
+
# sleep for 5 seconds and try again
|
| 74 |
+
time.sleep(5)
|
| 75 |
+
continue
|
| 76 |
+
else:
|
| 77 |
+
return self.model.encode(text, convert_to_tensor=True).to(self.device)
|
| 78 |
+
|
| 79 |
+
return embedding
|
| 80 |
+
|
| 81 |
+
def search_by_similarity(self, query: str) -> List[str]:
|
| 82 |
+
|
| 83 |
+
query_vector = self._compute_sentence_embedding(query)
|
| 84 |
+
scores = {doc_id: torch.nn.functional.cosine_similarity(query_vector, doc_vector, dim=0).item()
|
| 85 |
+
for doc_id, doc_vector in self.document_vectors.items()}
|
| 86 |
+
sorted_doc_ids = sorted(scores, key=scores.get, reverse=True)
|
| 87 |
+
sorted_documents=[self.documents[doc_id] for doc_id in sorted_doc_ids]
|
| 88 |
+
if self.number_of_neighbours == -1:
|
| 89 |
+
return list(self.documents.values()), list(self.documents.keys())
|
| 90 |
+
if self.number_of_neighbours > len(sorted_documents):
|
| 91 |
+
return sorted_documents, sorted_doc_ids
|
| 92 |
+
# if the retrieved document is the summary, return the summary and the next document to grauntee that always retieve clip name.
|
| 93 |
+
if self.number_of_neighbours==1 and sorted_doc_ids[0]=='summary':
|
| 94 |
+
return sorted_documents[0:2], sorted_doc_ids[:2]
|
| 95 |
+
print("Number of neighbours",self.number_of_neighbours)
|
| 96 |
+
return sorted_documents[:self.number_of_neighbours], sorted_doc_ids[:self.number_of_neighbours]
|
| 97 |
+
|
| 98 |
+
# # main function
|
| 99 |
+
# if __name__ == "__main__":
|
| 100 |
+
# memory_index = MemoryIndex(-1,use_openai=True)
|
| 101 |
+
# memory_index.load_documents_from_json('workspace/results/llama_vid/tt0035423.json')
|
| 102 |
+
# print(memory_index.documents.keys())
|
| 103 |
+
# docs,keys=memory_index.search_by_similarity('kerolos')
|
minigpt4/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
| 3 |
+
All rights reserved.
|
| 4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
|
| 11 |
+
from omegaconf import OmegaConf
|
| 12 |
+
|
| 13 |
+
from minigpt4.common.registry import registry
|
| 14 |
+
|
| 15 |
+
from minigpt4.datasets.builders import *
|
| 16 |
+
from minigpt4.models import *
|
| 17 |
+
from minigpt4.processors import *
|
| 18 |
+
from minigpt4.tasks import *
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
root_dir = os.path.dirname(os.path.abspath(__file__))
|
| 22 |
+
default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
|
| 23 |
+
|
| 24 |
+
registry.register_path("library_root", root_dir)
|
| 25 |
+
repo_root = os.path.join(root_dir, "..")
|
| 26 |
+
registry.register_path("repo_root", repo_root)
|
| 27 |
+
cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
|
| 28 |
+
registry.register_path("cache_root", cache_root)
|
| 29 |
+
|
| 30 |
+
registry.register("MAX_INT", sys.maxsize)
|
| 31 |
+
registry.register("SPLIT_NAMES", ["train", "val", "test"])
|
minigpt4/common/__init__.py
ADDED
|
File without changes
|
minigpt4/common/config.py
ADDED
|
@@ -0,0 +1,474 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
| 3 |
+
All rights reserved.
|
| 4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
import json
|
| 10 |
+
from typing import Dict
|
| 11 |
+
|
| 12 |
+
from omegaconf import OmegaConf
|
| 13 |
+
from minigpt4.common.registry import registry
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Config:
|
| 17 |
+
def __init__(self, args):
|
| 18 |
+
self.config = {}
|
| 19 |
+
|
| 20 |
+
self.args = args
|
| 21 |
+
|
| 22 |
+
# Register the config and configuration for setup
|
| 23 |
+
registry.register("configuration", self)
|
| 24 |
+
|
| 25 |
+
user_config = self._build_opt_list(self.args.options)
|
| 26 |
+
|
| 27 |
+
config = OmegaConf.load(self.args.cfg_path)
|
| 28 |
+
|
| 29 |
+
runner_config = self.build_runner_config(config)
|
| 30 |
+
model_config = self.build_model_config(config, **user_config)
|
| 31 |
+
dataset_config = self.build_dataset_config(config)
|
| 32 |
+
|
| 33 |
+
# Validate the user-provided runner configuration
|
| 34 |
+
# model and dataset configuration are supposed to be validated by the respective classes
|
| 35 |
+
# [TODO] validate the model/dataset configuration
|
| 36 |
+
# self._validate_runner_config(runner_config)
|
| 37 |
+
|
| 38 |
+
# Override the default configuration with user options.
|
| 39 |
+
self.config = OmegaConf.merge(
|
| 40 |
+
runner_config, model_config, dataset_config, user_config
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
def _validate_runner_config(self, runner_config):
|
| 44 |
+
"""
|
| 45 |
+
This method validates the configuration, such that
|
| 46 |
+
1) all the user specified options are valid;
|
| 47 |
+
2) no type mismatches between the user specified options and the config.
|
| 48 |
+
"""
|
| 49 |
+
runner_config_validator = create_runner_config_validator()
|
| 50 |
+
runner_config_validator.validate(runner_config)
|
| 51 |
+
|
| 52 |
+
def _build_opt_list(self, opts):
|
| 53 |
+
opts_dot_list = self._convert_to_dot_list(opts)
|
| 54 |
+
return OmegaConf.from_dotlist(opts_dot_list)
|
| 55 |
+
|
| 56 |
+
@staticmethod
|
| 57 |
+
def build_model_config(config, **kwargs):
|
| 58 |
+
model = config.get("model", None)
|
| 59 |
+
assert model is not None, "Missing model configuration file."
|
| 60 |
+
|
| 61 |
+
model_cls = registry.get_model_class(model.arch)
|
| 62 |
+
assert model_cls is not None, f"Model '{model.arch}' has not been registered."
|
| 63 |
+
|
| 64 |
+
model_type = kwargs.get("model.model_type", None)
|
| 65 |
+
if not model_type:
|
| 66 |
+
model_type = model.get("model_type", None)
|
| 67 |
+
# else use the model type selected by user.
|
| 68 |
+
|
| 69 |
+
assert model_type is not None, "Missing model_type."
|
| 70 |
+
|
| 71 |
+
print("--------------")
|
| 72 |
+
print("model arch",model.arch)
|
| 73 |
+
print("model cls",model_cls)
|
| 74 |
+
|
| 75 |
+
model_config_path = model_cls.PRETRAINED_MODEL_CONFIG_DICT[model_type]
|
| 76 |
+
|
| 77 |
+
model_config = OmegaConf.create()
|
| 78 |
+
# hierarchy override, customized config > default config
|
| 79 |
+
model_config = OmegaConf.merge(
|
| 80 |
+
model_config,
|
| 81 |
+
OmegaConf.load(model_config_path),
|
| 82 |
+
{"model": config["model"]},
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
return model_config
|
| 86 |
+
|
| 87 |
+
@staticmethod
|
| 88 |
+
def build_runner_config(config):
|
| 89 |
+
return {"run": config.run}
|
| 90 |
+
|
| 91 |
+
@staticmethod
|
| 92 |
+
def build_dataset_config(config):
|
| 93 |
+
datasets = config.get("datasets", None)
|
| 94 |
+
if datasets is None:
|
| 95 |
+
raise KeyError(
|
| 96 |
+
"Expecting 'datasets' as the root key for dataset configuration."
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
dataset_config = OmegaConf.create()
|
| 100 |
+
|
| 101 |
+
for dataset_name in datasets:
|
| 102 |
+
|
| 103 |
+
print("dataset name", dataset_name)
|
| 104 |
+
builder_cls = registry.get_builder_class(dataset_name)
|
| 105 |
+
|
| 106 |
+
dataset_config_type = datasets[dataset_name].get("type", "default")
|
| 107 |
+
dataset_config_path = builder_cls.default_config_path(
|
| 108 |
+
type=dataset_config_type
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
# hierarchy override, customized config > default config
|
| 112 |
+
dataset_config = OmegaConf.merge(
|
| 113 |
+
dataset_config,
|
| 114 |
+
OmegaConf.load(dataset_config_path),
|
| 115 |
+
{"datasets": {dataset_name: config["datasets"][dataset_name]}},
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
return dataset_config
|
| 119 |
+
|
| 120 |
+
def _convert_to_dot_list(self, opts):
|
| 121 |
+
if opts is None:
|
| 122 |
+
opts = []
|
| 123 |
+
|
| 124 |
+
if len(opts) == 0:
|
| 125 |
+
return opts
|
| 126 |
+
|
| 127 |
+
has_equal = opts[0].find("=") != -1
|
| 128 |
+
|
| 129 |
+
if has_equal:
|
| 130 |
+
return opts
|
| 131 |
+
|
| 132 |
+
return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]
|
| 133 |
+
|
| 134 |
+
def get_config(self):
|
| 135 |
+
return self.config
|
| 136 |
+
|
| 137 |
+
@property
|
| 138 |
+
def run_cfg(self):
|
| 139 |
+
return self.config.run
|
| 140 |
+
|
| 141 |
+
@property
|
| 142 |
+
def datasets_cfg(self):
|
| 143 |
+
return self.config.datasets
|
| 144 |
+
|
| 145 |
+
@property
|
| 146 |
+
def model_cfg(self):
|
| 147 |
+
return self.config.model
|
| 148 |
+
|
| 149 |
+
def pretty_print(self):
|
| 150 |
+
logging.info("\n===== Running Parameters =====")
|
| 151 |
+
logging.info(self._convert_node_to_json(self.config.run))
|
| 152 |
+
|
| 153 |
+
logging.info("\n====== Dataset Attributes ======")
|
| 154 |
+
datasets = self.config.datasets
|
| 155 |
+
|
| 156 |
+
for dataset in datasets:
|
| 157 |
+
if dataset in self.config.datasets:
|
| 158 |
+
logging.info(f"\n======== {dataset} =======")
|
| 159 |
+
dataset_config = self.config.datasets[dataset]
|
| 160 |
+
logging.info(self._convert_node_to_json(dataset_config))
|
| 161 |
+
else:
|
| 162 |
+
logging.warning(f"No dataset named '{dataset}' in config. Skipping")
|
| 163 |
+
|
| 164 |
+
logging.info(f"\n====== Model Attributes ======")
|
| 165 |
+
logging.info(self._convert_node_to_json(self.config.model))
|
| 166 |
+
|
| 167 |
+
def _convert_node_to_json(self, node):
|
| 168 |
+
container = OmegaConf.to_container(node, resolve=True)
|
| 169 |
+
return json.dumps(container, indent=4, sort_keys=True)
|
| 170 |
+
|
| 171 |
+
def to_dict(self):
|
| 172 |
+
return OmegaConf.to_container(self.config)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def node_to_dict(node):
|
| 176 |
+
return OmegaConf.to_container(node)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class ConfigValidator:
|
| 180 |
+
"""
|
| 181 |
+
This is a preliminary implementation to centralize and validate the configuration.
|
| 182 |
+
May be altered in the future.
|
| 183 |
+
|
| 184 |
+
A helper class to validate configurations from yaml file.
|
| 185 |
+
|
| 186 |
+
This serves the following purposes:
|
| 187 |
+
1. Ensure all the options in the yaml are defined, raise error if not.
|
| 188 |
+
2. when type mismatches are found, the validator will raise an error.
|
| 189 |
+
3. a central place to store and display helpful messages for supported configurations.
|
| 190 |
+
|
| 191 |
+
"""
|
| 192 |
+
|
| 193 |
+
class _Argument:
|
| 194 |
+
def __init__(self, name, choices=None, type=None, help=None):
|
| 195 |
+
self.name = name
|
| 196 |
+
self.val = None
|
| 197 |
+
self.choices = choices
|
| 198 |
+
self.type = type
|
| 199 |
+
self.help = help
|
| 200 |
+
|
| 201 |
+
def __str__(self):
|
| 202 |
+
s = f"{self.name}={self.val}"
|
| 203 |
+
if self.type is not None:
|
| 204 |
+
s += f", ({self.type})"
|
| 205 |
+
if self.choices is not None:
|
| 206 |
+
s += f", choices: {self.choices}"
|
| 207 |
+
if self.help is not None:
|
| 208 |
+
s += f", ({self.help})"
|
| 209 |
+
return s
|
| 210 |
+
|
| 211 |
+
def __init__(self, description):
|
| 212 |
+
self.description = description
|
| 213 |
+
|
| 214 |
+
self.arguments = dict()
|
| 215 |
+
|
| 216 |
+
self.parsed_args = None
|
| 217 |
+
|
| 218 |
+
def __getitem__(self, key):
|
| 219 |
+
assert self.parsed_args is not None, "No arguments parsed yet."
|
| 220 |
+
|
| 221 |
+
return self.parsed_args[key]
|
| 222 |
+
|
| 223 |
+
def __str__(self) -> str:
|
| 224 |
+
return self.format_help()
|
| 225 |
+
|
| 226 |
+
def add_argument(self, *args, **kwargs):
|
| 227 |
+
"""
|
| 228 |
+
Assume the first argument is the name of the argument.
|
| 229 |
+
"""
|
| 230 |
+
self.arguments[args[0]] = self._Argument(*args, **kwargs)
|
| 231 |
+
|
| 232 |
+
def validate(self, config=None):
|
| 233 |
+
"""
|
| 234 |
+
Convert yaml config (dict-like) to list, required by argparse.
|
| 235 |
+
"""
|
| 236 |
+
for k, v in config.items():
|
| 237 |
+
assert (
|
| 238 |
+
k in self.arguments
|
| 239 |
+
), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""
|
| 240 |
+
|
| 241 |
+
if self.arguments[k].type is not None:
|
| 242 |
+
try:
|
| 243 |
+
self.arguments[k].val = self.arguments[k].type(v)
|
| 244 |
+
except ValueError:
|
| 245 |
+
raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")
|
| 246 |
+
|
| 247 |
+
if self.arguments[k].choices is not None:
|
| 248 |
+
assert (
|
| 249 |
+
v in self.arguments[k].choices
|
| 250 |
+
), f"""{k} must be one of {self.arguments[k].choices}."""
|
| 251 |
+
|
| 252 |
+
return config
|
| 253 |
+
|
| 254 |
+
def format_arguments(self):
|
| 255 |
+
return str([f"{k}" for k in sorted(self.arguments.keys())])
|
| 256 |
+
|
| 257 |
+
def format_help(self):
|
| 258 |
+
# description + key-value pair string for each argument
|
| 259 |
+
help_msg = str(self.description)
|
| 260 |
+
return help_msg + ", available arguments: " + self.format_arguments()
|
| 261 |
+
|
| 262 |
+
def print_help(self):
|
| 263 |
+
# display help message
|
| 264 |
+
print(self.format_help())
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def create_runner_config_validator():
|
| 268 |
+
validator = ConfigValidator(description="Runner configurations")
|
| 269 |
+
|
| 270 |
+
validator.add_argument(
|
| 271 |
+
"runner",
|
| 272 |
+
type=str,
|
| 273 |
+
choices=["runner_base", "runner_iter"],
|
| 274 |
+
help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
|
| 275 |
+
runner runs based on iters. Default: runner_base""",
|
| 276 |
+
)
|
| 277 |
+
# add argumetns for training dataset ratios
|
| 278 |
+
validator.add_argument(
|
| 279 |
+
"train_dataset_ratios",
|
| 280 |
+
type=Dict[str, float],
|
| 281 |
+
help="""Ratios of training dataset. This is used in iteration-based runner.
|
| 282 |
+
Do not support for epoch-based runner because how to define an epoch becomes tricky.
|
| 283 |
+
Default: None""",
|
| 284 |
+
)
|
| 285 |
+
validator.add_argument(
|
| 286 |
+
"max_iters",
|
| 287 |
+
type=float,
|
| 288 |
+
help="Maximum number of iterations to run.",
|
| 289 |
+
)
|
| 290 |
+
validator.add_argument(
|
| 291 |
+
"max_epoch",
|
| 292 |
+
type=int,
|
| 293 |
+
help="Maximum number of epochs to run.",
|
| 294 |
+
)
|
| 295 |
+
# add arguments for iters_per_inner_epoch
|
| 296 |
+
validator.add_argument(
|
| 297 |
+
"iters_per_inner_epoch",
|
| 298 |
+
type=float,
|
| 299 |
+
help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
|
| 300 |
+
)
|
| 301 |
+
lr_scheds_choices = registry.list_lr_schedulers()
|
| 302 |
+
validator.add_argument(
|
| 303 |
+
"lr_sched",
|
| 304 |
+
type=str,
|
| 305 |
+
choices=lr_scheds_choices,
|
| 306 |
+
help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
|
| 307 |
+
)
|
| 308 |
+
task_choices = registry.list_tasks()
|
| 309 |
+
validator.add_argument(
|
| 310 |
+
"task",
|
| 311 |
+
type=str,
|
| 312 |
+
choices=task_choices,
|
| 313 |
+
help="Task to use, from {}".format(task_choices),
|
| 314 |
+
)
|
| 315 |
+
# add arguments for init_lr
|
| 316 |
+
validator.add_argument(
|
| 317 |
+
"init_lr",
|
| 318 |
+
type=float,
|
| 319 |
+
help="Initial learning rate. This will be the learning rate after warmup and before decay.",
|
| 320 |
+
)
|
| 321 |
+
# add arguments for min_lr
|
| 322 |
+
validator.add_argument(
|
| 323 |
+
"min_lr",
|
| 324 |
+
type=float,
|
| 325 |
+
help="Minimum learning rate (after decay).",
|
| 326 |
+
)
|
| 327 |
+
# add arguments for warmup_lr
|
| 328 |
+
validator.add_argument(
|
| 329 |
+
"warmup_lr",
|
| 330 |
+
type=float,
|
| 331 |
+
help="Starting learning rate for warmup.",
|
| 332 |
+
)
|
| 333 |
+
# add arguments for learning rate decay rate
|
| 334 |
+
validator.add_argument(
|
| 335 |
+
"lr_decay_rate",
|
| 336 |
+
type=float,
|
| 337 |
+
help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
|
| 338 |
+
)
|
| 339 |
+
# add arguments for weight decay
|
| 340 |
+
validator.add_argument(
|
| 341 |
+
"weight_decay",
|
| 342 |
+
type=float,
|
| 343 |
+
help="Weight decay rate.",
|
| 344 |
+
)
|
| 345 |
+
# add arguments for training batch size
|
| 346 |
+
validator.add_argument(
|
| 347 |
+
"batch_size_train",
|
| 348 |
+
type=int,
|
| 349 |
+
help="Training batch size.",
|
| 350 |
+
)
|
| 351 |
+
# add arguments for evaluation batch size
|
| 352 |
+
validator.add_argument(
|
| 353 |
+
"batch_size_eval",
|
| 354 |
+
type=int,
|
| 355 |
+
help="Evaluation batch size, including validation and testing.",
|
| 356 |
+
)
|
| 357 |
+
# add arguments for number of workers for data loading
|
| 358 |
+
validator.add_argument(
|
| 359 |
+
"num_workers",
|
| 360 |
+
help="Number of workers for data loading.",
|
| 361 |
+
)
|
| 362 |
+
# add arguments for warm up steps
|
| 363 |
+
validator.add_argument(
|
| 364 |
+
"warmup_steps",
|
| 365 |
+
type=int,
|
| 366 |
+
help="Number of warmup steps. Required if a warmup schedule is used.",
|
| 367 |
+
)
|
| 368 |
+
# add arguments for random seed
|
| 369 |
+
validator.add_argument(
|
| 370 |
+
"seed",
|
| 371 |
+
type=int,
|
| 372 |
+
help="Random seed.",
|
| 373 |
+
)
|
| 374 |
+
# add arguments for output directory
|
| 375 |
+
validator.add_argument(
|
| 376 |
+
"output_dir",
|
| 377 |
+
type=str,
|
| 378 |
+
help="Output directory to save checkpoints and logs.",
|
| 379 |
+
)
|
| 380 |
+
# add arguments for whether only use evaluation
|
| 381 |
+
validator.add_argument(
|
| 382 |
+
"evaluate",
|
| 383 |
+
help="Whether to only evaluate the model. If true, training will not be performed.",
|
| 384 |
+
)
|
| 385 |
+
# add arguments for splits used for training, e.g. ["train", "val"]
|
| 386 |
+
validator.add_argument(
|
| 387 |
+
"train_splits",
|
| 388 |
+
type=list,
|
| 389 |
+
help="Splits to use for training.",
|
| 390 |
+
)
|
| 391 |
+
# add arguments for splits used for validation, e.g. ["val"]
|
| 392 |
+
validator.add_argument(
|
| 393 |
+
"valid_splits",
|
| 394 |
+
type=list,
|
| 395 |
+
help="Splits to use for validation. If not provided, will skip the validation.",
|
| 396 |
+
)
|
| 397 |
+
# add arguments for splits used for testing, e.g. ["test"]
|
| 398 |
+
validator.add_argument(
|
| 399 |
+
"test_splits",
|
| 400 |
+
type=list,
|
| 401 |
+
help="Splits to use for testing. If not provided, will skip the testing.",
|
| 402 |
+
)
|
| 403 |
+
# add arguments for accumulating gradient for iterations
|
| 404 |
+
validator.add_argument(
|
| 405 |
+
"accum_grad_iters",
|
| 406 |
+
type=int,
|
| 407 |
+
help="Number of iterations to accumulate gradient for.",
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
# ====== distributed training ======
|
| 411 |
+
validator.add_argument(
|
| 412 |
+
"device",
|
| 413 |
+
type=str,
|
| 414 |
+
choices=["cpu", "cuda"],
|
| 415 |
+
help="Device to use. Support 'cuda' or 'cpu' as for now.",
|
| 416 |
+
)
|
| 417 |
+
validator.add_argument(
|
| 418 |
+
"world_size",
|
| 419 |
+
type=int,
|
| 420 |
+
help="Number of processes participating in the job.",
|
| 421 |
+
)
|
| 422 |
+
validator.add_argument("dist_url", type=str)
|
| 423 |
+
validator.add_argument("distributed", type=bool)
|
| 424 |
+
# add arguments to opt using distributed sampler during evaluation or not
|
| 425 |
+
validator.add_argument(
|
| 426 |
+
"use_dist_eval_sampler",
|
| 427 |
+
type=bool,
|
| 428 |
+
help="Whether to use distributed sampler during evaluation or not.",
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
# ====== task specific ======
|
| 432 |
+
# generation task specific arguments
|
| 433 |
+
# add arguments for maximal length of text output
|
| 434 |
+
validator.add_argument(
|
| 435 |
+
"max_len",
|
| 436 |
+
type=int,
|
| 437 |
+
help="Maximal length of text output.",
|
| 438 |
+
)
|
| 439 |
+
# add arguments for minimal length of text output
|
| 440 |
+
validator.add_argument(
|
| 441 |
+
"min_len",
|
| 442 |
+
type=int,
|
| 443 |
+
help="Minimal length of text output.",
|
| 444 |
+
)
|
| 445 |
+
# add arguments number of beams
|
| 446 |
+
validator.add_argument(
|
| 447 |
+
"num_beams",
|
| 448 |
+
type=int,
|
| 449 |
+
help="Number of beams used for beam search.",
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
# vqa task specific arguments
|
| 453 |
+
# add arguments for number of answer candidates
|
| 454 |
+
validator.add_argument(
|
| 455 |
+
"num_ans_candidates",
|
| 456 |
+
type=int,
|
| 457 |
+
help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
|
| 458 |
+
)
|
| 459 |
+
# add arguments for inference method
|
| 460 |
+
validator.add_argument(
|
| 461 |
+
"inference_method",
|
| 462 |
+
type=str,
|
| 463 |
+
choices=["genearte", "rank"],
|
| 464 |
+
help="""Inference method to use for question answering. If rank, requires a answer list.""",
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
# ====== model specific ======
|
| 468 |
+
validator.add_argument(
|
| 469 |
+
"k_test",
|
| 470 |
+
type=int,
|
| 471 |
+
help="Number of top k most similar samples from ITC/VTC selection to be tested.",
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
return validator
|
minigpt4/common/dist_utils.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
| 3 |
+
All rights reserved.
|
| 4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import datetime
|
| 9 |
+
import functools
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.distributed as dist
|
| 14 |
+
import timm.models.hub as timm_hub
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def setup_for_distributed(is_master):
|
| 18 |
+
"""
|
| 19 |
+
This function disables printing when not in master process
|
| 20 |
+
"""
|
| 21 |
+
import builtins as __builtin__
|
| 22 |
+
|
| 23 |
+
builtin_print = __builtin__.print
|
| 24 |
+
|
| 25 |
+
def print(*args, **kwargs):
|
| 26 |
+
force = kwargs.pop("force", False)
|
| 27 |
+
if is_master or force:
|
| 28 |
+
builtin_print(*args, **kwargs)
|
| 29 |
+
|
| 30 |
+
__builtin__.print = print
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def is_dist_avail_and_initialized():
|
| 34 |
+
if not dist.is_available():
|
| 35 |
+
return False
|
| 36 |
+
if not dist.is_initialized():
|
| 37 |
+
return False
|
| 38 |
+
return True
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def get_world_size():
|
| 42 |
+
if not is_dist_avail_and_initialized():
|
| 43 |
+
return 1
|
| 44 |
+
return dist.get_world_size()
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_rank():
|
| 48 |
+
if not is_dist_avail_and_initialized():
|
| 49 |
+
return 0
|
| 50 |
+
return dist.get_rank()
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def is_main_process():
|
| 54 |
+
return get_rank() == 0
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def init_distributed_mode(args):
|
| 58 |
+
if args.distributed is False:
|
| 59 |
+
print("Not using distributed mode")
|
| 60 |
+
args.rank = 0
|
| 61 |
+
return
|
| 62 |
+
|
| 63 |
+
if 'LOCAL_RANK' not in os.environ:
|
| 64 |
+
os.environ['LOCAL_RANK'] = str(args.local_rank)
|
| 65 |
+
|
| 66 |
+
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
| 67 |
+
args.rank = int(os.environ["RANK"])
|
| 68 |
+
args.world_size = int(os.environ["WORLD_SIZE"])
|
| 69 |
+
args.gpu = int(os.environ["LOCAL_RANK"])
|
| 70 |
+
elif "SLURM_PROCID" in os.environ:
|
| 71 |
+
args.rank = int(os.environ["SLURM_PROCID"])
|
| 72 |
+
args.gpu = args.rank % torch.cuda.device_count()
|
| 73 |
+
else:
|
| 74 |
+
print("Not using distributed mode")
|
| 75 |
+
args.distributed = False
|
| 76 |
+
args.rank = 0
|
| 77 |
+
return
|
| 78 |
+
|
| 79 |
+
args.distributed = True
|
| 80 |
+
|
| 81 |
+
torch.cuda.set_device(args.gpu)
|
| 82 |
+
args.dist_backend = "nccl"
|
| 83 |
+
print(
|
| 84 |
+
"| distributed init (rank {}, world {}): {}".format(
|
| 85 |
+
args.rank, args.world_size, args.dist_url
|
| 86 |
+
),
|
| 87 |
+
flush=True,
|
| 88 |
+
)
|
| 89 |
+
torch.distributed.init_process_group(
|
| 90 |
+
backend=args.dist_backend,
|
| 91 |
+
init_method=args.dist_url,
|
| 92 |
+
world_size=args.world_size,
|
| 93 |
+
rank=args.rank,
|
| 94 |
+
timeout=datetime.timedelta(
|
| 95 |
+
days=365
|
| 96 |
+
), # allow auto-downloading and de-compressing
|
| 97 |
+
)
|
| 98 |
+
torch.distributed.barrier()
|
| 99 |
+
setup_for_distributed(args.rank == 0)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def get_dist_info():
|
| 103 |
+
if torch.__version__ < "1.0":
|
| 104 |
+
initialized = dist._initialized
|
| 105 |
+
else:
|
| 106 |
+
initialized = dist.is_initialized()
|
| 107 |
+
if initialized:
|
| 108 |
+
rank = dist.get_rank()
|
| 109 |
+
world_size = dist.get_world_size()
|
| 110 |
+
else: # non-distributed training
|
| 111 |
+
rank = 0
|
| 112 |
+
world_size = 1
|
| 113 |
+
return rank, world_size
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def main_process(func):
|
| 117 |
+
@functools.wraps(func)
|
| 118 |
+
def wrapper(*args, **kwargs):
|
| 119 |
+
rank, _ = get_dist_info()
|
| 120 |
+
if rank == 0:
|
| 121 |
+
return func(*args, **kwargs)
|
| 122 |
+
|
| 123 |
+
return wrapper
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def download_cached_file(url, check_hash=True, progress=False):
|
| 127 |
+
"""
|
| 128 |
+
Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
|
| 129 |
+
If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
|
| 130 |
+
"""
|
| 131 |
+
|
| 132 |
+
def get_cached_file_path():
|
| 133 |
+
# a hack to sync the file path across processes
|
| 134 |
+
parts = torch.hub.urlparse(url)
|
| 135 |
+
filename = os.path.basename(parts.path)
|
| 136 |
+
cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
|
| 137 |
+
|
| 138 |
+
return cached_file
|
| 139 |
+
|
| 140 |
+
if is_main_process():
|
| 141 |
+
timm_hub.download_cached_file(url, check_hash, progress)
|
| 142 |
+
|
| 143 |
+
if is_dist_avail_and_initialized():
|
| 144 |
+
dist.barrier()
|
| 145 |
+
|
| 146 |
+
return get_cached_file_path()
|