tuandunghcmut commited on
Commit
7d9e5ac
·
verified ·
1 Parent(s): b02e3a6

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Groma/pyproject.toml +32 -0
  2. InternVL/.gitignore +171 -0
  3. InternVL/.isort.cfg +26 -0
  4. InternVL/INSTALLATION.md +69 -0
  5. InternVL/requirements.txt +4 -0
  6. LLM2CLIP/FAQ.md +133 -0
  7. LLM2CLIP/LICENSE +21 -0
  8. LLM2CLIP/SUPPORT.md +25 -0
  9. LLaVA/.dockerignore +21 -0
  10. LLaVA/.editorconfig +18 -0
  11. LLaVA/LICENSE +201 -0
  12. LLaVA/cog.yaml +37 -0
  13. LLaVA/pyproject.toml +37 -0
  14. OpenSeeD/README.md +77 -0
  15. OpenSeeD/__init__.py +0 -0
  16. OpenSeeD/requirements.txt +30 -0
  17. Ovis/README.md +110 -0
  18. PaddleMIX/.copyright.hook +134 -0
  19. PaddleMIX/.style.yapf +3 -0
  20. PaddleMIX/LICENSE +201 -0
  21. PaddleMIX/README_EN.md +390 -0
  22. PaddleMIX/VERSION +1 -0
  23. PaddleMIX/check_env.sh +101 -0
  24. PaddleMIX/pyproject.toml +23 -0
  25. PaddleMIX/requirements.txt +15 -0
  26. VILA/LongVILA.md +79 -0
  27. VILA/convert_ckpt.py +91 -0
  28. VILA/environment_setup.sh +33 -0
  29. VILA/predict.py +189 -0
  30. VLMEvalKit/.pre-commit-config.yaml +30 -0
  31. VLMEvalKit/requirements.txt +30 -0
  32. a_distributed_notebook/FSDP_tutorial.md +519 -0
  33. a_distributed_notebook/temp/all_gather.py +116 -0
  34. a_distributed_notebook/temp/run_4.py +63 -0
  35. a_main_folder/convert_hf_dataset.ipynb +0 -0
  36. a_temp/deepseek_vl2.ipynb +0 -0
  37. a_temp/docs.html +32 -0
  38. a_temp/example_image.jpg +0 -0
  39. a_temp/openapi.json +1 -0
  40. a_temp/temp1.ipynb +330 -0
  41. a_temp/vllm_example.sh +412 -0
  42. groundingLMM/train.py +671 -0
  43. lightning-hydra-template/.github/codecov.yml +15 -0
  44. lightning-hydra-template/.github/workflows/test.yml +139 -0
  45. lightning-hydra-template/configs/__init__.py +1 -0
  46. lightning-hydra-template/configs/local/.gitkeep +0 -0
  47. lightning-hydra-template/configs/train.yaml +49 -0
  48. lightning-hydra-template/logs/.gitkeep +0 -0
  49. lightning-hydra-template/tests/test_datamodules.py +38 -0
  50. lightning-hydra-template/tests/test_eval.py +39 -0
Groma/pyproject.toml ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools<67.0.0,>=61.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "groma"
7
+ version = "1.0.0"
8
+ description = "Grounded Multimodal Large Language Models."
9
+ readme = "README.md"
10
+ requires-python = ">=3.9"
11
+ classifiers = [
12
+ "Programming Language :: Python :: 3",
13
+ "License :: OSI Approved :: Apache Software License",
14
+ ]
15
+ dependencies = [
16
+ "einops", "fastapi", "gradio==3.23", "markdown2[all]", "numpy",
17
+ "requests", "sentencepiece", "tokenizers==0.12.1",
18
+ "uvicorn", "shortuuid", "scipy", "pycocotools", "pycocoevalcap",
19
+ "deepspeed==0.9.2", "peft==0.3.0", "terminaltables", "transformers==4.32.0",
20
+ "bitsandbytes==0.43.1",
21
+ "lvis @ git+https://github.com/lvis-dataset/lvis-api.git",
22
+ "accelerate @ git+https://github.com/huggingface/accelerate@a2d8f540c3ab37c8f84d616be1300a0572b69cf8"
23
+ ]
24
+
25
+ [project.urls]
26
+ "Homepage" = "https://groma-mllm.github.io/"
27
+
28
+ [tool.setuptools.packages.find]
29
+ exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
30
+
31
+ [tool.wheel]
32
+ exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
InternVL/.gitignore ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+
162
+ .idea/
163
+
164
+ .DS_Store
165
+ data_process/
166
+ internvl_chat/work_dirs/
167
+ internvl_chat/unittest/
168
+ internvl_chat/data/
169
+ Husky2/*
170
+ data_process/
171
+ *distillation*
InternVL/.isort.cfg ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [isort]
2
+ line-length = 180
3
+ multi_line_output = 0
4
+ extra_standard_library = setuptools
5
+ known_third_party = PIL,asynctest,cityscapesscripts,cv2,gather_models,matplotlib,mmcv,numpy,onnx,onnxruntime,pycocotools,pytest,pytorch_sphinx_theme,requests,scipy,seaborn,six,terminaltables,torch,ts,yaml
6
+ no_lines_before = STDLIB,LOCALFOLDER
7
+ default_section = THIRDPARTY
8
+
9
+ [yapf]
10
+ BASED_ON_STYLE = pep8
11
+ BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true
12
+ SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
13
+
14
+ [codespell]
15
+ skip = *.ipynb
16
+ quiet-level = 3
17
+ ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids,TOOD,tood
18
+ © 2022 GitHub, Inc.
19
+ Terms
20
+ Privacy
21
+ Security
22
+ Status
23
+ Docs
24
+ Contact GitHub
25
+ Pricing
26
+ API
InternVL/INSTALLATION.md ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## 🛠️ Installation
2
+
3
+ - Clone this repository:
4
+
5
+ ```bash
6
+ git clone https://github.com/OpenGVLab/InternVL.git
7
+ ```
8
+
9
+ - Create a conda virtual environment and activate it:
10
+
11
+ ```bash
12
+ conda create -n internvl python=3.9 -y
13
+ conda activate internvl
14
+ ```
15
+
16
+ - Install dependencies using `requirements.txt`:
17
+
18
+ ```bash
19
+ pip install -r requirements.txt
20
+ ```
21
+
22
+ By default, our `requirements.txt` file includes the following dependencies:
23
+
24
+ - `-r requirements/internvl_chat.txt`
25
+ - `-r requirements/streamlit_demo.txt`
26
+ - `-r requirements/classification.txt`
27
+ - `-r requirements/segmentation.txt`
28
+
29
+ The `clip_benchmark.txt` is **not** included in the default installation. If you require the `clip_benchmark` functionality, please install it manually by running the following command:
30
+
31
+ ```bash
32
+ pip install -r requirements/clip_benchmark.txt
33
+ ```
34
+
35
+ ### Additional Instructions
36
+
37
+ - Install `flash-attn==2.3.6`:
38
+
39
+ ```bash
40
+ pip install flash-attn==2.3.6 --no-build-isolation
41
+ ```
42
+
43
+ Alternatively you can compile from source:
44
+
45
+ ```bash
46
+ git clone https://github.com/Dao-AILab/flash-attention.git
47
+ cd flash-attention
48
+ git checkout v2.3.6
49
+ python setup.py install
50
+ ```
51
+
52
+ - Install `mmcv-full==1.6.2` (optional, for `segmentation`):
53
+
54
+ ```bash
55
+ pip install -U openmim
56
+ mim install mmcv-full==1.6.2
57
+ ```
58
+
59
+ - Install `apex` (optional, for `segmentation`):
60
+
61
+ ```bash
62
+ git clone https://github.com/NVIDIA/apex.git
63
+ git checkout 2386a912164b0c5cfcd8be7a2b890fbac5607c82 # https://github.com/NVIDIA/apex/issues/1735
64
+ pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
65
+ ```
66
+
67
+ If you encounter `ModuleNotFoundError: No module named 'fused_layer_norm_cuda'`, it is because apex's CUDA extensions are not being installed successfully. You can try uninstalling apex and the code will default to the PyTorch version of RMSNorm. Alternatively, if you prefer using apex, try adding a few lines to `setup.py` and then recompiling.
68
+
69
+ <img src=https://github.com/OpenGVLab/InternVL/assets/23737120/c04a989c-8024-49fa-b62c-2da623e63729 width=50%>
InternVL/requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ -r requirements/internvl_chat.txt
2
+ -r requirements/streamlit_demo.txt
3
+ -r requirements/classification.txt
4
+ -r requirements/segmentation.txt
LLM2CLIP/FAQ.md ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Selected Representative Q&A
2
+
3
+ ## Q1:
4
+
5
+ > **Q: It is foreseeable that the technology of LLM2CLIP will be of great significance in expanding CLIP's support for more modal data. As far as the article is concerned, LLM2CLIP has surprisingly improved CLIP's adaptability to cross-language and long text tasks. At the same time, it also proposes application possibilities for higher-dimensional data modalities such as audio and video. Of course, this puts forward further requirements for LLM2CLIP's adaptation strategy and fine-tuning methods. Based on your team's current understanding of LLM2CLIP, what additional challenges will arise, for example, the feature space alignment problem of high-dimensional modalities?**
6
+
7
+ ![A1](https://via.placeholder.com/15/blue/000000?text=+) **A:** To be honest, we’re already exploring a video-based version of LLM2CLIP, including scaling up both the dataset size and model parameters by several orders of magnitude. Please stay tuned for our future updates, and if you’re interested, we’d be happy to discuss this further!
8
+
9
+ Here are some additional challenges I see in this area:
10
+
11
+ 1. **Enhancing the Supervisory Signal in Contrastive Learning:** While LLMs have a strong capability to understand text, providing valuable and rich textual information is equally critical. For instance, for video tasks, we could enrich the input with denser captions, prompts, or instructions. These could provide more complex and detailed information for the LLM to interpret, thereby enabling it to better guide the construction of the cross-modal space.
12
+
13
+ 2. **Expanding Contrastive Learning Loss Across Dimensions:** Contrastive learning losses can be applied across various dimensions, such as the temporal dimension in video data. Different prompts provided to the LLM could be designed to guide and control the training process in these additional dimensions, further strengthening the multimodal representations.
14
+
15
+ 3. **Tackling Complex Temporal Logic in Videos:** The challenges in video understanding often involve designing solutions for complex temporal relationships over extended time spans. Here, we could incorporate self-play techniques using the LLM to introduce tasks and increase the complexity of the training objectives. This might involve designing scenarios where the LLM can simulate and reason about sequences, further enhancing its learning.
16
+
17
+ ## Q2:
18
+
19
+ > **Q: What a groundbreaking paper on LLM2CLIP! The innovative integration of large language models with CLIP to enhance cross-modal representation learning is truly inspiring. The performance improvements demonstrated, particularly in long-text and short-text retrieval tasks, are impressive and have significant implications for the field of multimodal AI.**
20
+ >
21
+ > **My admiration for your work encourages me to inquire about the potential applications of LLM2CLIP in more specialized domains, such as medicine or law, where the precision and expertise of textual understanding are paramount. Therefore, I am curious to know if LLM2CLIP has been tested or if there are plans to test it with domain-specific texts that require a high degree of accuracy and proficiency.**
22
+ >
23
+ > Looking forward to your insights on this matter and how LLM2CLIP might be adapted or extended to meet the challenges of these specialized fields!
24
+ >
25
+ ![A2](https://via.placeholder.com/15/green/000000?text=+) **A:** Your idea is fantastic, and in fact, we have had similar thoughts. I believe there is significant potential in working on specialized fields, and here are my reasons:
26
+
27
+ 1. **Limited Data, High Impact:** Our work focuses on fine-tuning pre-trained CLIP models with very limited data for LLM2CLIP, ranging from 3M to 60M. Compared to the 1-2B data commonly used in CLIP pre-training, this is a small amount, yet it has already demonstrated substantial performance improvements. If we focus on specialized fields, we could leverage limited domain-specific data to train the model exceptionally well in a specific knowledge area. This approach could potentially resolve issues like perception or cognition hallucinations in related multimodal domains entirely.
28
+
29
+ 2. **Leveraging LLM Knowledge as Data Augmentation:** Certain specialized fields, such as medical reports, often suffer from a lack of data. Here, the knowledge encoded in LLMs can serve as an excellent data augmenter due to their access to open-world knowledge over time.
30
+
31
+ We look forward to collaborating with you to push the boundaries of multimodal domains!
32
+
33
+ BTW, we plan to release scaled-up LLM2CLIP models (10-100x larger) next quarter. These models will inherit our general-purpose parameters, potentially making them even more powerful. Please stay tuned to our GitHub!
34
+
35
+ ## Q3:
36
+
37
+ > **Q: Thank you so much for such an outstanding work. I have a couple of questions regarding the fine-tuning process described in Section 3.2, particularly around the integration of loss functions and datasets:**
38
+ >
39
+ > **In the paper, two loss functions are mentioned: SimCSE loss and Masked Next Token Prediction (MNTP). However, it is unclear whether these two loss functions are used simultaneously during training, or if the training process is split into different phases where each loss is applied separately. Could you please clarify how the losses are used? If they are used together, what are the relative weights assigned to each?**
40
+ >
41
+ > **Regarding the datasets, CC-3M and Wikitext-103 are mentioned as part of the training process. It seems a bit unclear how these two datasets are combined in the training phase. Given that Wikitext-103 is a pure language corpus while CC-3M is image-caption based, how are they jointly used during the fine-tuning process? Are they used for different stages or tasks?**
42
+ >
43
+ > Looking forward to your insights on this!
44
+ >
45
+ ![A3](https://via.placeholder.com/15/red/000000?text=+) **A:** Thank you for your question. I’m glad to clarify.
46
+
47
+ **Loss Functions Integration:** We use the supervised SimCSE loss to make different captions of the same image positive samples for each other, while captions of different images serve as negative samples. This loss function is key to our method, allowing the LLM to provide meaningful supervisory signals to the image. However, the Masked Next Token Prediction (MNTP) was an initial stage we employed before using the supervised SimCSE loss; it can be understood as an earlier step in training. We first conduct MNTP, followed by supervised SimCSE loss, in a two-stage process. In practice, MNTP has little impact on the results, so removing it does not affect the conclusions. However, for optimal performance, we still chose to use MNTP before applying supervised SimCSE loss.
48
+
49
+ **Dataset Combination:** We indeed mix both pure text and caption datasets. This is because the LLM is initially pre-trained on pure text data, so we aim to retain its original distribution with minimal shift by using the pure text dataset Wikitext-103, which also helps mitigate any bias introduced by captions. Our approach is to mix and shuffle the two datasets and then sample batches normally for training. This is a common and effective practice.
50
+
51
+ If you have more questions, please feel free to ask.
52
+
53
+ ## Q4:
54
+
55
+ > **Q: LLM2CLIP does not bring out significant improvements on ImageNet-1k only or all these zero-shot benchmarks?**
56
+ >
57
+ > **Have you ever measured the average caption length between your method and vanilla EVA-02-CLIP? In my opinion, longer text captions do not always bring out improvements.**
58
+ >
59
+ > **It's reasonable to improve the performances of VLMs on the SQA and Wizwiz benchmarks while it's strange to drop the performances on the fundamental benchmarks such as MME.**
60
+
61
+ ![A4](https://via.placeholder.com/15/purple/000000?text=+) **A:** We haven’t specifically tested it, and the improvement on ImageNet is indeed not very noticeable. With OpenAI’s CLIP, we can achieve about a one-point improvement, which is relatively modest compared to other retrieval tasks. My guess is that we used a large amount of dense captions, which may cause the model to favor more complex text. However, we have found in experiments that ImageNet performance is strongly correlated with data volume, possibly related to the word distribution used during alignment. We only used 15 million data points for the alignment in LLM fine-tuning. In the next version, we’ll increase the training data for LLM2CLIP by tens of times, so we plan to re-evaluate it then.
62
+
63
+ The improvement of long captions or dense captions for CLIP is quite limited. Works like LongCLIP (https://arxiv.org/abs/2403.15378) and DCI (https://arxiv.org/abs/2312.08578) specifically address this issue. The problem here is that the original CLIP text encoder lacks the ability to understand such information or handle captions of this length. However, LLM2CLIP, even when trained on a fully short-text dataset, still demonstrates outstanding and leading performance, as shown in Table 5 of the paper.
64
+
65
+ ## Q5:
66
+
67
+ > **Q: Hello!**
68
+ >
69
+ > **I am very interested in your work, and I encountered some issues during the reproduction process.**
70
+ >
71
+ > **How can I replace the original text encoder with the tuned Llama 3 model? I checked the config file LLM2CLIP-EVA02-L-14-336/configuration_evaclip.py, and I noticed that the model parameters for the text encoder remain the same as those in the original CLIP model. This is a bit confusing to me.**
72
+ >
73
+ > **If I’m correct, is the run.sh script provided for training CLIP with a frozen Llama 3 encoder?**
74
+ >
75
+ > Looking forward for your reply!
76
+ >
77
+ ![A5](https://via.placeholder.com/15/orange/000000?text=+) **A:** We have updated the caption contrastive fine-tuned version of Llama3-8B-CC (https://huggingface.co/microsoft/LLM2CLIP-Llama-3-8B-Instruct-CC-Finetuned) to assist with your retrieval experiments and training of your own CLIP models. Additionally, the parameters for our adapter and projector have been made available in our OpenAI ViT-L repository (https://huggingface.co/microsoft/LLM2CLIP-Openai-L-14-336). The retrieval testing methods are documented in the model card for reference.
78
+
79
+ Our tests show retrieval performance exceeding the results reported in the paper, and we encourage you to try it out.
80
+
81
+ Regarding the EVA series of models, there have been precision mismatches during the conversion to Hugging Face, which are currently being fixed. Updates will be released progressively.
82
+
83
+ Furthermore, we will provide detailed instructions on how to use LLM2CLIP to fine-tune your own CLIP models in about a week—please stay tuned!
84
+
85
+ ## Q6:
86
+
87
+ > **Q: Hello!**
88
+ >
89
+ > **I am very interested in your work, and I encountered some issues during the reproduction process.**
90
+ >
91
+ > **How can I replace the original text encoder with the tuned Llama 3 model? I checked the config file LLM2CLIP-EVA02-L-14-336/configuration_evaclip.py, and I noticed that the model parameters for the text encoder remain the same as those in the original CLIP model. This is a bit confusing to me.**
92
+ >
93
+ > **If I’m correct, is the run.sh script provided for training CLIP with a frozen Llama 3 encoder?**
94
+ >
95
+ > Looking forward for your reply!
96
+ >
97
+ ![A6](https://via.placeholder.com/15/orange/000000?text=+) **A:** We have updated the caption contrastive fine-tuned version of Llama3-8B-CC (https://huggingface.co/microsoft/LLM2CLIP-Llama-3-8B-Instruct-CC-Finetuned) to assist with your retrieval experiments and training of your own CLIP models. Additionally, the parameters for our adapter and projector have been made available in our OpenAI ViT-L repository (https://huggingface.co/microsoft/LLM2CLIP-Openai-L-14-336). The retrieval testing methods are documented in the model card for reference.
98
+
99
+ Our tests show retrieval performance exceeding the results reported in the paper, and we encourage you to try it out.
100
+
101
+ Regarding the EVA series of models, there have been precision mismatches during the conversion to Hugging Face, which are currently being fixed. Updates will be released progressively.
102
+
103
+ Furthermore, we will provide detailed instructions on how to use LLM2CLIP to fine-tune your own CLIP models in about a week—please stay tuned!
104
+ >
105
+ ## Q6:
106
+
107
+ > **Q: I find the LLM2CLIP approach inspiring as it leverages large language models (LLMs) to enhance cross-modal representation learning. The integration of fine-tuned LLMs as a textual encoder offers substantial improvements over traditional CLIP models. However, I have a few questions and suggestions regarding the methodology and evaluation:**
108
+ >
109
+ > **While the paper highlights the efficiency of training using LoRA and freezing LLM gradients, scaling to datasets larger than the 60M configuration or involving multilingual captions could introduce challenges. Could you elaborate on the computational implications if fine-tuning were performed without freezing the LLM gradients?**
110
+ >
111
+ > **The contrastive fine-tuning strategy for improving feature discriminability is innovative. However, as mentioned, dense captions from ShareCaptioner may introduce noise or distribution mismatches. Have you explored the impact of using alternative caption-generation methods or real-world noisy datasets?**
112
+ >
113
+ > **The use of various datasets like DOCCI and ShareGPT4V provides comprehensive evaluations. However, benchmarks focusing on event understanding, video context, or temporal dependencies could further validate the model's capabilities in real-world multimodal tasks.**
114
+ >
115
+ > **Overall, LLM2CLIP presents a significant advancement in multimodal learning, setting a foundation for future enhancements in cross-modal representation tasks.**
116
+
117
+ ![A6](https://via.placeholder.com/15/orange/000000?text=+) **A:** We opened the latter layers of the network based on the GPU memory we could accommodate but did not observe significant performance improvements, so we decided not to continue this way. CLIP training relies heavily on batch size, and opening the LLM would compromise the batch size, which could have a negative impact. Additionally, keeping the LLM fixed is actually quite reasonable since our goal is to align the visual model with the correct textual modality. Now that we have access to more abundant computational resources, we plan to conduct more experiments in this area to provide answers for the community.
118
+
119
+ We have tried the Recaption-1B dataset (https://github.com/UCSC-VLAA/Recap-DataComp-1B) labeled using Llava 1.5, but its performance was not as good as ShareCaptioner 4V. Real-world noisy datasets essentially align with the conclusion in Table 5 of our paper, specifically the 0% short caption results, which show that they underperform compared to using VLLMs for recaptioning. In our next version, we plan to incorporate a large volume of GPT-4o recaptioned results—please stay tuned!
120
+
121
+ Thank you for your excellent suggestions. Do you have any specific benchmarks you would recommend? We’d be happy to test them.
122
+
123
+ We truly appreciate your recognition and look forward to contributing more valuable models and knowledge to the community in the future.
124
+
125
+ ## Q7:
126
+
127
+ > **Q: This is a really interesting paper that presents a compelling approach to improving visual representation learning by effectively integrating the power of LLMs with CLIP. The entire paper feels well motivated, thoroughly researched, and clearly presented - a truly excellent contribution to the field!**
128
+ >
129
+ > **I am a bit curious that given the importance of CLIP in guiding the image generation process of diffusion models, and the enhancement of CLIP's image-text understanding capabilities by LLM2CLIP demonstrated in the paper, can integrating LLM2CLIP into the training and inference of a diffusion model bring a boost in the text-to-image domain? For example, FLUX and Stable Diffusion 3 series show significant improvement in following natural language prompts than previous diffusion models, and I think LLM2CLIP will bring further improvements.**
130
+ >
131
+ > **Thank you for your innovative work and significant contribution to the field of multimodal learning!**
132
+
133
+ ![A7](https://via.placeholder.com/15/teal/000000?text=+) **A:** Yes, we have also considered that incorporating LLM2CLIP into image-text generative models could enable more complex and precise control, and we believe there is great potential in this direction. In fact, we’ve already conducted some initial experiments, which indicate that LLM2CLIP’s llama3 performs significantly better than a standard llama3 when simply integrated with Stable Diffusion 3. However, we haven’t had the chance to explore this further in depth yet. We might delve into this more thoroughly in the future. Thank you for recognizing our work!
LLM2CLIP/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) Microsoft Corporation.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE
LLM2CLIP/SUPPORT.md ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # TODO: The maintainer of this repo has not yet edited this file
2
+
3
+ **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
4
+
5
+ - **No CSS support:** Fill out this template with information about how to file issues and get help.
6
+ - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.
7
+ - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide.
8
+
9
+ *Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
10
+
11
+ # Support
12
+
13
+ ## How to file issues and get help
14
+
15
+ This project uses GitHub Issues to track bugs and feature requests. Please search the existing
16
+ issues before filing new issues to avoid duplicates. For new issues, file your bug or
17
+ feature request as a new Issue.
18
+
19
+ For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
20
+ FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
21
+ CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
22
+
23
+ ## Microsoft Support Policy
24
+
25
+ Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
LLaVA/.dockerignore ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # The .dockerignore file excludes files from the container build process.
2
+ #
3
+ # https://docs.docker.com/engine/reference/builder/#dockerignore-file
4
+
5
+ # Exclude Git files
6
+ .git
7
+ .github
8
+ .gitignore
9
+
10
+ # Exclude Python cache files
11
+ __pycache__
12
+ .mypy_cache
13
+ .pytest_cache
14
+ .ruff_cache
15
+
16
+ # Exclude Python virtual environment
17
+ /venv
18
+
19
+ # Exclude some weights
20
+ /openai
21
+ /liuhaotian
LLaVA/.editorconfig ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ root = true
2
+
3
+ # Unix-style newlines with a newline ending every file
4
+ [*]
5
+ end_of_line = lf
6
+ insert_final_newline = true
7
+ trim_trailing_whitespace = true
8
+ charset = utf-8
9
+
10
+ # 4 space indentation
11
+ [*.{py,json}]
12
+ indent_style = space
13
+ indent_size = 4
14
+
15
+ # 2 space indentation
16
+ [*.{md,sh,yaml,yml}]
17
+ indent_style = space
18
+ indent_size = 2
LLaVA/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
LLaVA/cog.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Configuration for Cog ⚙️
2
+ # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
3
+
4
+ build:
5
+ gpu: true
6
+
7
+ python_version: "3.11"
8
+
9
+ python_packages:
10
+ - "torch==2.0.1"
11
+ - "accelerate==0.21.0"
12
+ - "bitsandbytes==0.41.0"
13
+ - "deepspeed==0.9.5"
14
+ - "einops-exts==0.0.4"
15
+ - "einops==0.6.1"
16
+ - "gradio==3.35.2"
17
+ - "gradio_client==0.2.9"
18
+ - "httpx==0.24.0"
19
+ - "markdown2==2.4.10"
20
+ - "numpy==1.26.0"
21
+ - "peft==0.4.0"
22
+ - "scikit-learn==1.2.2"
23
+ - "sentencepiece==0.1.99"
24
+ - "shortuuid==1.0.11"
25
+ - "timm==0.6.13"
26
+ - "tokenizers==0.13.3"
27
+ - "torch==2.0.1"
28
+ - "torchvision==0.15.2"
29
+ - "transformers==4.31.0"
30
+ - "wandb==0.15.12"
31
+ - "wavedrom==2.0.3.post3"
32
+ - "Pygments==2.16.1"
33
+ run:
34
+ - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.0.3/pget" && chmod +x /usr/local/bin/pget
35
+
36
+ # predict.py defines how predictions are run on your model
37
+ predict: "predict.py:Predictor"
LLaVA/pyproject.toml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=61.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "llava"
7
+ version = "1.2.2.post1"
8
+ description = "Towards GPT-4 like large language and visual assistant."
9
+ readme = "README.md"
10
+ requires-python = ">=3.8"
11
+ classifiers = [
12
+ "Programming Language :: Python :: 3",
13
+ "License :: OSI Approved :: Apache Software License",
14
+ ]
15
+ dependencies = [
16
+ "torch==2.1.2", "torchvision==0.16.2",
17
+ "transformers==4.37.2", "tokenizers==0.15.1", "sentencepiece==0.1.99", "shortuuid",
18
+ "accelerate==0.21.0", "peft", "bitsandbytes",
19
+ "pydantic", "markdown2[all]", "numpy", "scikit-learn==1.2.2",
20
+ "gradio==4.16.0", "gradio_client==0.8.1",
21
+ "requests", "httpx==0.24.0", "uvicorn", "fastapi",
22
+ "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13",
23
+ ]
24
+
25
+ [project.optional-dependencies]
26
+ train = ["deepspeed==0.12.6", "ninja", "wandb"]
27
+ build = ["build", "twine"]
28
+
29
+ [project.urls]
30
+ "Homepage" = "https://llava-vl.github.io"
31
+ "Bug Tracker" = "https://github.com/haotian-liu/LLaVA/issues"
32
+
33
+ [tool.setuptools.packages.find]
34
+ exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
35
+
36
+ [tool.wheel]
37
+ exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
OpenSeeD/README.md ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # OpenSeeD
2
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/a-simple-framework-for-open-vocabulary/panoptic-segmentation-on-coco-minival)](https://paperswithcode.com/sota/panoptic-segmentation-on-coco-minival?p=a-simple-framework-for-open-vocabulary)
3
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/a-simple-framework-for-open-vocabulary/panoptic-segmentation-on-ade20k-val)](https://paperswithcode.com/sota/panoptic-segmentation-on-ade20k-val?p=a-simple-framework-for-open-vocabulary)
4
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/a-simple-framework-for-open-vocabulary/instance-segmentation-on-ade20k-val)](https://paperswithcode.com/sota/instance-segmentation-on-ade20k-val?p=a-simple-framework-for-open-vocabulary)
5
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/a-simple-framework-for-open-vocabulary/instance-segmentation-on-cityscapes-val)](https://paperswithcode.com/sota/instance-segmentation-on-cityscapes-val?p=a-simple-framework-for-open-vocabulary)
6
+
7
+ This is the official implementation of the paper "[A Simple Framework for Open-Vocabulary Segmentation and Detection](https://arxiv.org/pdf/2303.08131.pdf)".
8
+
9
+ https://user-images.githubusercontent.com/34880758/225408795-d1e714e0-cfc8-4466-b052-045d54409a1d.mp4
10
+
11
+ You can also find the more detailed demo at [video link on Youtube](https://www.youtube.com/watch?v=z4gsQw2n7iM).
12
+
13
+ :point_right: **[New] demo code is available**
14
+ :point_right: **[New] OpenSeeD has been accepted to ICCV 2023! training code is available!**
15
+
16
+ ### :rocket: Key Features
17
+ - A Simple Framework for Open-Vocabulary Segmentation and Detection.
18
+ - Support interactive segmentation with box input to generate mask.
19
+
20
+ ### :bulb: Installation
21
+ ```sh
22
+ pip3 install torch==1.13.1 torchvision==0.14.1 --extra-index-url https://download.pytorch.org/whl/cu113
23
+ python -m pip install 'git+https://github.com/MaureenZOU/detectron2-xyz.git'
24
+ pip install git+https://github.com/cocodataset/panopticapi.git
25
+ python -m pip install -r requirements.txt
26
+ export DATASET=/pth/to/dataset
27
+ ```
28
+ Download the pretrained checkpoint from [here](https://github.com/IDEA-Research/OpenSeeD/releases/download/openseed/model_state_dict_swint_51.2ap.pt).
29
+ ### :bulb: Demo script
30
+ ```sh
31
+ python demo/demo_panoseg.py evaluate --conf_files configs/openseed/openseed_swint_lang.yaml --image_path images/animals.png --overrides WEIGHT /path/to/ckpt/model_state_dict_swint_51.2ap.pt
32
+ ```
33
+ :fire: Remember to **modify the vocabulary** `thing_classes` and `stuff_classes` in `demo_panoseg.py` if your want to segment open-vocabulary objects.
34
+
35
+ **Evaluation on coco**
36
+ ```sh
37
+ python train_net.py --original_load --eval_only --num-gpus 8 --config-file configs/openseed/openseed_swint_lang.yaml MODEL.WEIGHTS=[/path/to/lang/weight](https://github.com/IDEA-Research/OpenSeeD/releases/download/openseed/model_state_dict_swint_51.2ap.pt)
38
+ ```
39
+ You are expected to get `55.4` PQ.
40
+ ### :bulb: Some coco-format data
41
+ Here is the coco-format json file for evaluating [BDD](https://github.com/IDEA-Research/OpenSeeD/releases/download/bdd_val_data/coco_val.json) and [SUN](https://github.com/IDEA-Research/OpenSeeD/releases/tag/sun_data).
42
+ ### Training OpenSeeD baseline
43
+ **Training on coco**
44
+ ```sh
45
+ python train_net.py --num-gpus 8 --config-file configs/openseed/openseed_swint_lang.yaml --lang_weight [/path/to/lang/weight](https://github.com/IDEA-Research/OpenSeeD/releases/download/training/model_state_dict_only_language.pt)
46
+ ```
47
+ **Training on coco+o365**
48
+ ```sh
49
+ python train_net.py --num-gpus 8 --config-file configs/openseed/openseed_swint_lang_o365.yaml --lang_weight [/path/to/lang/weight](https://github.com/IDEA-Research/OpenSeeD/releases/download/training/model_state_dict_only_language.pt)
50
+ ```
51
+ ### Checkpoints
52
+ - Swin-T model trained on COCO panoptic segmentation and Objects365 [weights](https://github.com/IDEA-Research/OpenSeeD/releases/tag/ckpt_swint_coco_o365).
53
+ - Swin-L model fine-tuned on COCO panoptic segmentation [weights](https://github.com/IDEA-Research/OpenSeeD/releases/tag/coco_pano_sota_swinl).
54
+ - Swin-L model fine-tuned on ADE20K semantic segmentation [weights](https://github.com/IDEA-Research/OpenSeeD/releases/tag/ade20k_swinl).
55
+ ![hero_figure](figs/intro.jpg)
56
+ ### :unicorn: Model Framework
57
+ ![hero_figure](figs/framework.jpg)
58
+ ### :volcano: Results
59
+ Results on open segmentation
60
+ ![hero_figure](figs/results1.jpg)
61
+ Results on task transfer and segmentation in the wild
62
+ ![hero_figure](figs/results2.jpg)
63
+
64
+
65
+ ### <a name="CitingOpenSeeD"></a>Citing OpenSeeD
66
+
67
+ If you find our work helpful for your research, please consider citing the following BibTeX entry.
68
+
69
+ ```BibTeX
70
+ @article{zhang2023simple,
71
+ title={A Simple Framework for Open-Vocabulary Segmentation and Detection},
72
+ author={Zhang, Hao and Li, Feng and Zou, Xueyan and Liu, Shilong and Li, Chunyuan and Gao, Jianfeng and Yang, Jianwei and Zhang, Lei},
73
+ journal={arXiv preprint arXiv:2303.08131},
74
+ year={2023}
75
+ }
76
+ ```
77
+
OpenSeeD/__init__.py ADDED
File without changes
OpenSeeD/requirements.txt ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ opencv-python
4
+ pyyaml
5
+ json_tricks
6
+ yacs
7
+ scikit-learn
8
+ pandas
9
+ timm==0.4.12
10
+ numpy==1.23.5
11
+ einops
12
+ fvcore
13
+ transformers==4.19.2
14
+ sentencepiece
15
+ ftfy
16
+ regex
17
+ nltk
18
+ vision-datasets==0.2.2
19
+ pycocotools==2.0.4
20
+ diffdist
21
+ pyarrow
22
+ cityscapesscripts
23
+ shapely
24
+ scikit-image
25
+ mup
26
+ gradio==3.13.0
27
+ scann
28
+ kornia==0.6.4
29
+ torchmetrics==0.6.0
30
+ mpi4py
Ovis/README.md ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ovis: Structural Embedding Alignment for Multimodal Large Language Model
2
+
3
+ Ovis (Open VISion) is a novel Multimodal Large Language Model (MLLM) architecture, designed to structurally align visual and textual embeddings. For a comprehensive introduction, please refer to the [Ovis paper](https://arxiv.org/abs/2405.20797).
4
+
5
+ <div style="text-align: center;">
6
+ <img style="max-width: 100%;" src="docs/ovis-illustration.png" alt="Ovis Illustration"/>
7
+ </div>
8
+
9
+ ## Release
10
+ - [11/26] 🔥 Announcing [Ovis1.6-Gemma2-27B](https://huggingface.co/AIDC-AI/Ovis1.6-Gemma2-27B)!
11
+ - [11/04] 🔥 Announcing quantized versions of Ovis1.6: [Ovis1.6-Gemma2-9B-GPTQ-Int4](https://huggingface.co/AIDC-AI/Ovis1.6-Gemma2-9B-GPTQ-Int4) and [Ovis1.6-Llama3.2-3B-GPTQ-Int4](https://huggingface.co/AIDC-AI/Ovis1.6-Llama3.2-3B-GPTQ-Int4)!
12
+ - [10/22] 🔥 Announcing Ovis1.6-Llama3.2-3B ([Model](https://huggingface.co/AIDC-AI/Ovis1.6-Llama3.2-3B), [Demo](https://huggingface.co/spaces/AIDC-AI/Ovis1.6-Llama3.2-3B))!
13
+ - [09/19] 🔥 Announcing Ovis1.6-Gemma2-9B ([Model](https://huggingface.co/AIDC-AI/Ovis1.6-Gemma2-9B), [Demo](https://huggingface.co/spaces/AIDC-AI/Ovis1.6-Gemma2-9B))! This latest release further enhances high-resolution image processing, is trained on a larger, more diverse, and higher-quality dataset, and refines the training process with DPO training following instruction-tuning.
14
+ - [07/24] 🔥 Introducing Ovis1.5, featuring improved high-resolution image processing and optimized training data for enhanced performance.
15
+ - [06/14] 🔥 Launch of Ovis1.0, the inaugural version of the Ovis model.
16
+
17
+ ## Contents
18
+ - [Install](#install)
19
+ - [Model](#model)
20
+ - [Performance](#performance)
21
+ - [Finetune](#finetune)
22
+ - [Inference](#inference)
23
+ - [Quantization](#quantization)
24
+ - [Citation](#citation)
25
+ - [Team](#team)
26
+ - [License](#license)
27
+
28
+ ## Install
29
+ Ovis has been tested with Python 3.10, Torch 2.4.0, Transformers 4.46.2, and DeepSpeed 0.15.4. For a comprehensive list of package dependencies, please consult the `requirements.txt` file. Before finetuning or inference, please install Ovis as follows.
30
+ ```bash
31
+ git clone git@github.com:AIDC-AI/Ovis.git
32
+ conda create -n ovis python=3.10 -y
33
+ conda activate ovis
34
+ cd Ovis
35
+ pip install -r requirements.txt
36
+ pip install -e .
37
+ ```
38
+
39
+ ## Model
40
+ Ovis can be instantiated with popular LLMs. We provide the following Ovis MLLMs:
41
+
42
+ | Ovis MLLMs | ViT | LLM | Model Weights | Demo |
43
+ |:------------------|:-----------:|:------------------:|:---------------------------------------------------------------:|:----------------------------------------------------------------:|
44
+ | Ovis1.6-Gemma2-27B | Siglip-400M | Gemma2-27B-It | [Huggingface](https://huggingface.co/AIDC-AI/Ovis1.6-Gemma2-27B) | - |
45
+ | Ovis1.6-Gemma2-9B | Siglip-400M | Gemma2-9B-It | [Huggingface](https://huggingface.co/AIDC-AI/Ovis1.6-Gemma2-9B) | [Space](https://huggingface.co/spaces/AIDC-AI/Ovis1.6-Gemma2-9B) |
46
+ | Ovis1.6-Llama3.2-3B | Siglip-400M | Llama-3.2-3B-Instruct | [Huggingface](https://huggingface.co/AIDC-AI/Ovis1.6-Llama3.2-3B) | [Space](https://huggingface.co/spaces/AIDC-AI/Ovis1.6-Llama3.2-3B) |
47
+
48
+ ## Performance
49
+ With **29B** parameters, **Ovis1.6-Gemma2-27B** achieves exceptional performance in the [OpenCompass](https://github.com/open-compass/VLMEvalKit) benchmark, ranking among the top-tier open-source MLLMs.
50
+
51
+ ![performance-Ovis1_6-Gemma2-27B](docs/performance/Ovis1_6-Gemma2-27B.png)
52
+
53
+ With just **10B** parameters, **Ovis1.6-Gemma2-9B** leads the [OpenCompass](https://github.com/open-compass/VLMEvalKit) benchmark among open-source MLLMs within **30B** parameters.
54
+
55
+ ![performance-Ovis1_6-Gemma2-9B](docs/performance/Ovis1_6-Gemma2-9B.png)
56
+
57
+ **Ovis1.6-Llama3.2-3B** leads the [OpenCompass](https://github.com/open-compass/VLMEvalKit) benchmark among open-source MLLMs under **4B** parameters, even surpassing Llama-3.2-11B-Vision-Instruct.
58
+
59
+ ![performance-Ovis1_6-Llama3_2-3B](docs/performance/Ovis1_6-Llama3_2-3B.png)
60
+
61
+ ## Finetune
62
+ Finetuning Ovis1.6-Gemma2-9B is supported in [ms-swift](https://github.com/modelscope/ms-swift).
63
+
64
+ ## Inference
65
+ We provide an inference wrapper in `ovis/serve/runner.py`, which can be used as:
66
+ ```python
67
+ from PIL import Image
68
+ from ovis.serve.runner import RunnerArguments, OvisRunner
69
+ image = Image.open('temp.png')
70
+ text = 'PROMPT'
71
+ runner_args = RunnerArguments(model_path='AIDC-AI/Ovis1.6-Gemma2-27B')
72
+ runner = OvisRunner(runner_args)
73
+ generation = runner.run([image, text])
74
+ ```
75
+ Based on [Gradio](https://github.com/gradio-app/gradio), Ovis can also be accessed via a web user interface:
76
+ ```bash
77
+ python ovis/serve/server.py --model_path MODEL_PATH --port PORT
78
+ ```
79
+
80
+ ## Quantization
81
+ We quantized Ovis1.6 using AutoGPTQ. For detailed information on running and creating your own quantized version, please refer to the respective Huggingface model cards: [Ovis1.6-Gemma2-9B-GPTQ-Int4](https://huggingface.co/AIDC-AI/Ovis1.6-Gemma2-9B-GPTQ-Int4) and [Ovis1.6-Llama3.2-3B-GPTQ-Int4](https://huggingface.co/AIDC-AI/Ovis1.6-Llama3.2-3B-GPTQ-Int4). Quantized Ovis1.6 maintains performance comparable to its non-quantized counterpart while requiring less GPU memory:
82
+
83
+ - Benchmark performance:
84
+ ![performance-Ovis1_6-Gemma2-9B-GPTQ-Int4](docs/performance/Ovis1_6-Gemma2-9B-GPTQ-Int4.png)
85
+ ![performance-Ovis1_6-Llama3_2-3B-GPTQ-Int4](docs/performance/Ovis1_6-Llama3_2-3B-GPTQ-Int4.png)
86
+
87
+ - GPU memory usage (max_partition=9):
88
+ ![performance-Ovis1_6-VRAM-Comparison](docs/performance/Ovis1_6-VRAM-Comparison.png)
89
+
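+ As a minimal sketch (not an official recipe), a quantized checkpoint can be loaded through the same `OvisRunner` interface shown in the Inference section; whether the runner wraps the GPTQ checkpoints without additional setup is an assumption here, and the linked model cards remain the authoritative reference:
+ ```python
+ from PIL import Image
+ from ovis.serve.runner import RunnerArguments, OvisRunner
+
+ # Sketch only: point the runner at a GPTQ-quantized checkpoint instead of the
+ # full-precision one; AutoGPTQ-related requirements follow the model card.
+ runner_args = RunnerArguments(model_path='AIDC-AI/Ovis1.6-Gemma2-9B-GPTQ-Int4')
+ runner = OvisRunner(runner_args)
+ generation = runner.run([Image.open('temp.png'), 'PROMPT'])
+ ```
+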
90
+ ## Citation
91
+ If you find Ovis useful, please cite the paper:
92
+ ```
93
+ @article{lu2024ovis,
94
+ title={Ovis: Structural Embedding Alignment for Multimodal Large Language Model},
95
+ author={Shiyin Lu and Yang Li and Qing-Guo Chen and Zhao Xu and Weihua Luo and Kaifu Zhang and Han-Jia Ye},
96
+ year={2024},
97
+ journal={arXiv:2405.20797}
98
+ }
99
+ ```
100
+
101
+ ## Team
102
+ This work is a collaborative effort by the MarcoVL team. We would also like to provide links to the following MLLM papers from our team:
103
+ - [Parrot: Multilingual Visual Instruction Tuning](https://arxiv.org/abs/2406.02539)
104
+ - [Wings: Learning Multimodal LLMs without Text-only Forgetting](https://arxiv.org/abs/2406.03496)
105
+
106
+ ## License
107
+ This project is licensed under the [Apache License, Version 2.0](https://www.apache.org/licenses/LICENSE-2.0.txt) (SPDX-License-Identifier: Apache-2.0).
108
+
109
+ ## Disclaimer
110
+ We used compliance-checking algorithms during the training process, to ensure the compliance of the trained model to the best of our ability. Due to the complexity of the data and the diversity of language model usage scenarios, we cannot guarantee that the model is completely free of copyright issues or improper content. If you believe anything infringes on your rights or generates improper content, please contact us, and we will promptly address the matter.
PaddleMIX/.copyright.hook ADDED
@@ -0,0 +1,134 @@
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import absolute_import
16
+ from __future__ import print_function
17
+ from __future__ import unicode_literals
18
+
19
+ import argparse
20
+ import io
21
+ import re
22
+ import sys
23
+ import os
24
+ import datetime
25
+
26
+ COPYRIGHT = '''Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
27
+
28
+ Licensed under the Apache License, Version 2.0 (the "License");
29
+ you may not use this file except in compliance with the License.
30
+ You may obtain a copy of the License at
31
+
32
+ http://www.apache.org/licenses/LICENSE-2.0
33
+
34
+ Unless required by applicable law or agreed to in writing, software
35
+ distributed under the License is distributed on an "AS IS" BASIS,
36
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
37
+ See the License for the specific language governing permissions and
38
+ limitations under the License.'''
39
+
40
+ def _generate_copyright(comment_mark):
41
+ copyright=COPYRIGHT.split(os.linesep)
42
+ header = copyright[0].rstrip()
43
+
44
+ p = re.search(r'(\d{4})', header).group(0)
45
+ now = datetime.datetime.now()
46
+
47
+ header = header.replace(p,str(now.year))
48
+
49
+ ans=[comment_mark + " " + header + os.linesep]
50
+ for idx, line in enumerate(copyright[1:]):
51
+ ans.append(comment_mark + " " + line.rstrip() + os.linesep)
52
+
53
+ return ans
54
+
55
+ def _get_comment_mark(path):
56
+ lang_type=re.compile(r"\.(py|sh)$")
57
+ if lang_type.search(path) is not None:
58
+ return "#"
59
+
60
+ lang_type=re.compile(r"\.(h|c|hpp|cc|cpp|cu|go|cuh|proto)$")
61
+ if lang_type.search(path) is not None:
62
+ return "//"
63
+
64
+ return None
65
+
66
+
67
+ RE_ENCODE = re.compile(r"^[ \t\v]*#.*?coding[:=]", re.IGNORECASE)
68
+ RE_COPYRIGHT = re.compile(r".*Copyright( \(c\))* \d{4}", re.IGNORECASE)
69
+ RE_SHEBANG = re.compile(r"^[ \t\v]*#[ \t]?\!")
70
+
71
+ def _check_copyright(path):
72
+ head=[]
73
+ try:
74
+ with open(path) as f:
75
+ head = [next(f) for x in range(4)]
76
+ except StopIteration:
77
+ pass
78
+
79
+ for idx, line in enumerate(head):
80
+ if RE_COPYRIGHT.search(line) is not None:
81
+ return True
82
+
83
+ return False
84
+
85
+ def generate_copyright(path, comment_mark):
86
+ original_contents = io.open(path, encoding="utf-8").readlines()
87
+ head = original_contents[0:4]
88
+
89
+ insert_line_no=0
90
+ for i, line in enumerate(head):
91
+ if RE_ENCODE.search(line) or RE_SHEBANG.search(line):
92
+ insert_line_no=i+1
93
+
94
+ copyright = _generate_copyright(comment_mark)
95
+ if insert_line_no == 0:
96
+ new_contents = copyright
97
+ if len(original_contents) > 0 and len(original_contents[0].strip()) != 0:
98
+ new_contents.append(os.linesep)
99
+ new_contents.extend(original_contents)
100
+ else:
101
+ new_contents=original_contents[0:insert_line_no]
102
+ new_contents.append(os.linesep)
103
+ new_contents.extend(copyright)
104
+ if len(original_contents) > insert_line_no and len(original_contents[insert_line_no].strip()) != 0:
105
+ new_contents.append(os.linesep)
106
+ new_contents.extend(original_contents[insert_line_no:])
107
+ new_contents="".join(new_contents)
108
+
109
+ with io.open(path, 'w') as output_file:
110
+ output_file.write(new_contents)
111
+
112
+
113
+
114
+ def main(argv=None):
115
+ parser = argparse.ArgumentParser(
116
+ description='Checker for copyright declaration.')
117
+ parser.add_argument('filenames', nargs='*', help='Filenames to check')
118
+ args = parser.parse_args(argv)
119
+
120
+ retv = 0
121
+ for path in args.filenames:
122
+ comment_mark = _get_comment_mark(path)
123
+ if comment_mark is None:
124
+ print("warning:Unsupported file", path, file=sys.stderr)
125
+ continue
126
+
127
+ if _check_copyright(path):
128
+ continue
129
+
130
+ generate_copyright(path, comment_mark)
131
+
132
+
133
+ if __name__ == '__main__':
134
+ exit(main())
PaddleMIX/.style.yapf ADDED
@@ -0,0 +1,3 @@
1
+ [style]
2
+ based_on_style = pep8
3
+ column_limit = 80
PaddleMIX/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
PaddleMIX/README_EN.md ADDED
@@ -0,0 +1,390 @@
1
+ [简体中文](README.md) | English
2
+
3
+ <p align="center">
4
+ <img src="https://github.com/PaddlePaddle/PaddleMIX/assets/22989727/2cd19298-1c52-4d73-a0f7-dcdab6a8ec90" align="middle" width = "600" />
5
+ </p>
6
+
7
+ <p align="center">
8
+ <a href="https://github.com/PaddlePaddle/PaddleMix/releases"><img src="https://img.shields.io/github/v/release/PaddlePaddle/PaddleMix?color=ffa"></a>
9
+ <a href="./LICENSE"><img src="https://img.shields.io/badge/license-Apache%202-dfd.svg"></a>
10
+ <a href=""><img src="https://img.shields.io/badge/python-3.7+-aff.svg"></a>
11
+ <a href=""><img src="https://img.shields.io/badge/os-linux%2C%20win%2C%20mac-pink.svg"></a>
12
+ <a href="#📌社区交流"><img src="https://img.shields.io/badge/微信-小助手加群-green?logo=wechat&amp"></a>
13
+ <a href="https://github.com/PaddlePaddle/PaddleMIX/stargazers"><img src="https://img.shields.io/github/stars/PaddlePaddle/PaddleMIX?color=ccf"></a>
14
+
15
+ </p>
16
+ </div>
17
+
18
+ ## 💌 Table of Contents
19
+ - [💌 Table of Contents](#table-of-contents)
20
+ - [📰 News](#news)
21
+ - [📣 Latest Developments](#latest-developments)
22
+ - [🌈 Introduction](#introduction)
23
+ - [✨ Key Features](#key-features)
24
+ - [📱 Rich Multimodal Capabilities](#rich-multimodal-capabilities)
25
+ - [🧩 Simple Development Experience](#simple-development-experience)
26
+ - [💡 High-Performance Distributed Training and Inference Capabilities](#high-performance-distributed-training-and-inference-capabilities)
27
+ - [🔧 Unique Features and Tools](#unique-features-and-tools)
28
+ - [🔍 Installation](#installation)
29
+ - [🔥 Tutorials](#tutorials)
30
+ - [🤔 FAQ](#faq)
31
+ - [📱 Model Library](#model-library)
32
+ - [📝 License](#license)
33
+ - [📌 Community](#community)
34
+
35
+
36
+ ## 📰 News
37
+
38
+ **🔥 PaddleMIX Development Project Challenge, 2024.11.21 - 2024.12.22 (Ended)**
41
+
42
+ - ✨「Experience Officer Recruitment」PaddleMIX Development Project Challenge
43
+ Click the link to register 🔗: [https://aistudio.baidu.com/activitydetail/1503019366](https://aistudio.baidu.com/activitydetail/1503019366)
44
+ 🏆 Submit your project to the PaddlePaddle Galaxy Community Project Hall to be featured and receive a PaddleMIX Experience Officer certificate and JD.com card incentives.
45
+ Everyone is welcome to submit~
46
+
47
+ <details>
48
+ <summary>Click to view the event poster</summary>
49
+ <p align="center">
50
+ <img src='https://github.com/user-attachments/assets/27e0bbe3-0ff8-49ef-bd39-81a31a2b288b' width="25%">
51
+ </p>
52
+ </details>
53
+
54
+ ## 📣 Latest Developments
55
+
56
+ **🎉 2024.12.17 Support for [InternVL2_5 (1B, 2B, 4B, 8B)](./paddlemix/examples/internvl2) inference**
57
+
58
+ **🎉 2024.11.27 Added support for [Janus/JanusFlow](./paddlemix/examples/janus) inference**
59
+
60
+ **🎉 2024.11.21 Added support for [MiniCPM-V-2_6](./paddlemix/examples/minicpm-v-2_6) inference**
61
+
62
+ **🎉 2024.11.8 Support for [DenseConnector](./paddlemix/examples/llava_denseconnector) and [Aquila-VL-2B-llava-qwen](./paddlemix/examples/llava_onevision/) inference**
63
+
64
+ **🎉 2024.11.1 Support for [LLaVA-OneVision](./paddlemix/examples/llava_onevision/) and [LLaVA-Critic](./paddlemix/examples/llava_critic/) inference**
65
+
66
+ **🎉 2024.10.31 The community [Tutorial Page](paddlemix_applications.md), collecting external developers' creative projects, has been updated**
67
+ * 🌟 Since the launch of our Large Model Suite Premium Project Collection activity on September 6th, we have received 30 high-quality developer projects. Among them, 25 premium projects have successfully passed the platform evaluation and been featured.
68
+
69
+ * 🙏 We sincerely thank all developers for their wonderful creations based on our suite! 🚀 We cordially invite you to share your creativity as well - welcome to publish your tutorials on public web pages or in the [PaddlePaddle AI Studio](https://aistudio.baidu.com/aistudio/community/multimodal?from=singlemessage) community!
70
+
71
+ <details>
72
+ <summary>Click to expand more</summary>
73
+
74
+ **🔥 PaddleMIX v2.1 Released on 2024.10.11**
75
+ * Supports the [PaddleNLP 3.0 beta](https://github.com/PaddlePaddle/PaddleNLP/releases/tag/v3.0.0-beta0) version, allowing early access to its latest features.
76
+ * Added cutting-edge models like [Qwen2-VL](./paddlemix/examples/qwen2_vl/), [InternVL2](./paddlemix/examples/internvl2/), and [Stable Diffusion 3 (SD3)](https://github.com/PaddlePaddle/PaddleMIX/blob/develop/ppdiffusers/examples/dreambooth/README_sd3.md).
77
+ * Released our self-developed multimodal data capability tagging model [PP-InsCapTagger](./paddlemix/datacopilot/example/pp_inscaptagger/), which can be used for data analysis and filtering. Experimental cases show that it can reduce data volume by 50% while maintaining model performance, significantly improving training efficiency.
78
+
79
+ * The multimodal large models InternVL2, LLaVA, SD3, and SDXL are now adapted to the Ascend 910B, offering training and inference capabilities on domestic computing chips.
80
+
81
+
82
+ **PaddleMIX v2.0 Released on 2024.07.25**
83
+ * Multimodal Understanding: Added LLaVA series, Qwen-VL, etc.; introduced Auto module to unify the SFT training process; introduced Mixtoken training strategy, increasing SFT throughput by 5.6 times.
84
+ * Multimodal Generation: Released [PPDiffusers 0.24.1](./ppdiffusers/README.md), supporting video generation capabilities, and added LCM to the text-to-image model. Also added a PaddlePaddle version of PEFT and the Accelerate backend. Provided a ComfyUI plugin developed with PaddlePaddle.
85
+ * Multimodal Data Processing Toolbox [DataCopilot](./paddlemix/datacopilot/): Supports custom data structures, data transformation, and offline format checks. Includes basic statistical information and data visualization functionality.
86
+
87
+ **PaddleMIX v1.0 Released on 2023.10.7**
88
+ * Added distributed training capabilities for vision-language pre-training models, and BLIP-2 now supports trillion-scale training.
89
+ * Introduced the cross-modal application pipeline [AppFlow](./applications/README.md), which supports 11 cross-modal applications such as automatic annotation, image editing, and audio-to-image with one click.
90
+ * [PPDiffusers](./ppdiffusers/README.md) released version 0.19.3, adding SDXL and related tasks.
91
+ </details>
92
+
93
+
94
+ ---
95
+
96
+ ## 🌈 Introduction
97
+
98
+ PaddleMIX is a multimodal large model development suite based on PaddlePaddle, integrating various modalities such as images, text, and video. It covers a wide range of multimodal tasks, including vision-language pre-training, fine-tuning, text-to-image, text-to-video, and multimodal understanding. It offers an out-of-the-box development experience while supporting flexible customization to meet diverse needs, empowering the exploration of general artificial intelligence.
99
+
100
+ <p align="center">
101
+ <img src="https://github.com/user-attachments/assets/764b32a4-3933-4ef8-a0b2-dd425af49ef8" align="middle" width = 100% />
102
+ </p>
103
+
104
+ The PaddleMIX toolchain includes data processing, model development, pre-training, fine-tuning, and inference deployment, supporting mainstream multimodal models such as EVA-CLIP, BLIP-2, and Stable Diffusion. With cross-modal task pipelines like AppFlow and text-to-image application pipelines, developers can quickly build multimodal applications.
105
+
106
+ ### An example of multimodal understanding is shown below:
107
+
108
+ <img src="https://github.com/user-attachments/assets/4c9a0427-57c7-4e1b-80f0-428c03119cc3"></img>
109
+
110
+
111
+ Multimodal understanding 🤝 integrates visual 👀 and linguistic 💬 processing capabilities. It includes functions such as basic perception, fine-grained image understanding, and complex visual reasoning 🧠. Our [Model Library](#model-library) offers practical applications for single-image, multi-image, and video inference. Features include natural image summarization 📝, question answering 🤔, OCR 🔍, sentiment recognition ❤️😢, specialized image analysis 🔬, and code interpretation 💻. These technologies can be applied in various fields such as education 📚, healthcare 🏥, industry 🏭, and more, enabling comprehensive intelligent analysis from static images 🖼️ to dynamic videos 🎥. We invite you to experience and explore these capabilities!
112
+
113
+ ### An example of multimodal generation is shown below:
114
+
115
+ <div style="display: flex; justify-content: center; gap: 5px;">
116
+ <img src="https://github.com/user-attachments/assets/f4768f08-f7a3-45e0-802c-c91554dc5dfc" style="height: 250px; object-fit: fill;">
117
+ <img src="https://github.com/user-attachments/assets/9bf4a333-af57-4ddd-a514-617dea8da435" style="height: 250px; object-fit: fill;">
118
+ </div>
119
+
120
+ Multimodal generation ✍️ combines the creative power of text 💬 and visuals 👀. It includes various technologies ranging from text-to-image 🖼️ to text-to-video 🎥, featuring advanced models like Stable Diffusion 3 and Open-Sora. We provide practical applications for single-image generation, multi-image synthesis, and video generation in [ppdiffusers](ppdiffusers/README.md). These features cover areas such as artistic creation 🎨, animation production 📽️, and content generation 📝. With these technologies, creative generation from static images to dynamic videos can be applied in fields like education 📚, entertainment 🎮, advertising 📺, and more. We invite you to experience and explore these innovations!
121
+
122
+ ### Example of featured applications (click the titles for a quick jump to the online experience):
123
+ | [**ComfyUI Creative Workflow**](https://aistudio.baidu.com/community/app/106043) | [**Art Style QR Code Model**](https://aistudio.baidu.com/community/app/1339) | [**Mix Image Overlay**](https://aistudio.baidu.com/community/app/1340) |
124
+ | :--------------------------------------------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------: |
125
+ | <img src='https://github.com/PaddlePaddle/PaddleMIX/assets/35400185/36ba7261-1744-41a4-b1cb-c9e99f6931f2' width="300px"> | <img src='https://github.com/PaddlePaddle/Paddle/assets/22989727/ba091291-a1ee-49dc-a1af-fc501c62bfc8' width="300px"> | <img src='https://github.com/PaddlePaddle/Paddle/assets/22989727/a71be5a0-b0f3-4aa8-bc20-740ea8ae6785' width="300px"> |
126
+ | [**Anime Text-to-Image**](https://aistudio.baidu.com/community/app/2/webUI?source=appCenter) | [**AI Art|50+ Lora Style Overlays**](https://aistudio.baidu.com/community/app/2848/webUI?source=appCenter) | [**ControlNet|Partial Image Repainting**](https://aistudio.baidu.com/community/app/1981/webUI?source=appCenter) |
127
+ | <img src='https://github.com/user-attachments/assets/a4af8f8a-08c7-4da7-8575-9dbfedaba56c' width="200px"> | <img src='https://github.com/user-attachments/assets/fa92c229-a885-46a1-b23f-a076855c93ec' width="200px"> | <img src='https://github.com/user-attachments/assets/78625876-d8ec-4c15-ae96-655c50f562ab' width="200px"> |
128
+
129
+
130
+
131
+
132
+
133
+ -----
134
+
135
+
136
+ ## ✨ Key Features
137
+
138
+ ### 📱 Rich Multimodal Capabilities
139
+ PaddleMIX supports a wide range of the latest mainstream algorithm benchmarks and pre-trained models, covering vision-language pre-training, text-to-image, cross-modal visual tasks, and enabling diverse functionalities such as image editing, image description, and data annotation. `Gateway`: [📱 Model Library](#model-library)
140
+
141
+ ### 🧩 Simple Development Experience
142
+ PaddleMIX provides a unified model development interface, allowing developers to quickly integrate and customize models. With the Auto module, users can efficiently load pre-trained models, perform tokenization, and easily complete model training, fine-tuning (SFT), inference, and deployment through a simplified API. Additionally, the Auto module supports developers in customizing automated model integration, ensuring flexibility and scalability while enhancing development efficiency.
143
+
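+ As a rough illustration of this Auto-style `from_pretrained` workflow, the sketch below uses PaddleNLP's Auto classes, which PaddleMIX builds on; the checkpoint name is only a placeholder, and the exact model/processor classes for each multimodal model are documented in the corresponding `paddlemix/examples/*` README.
+
+ ```python
+ # Minimal sketch; the checkpoint name below is illustrative only.
+ from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer
+
+ name = "Qwen/Qwen2-0.5B"  # placeholder checkpoint
+ tokenizer = AutoTokenizer.from_pretrained(name)
+ model = AutoModelForCausalLM.from_pretrained(name, dtype="float16")
+ inputs = tokenizer("Describe the image content.", return_tensors="pd")
+ ```
+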
144
+ ### 💡 High-Performance Distributed Training and Inference Capabilities
145
+ PaddleMIX offers high-performance distributed training and inference capabilities, integrating acceleration operators like ✨Fused Linear✨ and ✨Flash Attention✨. It supports 🌀BF16 mixed-precision training and 4D mixed-parallel strategies. By optimizing inference performance through convolution layout, GroupNorm fusion, and rotating positional encoding optimization, it significantly enhances large-scale pre-training and efficient inference performance.
146
+
147
+ <img src="https://github.com/user-attachments/assets/9ab9540a-fa89-41cb-838d-95df86e33382" width = 100% />
148
+
149
+ ### 🔧 Unique Features and Tools
150
+ The multimodal data processing toolbox, DataCopilot, accelerates model iteration and upgrades. It allows developers to perform basic data operations with low code based on specific tasks. `Gateway`: [🏆 Featured Models | Tools](#featured-models-tools)
151
+
152
+
153
+ ## 🔍 Installation
154
+ ### 1. Clone PaddleMIX Repository
155
+ ```
156
+ git clone https://github.com/PaddlePaddle/PaddleMIX
157
+ cd PaddleMIX
158
+ ```
159
+
160
+ ### 2. Create Virtual Environment
161
+ ```
162
+ conda create -n paddlemix python=3.10 -y
163
+ conda activate paddlemix
164
+ ```
165
+
166
+ ### 3. ‼️ Install PaddlePaddle
167
+
168
+ #### Method 1: One-click Installation (Recommended for GPU/CPU)
169
+
170
+ - CUDA 11.x or 12.3
171
+ - PaddlePaddle 3.0.0b1
172
+ ```
173
+ sh build_paddle_env.sh
174
+ ```
175
+
176
+ #### Method 2: Manual Installation
177
+ For detailed instructions on installing PaddlePaddle, please refer to the [Installation Guide](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/develop/install/pip/linux-pip.html).
178
+
179
+ ### 4. ‼️ Install Dependencies
180
+
181
+ #### Method 1: One-Click Installation (Recommended)
182
+ ```
183
+ sh build_env.sh
184
+ ```
185
+ #### Method 2: Manual Installation (please also refer to build_env.sh)
186
+ ```bash
187
+ # Install PaddleMIX
188
+ pip install -e .
189
+ # Install ppdiffusers
190
+ cd ppdiffusers
191
+ pip install -e .
192
+ cd ..
+ ```
193
+
194
+ ### 5. ‼️ Verify Installation
195
+
196
+ Run the following command to verify your installation:
197
+ ```bash
198
+ sh check_env.sh
199
+ ```
200
+
201
+ Recommended versions for environment and dependencies:
202
+ - paddlepaddle: 3.0.0b2 or develop version
203
+ - paddlenlp: 3.0.0b2
204
+ - ppdiffusers: 0.29.0
205
+ - huggingface_hub: 0.23.0
206
+
207
+ ### 6. Install Custom Operators (Optional)
208
+ * Some models require custom operators (FastLayerNorm, FusedLayerNorm), such as EVA-CLIP, DIT_LLAMA, etc.
209
+ * Skip this step for non-CUDA environments (e.g., Ascend NPU)
210
+ * ```bash
211
+ cd paddlemix/external_ops
212
+ python setup.py install
213
+ ```
214
+
215
+
216
+
217
+
219
+ ## 🔥 Tutorials
220
+
221
+ **Quick Start**
222
+ - [Multimodal Understanding: Beginner's Guide [Example: InternVL2 Model]](paddlemix/examples/internvl2/README.md)
223
+ - [Multimodal Generation: Zero to Hero Guide [Example: Stable Diffusion Model]](ppdiffusers/examples/stable_diffusion/README.md)
224
+ - [Cross-modal Task Pipeline: Getting Started](applications/README.md/#getting-started)
225
+
226
+ **Hands-On Practice & Examples**
227
+ - [LLaVA Model: Full Process Practice from Training to Inference](https://aistudio.baidu.com/projectdetail/7917712)
228
+ - [SDXL Application: Create Your Own Olympic Poster Generator](https://aistudio.baidu.com/projectdetail/8251202)
229
+ - [PaddleMIX Multimodal AI Applications: Project Classification Overview](./paddlemix_applications.md)
230
+
231
+ **Multi-Hardware Usage**
232
+ - For the model list and usage supported by Ascend 910B, please refer to [Ascend Hardware Usage](./docs/hardware_support/ascend_usage.md)
233
+
234
+ **Data Preparation & Fine-Tuning**
235
+ - [Model Training and Fine-Tuning Techniques](paddlemix/tools/README.md)
236
+
237
+ **Inference Deployment**
238
+ - [Deployment Guide: From Development to Production Environment](deploy/README.md)
239
+
240
+
241
+
242
+ ## 📱 Model Library
243
+ <table align="center">
244
+ <tbody>
245
+ <tr align="center" valign="center">
246
+ <td>
247
+ <b>Multimodal Understanding</b>
248
+ </td>
249
+ <td>
250
+ <b>Multimodal Generation</b>
251
+ </td>
252
+ <td>
253
+ <b>Unified Multimodal Foundation Model</b>
254
+ </td>
255
+ </tr>
256
+ <tr valign="top">
257
+ <td>
258
+ <ul>
259
+ </ul>
260
+ <li><b>Image-Text Pre-training</b></li>
261
+ <ul>
262
+ <li><a href="paddlemix/examples/clip">CLIP</a></li>
263
+ <li><a href="paddlemix/examples/evaclip">EVA-CLIP</a></li>
264
+ <li><a href="paddlemix/examples/llava">LLaVA-1.5</a></li>
265
+ <li><a href="paddlemix/examples/llava">LLaVA-1.6</a></li>
266
+ <li><a href="paddlemix/examples/llava">LLaVA-NeXT</a></li>
267
+ <li><a href="paddlemix/examples/llava_onevision">LLaVA-onevision</a></li>
268
+ <li><a href="paddlemix/examples/llava_onevision">Aquila-VL-2B-llava-qwen</a></li>
269
+ <li><a href="paddlemix/examples/llava_critic">LLaVA-Critic</a></li>
270
+ <li><a href="paddlemix/examples/llava_denseconnector">LLaVA-DenseConnector</a></li>
271
+ <li><a href="paddlemix/examples/qwen_vl">Qwen-VL</a></li>
272
+ <li><a href="paddlemix/examples/qwen2_vl">Qwen2-VL</a></li>
273
+ <li><a href="paddlemix/examples/internvl2">InternVL2</a></li>
274
+ <li><a href="paddlemix/examples/minimonkey">Mini-Monkey</a></li>
275
+ <li><a href="paddlemix/examples/coca">CoCa</a></li>
276
+ <li><a href="paddlemix/examples/blip2">BLIP-2</a></li>
277
+ <li><a href="paddlemix/examples/minigpt4">miniGPT-4</a></li>
278
+ <li><a href="paddlemix/examples/visualglm">VIsualGLM</a></li>
279
+ <li><a href="paddlemix/examples/cogvlm">CogVLM && CogAgent</a></li>
280
+ <li><a href="paddlemix/examples/internlm_xcomposer2">InternLM-XComposer2</a></li>
281
+ </ul>
282
+ </ul>
283
+ <li><b>Open-World Visual Model</b></li>
284
+ <ul>
285
+ <li><a href="paddlemix/examples/groundingdino">Grounding DINO</a></li>
286
+ <li><a href="paddlemix/examples/sam">SAM</a></li>
287
+ <li><a href="paddlemix/examples/YOLO-World">YOLO-World</a></li>
288
+ </ul>
289
+ </ul>
290
+ <li><b>More Multimodal Pre-trained Models</b></li>
291
+ <ul>
292
+ <li><a href="paddlemix/examples/imagebind">ImageBind</a></li>
293
+ </ul>
294
+ </ul>
295
+ <li><b>Data Analysis</b></li>
296
+ <ul>
297
+ <li><a href="./paddlemix/datacopilot/example/pp_inscaptagger/">PP-InsCapTagger</a></li>
298
+ </ul>
299
+ </td>
300
+ <td>
301
+ <ul>
302
+ </ul>
303
+ <li><b>Text-to-Image</b></li>
304
+ <ul>
305
+ <li><a href="ppdiffusers/examples/stable_diffusion">Stable Diffusion</a></li>
306
+ <li><a href="ppdiffusers/examples/dreambooth/README_sd3.md">Stable Diffusion 3 (SD3)</a></li>
307
+ <li><a href="ppdiffusers/examples/controlnet">ControlNet</a></li>
308
+ <li><a href="ppdiffusers/examples/t2i-adapter">T2I-Adapter</a></li>
309
+ <li><a href="ppdiffusers/examples/text_to_image_laion400m">LDM</a></li>
310
+ <li><a href="ppdiffusers/ppdiffusers/pipelines/unidiffuser">Unidiffuser</a></li>
311
+ <li><a href="ppdiffusers/examples/class_conditional_image_generation/DiT">DiT</a></li>
312
+ <li><a href="ppdiffusers/examples/HunyuanDiT">HunyuanDiT</a></li>
313
+ </ul>
314
+ </ul>
315
+ <li><b>Text-to-Video</b></li>
316
+ <ul>
317
+ <li><a href="ppdiffusers/examples/lvdm">LVDM</a></li>
318
+ <li><a href="ppdiffusers/examples/stable_video_diffusion">SVD</a></li>
319
+ <li><a href="ppdiffusers/examples/AnimateAnyone">AnimateAnyone</a></li>
320
+ <li><a href="ppdiffusers/examples/Open-Sora">OpenSora</a></li>
321
+ </ul>
322
+ </ul>
323
+ <li><b>Audio Generation</b></li>
324
+ <ul>
325
+ <li><a href="ppdiffusers/ppdiffusers/pipelines/audioldm">AudioLDM</a></li>
326
+ <li><a href="ppdiffusers/ppdiffusers/pipelines/audioldm2">AudioLDM2</a></li>
327
+ </ul>
328
+ </td>
329
+ <td>
330
+ <ul>
331
+ </ul>
332
+ <li><b>Unified Multimodal Model</b></li>
333
+ <ul>
334
+ <li><a href="paddlemix/examples/janus">Janus</a></li>
335
+ </ul>
336
+ </td>
337
+ </tr>
338
+ </tbody>
339
+ </table>
340
+
341
+ For more model capabilities, please refer to the [Model Capability Matrix](./paddlemix/examples/README.md)
342
+
343
+ ## 🏆 Featured Models | Tools
344
+
345
+ ### 💎 Cross-Modal Task Pipeline AppFlow
346
+ <details>
347
+ <summary><b> Introduction (Click to Expand)</b></summary>
348
+
349
+ AppFlow, as the cross-modal application task pipeline of PaddleMIX, possesses powerful functionality and ease of use. By integrating cutting-edge algorithms such as LLaVA and Stable Diffusion, AppFlow has comprehensively covered various modalities including images, text, audio, and video. Through a flexible pipeline approach, it has constructed over ten multimodal applications, encompassing text-image generation, text-video generation, text-audio generation, image understanding, and more, providing users with rich demo examples. The highlight of AppFlow is its one-click prediction feature, allowing users to complete model inference with simple commands, eliminating cumbersome training and extensive coding, significantly lowering the barrier to use. Additionally, AppFlow fully leverages the dynamic-static unification advantages of the PaddlePaddle framework; users only need to set simple parameters to automatically complete model dynamic-to-static export and high-performance inference, enhancing work efficiency and optimizing model performance for one-stop application deployment.
350
+
351
+ `Gateway`: [Application Documentation Example](applications/README.md/#quick-start).
352
+
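+ As a rough sketch of this one-click pipeline style (the `app` name and model identifiers below are illustrative assumptions; the Application Documentation Example linked above lists the exact ones):
+
+ ```python
+ # Sketch only: an open-set detection + segmentation pipeline through Appflow.
+ from paddlemix.appflow import Appflow
+ from ppdiffusers.utils import load_image
+
+ image = load_image("example_image.jpg")  # placeholder input image
+ task = Appflow(app="openset_det_sam",
+                models=["GroundingDino/groundingdino-swint-ogc", "Sam/SamVitH-1024"])
+ result = task(image=image, prompt="dog")
+ ```
+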
353
+ </details>
354
+
355
+ ### 💎 Multimodal Data Processing Toolbox DataCopilot
356
+ <details>
357
+ <summary><b> Introduction (Click to Expand)</b></summary>
358
+
359
+ In real-world application scenarios, there is a substantial demand for fine-tuning multimodal large models using proprietary data to enhance model performance, making data elements the core of this process. Based on this, PaddleMIX provides the DataCopilot tool for data processing and analysis, allowing developers to achieve an end-to-end development experience within the PaddleMIX suite.
360
+
361
+ PP-InsCapTagger (Instance Capability Tagger) is a dataset capability tagging model implemented by DataCopilot based on PaddleMIX. It is used to label the capabilities of multimodal data instances. By optimizing the dataset through instance capability distribution, it can improve model training efficiency and provide an efficient solution for dataset analysis and evaluation. Combining the model inference labeling results with the LLaVA SFT dataset optimization can **improve LLaVA model training efficiency by 50% during the SFT phase.**
362
+
363
+ `Gateway`: [Application Documentation Example](paddlemix/datacopilot/readme.md).
364
+
365
+ </details>
366
+
367
+ <details>
368
+ <summary><b> PP-InsCapTagger (Click to Expand)</b></summary>
369
+
370
+ | Model | ScienceQA | TextVQA | VQAv2 | GQA | MMMU | MME |
371
+ |----------------------------------|-----------------------------------------|----------------------------------------|----------------------------------------|----------------------------------------|----------------------------------------|-----------------------------------------|
372
+ | llava-1.5-7b (origin) | 66.8 | 58.2 | 78.5 | 62 | - | - |
373
+ | llava-1.5-7b (rerun) | 69.01 | 57.6 | 79 | 62.95 | 36.89 | 1521<br>323 |
374
+ | llava-1.5-7b (random 50%) | 67.31 | 55.6 | 76.89 | 61.01 | 34.67 | 1421<br>286 |
375
+ | **llava-1.5-7b (our 50%)** | **70.24** *(+2.93)* | **57.12** *(+1.52)* | **78.32** *(+1.43)* | **62.14** *(+1.13)* | **37.11** *(+2.44)* | **1476** *(+55)*<br>**338** *(+52)* |
376
+ `Gateway`: [Application Documentation Example](paddlemix/datacopilot/example/pp_inscaptagger/readme.md).
377
+ </details>
378
+
379
+ ## 🤔 FAQ
380
+ For answers to some common questions about our project, please refer to the [FAQ](docs/FAQ.md). If your question is not addressed, feel free to raise it in the [Issues](https://github.com/PaddlePaddle/PaddleMIX/issues).
381
+
382
+ ## 📝 License
383
+ This project is released under the [Apache 2.0 license](LICENSE).
384
+
385
+ ## 📌 Community Communication
386
+
387
+ - Scan the QR code and fill out the questionnaire to join the communication group and engage deeply with numerous community developers and the official team.
388
+ <div align="center">
389
+ <img src="https://github.com/user-attachments/assets/ecf292da-9ac6-41cb-84b6-df726ef4522d" width="300" height="300" />
390
+ </div>
PaddleMIX/VERSION ADDED
@@ -0,0 +1 @@
1
+ 0.1.0
PaddleMIX/check_env.sh ADDED
@@ -0,0 +1,101 @@
1
+ #!/bin/bash
2
+ # Exit on error
3
+ set -e
4
+
5
+ # Find an available Python interpreter
6
+ find_python() {
7
+ for cmd in python3 python python3.8 python3.9 python3.10; do
8
+ if command -v "$cmd" > /dev/null 2>&1; then
9
+ if $cmd -c "import sys; exit(0 if sys.version_info >= (3,7) else 1)" 2>/dev/null; then
10
+ echo "$cmd"
11
+ return 0
12
+ fi
13
+ fi
14
+ done
15
+ return 1
16
+ }
17
+
18
+ # Locate the Python interpreter
19
+ PYTHON_CMD=$(find_python)
20
+
21
+ if [ -z "$PYTHON_CMD" ]; then
22
+ echo "错误: 未找到合适的Python环境 (需要Python >= 3.7)"
23
+ exit 1
24
+ fi
25
+
26
+ echo "使用Python环境: $($PYTHON_CMD --version)"
27
+ echo "=====================Package Versions====================="
28
+
29
+ # Check the paddlepaddle version
30
+ echo "检查paddlepaddle版本..."
31
+ if $PYTHON_CMD -c "import paddle" 2>/dev/null; then
32
+ paddle_version=$($PYTHON_CMD -c "import paddle; print(paddle.__version__)")
33
+ echo "当前paddlepaddle版本: $paddle_version"
34
+
35
+ # Check whether this is the GPU build
36
+ if $PYTHON_CMD -c "import paddle; print(paddle.device.is_compiled_with_cuda())" 2>/dev/null | grep -q "True"; then
37
+ echo "paddlepaddle类型: GPU版本"
38
+ cuda_version=$($PYTHON_CMD -c "import paddle; print(paddle.version.cuda())")
39
+ echo "CUDA版本: $cuda_version"
40
+ else
41
+ echo "⚠️ paddlepaddle类型: CPU版本,推荐使用GPU版本"
42
+ fi
43
+
44
+ if [[ "$paddle_version" == "3.0.0b2" || "$paddle_version" == *"0.0.0"* ]]; then
45
+ echo "✅ paddlepaddle版本符合要求"
46
+ else
47
+ echo "⚠️ 建议使用paddlepaddle 3.0.0b2或develop版本"
48
+ fi
49
+ else
50
+ echo "❌ 未安装paddlepaddle"
51
+ fi
52
+
53
+ # Check the paddlenlp version
54
+ echo -e "\n检查paddlenlp版本..."
55
+ if $PYTHON_CMD -c "import paddlenlp" 2>/dev/null; then
56
+ paddlenlp_version=$($PYTHON_CMD -c "import paddlenlp; print(paddlenlp.__version__)")
57
+ echo "当前paddlenlp版本: $paddlenlp_version"
58
+ if [[ "$paddlenlp_version" == "3.0.0b2" ]]; then
59
+ echo "✅ paddlenlp版本符合要求"
60
+ else
61
+ echo "⚠️ 建议使用paddlenlp 3.0.0b2版本"
62
+ fi
63
+ else
64
+ echo "❌ 未安装paddlenlp"
65
+ fi
66
+
67
+ # Check the ppdiffusers version
68
+ echo -e "\n检查ppdiffusers版本..."
69
+ if $PYTHON_CMD -c "import ppdiffusers" 2>/dev/null; then
70
+ ppdiffusers_version=$($PYTHON_CMD -c "import ppdiffusers; print(ppdiffusers.__version__)")
71
+ echo "当前ppdiffusers版本: $ppdiffusers_version"
72
+ if [[ "$ppdiffusers_version" == "0.29.0" ]]; then
73
+ echo "✅ ppdiffusers版本符合要求"
74
+ else
75
+ echo "⚠️ 建议使用ppdiffusers 0.29.0版本"
76
+ fi
77
+ else
78
+ echo "❌ 未安装ppdiffusers"
79
+ fi
80
+
81
+ # Check the huggingface_hub version
82
+ echo -e "\n检查huggingface_hub版本..."
83
+ if $PYTHON_CMD -c "import huggingface_hub" 2>/dev/null; then
84
+ hf_version=$($PYTHON_CMD -c "import huggingface_hub; print(huggingface_hub.__version__)")
85
+ echo "当前huggingface_hub版本: $hf_version"
86
+ if [[ "$hf_version" == "0.23.0" ]]; then
87
+ echo "✅ huggingface_hub版本符合要求"
88
+ else
89
+ echo "⚠️ 建议使用huggingface_hub 0.23.0版本"
90
+ fi
91
+ else
92
+ echo "❌ 未安装huggingface_hub"
93
+ fi
94
+
95
+ echo -e "\n===================Version Summary===================="
96
+ echo "推荐版本:"
97
+ echo "- paddlepaddle: 3.0.0b2或develop版本"
98
+ echo "- paddlenlp: 3.0.0b2"
99
+ echo "- ppdiffusers: 0.29.0"
100
+ echo "- huggingface_hub: 0.23.0"
101
+ echo "===================================================="
PaddleMIX/pyproject.toml ADDED
@@ -0,0 +1,23 @@
1
+ [tool.isort]
2
+ profile = 'black'
3
+ known_third_party = ["paddle"]
4
+
5
+ [tool.black]
6
+ line-length = 119
7
+ target_version = ['py35', 'py36', 'py37', 'py38', 'py39', 'py310']
8
+ exclude = ['.flake8']
9
+
10
+ [tool.pytest.ini_options]
11
+ minversion = "6.0"
12
+ pythonpath = ["."]
13
+ testpaths = [
14
+ # "tests/models",
15
+ ]
16
+ python_files = [
17
+ "test.py",
18
+ "test_*.py"
19
+ ]
20
+ filterwarnings = [
21
+ "ignore::UserWarning",
22
+ 'ignore::DeprecationWarning',
23
+ ]
PaddleMIX/requirements.txt ADDED
@@ -0,0 +1,15 @@
1
+ numpy
2
+ paddlenlp>=3.0.0b2
3
+ tensorboardX
4
+ opencv-python
5
+ Pillow
6
+ pycocoevalcap
7
+ ftfy
8
+ regex
9
+ einops>=0.6.1
10
+ soundfile
11
+ librosa
12
+ h5py
13
+ jsonschema>=4.19.0
14
+ referencing>=0.32.1
15
+ decord>=0.6.0
VILA/LongVILA.md ADDED
@@ -0,0 +1,79 @@
1
+ <p align="center">
2
+ <img src="demo_images/longvila-logo.png" width="60%"/>
3
+ </p>
4
+
5
+ # LongVILA: Scaling Long-Context Visual Language Models for Long Videos
6
+
7
+ [![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg)](CODE_LICENSE)
8
+ [![Model License](https://img.shields.io/badge/MODEL%20License-CC%20By%20NC%204.0-red.svg)](MODEL_LICENSE)
9
+ [![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/release/python-3100/)
10
+
11
+
12
+ [![Paper](https://img.shields.io/badge/Paper-Arvix%20Link-green)](https://arxiv.org/abs/2408.10188)
13
+ [![Huggingface Models](https://img.shields.io/badge/Models-Huggingface%20Models-bron)](https://huggingface.co/collections/Efficient-Large-Model/longvila-66c3fce79284c8209f119b32)
14
+
15
+ ## 💡 Introduction
16
+
17
+ Long-context capability is critical for multi-modal foundation models. We introduce LongVILA, a full-stack solution for long-context vision-language models, including system, model training, and dataset development. On the system side, we introduce the first long-context Multi-Modal Sequence Parallelism (MM-SP) system that enables long-context training and inference, supporting 2M context length training on 256 GPUs. MM-SP is also efficient, being 2.1x - 5.7x faster than Ring-Style Sequence Parallelism and 1.1x - 1.4x faster than Megatron-LM in text-only settings. Moreover, it seamlessly integrates with Hugging Face Transformers. For model training, we propose a five-stage pipeline comprising alignment, pre-training, short supervised fine-tuning, context extension, and long supervised fine-tuning. Regarding datasets, we meticulously construct large-scale visual language pre-training datasets and long video instruction-following datasets to support our multi-stage training process. The full-stack solution extends the feasible frame number of VILA by a factor of 128 (from 8 to 1024 frames) and improves the long video captioning score from 2.00 to 3.26 (1.6x), achieving 99.5% accuracy on a 1400-frame video (274k context length) needle-in-a-haystack test. LongVILA-8B also demonstrates consistent accuracy improvements on long videos in the VideoMME benchmark as the number of video frames increases.
18
+
19
+ <p align="center">
20
+ <img src="demo_images/LongVILA-pipeline.png" width="100%"/>
21
+ </p>
22
+
23
+ ## Installation
24
+
25
+ ```bash
26
+ ./environment_setup.sh vila
27
+ ```
28
+
29
+ ## Evaluations
30
+ Please refer to `scripts/v1_5/eval/needle.sh`, `scripts/v1_5/eval/video_chatgpt/run_vila_benchmark.sh`, and `llava/eval/video_mme/eval.sh` for needle-in-a-haystack, LongVILA-Caption, and Video MME evaluations.
31
+
32
+
33
+ > [!Note]
34
+ > 💡**Sequence Parallelism Configuration**
35
+ >
36
+ > To enable sequence parallelism, you can set the following parameters in the training script:
37
+ >
38
+ > `seq_parallel_size`: The degree of sequence parallelism (SP). SP is disabled by default (value: -1).
39
+ >
40
+ > `seq_parallel_ring_size`: The size of the communication process group used by the optimized Ring Attention approach within SP. Ring Attention is disabled by default.
41
+ >
42
+ > `seq_parallel_ring_type`: The Ring Attention implementation. Supports ['ring_varlen', 'zigzag_ring_varlen'] in 2D attention; only takes effect when *seq_parallel_ring_size* > 1.
43
+ >
44
+ > Please note that when SP is enabled, we treat each group of seq_parallel_size GPUs as a single device, with the global batch size calculated as the product of the per-device batch size and the data parallelism size.
45
+
46
+
47
+
48
+ ## 🔒 License
49
+
50
+ - The code is released under the Apache 2.0 license as found in the [LICENSE](./LICENSE) file.
51
+ - The pretrained weights are released under the [CC-BY-NC-SA-4.0 license](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en).
52
+ - The service is a research preview intended for non-commercial use only, and is subject to the following licenses and terms:
53
+ - [Model License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA. For LLAMA3-VILA checkpoints terms of use, please refer to the [LLAMA3 License](https://llama.meta.com/llama3/license/) for additional details.
54
+ - [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI
55
+ - [Dataset Licenses](./data_prepare/LICENSE) for each one used during training.
56
+
57
+ ## Citations
58
+
59
+ ```
60
+ @article{longvila,
61
+ title={LongVILA: Scaling Long-Context Visual Language Models for Long Videos},
62
+ author={Fuzhao Xue and Yukang Chen and Dacheng Li and Qinghao Hu and Ligeng Zhu and Xiuyu Li and Yunhao Fang and Haotian Tang and Shang Yang and Zhijian Liu and Yihui He and Hongxu Yin and Pavlo Molchanov and Jan Kautz and Linxi Fan and Yuke Zhu and Yao Lu and Song Han},
63
+ year={2024},
64
+ eprint={2408.10188},
65
+ archivePrefix={arXiv},
66
+ primaryClass={cs.CV}
67
+ }
68
+ ```
69
+
70
+ # Acknowledgement
71
+
72
+ - [LLaVA](https://github.com/haotian-liu/LLaVA): the codebase we built upon. Thanks for their wonderful work.
73
+ - [LongVA](https://github.com/EvolvingLMMs-Lab/LongVA): we borrowed the long video needle in the haystack evaluation script from this repository.
74
+ - [LongLoRA](https://github.com/dvlab-research/LongLoRA): we modified the low-rank long-context fine-tuning code from this repository.
75
+ - [USP (YunChang)](https://github.com/feifeibear/long-context-attention): we adopted the 2D attention implementation from this repository.
76
+ - [RingFlashAttention](https://github.com/zhuzilin/ring-flash-attention): we adopted the ring flash attention implementation from this repository.
77
+ - [DeepSpeed Ulysses](https://github.com/microsoft/DeepSpeed): we adopted the all-to-all implementation from this repository.
78
+ - [Video-ChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT): we borrowed video evaluation script from this repository.
79
+ - [MMC4](https://github.com/allenai/mmc4), [COYO-700M](https://github.com/kakaobrain/coyo-dataset), [M3IT](https://huggingface.co/datasets/MMInstruction/M3IT), [OpenORCA/FLAN](https://huggingface.co/datasets/Open-Orca/FLAN), [ShareGPT4V](https://github.com/InternLM/InternLM-XComposer/tree/main/projects/ShareGPT4V), [WIT](google-research-datasets/wit), [GSM8K-ScRel](https://github.com/OFA-Sys/gsm8k-ScRel/blob/main/data/train_use.jsonl), [VisualGenome](https://visualgenome.org/api/v0/api_home.html), [VCR](https://visualcommonsense.com/download/), [ScienceQA](https://huggingface.co/datasets/derek-thomas/ScienceQA), [Shot2Story](https://github.com/bytedance/Shot2Story/blob/master/DATA.md), [Youcook2](http://youcook2.eecs.umich.edu/), [Vatex](https://eric-xw.github.io/vatex-website/download.html), [ShareGPT-Video](https://huggingface.co/datasets/ShareGPTVideo/train_video_and_instruction), [ShareGPT4o](https://sharegpt4o.github.io/) for providing datasets used in this research.
VILA/convert_ckpt.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import json
18
+ import os.path as osp
19
+ from collections import OrderedDict
20
+ from glob import glob
21
+
22
+ from safetensors import safe_open
23
+ from transformers import (
24
+ AutoConfig,
25
+ AutoModel,
26
+ AutoModelForCausalLM,
27
+ AutoTokenizer,
28
+ BitsAndBytesConfig,
29
+ LlamaForCausalLM,
30
+ PretrainedConfig,
31
+ PreTrainedModel,
32
+ )
33
+
34
+ import llava.model.language_model.llava_llama
35
+ from llava.model import *
36
+ from llava.model.configuration_llava import LlavaConfig
37
+ from llava.model.language_model.builder import build_llm_and_tokenizer
38
+ from llava.model.multimodal_encoder.builder import SiglipVisionTower, build_vision_tower
39
+ from llava.model.multimodal_encoder.siglip import SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel
40
+ from llava.model.multimodal_projector.builder import build_mm_projector
41
+ from llava.model.utils import get_model_config
42
+
43
+
44
+ def main(
45
+ path="~/workspace/VILA/checkpoints/Llama-2-7b-hf-google/siglip-large-patch16-384-align-llava_1_5_mm_align",
46
+ output_dir="checkpoints/converted_models",
47
+ ):
48
+ path = osp.expanduser(path)
49
+ # assuming 7b llama + siglip
50
+ config = AutoConfig.from_pretrained("CI-new-format-llama7b-siglip")
51
+ model = AutoModel.from_config(config)
52
+
53
+ # key mapping: remap checkpoint keys to the converted model's layout
54
+ state_dict = {}
55
+
56
+ def fn(k):
57
+ if (
58
+ k.startswith("model.layers")
59
+ or k.startswith("model.norm")
60
+ or k.startswith("model.embed_tokens")
61
+ or k.startswith("lm_head")
62
+ ):
63
+ # llm layer
64
+ new_k = "llm." + k
65
+ return new_k
66
+ if k.startswith("model.vision_tower.vision_tower.vision_model."):
67
+ new_k = k.replace(
68
+ "model.vision_tower.vision_tower.vision_model.", "vision_tower.vision_tower.vision_model."
69
+ )
70
+ return new_k
71
+ if k.startswith("model.mm_projector"):
72
+ new_k = k.replace("model.mm_projector.", "mm_projector.layers.")
73
+ return new_k
74
+ return k
75
+
76
+ for sf in glob(osp.join(path, "*.safetensors")):
77
+ with safe_open(sf, framework="pt") as f:
78
+ for key in f.keys():
79
+ state_dict[fn(key)] = f.get_tensor(key)
80
+
81
+ for k in state_dict.keys():
82
+ assert k in model.state_dict().keys()
83
+
84
+ model.load_state_dict(state_dict)
85
+ model.save_pretrained(output_dir)
86
+
87
+
88
+ if __name__ == "__main__":
89
+ import fire
90
+
91
+ fire.Fire(main)
VILA/environment_setup.sh ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ # This is required to activate conda environment
4
+ eval "$(conda shell.bash hook)"
5
+
6
+ # CONDA_ENV=${1:-""}
7
+ CONDA_ENV=vila
8
+ if [ -n "$CONDA_ENV" ]; then
9
+ conda create -n $CONDA_ENV python=3.10 -y
10
+ conda activate $CONDA_ENV
11
+ else
12
+ echo "Skipping conda environment creation. Make sure you have the correct environment activated."
13
+ fi
14
+
15
+ # This is required to enable PEP 660 support
16
+ pip install --upgrade pip
17
+
18
+ # This is optional if you prefer to use built-in nvcc
19
+ conda install -c nvidia cuda-toolkit -y
20
+
21
+ # Install FlashAttention2
22
+ pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.8/flash_attn-2.5.8+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
23
+
24
+ # Install VILA
25
+ pip install -e .
26
+ pip install -e ".[train]"
27
+ pip install -e ".[eval]"
28
+
29
+ # Install HF's Transformers
30
+ pip install git+https://github.com/huggingface/transformers@v4.37.2
31
+ site_pkg_path=$(python -c 'import site; print(site.getsitepackages()[0])')
32
+ cp -rv ./llava/train/transformers_replace/* $site_pkg_path/transformers/
33
+ cp -rv ./llava/train/deepspeed_replace/* $site_pkg_path/deepspeed/
VILA/predict.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+ # This file is originated from: https://github.com/haotian-liu/LLaVA/
17
+
18
+ import os
19
+ import subprocess
20
+ import time
21
+ from io import BytesIO
22
+ from threading import Thread
23
+
24
+ import requests
25
+ import torch
26
+ from cog import BasePredictor, ConcatenateIterator, Input, Path
27
+ from PIL import Image
28
+ from transformers.generation.streamers import TextIteratorStreamer
29
+
30
+ from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
31
+ from llava.conversation import SeparatorStyle, conv_templates
32
+ from llava.mm_utils import KeywordsStoppingCriteria, tokenizer_image_token
33
+ from llava.model.builder import load_pretrained_model
34
+ from llava.utils import disable_torch_init
35
+
36
+ os.environ["HUGGINGFACE_HUB_CACHE"] = os.getcwd() + "/weights"
37
+
38
+ # url for the weights mirror
39
+ REPLICATE_WEIGHTS_URL = "https://weights.replicate.delivery/default"
40
+ # files to download from the weights mirrors
41
+ weights = [
42
+ {
43
+ "dest": "liuhaotian/llava-v1.5-13b",
44
+ # git commit hash from huggingface
45
+ "src": "llava-v1.5-13b/006818fc465ebda4c003c0998674d9141d8d95f8",
46
+ "files": [
47
+ "config.json",
48
+ "generation_config.json",
49
+ "pytorch_model-00001-of-00003.bin",
50
+ "pytorch_model-00002-of-00003.bin",
51
+ "pytorch_model-00003-of-00003.bin",
52
+ "pytorch_model.bin.index.json",
53
+ "special_tokens_map.json",
54
+ "tokenizer.model",
55
+ "tokenizer_config.json",
56
+ ],
57
+ },
58
+ {
59
+ "dest": "openai/clip-vit-large-patch14-336",
60
+ "src": "clip-vit-large-patch14-336/ce19dc912ca5cd21c8a653c79e251e808ccabcd1",
61
+ "files": ["config.json", "preprocessor_config.json", "pytorch_model.bin"],
62
+ },
63
+ ]
64
+
65
+
66
+ def download_json(url: str, dest: Path):
67
+ res = requests.get(url, allow_redirects=True)
68
+ if res.status_code == 200 and res.content:
69
+ with dest.open("wb") as f:
70
+ f.write(res.content)
71
+ else:
72
+ print(f"Failed to download {url}. Status code: {res.status_code}")
73
+
74
+
75
+ def download_weights(baseurl: str, basedest: str, files: list[str]):
76
+ basedest = Path(basedest)
77
+ start = time.time()
78
+ print("downloading to: ", basedest)
79
+ basedest.mkdir(parents=True, exist_ok=True)
80
+ for f in files:
81
+ dest = basedest / f
82
+ url = os.path.join(REPLICATE_WEIGHTS_URL, baseurl, f)
83
+ if not dest.exists():
84
+ print("downloading url: ", url)
85
+ if dest.suffix == ".json":
86
+ download_json(url, dest)
87
+ else:
88
+ subprocess.check_call(["pget", url, str(dest)], close_fds=False)
89
+ print("downloading took: ", time.time() - start)
90
+
91
+
92
+ class Predictor(BasePredictor):
93
+ def setup(self) -> None:
94
+ """Load the model into memory to make running multiple predictions efficient"""
95
+ for weight in weights:
96
+ download_weights(weight["src"], weight["dest"], weight["files"])
97
+ disable_torch_init()
98
+
99
+ self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
100
+ "liuhaotian/llava-v1.5-13b", model_name="llava-v1.5-13b", model_base=None, load_8bit=False, load_4bit=False
101
+ )
102
+
103
+ def predict(
104
+ self,
105
+ image: Path = Input(description="Input image"),
106
+ prompt: str = Input(description="Prompt to use for text generation"),
107
+ top_p: float = Input(
108
+ description="When decoding text, samples from the top p percentage of most likely tokens; lower to ignore less likely tokens",
109
+ ge=0.0,
110
+ le=1.0,
111
+ default=1.0,
112
+ ),
113
+ temperature: float = Input(
114
+ description="Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic",
115
+ default=0.2,
116
+ ge=0.0,
117
+ ),
118
+ max_tokens: int = Input(
119
+ description="Maximum number of tokens to generate. A word is generally 2-3 tokens", default=1024, ge=0
120
+ ),
121
+ ) -> ConcatenateIterator[str]:
122
+ """Run a single prediction on the model"""
123
+
124
+ conv_mode = "llava_v1"
125
+ conv = conv_templates[conv_mode].copy()
126
+
127
+ image_data = load_image(str(image))
128
+ image_tensor = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"].half().cuda()
129
+
130
+ # loop start
131
+
132
+ # just one turn, always prepend image token
133
+ inp = DEFAULT_IMAGE_TOKEN + "\n" + prompt
134
+ conv.append_message(conv.roles[0], inp)
135
+
136
+ conv.append_message(conv.roles[1], None)
137
+ prompt = conv.get_prompt()
138
+
139
+ input_ids = (
140
+ tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
141
+ )
142
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
143
+ keywords = [stop_str]
144
+ stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
145
+ streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, timeout=20.0)
146
+
147
+ with torch.inference_mode():
148
+ thread = Thread(
149
+ target=self.model.generate,
150
+ kwargs=dict(
151
+ inputs=input_ids,
152
+ images=image_tensor,
153
+ do_sample=True,
154
+ temperature=temperature,
155
+ top_p=top_p,
156
+ max_new_tokens=max_tokens,
157
+ streamer=streamer,
158
+ use_cache=True,
159
+ stopping_criteria=[stopping_criteria],
160
+ ),
161
+ )
162
+ thread.start()
163
+ # workaround: second-to-last token is always " "
164
+ # but we want to keep it if it's not the second-to-last token
165
+ prepend_space = False
166
+ for new_text in streamer:
167
+ if new_text == " ":
168
+ prepend_space = True
169
+ continue
170
+ if new_text.endswith(stop_str):
171
+ new_text = new_text[: -len(stop_str)].strip()
172
+ prepend_space = False
173
+ elif prepend_space:
174
+ new_text = " " + new_text
175
+ prepend_space = False
176
+ if len(new_text):
177
+ yield new_text
178
+ if prepend_space:
179
+ yield " "
180
+ thread.join()
181
+
182
+
183
+ def load_image(image_file):
184
+ if image_file.startswith("http") or image_file.startswith("https"):
185
+ response = requests.get(image_file)
186
+ image = Image.open(BytesIO(response.content)).convert("RGB")
187
+ else:
188
+ image = Image.open(image_file).convert("RGB")
189
+ return image
VLMEvalKit/.pre-commit-config.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ exclude: |
2
+ (?x)^(
3
+ scripts/|
4
+ assets/|
5
+ vlmeval/config.py
6
+ )
7
+ repos:
8
+ - repo: https://github.com/PyCQA/flake8
9
+ rev: 6.1.0
10
+ hooks:
11
+ - id: flake8
12
+ args: ["--max-line-length=120", "--ignore=F401,F403,F405,E402,E722,E741,W503,E231,E702"]
13
+ exclude: ^configs/
14
+ - repo: https://github.com/pre-commit/mirrors-yapf
15
+ rev: v0.30.0
16
+ hooks:
17
+ - id: yapf
18
+ args: ["--style={column_limit=120}"]
19
+ - repo: https://github.com/pre-commit/pre-commit-hooks
20
+ rev: v3.1.0
21
+ hooks:
22
+ - id: trailing-whitespace
23
+ - id: check-yaml
24
+ - id: end-of-file-fixer
25
+ - id: requirements-txt-fixer
26
+ - id: check-merge-conflict
27
+ - id: fix-encoding-pragma
28
+ args: ["--remove"]
29
+ - id: mixed-line-ending
30
+ args: ["--fix=lf"]
VLMEvalKit/requirements.txt ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ decord; platform_machine != 'arm64'
2
+ eva-decord; platform_machine == 'arm64'
3
+ gradio
4
+ huggingface_hub
5
+ imageio
6
+ matplotlib
7
+ numpy
8
+ omegaconf
9
+ openai
10
+ opencv-python>=4.4.0.46
11
+ openpyxl
12
+ pandas
13
+ pillow
14
+ portalocker
15
+ protobuf
16
+ python-dotenv
17
+ requests
18
+ rich
19
+ sentencepiece
20
+ setuptools
21
+ sty
22
+ tabulate
23
+ tiktoken
24
+ timeout-decorator
25
+ torch
26
+ tqdm
27
+ transformers
28
+ typing_extensions
29
+ validators
30
+ xlsxwriter
a_distributed_notebook/FSDP_tutorial.md ADDED
@@ -0,0 +1,519 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Getting Started with Fully Sharded Data Parallel(FSDP)
2
+ ======================================================
3
+
4
+ **Author**: [Hamid Shojanazeri](https://github.com/HamidShojanazeri),
5
+ [Yanli Zhao](https://github.com/zhaojuanmao), [Shen
6
+ Li](https://mrshenli.github.io/)
7
+
8
+ ::: {.note}
9
+ ::: {.title}
10
+ Note
11
+ :::
12
+
13
+ View and edit this tutorial in
14
+ [github](https://github.com/pytorch/tutorials/blob/main/intermediate_source/FSDP_tutorial.rst).
15
+ :::
16
+
17
+ Training AI models at a large scale is a challenging task that requires
18
+ a lot of compute power and resources. It also comes with considerable
19
+ engineering complexity to handle the training of these very large
20
+ models. [PyTorch
21
+ FSDP](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/),
22
+ released in PyTorch 1.11, makes this easier.
23
+
24
+ In this tutorial, we show how to use [FSDP
25
+ APIs](https://pytorch.org/docs/stable/fsdp.html) for simple MNIST
26
+ models that can be extended to other larger models such as [HuggingFace
27
+ BERT models](https://huggingface.co/blog/zero-deepspeed-fairscale) and [GPT
28
+ 3 models up to 1T
29
+ parameters](https://pytorch.medium.com/training-a-1-trillion-parameter-model-with-pytorch-fully-sharded-data-parallel-on-aws-3ac13aa96cff).
30
+ The sample DDP MNIST code has been borrowed from
31
+ [here](https://github.com/yqhu/mnist_examples).
32
+
33
+ How FSDP works
34
+ --------------
35
+
36
+ In
37
+ [DistributedDataParallel](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html)
38
+ (DDP) training, each process/worker owns a replica of the model and
39
+ processes a batch of data; finally, it uses all-reduce to sum up
40
+ gradients over different workers. In DDP, the model weights and optimizer
41
+ states are replicated across all workers. FSDP is a type of data
42
+ parallelism that shards model parameters, optimizer states and gradients
43
+ across DDP ranks.
44
+
45
+ When training with FSDP, the GPU memory footprint is smaller than when
46
+ training with DDP across all workers. This makes the training of some
47
+ very large models feasible by allowing larger models or batch sizes to
48
+ fit on device. This comes with the cost of increased communication
49
+ volume. The communication overhead is reduced by internal optimizations
50
+ like overlapping communication and computation.
51
+
52
+ ![FSDP
53
+ Workflow](/_static/img/distributed/fsdp_workflow.png){.align-center
54
+ width="100.0%"}
55
+
56
+ At a high level, FSDP works as follows:
57
+
58
+ *In constructor*
59
+
60
+ - Shard model parameters and each rank only keeps its own shard
61
+
62
+ *In forward path*
63
+
64
+ - Run all\_gather to collect all shards from all ranks to recover the
65
+ full parameter in this FSDP unit
66
+ - Run forward computation
67
+ - Discard parameter shards it has just collected
68
+
69
+ *In backward path*
70
+
71
+ - Run all\_gather to collect all shards from all ranks to recover the
72
+ full parameter in this FSDP unit
73
+ - Run backward computation
74
+ - Run reduce\_scatter to sync gradients
75
+ - Discard parameters.
76
+
77
+ One way to view FSDP\'s sharding is to decompose the DDP gradient
78
+ all-reduce into reduce-scatter and all-gather. Specifically, during the
79
+ backward pass, FSDP reduces and scatters gradients, ensuring that each
80
+ rank possesses a shard of the gradients. Then it updates the
81
+ corresponding shard of the parameters in the optimizer step. Finally, in
82
+ the subsequent forward pass, it performs an all-gather operation to
83
+ collect and combine the updated parameter shards.
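+
+ To make this decomposition concrete, the following is a toy, single-process sketch of the bookkeeping: plain tensor arithmetic stands in for the collectives, and it is only an illustration, not real FSDP code.
+
+ ``` {.python}
+ # Toy simulation of FSDP's sharding bookkeeping with two simulated "ranks".
+ # torch.cat stands in for all-gather; sum + chunk stands in for reduce-scatter.
+ import torch
+
+ world_size = 2
+ flat_param = torch.arange(8, dtype=torch.float32)   # pretend flattened parameter
+ shards = list(flat_param.chunk(world_size))         # rank r persistently owns shards[r]
+
+ # Forward (and again before backward): all-gather recovers the full parameter.
+ gathered = torch.cat(shards)
+ assert torch.equal(gathered, flat_param)
+
+ # Backward: each rank produces a full-length local gradient...
+ local_grads = [torch.full((8,), float(r + 1)) for r in range(world_size)]
+ # ...then reduce-scatter: sum across ranks ("reduce") and keep only the owned
+ # shard ("scatter"); DDP's all-reduce would instead keep the whole summed tensor.
+ summed = torch.stack(local_grads).sum(dim=0)
+ grad_shards = list(summed.chunk(world_size))        # rank r keeps grad_shards[r]
+
+ # The optimizer step then updates only the locally owned parameter shard.
+ print(grad_shards[0], grad_shards[1])               # tensor([3., 3., 3., 3.]) twice
+ ```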
84
+
85
+ ![FSDP
86
+ Allreduce](/_static/img/distributed/fsdp_sharding.png){.align-center
87
+ width="100.0%"}
88
+
89
+ How to use FSDP
90
+ ---------------
91
+
92
+ Here we use a toy model to run training on the MNIST dataset for
93
+ demonstration purposes. The APIs and logic can be applied to training
94
+ larger models as well.
95
+
96
+ *Setup*
97
+
98
+ 1.1 Install PyTorch along with Torchvision
99
+
100
+ See the [Get Started guide](https://pytorch.org/get-started/locally/)
101
+ for information on installation.
102
+
103
+ We add the following code snippets to a Python script "FSDP\_mnist.py".
104
+
105
+ 1.2 Import necessary packages
106
+
107
+ ::: {.note}
108
+ ::: {.title}
109
+ Note
110
+ :::
111
+
112
+ This tutorial is intended for PyTorch versions 1.12 and later. If you
113
+ are using an earlier version, replace all instances of
114
+ [size\_based\_auto\_wrap\_policy]{.title-ref} with
115
+ [default\_auto\_wrap\_policy]{.title-ref} and
116
+ [fsdp\_auto\_wrap\_policy]{.title-ref} with
117
+ [auto\_wrap\_policy]{.title-ref}.
118
+ :::
119
+
120
+ ``` {.python}
121
+ # Based on: https://github.com/pytorch/examples/blob/master/mnist/main.py
122
+ import os
123
+ import argparse
124
+ import functools
125
+ import torch
126
+ import torch.nn as nn
127
+ import torch.nn.functional as F
128
+ import torch.optim as optim
129
+ from torchvision import datasets, transforms
130
+
131
+
132
+ from torch.optim.lr_scheduler import StepLR
133
+
134
+ import torch.distributed as dist
135
+ import torch.multiprocessing as mp
136
+ from torch.nn.parallel import DistributedDataParallel as DDP
137
+ from torch.utils.data.distributed import DistributedSampler
138
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
139
+ from torch.distributed.fsdp.fully_sharded_data_parallel import (
140
+ CPUOffload,
141
+ BackwardPrefetch,
142
+ )
143
+ from torch.distributed.fsdp.wrap import (
144
+ size_based_auto_wrap_policy,
145
+ enable_wrap,
146
+ wrap,
147
+ )
148
+ ```
149
+
150
+ 1.3 Distributed training setup. As we mentioned, FSDP is a type of data
151
+ parallelism which requires a distributed training environment, so here
152
+ we use two helper functions to initialize the processes for distributed
153
+ training and clean up.
154
+
155
+ ``` {.python}
156
+ def setup(rank, world_size):
157
+ os.environ['MASTER_ADDR'] = 'localhost'
158
+ os.environ['MASTER_PORT'] = '12355'
159
+
160
+ # initialize the process group
161
+ dist.init_process_group("nccl", rank=rank, world_size=world_size)
162
+
163
+ def cleanup():
164
+ dist.destroy_process_group()
165
+ ```
166
+
167
+ 2.1 Define our toy model for handwritten digit classification.
168
+
169
+ ``` {.python}
170
+ class Net(nn.Module):
171
+ def __init__(self):
172
+ super(Net, self).__init__()
173
+ self.conv1 = nn.Conv2d(1, 32, 3, 1)
174
+ self.conv2 = nn.Conv2d(32, 64, 3, 1)
175
+ self.dropout1 = nn.Dropout(0.25)
176
+ self.dropout2 = nn.Dropout(0.5)
177
+ self.fc1 = nn.Linear(9216, 128)
178
+ self.fc2 = nn.Linear(128, 10)
179
+
180
+ def forward(self, x):
181
+
182
+ x = self.conv1(x)
183
+ x = F.relu(x)
184
+ x = self.conv2(x)
185
+ x = F.relu(x)
186
+ x = F.max_pool2d(x, 2)
187
+ x = self.dropout1(x)
188
+ x = torch.flatten(x, 1)
189
+ x = self.fc1(x)
190
+ x = F.relu(x)
191
+ x = self.dropout2(x)
192
+ x = self.fc2(x)
193
+ output = F.log_softmax(x, dim=1)
194
+ return output
195
+ ```
196
+
197
+ 2.2 Define a train function
198
+
199
+ ``` {.python}
200
+ def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
201
+ model.train()
202
+ ddp_loss = torch.zeros(2).to(rank)
203
+ if sampler:
204
+ sampler.set_epoch(epoch)
205
+ for batch_idx, (data, target) in enumerate(train_loader):
206
+ data, target = data.to(rank), target.to(rank)
207
+ optimizer.zero_grad()
208
+ output = model(data)
209
+ loss = F.nll_loss(output, target, reduction='sum')
210
+ loss.backward()
211
+ optimizer.step()
212
+ ddp_loss[0] += loss.item()
213
+ ddp_loss[1] += len(data)
214
+
215
+ dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
216
+ if rank == 0:
217
+ print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1]))
218
+ ```
219
+
220
+ 2.3 Define a validation function
221
+
222
+ ``` {.python}
223
+ def test(model, rank, world_size, test_loader):
224
+ model.eval()
225
+ correct = 0
226
+ ddp_loss = torch.zeros(3).to(rank)
227
+ with torch.no_grad():
228
+ for data, target in test_loader:
229
+ data, target = data.to(rank), target.to(rank)
230
+ output = model(data)
231
+ ddp_loss[0] += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
232
+ pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
233
+ ddp_loss[1] += pred.eq(target.view_as(pred)).sum().item()
234
+ ddp_loss[2] += len(data)
235
+
236
+ dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
237
+
238
+ if rank == 0:
239
+ test_loss = ddp_loss[0] / ddp_loss[2]
240
+ print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
241
+ test_loss, int(ddp_loss[1]), int(ddp_loss[2]),
242
+ 100. * ddp_loss[1] / ddp_loss[2]))
243
+ ```
244
+
245
+ 2.4 Define a distributed train function that wraps the model in FSDP
246
+
247
+ **Note: to save the FSDP model, we need to call the state\_dict on each
248
+ rank then on Rank 0 save the overall states.**
249
+
250
+ ``` {.python}
251
+ def fsdp_main(rank, world_size, args):
252
+ setup(rank, world_size)
253
+
254
+ transform=transforms.Compose([
255
+ transforms.ToTensor(),
256
+ transforms.Normalize((0.1307,), (0.3081,))
257
+ ])
258
+
259
+ dataset1 = datasets.MNIST('../data', train=True, download=True,
260
+ transform=transform)
261
+ dataset2 = datasets.MNIST('../data', train=False,
262
+ transform=transform)
263
+
264
+ sampler1 = DistributedSampler(dataset1, rank=rank, num_replicas=world_size, shuffle=True)
265
+ sampler2 = DistributedSampler(dataset2, rank=rank, num_replicas=world_size)
266
+
267
+ train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
268
+ test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
269
+ cuda_kwargs = {'num_workers': 2,
270
+ 'pin_memory': True,
271
+ 'shuffle': False}
272
+ train_kwargs.update(cuda_kwargs)
273
+ test_kwargs.update(cuda_kwargs)
274
+
275
+ train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
276
+ test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
277
+ my_auto_wrap_policy = functools.partial(
278
+ size_based_auto_wrap_policy, min_num_params=100
279
+ )
280
+ torch.cuda.set_device(rank)
281
+
282
+
283
+ init_start_event = torch.cuda.Event(enable_timing=True)
284
+ init_end_event = torch.cuda.Event(enable_timing=True)
285
+
286
+ model = Net().to(rank)
287
+
288
+ model = FSDP(model)
289
+
290
+ optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
291
+
292
+ scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
293
+ init_start_event.record()
294
+ for epoch in range(1, args.epochs + 1):
295
+ train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
296
+ test(model, rank, world_size, test_loader)
297
+ scheduler.step()
298
+
299
+ init_end_event.record()
300
+
301
+ if rank == 0:
302
+ print(f"CUDA event elapsed time: {init_start_event.elapsed_time(init_end_event) / 1000}sec")
303
+ print(f"{model}")
304
+
305
+ if args.save_model:
306
+ # use a barrier to make sure training is done on all ranks
307
+ dist.barrier()
308
+ states = model.state_dict()
309
+ if rank == 0:
310
+ torch.save(states, "mnist_cnn.pt")
311
+
312
+ cleanup()
313
+ ```
314
+
315
+ 2.5 Finally, parse the arguments and set the main function
316
+
317
+ ``` {.python}
318
+ if __name__ == '__main__':
319
+ # Training settings
320
+ parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
321
+ parser.add_argument('--batch-size', type=int, default=64, metavar='N',
322
+ help='input batch size for training (default: 64)')
323
+ parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
324
+ help='input batch size for testing (default: 1000)')
325
+ parser.add_argument('--epochs', type=int, default=10, metavar='N',
326
+ help='number of epochs to train (default: 14)')
327
+ parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
328
+ help='learning rate (default: 1.0)')
329
+ parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
330
+ help='Learning rate step gamma (default: 0.7)')
331
+ parser.add_argument('--no-cuda', action='store_true', default=False,
332
+ help='disables CUDA training')
333
+ parser.add_argument('--seed', type=int, default=1, metavar='S',
334
+ help='random seed (default: 1)')
335
+ parser.add_argument('--save-model', action='store_true', default=False,
336
+ help='For Saving the current Model')
337
+ args = parser.parse_args()
338
+
339
+ torch.manual_seed(args.seed)
340
+
341
+ WORLD_SIZE = torch.cuda.device_count()
342
+ mp.spawn(fsdp_main,
343
+ args=(WORLD_SIZE, args),
344
+ nprocs=WORLD_SIZE,
345
+ join=True)
346
+ ```
347
+
348
+ We have recorded CUDA events to measure the time of FSDP model
349
+ specifics. The CUDA event time was 110.85 seconds.
350
+
351
+ ``` {.bash}
352
+ python FSDP_mnist.py
353
+
354
+ CUDA event elapsed time on training loop 40.67462890625sec
355
+ ```
356
+
357
+ After wrapping the model with FSDP, the model looks as follows; we can see
358
+ that the model has been wrapped in one FSDP unit. Next, we will look at
359
+ adding the auto\_wrap\_policy and will discuss the differences.
360
+
361
+ ``` {.bash}
362
+ FullyShardedDataParallel(
363
+ (_fsdp_wrapped_module): FlattenParamsWrapper(
364
+ (_fpw_module): Net(
365
+ (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
366
+ (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
367
+ (dropout1): Dropout(p=0.25, inplace=False)
368
+ (dropout2): Dropout(p=0.5, inplace=False)
369
+ (fc1): Linear(in_features=9216, out_features=128, bias=True)
370
+ (fc2): Linear(in_features=128, out_features=10, bias=True)
371
+ )
372
+ )
373
+ )
374
+ ```
375
+
376
+ The following is the peak memory usage from FSDP MNIST training on a
377
+ g4dn.12.xlarge AWS EC2 instance with 4 GPUs captured from PyTorch
378
+ Profiler.
379
+
380
+ ![FSDP Peak Memory
381
+ Usage](/_static/img/distributed/FSDP_memory.gif){.align-center
382
+ width="100.0%"}
383
+
384
+ Applying *auto\_wrap\_policy*: otherwise, FSDP will put the
385
+ entire model in one FSDP unit, which will reduce computation efficiency
386
+ and memory efficiency. The way it works is that, suppose your model
387
+ contains 100 Linear layers. If you do FSDP(model), there will only be
388
+ one FSDP unit which wraps the entire model. In that case, the allgather
389
+ would collect the full parameters for all 100 linear layers, and hence
390
+ won\'t save CUDA memory for parameter sharding. Also, there is only one
391
+ blocking allgather call for all 100 linear layers, so there will not be
392
+ communication and computation overlapping between layers.
393
+
394
+ To avoid that, you can pass in an auto\_wrap\_policy, which will seal
395
+ the current FSDP unit and start a new one automatically when the
396
+ specified condition is met (e.g., size limit). In that way you will have
397
+ multiple FSDP units, and only one FSDP unit needs to collect full
398
+ parameters at a time. E.g., suppose you have 5 FSDP units, and each
399
+ wraps 20 linear layers. Then, in the forward, the 1st FSDP unit will
400
+ allgather parameters for the first 20 linear layers, do computation,
401
+ discard the parameters and then move on to the next 20 linear layers.
402
+ So, at any point in time, each rank only materializes parameters/grads
403
+ for 20 linear layers instead of 100.
404
+
405
+ To do so, in 2.4 we define the auto\_wrap\_policy and pass it to the FSDP
406
+ wrapper. In the following example, my\_auto\_wrap\_policy defines that a
407
+ layer could be wrapped or sharded by FSDP if the number of parameters in
408
+ this layer is larger than 100. If the number of parameters in this layer
409
+ is smaller than 100, it will be wrapped with other small layers together
410
+ by FSDP. Finding an optimal auto wrap policy is challenging; PyTorch
411
+ will add auto tuning for this config in the future. Without an auto
412
+ tuning tool, it is good to profile your workflow using different auto
413
+ wrap policies experimentally and find the optimal one.
414
+
415
+ ``` {.python}
416
+ my_auto_wrap_policy = functools.partial(
417
+ size_based_auto_wrap_policy, min_num_params=20000
418
+ )
419
+ torch.cuda.set_device(rank)
420
+ model = Net().to(rank)
421
+
422
+ model = FSDP(model,
423
+ auto_wrap_policy=my_auto_wrap_policy)
424
+ ```
425
+
426
+ Applying the auto\_wrap\_policy, the model would be as follows:
427
+
428
+ ``` {.bash}
429
+ FullyShardedDataParallel(
430
+ (_fsdp_wrapped_module): FlattenParamsWrapper(
431
+ (_fpw_module): Net(
432
+ (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
433
+ (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
434
+ (dropout1): Dropout(p=0.25, inplace=False)
435
+ (dropout2): Dropout(p=0.5, inplace=False)
436
+ (fc1): FullyShardedDataParallel(
437
+ (_fsdp_wrapped_module): FlattenParamsWrapper(
438
+ (_fpw_module): Linear(in_features=9216, out_features=128, bias=True)
439
+ )
440
+ )
441
+ (fc2): Linear(in_features=128, out_features=10, bias=True)
442
+ )
443
+ )
444
+ ```
445
+
446
+ ``` {.bash}
447
+ python FSDP_mnist.py
448
+
449
+ CUDA event elapsed time on training loop 41.89130859375sec
450
+ ```
451
+
452
+ The following is the peak memory usage from FSDP with auto\_wrap policy
453
+ of MNIST training on a g4dn.12.xlarge AWS EC2 instance with 4 GPUs
454
+ captured from PyTorch Profiler. It can be observed that the peak memory
455
+ usage on each device is smaller compared to FSDP without auto wrap
456
+ policy applied, from \~75 MB to 66 MB.
457
+
458
+ ![FSDP Peak Memory Usage using Auto\_wrap
459
+ policy](/_static/img/distributed/FSDP_autowrap.gif){.align-center
460
+ width="100.0%"}
461
+
462
+ *CPU Off-loading*: In case the model is so large that even with FSDP it
463
+ wouldn\'t fit into GPUs, CPU offloading can be helpful here.
464
+
465
+ Currently, only parameter and gradient CPU offload is supported. It can
466
+ be enabled via passing in cpu\_offload=CPUOffload(offload\_params=True).
467
+
468
+ Note that this currently implicitly enables gradient offloading to CPU
469
+ in order for params and grads to be on the same device to work with the
470
+ optimizer. This API is subject to change. The default is None in which
471
+ case there will be no offloading.
472
+
473
+ Using this feature may slow down the training considerably, due to
474
+ frequent copying of tensors from host to device, but it could help
475
+ improve memory efficiency and train larger scale models.
476
+
477
+ In 2.4 we just add it to the FSDP wrapper
478
+
479
+ ``` {.python}
480
+ model = FSDP(model,
481
+ auto_wrap_policy=my_auto_wrap_policy,
482
+ cpu_offload=CPUOffload(offload_params=True))
483
+ ```
484
+
485
+ To compare it with DDP, in 2.4 we just normally wrap the model in DDP,
486
+ saving the changes in "DDP\_mnist.py".
487
+
488
+ ``` {.python}
489
+ model = Net().to(rank)
490
+ model = DDP(model)
491
+ ```
492
+
493
+ ``` {.bash}
494
+ python DDP_mnist.py
495
+
496
+ CUDA event elapsed time on training loop 39.77766015625sec
497
+ ```
498
+
499
+ The following is the peak memory usage from DDP MNIST training on a
500
+ g4dn.12.xlarge AWS EC2 instance with 4 GPUs captured from PyTorch
501
+ profiler.
502
+
503
+ ![DDP Peak Memory Usage using Auto\_wrap
504
+ policy](/_static/img/distributed/DDP_memory.gif){.align-center
505
+ width="100.0%"}
506
+
507
+ Considering the toy example and tiny MNIST model we defined here, we can
508
+ observe the difference between peak memory usage of DDP and FSDP. In DDP
509
+ each process holds a replica of the model, so the memory footprint is
510
+ higher compared to FSDP which shards the model parameters, optimizer
511
+ states and gradients over DDP ranks. The peak memory usage using FSDP
512
+ with auto\_wrap policy is the lowest, followed by FSDP and DDP.
513
+
514
+ Also, looking at timings, considering the small model and running the
515
+ training on a single machine, FSDP with and without auto\_wrap policy
516
+ performed almost as fast as DDP. This example does not represent most of
517
+ the real applications; for detailed analysis and comparison between DDP
518
+ and FSDP please refer to this [blog
519
+ post](https://pytorch.medium.com/6c8da2be180d).
a_distributed_notebook/temp/all_gather.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.distributed as dist
3
+ import os
4
+ import torch
5
+ import torch.distributed as dist
6
+ import torch.multiprocessing as mp
7
+ from pprint import pprint
8
+ import time
9
+
10
+ def my_print(*args, **kwargs):
11
+ if dist.get_rank() == 0:
12
+ print(*args, **kwargs)
13
+ else:
14
+ time.sleep(0.01)
15
+ print(*args, **kwargs)
16
+
17
+ class AllGatherLB(torch.autograd.Function):
18
+ """
19
+ An autograd function that performs allgather on a tensor.
20
+ This function only performs local-backpropagation on a single GPU.
21
+ It has worse convergence and lower efficiency compared to global-backpropagation.
22
+ """
23
+
24
+ @staticmethod
25
+ def forward(ctx, tensor, rank, world_size):
26
+ output = [torch.empty_like(tensor) for _ in range(world_size)]
27
+ dist.all_gather(output, tensor)
28
+ ctx.rank = rank
29
+ ctx.batch_size = tensor.shape[0]
30
+ return torch.cat(output, 0)
31
+
32
+ @staticmethod
33
+ def backward(ctx, grad_output):
34
+ return (
35
+ grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)],
36
+ None,
37
+ None,
38
+ )
39
+
40
+
41
+ class AllGatherGB(torch.autograd.Function):
42
+ """
43
+ An autograd function that performs allgather on a tensor.
44
+ Global-backpropagation on all GPUs.
45
+ This function has better convergence and higher efficiency compared to local-backpropagation.
46
+ This function is used as the default gather strategy.
47
+ """
48
+
49
+ @staticmethod
50
+ def forward(ctx, tensor):
51
+ world_size = dist.get_world_size()
52
+ output = [torch.empty_like(tensor) for _ in range(world_size)]
53
+ dist.all_gather(output, tensor)
54
+ ctx.world_size = world_size
55
+ return torch.cat(output, 0)
56
+
57
+ @staticmethod
58
+ def backward(ctx, grad_output):
59
+ batch_size = grad_output.shape[0] // ctx.world_size
60
+ rank = dist.get_rank()
61
+ in_grad = grad_output.clone()
62
+
63
+ my_print("Rank ", rank, " has in_grad before all reduce", in_grad)
64
+ dist.all_reduce(in_grad, op=dist.ReduceOp.SUM)
65
+ my_print("Rank ", rank, " has in_grad after all reduce", in_grad)
66
+ return (in_grad[batch_size * rank : batch_size * (rank + 1)],)
67
+
68
+
69
+ all_gather = AllGatherGB.apply
70
+
71
+ """ All-Reduce example."""
72
+ def run(rank, size):
73
+ """ Simple collective communication. """
74
+ # group = dist.new_group([0, 1])
75
+ BATCH_SIZE = 4
76
+ LENGTH = 3
77
+ VECTOR_DIM = 2
78
+ tensor = torch.zeros(BATCH_SIZE, LENGTH, VECTOR_DIM) + (2 * rank - 2) # rank 0: -2, rank 1: 0
79
+ tensor.requires_grad = True
80
+ # tensor_list = [torch.zeros(4, 3) for _ in range(size)]
81
+ gather_tensor = all_gather(tensor)
82
+ # data
83
+ # print('Rank ', rank, ' has gather data ', gather_tensor)
84
+ # shape
85
+ print('Rank ', rank, ' has gather shape ', gather_tensor.shape)
86
+
87
+
88
+ loss = gather_tensor ** 2
89
+ # + random tensor
90
+ loss = loss.sum() + torch.rand(1, requires_grad=True)
91
+ loss.backward()
92
+
93
+ # Mathematically, the gradient of the loss w.r.t. gather_tensor is 2 * gather_tensor
94
+
95
+ print('Rank ', rank, ' has final tensor grad ', tensor.grad)
96
+
97
+
98
+ def init_process(rank, size, fn, backend='gloo'):
99
+ """ Initialize the distributed environment. """
100
+ os.environ['MASTER_ADDR'] = '127.0.0.1'
101
+ os.environ['MASTER_PORT'] = '29500'
102
+ dist.init_process_group(backend, rank=rank, world_size=size)
103
+ fn(rank, size)
104
+
105
+
106
+ if __name__ == "__main__":
107
+ size = 2
108
+ processes = []
109
+ mp.set_start_method("spawn")
110
+ for rank in range(size):
111
+ p = mp.Process(target=init_process, args=(rank, size, run))
112
+ p.start()
113
+ processes.append(p)
114
+
115
+ for p in processes:
116
+ p.join()
a_distributed_notebook/temp/run_4.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import tempfile
4
+ import torch
5
+ import torch.distributed as dist
6
+ import torch.nn as nn
7
+ import torch.optim as optim
8
+ import torch.multiprocessing as mp
9
+
10
+ from torch.nn.parallel import DistributedDataParallel as DDP
11
+
12
+ def setup(rank, world_size):
13
+ os.environ['MASTER_ADDR'] = 'localhost'
14
+ os.environ['MASTER_PORT'] = '12355'
15
+
16
+ # initialize the process group
17
+ dist.init_process_group("gloo", rank=rank, world_size=world_size)
18
+
19
+ def cleanup():
20
+ dist.destroy_process_group()
21
+
22
+ class ToyModel(nn.Module):
23
+ def __init__(self):
24
+ super(ToyModel, self).__init__()
25
+ self.net1 = nn.Linear(10, 10)
26
+ self.relu = nn.ReLU()
27
+ self.net2 = nn.Linear(10, 5)
28
+
29
+ def forward(self, x):
30
+ return self.net2(self.relu(self.net1(x)))
31
+
32
+
33
+ def demo_basic(rank, world_size):
34
+ print(f"Running basic DDP example on rank {rank}.")
35
+ setup(rank, world_size)
36
+
37
+ # create model and move it to GPU with id rank
38
+ model = ToyModel().to(rank)
39
+ ddp_model = DDP(model, device_ids=[rank])
40
+
41
+ loss_fn = nn.MSELoss()
42
+ optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
43
+
44
+ optimizer.zero_grad()
45
+ outputs = ddp_model(torch.randn(20, 10))
46
+ labels = torch.randn(20, 5).to(rank)
47
+ loss_fn(outputs, labels).backward()
48
+ optimizer.step()
49
+
50
+ cleanup()
51
+ print(f"Finished running basic DDP example on rank {rank}.")
52
+
53
+
54
+ def run_demo(demo_fn, world_size):
55
+ print(f"Running DDP example with {world_size} processes.")
56
+ mp.set_start_method("spawn")
57
+ mp.spawn(demo_fn,
58
+ args=(world_size,),
59
+ nprocs=world_size,
60
+ join=True)
61
+
62
+ if __name__ == "__main__":
63
+ run_demo(demo_basic, 2)
a_main_folder/convert_hf_dataset.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
a_temp/deepseek_vl2.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
a_temp/docs.html ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html>
3
+
4
+ <head>
5
+ <link type="text/css" rel="stylesheet" href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui.css">
6
+ <link rel="shortcut icon" href="https://fastapi.tiangolo.com/img/favicon.png">
7
+ <title>FastAPI - Swagger UI</title>
8
+ </head>
9
+
10
+ <body>
11
+ <div id="swagger-ui">
12
+ </div>
13
+ <script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui-bundle.js"></script>
14
+ <!-- `SwaggerUIBundle` is now available on the page -->
15
+ <script>
16
+ const ui = SwaggerUIBundle({
17
+ url: '/openapi.json',
18
+ "dom_id": "#swagger-ui",
19
+ "layout": "BaseLayout",
20
+ "deepLinking": true,
21
+ "showExtensions": true,
22
+ "showCommonExtensions": true,
23
+ oauth2RedirectUrl: window.location.origin + '/docs/oauth2-redirect',
24
+ presets: [
25
+ SwaggerUIBundle.presets.apis,
26
+ SwaggerUIBundle.SwaggerUIStandalonePreset
27
+ ],
28
+ })
29
+ </script>
30
+ </body>
31
+
32
+ </html>
a_temp/example_image.jpg ADDED
a_temp/openapi.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"openapi":"3.1.0","info":{"title":"FastAPI","version":"0.1.0"},"paths":{"/health":{"get":{"summary":"Health","description":"Health check.","operationId":"health_health_get","responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}}}}},"/tokenize":{"post":{"summary":"Tokenize","operationId":"tokenize_tokenize_post","requestBody":{"content":{"application/json":{"schema":{"anyOf":[{"$ref":"#/components/schemas/TokenizeCompletionRequest"},{"$ref":"#/components/schemas/TokenizeChatRequest"}],"title":"Request"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/detokenize":{"post":{"summary":"Detokenize","operationId":"detokenize_detokenize_post","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/DetokenizeRequest"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/v1/models":{"get":{"summary":"Show Available Models","operationId":"show_available_models_v1_models_get","responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}}}}},"/version":{"get":{"summary":"Show Version","operationId":"show_version_version_get","responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}}}}},"/v1/chat/completions":{"post":{"summary":"Create Chat Completion","operationId":"create_chat_completion_v1_chat_completions_post","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/ChatCompletionRequest"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/v1/completions":{"post":{"summary":"Create Completion","operationId":"create_completion_v1_completions_post","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/CompletionRequest"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/v1/embeddings":{"post":{"summary":"Create Embedding","operationId":"create_embedding_v1_embeddings_post","requestBody":{"content":{"application/json":{"schema":{"anyOf":[{"$ref":"#/components/schemas/EmbeddingCompletionRequest"},{"$ref":"#/components/schemas/EmbeddingChatRequest"}],"title":"Request"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/pooling":{"post":{"summary":"Create 
Pooling","operationId":"create_pooling_pooling_post","requestBody":{"content":{"application/json":{"schema":{"anyOf":[{"$ref":"#/components/schemas/EmbeddingCompletionRequest"},{"$ref":"#/components/schemas/EmbeddingChatRequest"}],"title":"Request"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/score":{"post":{"summary":"Create Score","operationId":"create_score_score_post","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/ScoreRequest"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/v1/score":{"post":{"summary":"Create Score V1","operationId":"create_score_v1_v1_score_post","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/ScoreRequest"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}}},"components":{"schemas":{"Audio":{"properties":{"id":{"type":"string","title":"Id"}},"type":"object","required":["id"],"title":"Audio"},"AudioURL":{"properties":{"url":{"type":"string","title":"Url"}},"type":"object","required":["url"],"title":"AudioURL"},"BaseModel":{"properties":{},"type":"object","title":"BaseModel"},"ChatCompletionAssistantMessageParam":{"properties":{"role":{"type":"string","enum":["assistant"],"const":"assistant","title":"Role"},"audio":{"anyOf":[{"$ref":"#/components/schemas/Audio"},{"type":"null"}]},"content":{"anyOf":[{"type":"string"},{"items":{"anyOf":[{"$ref":"#/components/schemas/ChatCompletionContentPartTextParam"},{"$ref":"#/components/schemas/ChatCompletionContentPartRefusalParam"}]},"type":"array"},{"type":"null"}],"title":"Content"},"function_call":{"anyOf":[{"$ref":"#/components/schemas/FunctionCall"},{"type":"null"}]},"name":{"type":"string","title":"Name"},"refusal":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Refusal"},"tool_calls":{"items":{"$ref":"#/components/schemas/ChatCompletionMessageToolCallParam"},"type":"array","title":"Tool 
Calls"}},"type":"object","required":["role"],"title":"ChatCompletionAssistantMessageParam"},"ChatCompletionContentPartAudioParam":{"properties":{"audio_url":{"$ref":"#/components/schemas/AudioURL"},"type":{"type":"string","enum":["audio_url"],"const":"audio_url","title":"Type"}},"type":"object","required":["audio_url","type"],"title":"ChatCompletionContentPartAudioParam"},"ChatCompletionContentPartImageParam":{"properties":{"image_url":{"$ref":"#/components/schemas/ImageURL"},"type":{"type":"string","enum":["image_url"],"const":"image_url","title":"Type"}},"type":"object","required":["image_url","type"],"title":"ChatCompletionContentPartImageParam"},"ChatCompletionContentPartInputAudioParam":{"properties":{"input_audio":{"$ref":"#/components/schemas/InputAudio"},"type":{"type":"string","enum":["input_audio"],"const":"input_audio","title":"Type"}},"type":"object","required":["input_audio","type"],"title":"ChatCompletionContentPartInputAudioParam"},"ChatCompletionContentPartRefusalParam":{"properties":{"refusal":{"type":"string","title":"Refusal"},"type":{"type":"string","enum":["refusal"],"const":"refusal","title":"Type"}},"type":"object","required":["refusal","type"],"title":"ChatCompletionContentPartRefusalParam"},"ChatCompletionContentPartTextParam":{"properties":{"text":{"type":"string","title":"Text"},"type":{"type":"string","enum":["text"],"const":"text","title":"Type"}},"type":"object","required":["text","type"],"title":"ChatCompletionContentPartTextParam"},"ChatCompletionContentPartVideoParam":{"properties":{"video_url":{"$ref":"#/components/schemas/VideoURL"},"type":{"type":"string","enum":["video_url"],"const":"video_url","title":"Type"}},"type":"object","required":["video_url","type"],"title":"ChatCompletionContentPartVideoParam"},"ChatCompletionDeveloperMessageParam":{"properties":{"content":{"anyOf":[{"type":"string"},{"items":{"$ref":"#/components/schemas/ChatCompletionContentPartTextParam"},"type":"array"}],"title":"Content"},"role":{"type":"string","enum":["developer"],"const":"developer","title":"Role"},"name":{"type":"string","title":"Name"}},"type":"object","required":["content","role"],"title":"ChatCompletionDeveloperMessageParam"},"ChatCompletionFunctionMessageParam":{"properties":{"content":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Content"},"name":{"type":"string","title":"Name"},"role":{"type":"string","enum":["function"],"const":"function","title":"Role"}},"type":"object","required":["content","name","role"],"title":"ChatCompletionFunctionMessageParam"},"ChatCompletionMessageToolCallParam":{"properties":{"id":{"type":"string","title":"Id"},"function":{"$ref":"#/components/schemas/Function"},"type":{"type":"string","enum":["function"],"const":"function","title":"Type"}},"type":"object","required":["id","function","type"],"title":"ChatCompletionMessageToolCallParam"},"ChatCompletionNamedFunction":{"properties":{"name":{"type":"string","title":"Name"}},"additionalProperties":true,"type":"object","required":["name"],"title":"ChatCompletionNamedFunction"},"ChatCompletionNamedToolChoiceParam":{"properties":{"function":{"$ref":"#/components/schemas/ChatCompletionNamedFunction"},"type":{"type":"string","enum":["function"],"const":"function","title":"Type","default":"function"}},"additionalProperties":true,"type":"object","required":["function"],"title":"ChatCompletionNamedToolChoiceParam"},"ChatCompletionRequest":{"properties":{"messages":{"items":{"anyOf":[{"$ref":"#/components/schemas/ChatCompletionDeveloperMessageParam"},{"$ref":"#/components/schemas/ChatCompl
etionSystemMessageParam"},{"$ref":"#/components/schemas/ChatCompletionUserMessageParam"},{"$ref":"#/components/schemas/ChatCompletionAssistantMessageParam"},{"$ref":"#/components/schemas/ChatCompletionToolMessageParam"},{"$ref":"#/components/schemas/ChatCompletionFunctionMessageParam"},{"$ref":"#/components/schemas/CustomChatCompletionMessageParam"}]},"type":"array","title":"Messages"},"model":{"type":"string","title":"Model"},"frequency_penalty":{"anyOf":[{"type":"number"},{"type":"null"}],"title":"Frequency Penalty","default":0.0},"logit_bias":{"anyOf":[{"additionalProperties":{"type":"number"},"type":"object"},{"type":"null"}],"title":"Logit Bias"},"logprobs":{"anyOf":[{"type":"boolean"},{"type":"null"}],"title":"Logprobs","default":false},"top_logprobs":{"anyOf":[{"type":"integer"},{"type":"null"}],"title":"Top Logprobs","default":0},"max_tokens":{"anyOf":[{"type":"integer"},{"type":"null"}],"title":"Max Tokens","deprecated":true},"max_completion_tokens":{"anyOf":[{"type":"integer"},{"type":"null"}],"title":"Max Completion Tokens"},"n":{"anyOf":[{"type":"integer"},{"type":"null"}],"title":"N","default":1},"presence_penalty":{"anyOf":[{"type":"number"},{"type":"null"}],"title":"Presence Penalty","default":0.0},"response_format":{"anyOf":[{"$ref":"#/components/schemas/ResponseFormat"},{"type":"null"}]},"seed":{"anyOf":[{"type":"integer","maximum":9.223372036854776e+18,"minimum":-9.223372036854776e+18},{"type":"null"}],"title":"Seed"},"stop":{"anyOf":[{"type":"string"},{"items":{"type":"string"},"type":"array"},{"type":"null"}],"title":"Stop"},"stream":{"anyOf":[{"type":"boolean"},{"type":"null"}],"title":"Stream","default":false},"stream_options":{"anyOf":[{"$ref":"#/components/schemas/StreamOptions"},{"type":"null"}]},"temperature":{"anyOf":[{"type":"number"},{"type":"null"}],"title":"Temperature"},"top_p":{"anyOf":[{"type":"number"},{"type":"null"}],"title":"Top P"},"tools":{"anyOf":[{"items":{"$ref":"#/components/schemas/ChatCompletionToolsParam"},"type":"array"},{"type":"null"}],"title":"Tools"},"tool_choice":{"anyOf":[{"type":"string","enum":["none"],"const":"none"},{"type":"string","enum":["auto"],"const":"auto"},{"$ref":"#/components/schemas/ChatCompletionNamedToolChoiceParam"},{"type":"null"}],"title":"Tool Choice","default":"none"},"parallel_tool_calls":{"anyOf":[{"type":"boolean"},{"type":"null"}],"title":"Parallel Tool Calls","default":false},"user":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"User"},"best_of":{"anyOf":[{"type":"integer"},{"type":"null"}],"title":"Best Of"},"use_beam_search":{"type":"boolean","title":"Use Beam Search","default":false},"top_k":{"anyOf":[{"type":"integer"},{"type":"null"}],"title":"Top K"},"min_p":{"anyOf":[{"type":"number"},{"type":"null"}],"title":"Min P"},"repetition_penalty":{"anyOf":[{"type":"number"},{"type":"null"}],"title":"Repetition Penalty"},"length_penalty":{"type":"number","title":"Length Penalty","default":1.0},"stop_token_ids":{"anyOf":[{"items":{"type":"integer"},"type":"array"},{"type":"null"}],"title":"Stop Token Ids"},"include_stop_str_in_output":{"type":"boolean","title":"Include Stop Str In Output","default":false},"ignore_eos":{"type":"boolean","title":"Ignore Eos","default":false},"min_tokens":{"type":"integer","title":"Min Tokens","default":0},"skip_special_tokens":{"type":"boolean","title":"Skip Special Tokens","default":true},"spaces_between_special_tokens":{"type":"boolean","title":"Spaces Between Special 
Tokens","default":true},"truncate_prompt_tokens":{"anyOf":[{"type":"integer","minimum":1.0},{"type":"null"}],"title":"Truncate Prompt Tokens"},"prompt_logprobs":{"anyOf":[{"type":"integer"},{"type":"null"}],"title":"Prompt Logprobs"},"echo":{"type":"boolean","title":"Echo","description":"If true, the new message will be prepended with the last message if they belong to the same role.","default":false},"add_generation_prompt":{"type":"boolean","title":"Add Generation Prompt","description":"If true, the generation prompt will be added to the chat template. This is a parameter used by chat template in tokenizer config of the model.","default":true},"continue_final_message":{"type":"boolean","title":"Continue Final Message","description":"If this is set, the chat will be formatted so that the final message in the chat is open-ended, without any EOS tokens. The model will continue this message rather than starting a new one. This allows you to \"prefill\" part of the model's response for it. Cannot be used at the same time as `add_generation_prompt`.","default":false},"add_special_tokens":{"type":"boolean","title":"Add Special Tokens","description":"If true, special tokens (e.g. BOS) will be added to the prompt on top of what is added by the chat template. For most models, the chat template takes care of adding the special tokens so this should be set to false (as is the default).","default":false},"documents":{"anyOf":[{"items":{"additionalProperties":{"type":"string"},"type":"object"},"type":"array"},{"type":"null"}],"title":"Documents","description":"A list of dicts representing documents that will be accessible to the model if it is performing RAG (retrieval-augmented generation). If the template does not support RAG, this argument will have no effect. We recommend that each document should be a dict containing \"title\" and \"text\" keys."},"chat_template":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Chat Template","description":"A Jinja template to use for this conversion. As of transformers v4.44, default chat template is no longer allowed, so you must provide a chat template if the tokenizer does not define one."},"chat_template_kwargs":{"anyOf":[{"type":"object"},{"type":"null"}],"title":"Chat Template Kwargs","description":"Additional kwargs to pass to the template renderer. Will be accessible by the chat template."},"guided_json":{"anyOf":[{"type":"string"},{"type":"object"},{"$ref":"#/components/schemas/BaseModel"},{"type":"null"}],"title":"Guided Json","description":"If specified, the output will follow the JSON schema."},"guided_regex":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Guided Regex","description":"If specified, the output will follow the regex pattern."},"guided_choice":{"anyOf":[{"items":{"type":"string"},"type":"array"},{"type":"null"}],"title":"Guided Choice","description":"If specified, the output will be exactly one of the choices."},"guided_grammar":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Guided Grammar","description":"If specified, the output will follow the context free grammar."},"guided_decoding_backend":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Guided Decoding Backend","description":"If specified, will override the default guided decoding backend of the server for this specific request. 
If set, must be either 'outlines' / 'lm-format-enforcer'"},"guided_whitespace_pattern":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Guided Whitespace Pattern","description":"If specified, will override the default whitespace pattern for guided json decoding."},"priority":{"type":"integer","title":"Priority","description":"The priority of the request (lower means earlier handling; default: 0). Any priority other than 0 will raise an error if the served model does not use priority scheduling.","default":0},"request_id":{"type":"string","title":"Request Id","description":"The request_id related to this request. If the caller does not set it, a random_uuid will be generated. This id is used through out the inference process and return in response."},"logits_processors":{"anyOf":[{"items":{"anyOf":[{"type":"string"},{"$ref":"#/components/schemas/LogitsProcessorConstructor"}]},"type":"array"},{"type":"null"}],"title":"Logits Processors","description":"A list of either qualified names of logits processors, or constructor objects, to apply when sampling. A constructor is a JSON object with a required 'qualname' field specifying the qualified name of the processor class/factory, and optional 'args' and 'kwargs' fields containing positional and keyword arguments. For example: {'qualname': 'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': {'param': 'value'}}."}},"additionalProperties":true,"type":"object","required":["messages","model"],"title":"ChatCompletionRequest"},"ChatCompletionSystemMessageParam":{"properties":{"content":{"anyOf":[{"type":"string"},{"items":{"$ref":"#/components/schemas/ChatCompletionContentPartTextParam"},"type":"array"}],"title":"Content"},"role":{"type":"string","enum":["system"],"const":"system","title":"Role"},"name":{"type":"string","title":"Name"}},"type":"object","required":["content","role"],"title":"ChatCompletionSystemMessageParam"},"ChatCompletionToolMessageParam":{"properties":{"content":{"anyOf":[{"type":"string"},{"items":{"$ref":"#/components/schemas/ChatCompletionContentPartTextParam"},"type":"array"}],"title":"Content"},"role":{"type":"string","enum":["tool"],"const":"tool","title":"Role"},"tool_call_id":{"type":"string","title":"Tool Call Id"}},"type":"object","required":["content","role","tool_call_id"],"title":"ChatCompletionToolMessageParam"},"ChatCompletionToolsParam":{"properties":{"type":{"type":"string","enum":["function"],"const":"function","title":"Type","default":"function"},"function":{"$ref":"#/components/schemas/FunctionDefinition"}},"additionalProperties":true,"type":"object","required":["function"],"title":"ChatCompletionToolsParam"},"ChatCompletionUserMessageParam":{"properties":{"content":{"anyOf":[{"type":"string"},{"items":{"anyOf":[{"$ref":"#/components/schemas/ChatCompletionContentPartTextParam"},{"$ref":"#/components/schemas/ChatCompletionContentPartImageParam"},{"$ref":"#/components/schemas/ChatCompletionContentPartInputAudioParam"}]},"type":"array"}],"title":"Content"},"role":{"type":"string","enum":["user"],"const":"user","title":"Role"},"name":{"type":"string","title":"Name"}},"type":"object","required":["content","role"],"title":"ChatCompletionUserMessageParam"},"CompletionRequest":{"properties":{"model":{"type":"string","title":"Model"},"prompt":{"anyOf":[{"items":{"type":"integer"},"type":"array"},{"items":{"items":{"type":"integer"},"type":"array"},"type":"array"},{"type":"string"},{"items":{"type":"string"},"type":"array"}],"title":"Prompt"},"best_of":{"anyOf":[{"type":"integer"},{"type":"null"}],"title":"Best 
Of"},"echo":{"anyOf":[{"type":"boolean"},{"type":"null"}],"title":"Echo","default":false},"frequency_penalty":{"anyOf":[{"type":"number"},{"type":"null"}],"title":"Frequency Penalty","default":0.0},"logit_bias":{"anyOf":[{"additionalProperties":{"type":"number"},"type":"object"},{"type":"null"}],"title":"Logit Bias"},"logprobs":{"anyOf":[{"type":"integer"},{"type":"null"}],"title":"Logprobs"},"max_tokens":{"anyOf":[{"type":"integer"},{"type":"null"}],"title":"Max Tokens","default":16},"n":{"type":"integer","title":"N","default":1},"presence_penalty":{"anyOf":[{"type":"number"},{"type":"null"}],"title":"Presence Penalty","default":0.0},"seed":{"anyOf":[{"type":"integer","maximum":9.223372036854776e+18,"minimum":-9.223372036854776e+18},{"type":"null"}],"title":"Seed"},"stop":{"anyOf":[{"type":"string"},{"items":{"type":"string"},"type":"array"},{"type":"null"}],"title":"Stop"},"stream":{"anyOf":[{"type":"boolean"},{"type":"null"}],"title":"Stream","default":false},"stream_options":{"anyOf":[{"$ref":"#/components/schemas/StreamOptions"},{"type":"null"}]},"suffix":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Suffix"},"temperature":{"anyOf":[{"type":"number"},{"type":"null"}],"title":"Temperature"},"top_p":{"anyOf":[{"type":"number"},{"type":"null"}],"title":"Top P"},"user":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"User"},"use_beam_search":{"type":"boolean","title":"Use Beam Search","default":false},"top_k":{"anyOf":[{"type":"integer"},{"type":"null"}],"title":"Top K"},"min_p":{"anyOf":[{"type":"number"},{"type":"null"}],"title":"Min P"},"repetition_penalty":{"anyOf":[{"type":"number"},{"type":"null"}],"title":"Repetition Penalty"},"length_penalty":{"type":"number","title":"Length Penalty","default":1.0},"stop_token_ids":{"anyOf":[{"items":{"type":"integer"},"type":"array"},{"type":"null"}],"title":"Stop Token Ids"},"include_stop_str_in_output":{"type":"boolean","title":"Include Stop Str In Output","default":false},"ignore_eos":{"type":"boolean","title":"Ignore Eos","default":false},"min_tokens":{"type":"integer","title":"Min Tokens","default":0},"skip_special_tokens":{"type":"boolean","title":"Skip Special Tokens","default":true},"spaces_between_special_tokens":{"type":"boolean","title":"Spaces Between Special Tokens","default":true},"truncate_prompt_tokens":{"anyOf":[{"type":"integer","minimum":1.0},{"type":"null"}],"title":"Truncate Prompt Tokens"},"allowed_token_ids":{"anyOf":[{"items":{"type":"integer"},"type":"array"},{"type":"null"}],"title":"Allowed Token Ids"},"prompt_logprobs":{"anyOf":[{"type":"integer"},{"type":"null"}],"title":"Prompt Logprobs"},"add_special_tokens":{"type":"boolean","title":"Add Special Tokens","description":"If true (the default), special tokens (e.g. BOS) will be added to the prompt.","default":true},"response_format":{"anyOf":[{"$ref":"#/components/schemas/ResponseFormat"},{"type":"null"}],"description":"Similar to chat completion, this parameter specifies the format of output. 
Only {'type': 'json_object'}, {'type': 'json_schema'} or {'type': 'text' } is supported."},"guided_json":{"anyOf":[{"type":"string"},{"type":"object"},{"$ref":"#/components/schemas/BaseModel"},{"type":"null"}],"title":"Guided Json","description":"If specified, the output will follow the JSON schema."},"guided_regex":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Guided Regex","description":"If specified, the output will follow the regex pattern."},"guided_choice":{"anyOf":[{"items":{"type":"string"},"type":"array"},{"type":"null"}],"title":"Guided Choice","description":"If specified, the output will be exactly one of the choices."},"guided_grammar":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Guided Grammar","description":"If specified, the output will follow the context free grammar."},"guided_decoding_backend":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Guided Decoding Backend","description":"If specified, will override the default guided decoding backend of the server for this specific request. If set, must be one of 'outlines' / 'lm-format-enforcer'"},"guided_whitespace_pattern":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Guided Whitespace Pattern","description":"If specified, will override the default whitespace pattern for guided json decoding."},"priority":{"type":"integer","title":"Priority","description":"The priority of the request (lower means earlier handling; default: 0). Any priority other than 0 will raise an error if the served model does not use priority scheduling.","default":0},"logits_processors":{"anyOf":[{"items":{"anyOf":[{"type":"string"},{"$ref":"#/components/schemas/LogitsProcessorConstructor"}]},"type":"array"},{"type":"null"}],"title":"Logits Processors","description":"A list of either qualified names of logits processors, or constructor objects, to apply when sampling. A constructor is a JSON object with a required 'qualname' field specifying the qualified name of the processor class/factory, and optional 'args' and 'kwargs' fields containing positional and keyword arguments. 
For example: {'qualname': 'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': {'param': 'value'}}."}},"additionalProperties":true,"type":"object","required":["model","prompt"],"title":"CompletionRequest"},"CustomChatCompletionContentSimpleAudioParam":{"properties":{"audio_url":{"type":"string","title":"Audio Url"}},"type":"object","required":["audio_url"],"title":"CustomChatCompletionContentSimpleAudioParam","description":"A simpler version of the param that only accepts a plain audio_url.\n\nExample:\n{\n \"audio_url\": \"https://example.com/audio.mp3\"\n}"},"CustomChatCompletionContentSimpleImageParam":{"properties":{"image_url":{"type":"string","title":"Image Url"}},"type":"object","required":["image_url"],"title":"CustomChatCompletionContentSimpleImageParam","description":"A simpler version of the param that only accepts a plain image_url.\nThis is supported by OpenAI API, although it is not documented.\n\nExample:\n{\n \"image_url\": \"https://example.com/image.jpg\"\n}"},"CustomChatCompletionContentSimpleVideoParam":{"properties":{"video_url":{"type":"string","title":"Video Url"}},"type":"object","required":["video_url"],"title":"CustomChatCompletionContentSimpleVideoParam","description":"A simpler version of the param that only accepts a plain audio_url.\n\nExample:\n{\n \"video_url\": \"https://example.com/video.mp4\"\n}"},"CustomChatCompletionMessageParam":{"properties":{"role":{"type":"string","title":"Role"},"content":{"anyOf":[{"type":"string"},{"items":{"anyOf":[{"$ref":"#/components/schemas/ChatCompletionContentPartTextParam"},{"$ref":"#/components/schemas/ChatCompletionContentPartImageParam"},{"$ref":"#/components/schemas/ChatCompletionContentPartInputAudioParam"},{"$ref":"#/components/schemas/ChatCompletionContentPartAudioParam"},{"$ref":"#/components/schemas/ChatCompletionContentPartVideoParam"},{"$ref":"#/components/schemas/ChatCompletionContentPartRefusalParam"},{"$ref":"#/components/schemas/CustomChatCompletionContentSimpleImageParam"},{"$ref":"#/components/schemas/CustomChatCompletionContentSimpleAudioParam"},{"$ref":"#/components/schemas/CustomChatCompletionContentSimpleVideoParam"},{"type":"string"}]},"type":"array"}],"title":"Content"},"name":{"type":"string","title":"Name"},"tool_call_id":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Tool Call Id"},"tool_calls":{"anyOf":[{"items":{"$ref":"#/components/schemas/ChatCompletionMessageToolCallParam"},"type":"array"},{"type":"null"}],"title":"Tool Calls"}},"type":"object","required":["role"],"title":"CustomChatCompletionMessageParam","description":"Enables custom roles in the Chat Completion 
API."},"DetokenizeRequest":{"properties":{"model":{"type":"string","title":"Model"},"tokens":{"items":{"type":"integer"},"type":"array","title":"Tokens"}},"additionalProperties":true,"type":"object","required":["model","tokens"],"title":"DetokenizeRequest"},"EmbeddingChatRequest":{"properties":{"model":{"type":"string","title":"Model"},"messages":{"items":{"anyOf":[{"$ref":"#/components/schemas/ChatCompletionDeveloperMessageParam"},{"$ref":"#/components/schemas/ChatCompletionSystemMessageParam"},{"$ref":"#/components/schemas/ChatCompletionUserMessageParam"},{"$ref":"#/components/schemas/ChatCompletionAssistantMessageParam"},{"$ref":"#/components/schemas/ChatCompletionToolMessageParam"},{"$ref":"#/components/schemas/ChatCompletionFunctionMessageParam"},{"$ref":"#/components/schemas/CustomChatCompletionMessageParam"}]},"type":"array","title":"Messages"},"encoding_format":{"type":"string","enum":["float","base64"],"title":"Encoding Format","default":"float"},"dimensions":{"anyOf":[{"type":"integer"},{"type":"null"}],"title":"Dimensions"},"user":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"User"},"truncate_prompt_tokens":{"anyOf":[{"type":"integer","minimum":1.0},{"type":"null"}],"title":"Truncate Prompt Tokens"},"additional_data":{"anyOf":[{},{"type":"null"}],"title":"Additional Data"},"add_special_tokens":{"type":"boolean","title":"Add Special Tokens","description":"If true, special tokens (e.g. BOS) will be added to the prompt on top of what is added by the chat template. For most models, the chat template takes care of adding the special tokens so this should be set to false (as is the default).","default":false},"chat_template":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Chat Template","description":"A Jinja template to use for this conversion. As of transformers v4.44, default chat template is no longer allowed, so you must provide a chat template if the tokenizer does not define one."},"chat_template_kwargs":{"anyOf":[{"type":"object"},{"type":"null"}],"title":"Chat Template Kwargs","description":"Additional kwargs to pass to the template renderer. Will be accessible by the chat template."},"priority":{"type":"integer","title":"Priority","description":"The priority of the request (lower means earlier handling; default: 0). Any priority other than 0 will raise an error if the served model does not use priority scheduling.","default":0}},"additionalProperties":true,"type":"object","required":["model","messages"],"title":"EmbeddingChatRequest"},"EmbeddingCompletionRequest":{"properties":{"model":{"type":"string","title":"Model"},"input":{"anyOf":[{"items":{"type":"integer"},"type":"array"},{"items":{"items":{"type":"integer"},"type":"array"},"type":"array"},{"type":"string"},{"items":{"type":"string"},"type":"array"}],"title":"Input"},"encoding_format":{"type":"string","enum":["float","base64"],"title":"Encoding Format","default":"float"},"dimensions":{"anyOf":[{"type":"integer"},{"type":"null"}],"title":"Dimensions"},"user":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"User"},"truncate_prompt_tokens":{"anyOf":[{"type":"integer","minimum":1.0},{"type":"null"}],"title":"Truncate Prompt Tokens"},"additional_data":{"anyOf":[{},{"type":"null"}],"title":"Additional Data"},"add_special_tokens":{"type":"boolean","title":"Add Special Tokens","description":"If true (the default), special tokens (e.g. 
BOS) will be added to the prompt.","default":true},"priority":{"type":"integer","title":"Priority","description":"The priority of the request (lower means earlier handling; default: 0). Any priority other than 0 will raise an error if the served model does not use priority scheduling.","default":0}},"additionalProperties":true,"type":"object","required":["model","input"],"title":"EmbeddingCompletionRequest"},"Function":{"properties":{"arguments":{"type":"string","title":"Arguments"},"name":{"type":"string","title":"Name"}},"type":"object","required":["arguments","name"],"title":"Function"},"FunctionCall":{"properties":{"arguments":{"type":"string","title":"Arguments"},"name":{"type":"string","title":"Name"}},"type":"object","required":["arguments","name"],"title":"FunctionCall"},"FunctionDefinition":{"properties":{"name":{"type":"string","title":"Name"},"description":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Description"},"parameters":{"anyOf":[{"type":"object"},{"type":"null"}],"title":"Parameters"}},"additionalProperties":true,"type":"object","required":["name"],"title":"FunctionDefinition"},"HTTPValidationError":{"properties":{"detail":{"items":{"$ref":"#/components/schemas/ValidationError"},"type":"array","title":"Detail"}},"type":"object","title":"HTTPValidationError"},"ImageURL":{"properties":{"url":{"type":"string","title":"Url"},"detail":{"type":"string","enum":["auto","low","high"],"title":"Detail"}},"type":"object","required":["url"],"title":"ImageURL"},"InputAudio":{"properties":{"data":{"type":"string","title":"Data"},"format":{"type":"string","enum":["wav","mp3"],"title":"Format"}},"type":"object","required":["data","format"],"title":"InputAudio"},"JsonSchemaResponseFormat":{"properties":{"name":{"type":"string","title":"Name"},"description":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Description"},"schema":{"anyOf":[{"type":"object"},{"type":"null"}],"title":"Schema"},"strict":{"anyOf":[{"type":"boolean"},{"type":"null"}],"title":"Strict"}},"additionalProperties":true,"type":"object","required":["name"],"title":"JsonSchemaResponseFormat"},"LogitsProcessorConstructor":{"properties":{"qualname":{"type":"string","title":"Qualname"},"args":{"anyOf":[{"items":{},"type":"array"},{"type":"null"}],"title":"Args"},"kwargs":{"anyOf":[{"type":"object"},{"type":"null"}],"title":"Kwargs"}},"type":"object","required":["qualname"],"title":"LogitsProcessorConstructor"},"ResponseFormat":{"properties":{"type":{"type":"string","enum":["text","json_object","json_schema"],"title":"Type"},"json_schema":{"anyOf":[{"$ref":"#/components/schemas/JsonSchemaResponseFormat"},{"type":"null"}]}},"additionalProperties":true,"type":"object","required":["type"],"title":"ResponseFormat"},"ScoreRequest":{"properties":{"model":{"type":"string","title":"Model"},"text_1":{"anyOf":[{"items":{"type":"string"},"type":"array"},{"type":"string"}],"title":"Text 1"},"text_2":{"anyOf":[{"items":{"type":"string"},"type":"array"},{"type":"string"}],"title":"Text 2"},"truncate_prompt_tokens":{"anyOf":[{"type":"integer","minimum":1.0},{"type":"null"}],"title":"Truncate Prompt Tokens"},"additional_data":{"anyOf":[{},{"type":"null"}],"title":"Additional Data"},"priority":{"type":"integer","title":"Priority","description":"The priority of the request (lower means earlier handling; default: 0). 
Any priority other than 0 will raise an error if the served model does not use priority scheduling.","default":0}},"additionalProperties":true,"type":"object","required":["model","text_1","text_2"],"title":"ScoreRequest"},"StreamOptions":{"properties":{"include_usage":{"anyOf":[{"type":"boolean"},{"type":"null"}],"title":"Include Usage","default":true},"continuous_usage_stats":{"anyOf":[{"type":"boolean"},{"type":"null"}],"title":"Continuous Usage Stats","default":false}},"additionalProperties":true,"type":"object","title":"StreamOptions"},"TokenizeChatRequest":{"properties":{"model":{"type":"string","title":"Model"},"messages":{"items":{"anyOf":[{"$ref":"#/components/schemas/ChatCompletionDeveloperMessageParam"},{"$ref":"#/components/schemas/ChatCompletionSystemMessageParam"},{"$ref":"#/components/schemas/ChatCompletionUserMessageParam"},{"$ref":"#/components/schemas/ChatCompletionAssistantMessageParam"},{"$ref":"#/components/schemas/ChatCompletionToolMessageParam"},{"$ref":"#/components/schemas/ChatCompletionFunctionMessageParam"},{"$ref":"#/components/schemas/CustomChatCompletionMessageParam"}]},"type":"array","title":"Messages"},"add_generation_prompt":{"type":"boolean","title":"Add Generation Prompt","description":"If true, the generation prompt will be added to the chat template. This is a parameter used by chat template in tokenizer config of the model.","default":true},"continue_final_message":{"type":"boolean","title":"Continue Final Message","description":"If this is set, the chat will be formatted so that the final message in the chat is open-ended, without any EOS tokens. The model will continue this message rather than starting a new one. This allows you to \"prefill\" part of the model's response for it. Cannot be used at the same time as `add_generation_prompt`.","default":false},"add_special_tokens":{"type":"boolean","title":"Add Special Tokens","description":"If true, special tokens (e.g. BOS) will be added to the prompt on top of what is added by the chat template. For most models, the chat template takes care of adding the special tokens so this should be set to false (as is the default).","default":false},"chat_template":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Chat Template","description":"A Jinja template to use for this conversion. As of transformers v4.44, default chat template is no longer allowed, so you must provide a chat template if the tokenizer does not define one."},"chat_template_kwargs":{"anyOf":[{"type":"object"},{"type":"null"}],"title":"Chat Template Kwargs","description":"Additional kwargs to pass to the template renderer. Will be accessible by the chat template."}},"additionalProperties":true,"type":"object","required":["model","messages"],"title":"TokenizeChatRequest"},"TokenizeCompletionRequest":{"properties":{"model":{"type":"string","title":"Model"},"prompt":{"type":"string","title":"Prompt"},"add_special_tokens":{"type":"boolean","title":"Add Special Tokens","description":"If true (the default), special tokens (e.g. 
BOS) will be added to the prompt.","default":true}},"additionalProperties":true,"type":"object","required":["model","prompt"],"title":"TokenizeCompletionRequest"},"ValidationError":{"properties":{"loc":{"items":{"anyOf":[{"type":"string"},{"type":"integer"}]},"type":"array","title":"Location"},"msg":{"type":"string","title":"Message"},"type":{"type":"string","title":"Error Type"}},"type":"object","required":["loc","msg","type"],"title":"ValidationError"},"VideoURL":{"properties":{"url":{"type":"string","title":"Url"}},"type":"object","required":["url"],"title":"VideoURL"}}}}
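The JSON above is the tail of the vLLM OpenAI-compatible server's OpenAPI spec: the ChatCompletionRequest / CompletionRequest schemas plus vLLM-only extensions such as guided_json, guided_choice, guided_regex, logits_processors and priority. A minimal sketch of how those extra fields ride on top of the standard openai client, assuming a local vLLM server on port 19400 (the port the notebook below uses) and placeholder prompt text:

from openai import OpenAI

# assumed local endpoint; vLLM ignores the API key unless the server was started with --api-key
client = OpenAI(api_key="EMPTY", base_url="http://0.0.0.0:19400/v1")
model_name = client.models.list().data[0].id

# vLLM-specific request fields (guided_choice, guided_json, priority, ...) are not part of
# the official OpenAI client signature, so they are passed through extra_body.
completion = client.chat.completions.create(
    model=model_name,
    messages=[{"role": "user", "content": "Is the person carrying a bag? Answer yes or no."}],
    max_tokens=4,
    extra_body={"guided_choice": ["yes", "no"]},
)
print(completion.choices[0].message.content)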
a_temp/temp1.ipynb ADDED
@@ -0,0 +1,330 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stderr",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "/dscilab_dungvo/workspace/bin/envs/lmdeploy/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13
+ " from .autonotebook import tqdm as notebook_tqdm\n"
14
+ ]
15
+ }
16
+ ],
17
+ "source": [
18
+ "import datasets, huggingface_hub\n",
19
+ "disk_path ='/dscilab_dungvo/workspace/BA-PRE_THESIS/dataset_pretraining/SYNTH-PEDES/annotation_english_vietnamese_processed'\n",
20
+ "dataset = datasets.load_from_disk(disk_path)"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": null,
26
+ "metadata": {},
27
+ "outputs": [],
28
+ "source": [
29
+ "CUDA_VISIBLE_DEVICES=0 python inference.py --model_path \"deepseek-ai/deepseek-vl2-small\" --chunk_size 512\n",
30
+ "CUDA_VISIBLE_DEVICES=0,1,2 python inference.py --model_path \"deepseek-ai/deepseek-vl2\""
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": 31,
36
+ "metadata": {},
37
+ "outputs": [],
38
+ "source": [
39
+ "# Base64\n",
40
+ "import requests\n",
41
+ "from PIL import Image\n",
42
+ "from io import BytesIO\n",
43
+ "import base64\n",
44
+ "from openai import OpenAI\n",
45
+ "from langchain_community.llms import VLLMOpenAI\n",
46
+ "from langchain_openai import ChatOpenAI\n",
47
+ "from langchain_core.messages import HumanMessage, SystemMessage\n",
48
+ "from langchain_core.prompts.chat import (\n",
49
+ " ChatPromptTemplate,\n",
50
+ " HumanMessagePromptTemplate,\n",
51
+ " SystemMessagePromptTemplate,\n",
52
+ ")\n",
53
+ "\n",
54
+ "\n",
55
+ "PORT = 19400\n",
56
+ "client = OpenAI(api_key=\"YOUR_API_KEY\", base_url=f\"http://0.0.0.0:{PORT}/v1\")\n",
57
+ "model_name = client.models.list().data[0].id\n",
58
+ "\n",
59
+ "inference_server_url = f\"http://0.0.0.0:{PORT}/v1\"\n",
60
+ "\n",
61
+ "llm = ChatOpenAI(\n",
62
+ " model=model_name,\n",
63
+ " openai_api_key=\"EMPTY\",\n",
64
+ " openai_api_base=inference_server_url,\n",
65
+ " max_tokens=2000,\n",
66
+ " # temperature=0.1,\n",
67
+ " # top_p=0.8,\n",
68
+ " temperature=0.05,\n",
69
+ " top_p=0.9,\n",
70
+ ")\n",
71
+ "\n",
72
+ "def make_message(pil_image):\n",
73
+ "\n",
74
+ " # INSERT THIS ...\n",
75
+ " buffered = BytesIO()\n",
76
+ " pil_image.save(buffered, format=\"JPEG\")\n",
77
+ " img_str = base64.b64encode(buffered.getvalue()).decode(\"utf-8\")\n",
78
+ " img_str = str(img_str)\n",
79
+ " message = HumanMessage(\n",
80
+ " content=[\n",
81
+ " {\"type\": \"text\", \"text\": \"Describe the image\"},\n",
82
+ " {\"type\": \"image_url\", \"image_url\": {\"url\": 'data:image/jpeg;base64,' + img_str}},\n",
83
+ " ],\n",
84
+ " )\n",
85
+ " return message\n",
86
+ "# response = llm.invoke([message], temperature=0.1, top_p=0.9)\n",
87
+ "# response\n",
88
+ "def get_answer(chain, message):\n",
89
+ " response = chain.invoke([message], temperature=0.1, top_p=0.9)\n",
90
+ " return response.content\n"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": 14,
96
+ "metadata": {},
97
+ "outputs": [
98
+ {
99
+ "data": {
100
+ "text/plain": [
101
+ "'OpenGVLab/InternVL2_5-8B-AWQ'"
102
+ ]
103
+ },
104
+ "execution_count": 14,
105
+ "metadata": {},
106
+ "output_type": "execute_result"
107
+ }
108
+ ],
109
+ "source": [
110
+ "model_name"
111
+ ]
112
+ },
113
+ {
114
+ "cell_type": "code",
115
+ "execution_count": 15,
116
+ "metadata": {},
117
+ "outputs": [],
118
+ "source": [
119
+ "example_image = dataset[1000]['image']\n",
120
+ "message = make_message(example_image)\n",
121
+ "response = get_answer(message)"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "markdown",
126
+ "metadata": {},
127
+ "source": []
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "execution_count": 16,
132
+ "metadata": {},
133
+ "outputs": [
134
+ {
135
+ "data": {
136
+ "text/plain": [
137
+ "'The image shows a person from behind walking on a tiled floor. The person is wearing a dark shirt and dark pants. The lighting is dim, and there is a bright screen or display in the background. The person appears to be holding something in their right hand.'"
138
+ ]
139
+ },
140
+ "execution_count": 16,
141
+ "metadata": {},
142
+ "output_type": "execute_result"
143
+ }
144
+ ],
145
+ "source": [
146
+ "response"
147
+ ]
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "execution_count": 24,
152
+ "metadata": {},
153
+ "outputs": [
154
+ {
155
+ "data": {
156
+ "text/plain": [
157
+ "[SystemMessage(content='You are a helpful assistant who is helping user to caption about the image related to person, taking from surveillance camera. Please provide the caption in detail.'),\n",
158
+ " HumanMessage(content='Describe the image')]"
159
+ ]
160
+ },
161
+ "execution_count": 24,
162
+ "metadata": {},
163
+ "output_type": "execute_result"
164
+ }
165
+ ],
166
+ "source": [
167
+ "init_prompt = ChatPromptTemplate(\n",
168
+ " [\n",
169
+ " (\n",
170
+ " \"system\",\n",
171
+ " \"You are a helpful assistant who is helping the user write a clear prompt for guiding a Multimodal Large Language Model (MLLM) to describe the image.\",\n",
172
+ " ),\n",
173
+ " (\n",
174
+ " \"user\",\n",
175
+ " \"\"\"I want the MLLM to provide a detailed, fine-grained description of the image related to a person, taken from surveillance. The model must cover these aspects:\n",
176
+ " - The gender, pose, appearance, and age of the person in the image.\n",
177
+ " - The region of the head, face, and items such as hats, glasses, helmets, etc.\n",
178
+ " - Characteristics of the upper body, such as a red shirt, blue and white jacket, etc.\n",
179
+ " - Characteristics of the lower body, such as black jeans, white skirt, etc.\n",
180
+ " - Characteristics of accessories the person is holding, such as a phone, bag, etc.\n",
181
+ " - Characteristics of the bottom of the person, such as shoes, sandals, etc.\n",
182
+ " - The location of the person and objects in the image, such as in the park, on the street, in the house, etc.\n",
183
+ " - The transportation in the image, such as a car, bike, bus, etc.\n",
184
+ " - The time of day or lighting conditions.\n",
185
+ " - The weather conditions, such as sunny, rainy, etc.\n",
186
+ " - Any notable actions or activities the person is engaged in.\n",
187
+ " \n",
188
+ " For the objects that occur in the image or on the person, please provide a detailed description of the object, such as the color, shape, size, and any other relevant details.\n",
189
+ " Please generate three example templates to help the model describe the image in detail. For example:\n",
190
+ " EX1: \"The [gender] [age] person is wearing a [color] [type of clothing] and holding a [object] in the [location]. [He/She] is standing next to a [object] and [object]. The [upper body clothing] is [color] and [lower body clothing] is [color]. [He/She] is wearing [accessories] and [shoes].\"\n",
191
+ " EX2: \"The [gender] [age] person is [action] while wearing a [color] [type of clothing]. [He/She] is holding a [object] and is located in the [location]. The [upper body clothing] is [color] and [lower body clothing] is [color]. [He/She] is wearing [accessories] and [shoes].\"\n",
192
+ " EX3: \"In the [location], the [gender] [age] person is seen wearing a [color] [type of clothing] and holding a [object]. [He/She] is next to a [object] and [object]. The [upper body clothing] is [color] and [lower body clothing] is [color]. [He/She] is wearing [accessories] and [shoes].\"\n",
193
+ " \"\"\"\n",
194
+ " ),\n",
195
+ " (\n",
196
+ " \"user\",\n",
197
+ " [\n",
198
+ " {\n",
199
+ " \"type\": \"image_url\",\n",
200
+ " \"image_url\": {\"url\": \"data:image/jpeg;base64,{image_data}\"},\n",
201
+ " }\n",
202
+ " ],\n",
203
+ " )\n",
204
+ " ]\n",
205
+ ")\n",
206
+ "\n",
207
+ "\n",
208
+ "\n",
209
+ "\n",
210
+ "extract_prompt = ChatPromptTemplate(\n",
211
+ " [\n",
212
+ " (\n",
213
+ " \"system\",\n",
214
+ " \"You are a helpful assistant who is helping user to caption about the image related to person, taking from surveillance camera. Please provide the caption in detail.\",\n",
215
+ " ),\n",
216
+ " (\n",
217
+ " \"user\",\n",
218
+ " \"{guild}\"\n",
219
+ " ,\n",
220
+ " ),\n",
221
+ " ]\n",
222
+ ")\n",
223
+ "\n",
224
+ "extract_prompt.format_messages(guild=\"Describe the image\")\n"
225
+ ]
226
+ },
227
+ {
228
+ "cell_type": "code",
229
+ "execution_count": null,
230
+ "metadata": {},
231
+ "outputs": [],
232
+ "source": [
233
+ "chain = init_prompt + llm \n",
234
+ "\n",
235
+ "response = chain.invoke()"
236
+ ]
237
+ },
238
+ {
239
+ "cell_type": "code",
240
+ "execution_count": 25,
241
+ "metadata": {},
242
+ "outputs": [
243
+ {
244
+ "data": {
245
+ "text/plain": [
246
+ "[SystemMessage(content='You are a helpful assistant who is helping the user write a clear prompt for guiding a Multimodal Large Language Model (MLLM) to describe the image.'),\n",
247
+ " HumanMessage(content='I want the MLLM to provide a detailed, fine-grained description of the image related to a person, taken from surveillance. The model must cover these aspects:\\n - The gender, pose, appearance, and age of the person in the image.\\n - The region of the head, face, and items such as hats, glasses, helmets, etc.\\n - Characteristics of the upper body, such as a red shirt, blue and white jacket, etc.\\n - Characteristics of the lower body, such as black jeans, white skirt, etc.\\n - Characteristics of accessories the person is holding, such as a phone, bag, etc.\\n - Characteristics of the bottom of the person, such as shoes, sandals, etc.\\n - The location of the person and objects in the image, such as in the park, on the street, in the house, etc.\\n - The transportation in the image, such as a car, bike, bus, etc.\\n - The time of day or lighting conditions.\\n - The weather conditions, such as sunny, rainy, etc.\\n - Any notable actions or activities the person is engaged in.\\n \\n For the objects that occur in the image or on the person, please provide a detailed description of the object, such as the color, shape, size, and any other relevant details.\\n Please generate three example templates to help the model describe the image in detail. For example:\\n EX1: \"The [gender] [age] person is wearing a [color] [type of clothing] and holding a [object] in the [location]. [He/She] is standing next to a [object] and [object]. The [upper body clothing] is [color] and [lower body clothing] is [color]. [He/She] is wearing [accessories] and [shoes].\"\\n EX2: \"The [gender] [age] person is [action] while wearing a [color] [type of clothing]. [He/She] is holding a [object] and is located in the [location]. The [upper body clothing] is [color] and [lower body clothing] is [color]. [He/She] is wearing [accessories] and [shoes].\"\\n EX3: \"In the [location], the [gender] [age] person is seen wearing a [color] [type of clothing] and holding a [object]. [He/She] is next to a [object] and [object]. The [upper body clothing] is [color] and [lower body clothing] is [color]. [He/She] is wearing [accessories] and [shoes].\"\\n '),\n",
248
+ " HumanMessage(content=[{'type': 'image_url', 'image_url': {'url': 'data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCADwAFgDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDxIRnNPKGri2zE8jFSfZ8dawvoIzTGaXyzV9ocVCVIPIpILFKQlBjvUJdj1JNLLnzGz60ytkgHpM6EYP51o20wmBDgBuxrMq9psZknIAyqjLUpLQqKLFwpMQB7Gn2Qy2T2oKSGXyn/AJ16DeeA9L0zwiNRfWE+3GBZVgGCDk9K5KtaMLJ9TqhSurnK3lzFdtGRDsZEC5B60VSjy0645HrRVWQ9tCyY8N0pjJ0q9LAUySMH3qnJj1q2cEJ3IWQVEyKQae3XrTCpPc800aXRh3IxO3FRGr2oR4IcHrwRVEiuiJIsal3Cjua6XT9BmktYpGdI0fksW5C+uKwbTHnLkgEnAJ6V7P8ADrwZpfiS9kttXXz1t7RGXypSMMWOeRWVRttJG9NK12efTaJcxTFYWSZB0bcBmp3tdQlCLIkSIg6B8/jX0MfhF4R24FncKPa4b/GkHwh8IgcWlx/4ENUuinqy1XtsfPkNi4bkxg+xor32X4M+EpRjyrxP924NFV7PsT7RM8T8RG9+1XDXob7T5hEgbqDn2rmmLE8muq14/vZFbczFiSTySa5i6KwqzcgnsRUR1PCoYn2yuiJZoYZBI+SB2FQXerPcWqW6RhEj6HuaqOxYnJqI9K6Ixsd8bkZBPWkwKcaSrSLuJiuh8N+LNV8L3X2nS5/LkPUHlWHcEVz9AosmNSaPofwz8drecxwa9ZeWWIH2iA5H1KmvYbO7gv7SK7tpFkglUMjr0Ir4dimKEV9B/AzxZFcaXcaBcSYmt8zQgn7yHrj6cfnSaLWp7FNKsUTOTwoyaK47V9Zhae5tFyrPwzeoxRXO6h6FLBSlG7PBfELzGVnEbHLcYFczqpKukR6hcn6mu+mv7e0t5ZplVwikqCO9ea3dybid5GGNxJxmpoq+qPAo4eNJWiisxxTc0E03NdZ0oWkNANGaZQhFJ0oJppNAC5Oa2PDuv3fh7WbXUbRsSQPuI/vDuKxRT8kdDSZcXZn0tFqkGspBcxWwU3OJd2/PXnp2oriPhPqi3SjS5dzS2+ZFYH+Dj+RP60VyuLue9GceVWOH8Q3F26KCCsIOScVzrMSa9CimguU3owY453Y6VyWuwWyajItsqqMfNt6ZopNRVkeJKhJPVGMWycUlW4NOkntLm6U4S3ALcHucVUNdKdzJqwUUg6UpOKoBDSUUUDEpRz0ppBFX9KskutRtYpnKRSyojEdQCQM/rUt2GkbPgbX5fDvi2xvwQYlcJMB3jbhv5/pRX0toHwr8J6CEeHTlupl5Et0fMIPsDwKKhs2VVpWPmiKCUwSSpzHGQrMDWZMx3NXremaRYReGr25lt48MzOqEbsAZ7dq8puADK+Om49frXJTkdbxCqx2Og0WC3b4e64zxDzndcSdxg5riFzxXqGh2duPhjrVzPESY1YxsGxhu3868vXj610wdzhqQaeo7tSHpS0h6VsjG1hKKKKZQh5rS0mY299bTg5MciuB64OazsVasgfOXPTNZVNjWnue+P8bNVGRDoNuRnAZ5zz+Q4orxy4Sa4l/0YMqY+7uJx+dFY3ZTjG56lJOU8NzQJ90KUwfqRXkc64nkx616A2h37SyGctNEzFjG8+ee2OK5SwsVu/EMFkwwJLhYyPb0rlU9Ga0oWsdp4pjGk/CWG2C7WuWjUj1Od39K8dP3jkYOelev/F8lNE0yKP7izHC+mFxXkABPXrXVhneNxYrR2A0UdKK6zjENFFFACir9upx8vWqFaVodqk46rxWVQ0hubmiqGndZHUDbkE8UVmWc6iPMn3kJx70VjdG1z22SxnTcxAwOTXkTyNF4hZlOCLgEfmK9t1gyRaTdsh+cQvj67TXz/NM0kpkz8+ST9c152HbmmejOmoNHe/EqWeTwtYvI+4NOvGP9mvLB6mvWvFenXN54B01nuBuVldgV7EfWvOv7GduRPHj1r0MK1y2OHF6yMrHNIa1v7Gx1uox+FJ/ZEQ+9djH0rsucVjJFLWodNtuguz+Qpy6NFjIu8/RalsLGQc9a7rw34PfW9IjvheLCGJXaV9DisGLRUZtu6ST02ivSfA9rJBoLQlHVVmfbvGCRxXNXnZaHTQhdmbH8ObdOH1JyCeQFortvKOTRXE6kjs9mi7cXcUsbxupKMCpA9DXA614JtJY5JNM3pKQSI3bIY+ldS8jYpEckisaa5XobylzLUjug6+EpLORVWSOzZWyM4IBrwVriYtzK/wCdfQzBXd0cZVsgj1FfP2pIItRuIlxhJXUY/wB4134d7nDiI9SuZXJ5Zj9TSFieppKCK6zjsGaesjjgOwHsaZRkjmkwSPV/hXKqaXf7yxZphz9F/wDr13rXMQzwx+grifhfZyTeH55ET5fOOWz3GP8A61d6tlIB1FcVRNs9KilyoqG4XsjUVbNkcctRWXIzaxnSQEA0xY8YNXH6GosVmkCehHjBFeC69bva69fQyfeWZs/nmvfQMsBXg/ie4+1eJdQm27d07cZzjHH9K68PucuI2MuugNhDL4GhvVX9/FdOrn/ZI4/Wue71NHbTSW81wiZihx5jZ6ZOB+tdhyX0IaKKVSA2WGQBSRKPb/hFbsvhOZyRhrliB+Arv2jwK5b4Y2n2TwPZ5BzMzSHI9Tjj8q69sYrCW56VNe
6imykUVI45NFSaGA8oAqEy5HXFQtKo71F5y461y2JjItNKY43kHJVSePpXz5M264kPqx/nXuF/Oh0y6BPBifI/A14WfvH6110DlxDA1raXIn9j6zCx5eGNlHuJB/jWTV7TuY9QX/p1J/JlrpOdMpUZOCB/FxSVPaQNdX1vboMvJKqAfU0CW59NeGrVrXwzpkDH547dR6dulaRYbR1xjg+/pUUOIrWFDlAqKj+qsBjIpWlO5txGQfmAP/j1YTR6NPYazcUVXlnVVO7r/P3oqLGhxhfIppYAVWLECk3n1rBHIpMr6+C/hzUNjYYQsePavIcc169qAM2mXUQ6vEw/SvIcYJHpXVS2Mql2wHFWrG4+zSysTw8Ekf8A30pFVqK2MkgrW8L2aah4o021dmVXnXJHtzWR3Fb3g60a98WadGjlGEobcO2Kb2KjufR/QsJcMV+ViOBjsaikVDGDkFl5Ge4pnmlU3Mh8xD5cik/eHY/kaGyF8pSPl5ic9x6VizvhsUrkBwpHRuUPofSiknYYZmGEYcj+6fWikWcSZc03zKrF6YZDXOcKLhYSBkJ+8Cv5ivJrqMwXcsRGCjkc/WvTN7dQeRXC+Jrcwaw8hOVmG8H+db0hS2MinpE8iSMo+WNdzH05xTK2La3hTwvcXbN++kkEYGOw5roMzG716F8JrVJtdvZTjfFANh9CT1rz016N8LrJ1e/1Vc7oAqqoH3h1NTLYuG566ylpVn9fkk/lSv8AKvk7gGGTGfWmG4TdC4OYboY3ehxVeeUspjwfNhO4c8kVkdqWglzKHG88AjEi+/rRVG7vUYLOobZJ8rADoaKm4XOI4703IVgaUjmmMKyOUkfiU8Vj+ItNi1C3Q+ZtmXhBjO72rZlZSykegrF1slo1wehrSDsxSVzjRbEZB7GmyM6x+VuPlg7tvbNX5MkkBCfpVZreSRvulfrW6kRylTGRXq3wnISwuhIF2TSlc+4HSvL/ACJAcYyfQV7T4G0pbLwuqfKZ2fznKnoccCiT0HBanSWgDJNpkn3ozuiJ7jtTXclUuwDviJDqO9F/JstrfUYVYyQHLgfxL3FR3Vwsdwk68wXA+cdgexrM676Fe5McUhA5imGR/smiq0+3zXtJfutyh9P/AK9FQK5yo4pDjFQPdIuT1qlLes/AXA9qzOd7k9zdeWVAHese7kMpbPep7hHdVIJzmnxWO/Bc4FWkMz7KFvOzUt6Q8y7h90YrSZYIDhetY935skzbFyatMkiih/06OTblEO5h7CvSvDmqy72nTKwykoQxzivOtMt72a62sQqdwRya7vSImj0OWLH72M8j3FNscdzp4byWO6e1bJjlBIyenrVO3kOLjTZm3GPJQ+x5GPpVe4u2msYryPh0IOAO3Qim38oja21CLkZCPkfwn/Cg15hQ0k9v8x/fxHAPvRTLl1inE6fdbqR6UVFxn//Z'}}])]"
249
+ ]
250
+ },
251
+ "execution_count": 25,
252
+ "metadata": {},
253
+ "output_type": "execute_result"
254
+ }
255
+ ],
256
+ "source": [
257
+ "def get_str_img(pil_image):\n",
258
+ " buffered = BytesIO()\n",
259
+ " pil_image.save(buffered, format=\"JPEG\")\n",
260
+ " img_str = base64.b64encode(buffered.getvalue()).decode(\"utf-8\")\n",
261
+ " img_str = str(img_str)\n",
262
+ " return img_str\n",
263
+ "\n",
264
+ "\n",
265
+ "\n",
266
+ "init_prompt.format_messages(\n",
267
+ " image_data=get_str_img(example_image)\n",
268
+ ")"
269
+ ]
270
+ },
271
+ {
272
+ "cell_type": "code",
273
+ "execution_count": 29,
274
+ "metadata": {},
275
+ "outputs": [],
276
+ "source": [
277
+ "init_chain = init_prompt | llm\n",
278
+ "response = init_chain.invoke(input={\"image_data\": get_str_img(example_image)})"
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "code",
283
+ "execution_count": 30,
284
+ "metadata": {},
285
+ "outputs": [
286
+ {
287
+ "data": {
288
+ "text/plain": [
289
+ "AIMessage(content='To guide a Multimodal Large Language Model (MLLM) to provide a detailed description of the image, you can use the following templates. These templates are designed to cover all the aspects you mentioned, ensuring a comprehensive and clear description:\\n\\n### Template 1:\\n\"The [gender] [age] person is seen from behind, wearing a [color] [type of clothing] and [type of pants]. [He/She] is walking on a [surface] and appears to be in a [location]. The [upper body clothing] is [color] and [lower body clothing] is [color]. [He/She] is wearing [accessories] and [shoes].\"\\n\\n### Template 2:\\n\"The [gender] [age] person is walking while wearing a [color] [type of clothing] and [type of pants]. [He/She] is holding a [object] and is located in the [location]. The [upper body clothing] is [color] and [lower body clothing] is [color]. [He/She] is wearing [accessories] and [shoes].\"\\n\\n### Template 3:\\n\"In the [location], the [gender] [age] person is seen from behind, wearing a [color] [type of clothing] and [type of pants]. [He/She] is walking on a [surface] and appears to be in a [location]. The [upper body clothing] is [color] and [lower body clothing] is [color]. [He/She] is wearing [accessories] and [shoes].\"\\n\\nThese templates provide a structured format for the MLLM to describe the image, ensuring that all relevant details are covered.', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 350, 'prompt_tokens': 1649, 'total_tokens': 1999, 'completion_tokens_details': None, 'prompt_tokens_details': None}, 'model_name': 'OpenGVLab/InternVL2_5-8B-AWQ', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None}, id='run-eb1871d9-302e-45a6-a6c5-5b2f425f7c3b-0', usage_metadata={'input_tokens': 1649, 'output_tokens': 350, 'total_tokens': 1999})"
290
+ ]
291
+ },
292
+ "execution_count": 30,
293
+ "metadata": {},
294
+ "output_type": "execute_result"
295
+ }
296
+ ],
297
+ "source": [
298
+ "response"
299
+ ]
300
+ },
301
+ {
302
+ "cell_type": "code",
303
+ "execution_count": null,
304
+ "metadata": {},
305
+ "outputs": [],
306
+ "source": []
307
+ }
308
+ ],
309
+ "metadata": {
310
+ "kernelspec": {
311
+ "display_name": "Python 3 (ipykernel)",
312
+ "language": "python",
313
+ "name": "python3"
314
+ },
315
+ "language_info": {
316
+ "codemirror_mode": {
317
+ "name": "ipython",
318
+ "version": 3
319
+ },
320
+ "file_extension": ".py",
321
+ "mimetype": "text/x-python",
322
+ "name": "python",
323
+ "nbconvert_exporter": "python",
324
+ "pygments_lexer": "ipython3",
325
+ "version": "3.12.2"
326
+ }
327
+ },
328
+ "nbformat": 4,
329
+ "nbformat_minor": 4
330
+ }
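The notebook above routes a base64 data-URL image through LangChain's ChatOpenAI wrapper to the vLLM-served model. The same request can be made with the bare openai client, which is sometimes easier to debug; a minimal sketch, assuming the same local vLLM endpoint and any local PIL image:

import base64
from io import BytesIO

from openai import OpenAI
from PIL import Image

client = OpenAI(api_key="EMPTY", base_url="http://0.0.0.0:19400/v1")
model_name = client.models.list().data[0].id

def pil_to_data_url(pil_image: Image.Image) -> str:
    # serialize the image to an in-memory JPEG and wrap it as a base64 data URL
    buffered = BytesIO()
    pil_image.save(buffered, format="JPEG")
    return "data:image/jpeg;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8")

def caption(pil_image: Image.Image) -> str:
    # single chat-completion call with a text part and an image_url part
    response = client.chat.completions.create(
        model=model_name,
        messages=[{
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe the person in the image in detail."},
                {"type": "image_url", "image_url": {"url": pil_to_data_url(pil_image)}},
            ],
        }],
        max_tokens=512,
        temperature=0.1,
    )
    return response.choices[0].message.content

# print(caption(Image.open("example_image.jpg")))  # any local image file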
a_temp/vllm_example.sh ADDED
@@ -0,0 +1,412 @@
1
+ # usage: vllm serve <model_tag> [options]
2
+
3
+ # positional arguments:
4
+ # model_tag The model tag to serve
5
+
6
+ # options:
7
+ # --allow-credentials allow credentials
8
+ # --allowed-headers ALLOWED_HEADERS
9
+ # allowed headers
10
+ # --allowed-local-media-path ALLOWED_LOCAL_MEDIA_PATH
11
+ # Allowing API requests to read local images or videos from directories
12
+ # specified by the server file system. This is a security risk. Should only be
13
+ # enabled in trusted environments.
14
+ # --allowed-methods ALLOWED_METHODS
15
+ # allowed methods
16
+ # --allowed-origins ALLOWED_ORIGINS
17
+ # allowed origins
18
+ # --api-key API_KEY If provided, the server will require this key to be presented in the header.
19
+ # --block-size {8,16,32,64,128}
20
+ # Token block size for contiguous chunks of tokens. This is ignored on neuron
21
+ # devices and set to max-model-len. On CUDA devices, only block sizes up to 32
22
+ # are supported. On HPU devices, block size defaults to 128.
23
+ # --chat-template CHAT_TEMPLATE
24
+ # The file path to the chat template, or the template in single-line form for
25
+ # the specified model
26
+ # --chat-template-content-format {auto,string,openai}
27
+ # The format to render message content within a chat template. * "string" will
28
+ # render the content as a string. Example: "Hello World" * "openai" will render
29
+ # the content as a list of dictionaries, similar to OpenAI schema. Example:
30
+ # [{"type": "text", "text": "Hello world!"}]
31
+ # --code-revision CODE_REVISION
32
+ # The specific revision to use for the model code on Hugging Face Hub. It can
33
+ # be a branch name, a tag name, or a commit id. If unspecified, will use the
34
+ # default version.
35
+ # --collect-detailed-traces COLLECT_DETAILED_TRACES
36
+ # Valid choices are model,worker,all. It makes sense to set this only if
37
+ # --otlp-traces-endpoint is set. If set, it will collect detailed traces for
38
+ # the specified modules. This involves use of possibly costly and or blocking
39
+ # operations and hence might have a performance impact.
40
+ # --compilation-config COMPILATION_CONFIG, -O COMPILATION_CONFIG
41
+ # torch.compile configuration for the model.When it is a number (0, 1, 2, 3),
42
+ # it will be interpreted as the optimization level. NOTE: level 0 is the
43
+ # default level without any optimization. level 1 and 2 are for internal
44
+ # testing only. level 3 is the recommended level for production. To specify the
45
+ # full compilation config, use a JSON string. Following the convention of
46
+ # traditional compilers, using -O without space is also supported. -O3 is
47
+ # equivalent to -O 3.
48
+ # --config CONFIG Read CLI options from a config file.Must be a YAML with the following options
49
+ # :https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#cli-
50
+ # reference
51
+ # --config-format {auto,hf,mistral}
52
+ # The format of the model config to load. * "auto" will try to load the config
53
+ # in hf format if available else it will try to load in mistral format
54
+ # --cpu-offload-gb CPU_OFFLOAD_GB
55
+ # The space in GiB to offload to CPU, per GPU. Default is 0, which means no
56
+ # offloading. Intuitively, this argument can be seen as a virtual way to
57
+ # increase the GPU memory size. For example, if you have one 24 GB GPU and set
58
+ # this to 10, virtually you can think of it as a 34 GB GPU. Then you can load a
59
+ # 13B model with BF16 weight, which requires at least 26GB GPU memory. Note
60
+ # that this requires fast CPU-GPU interconnect, as part of the model is loaded
61
+ # from CPU memory to GPU memory on the fly in each model forward pass.
62
+ # --device {auto,cuda,neuron,cpu,openvino,tpu,xpu,hpu}
63
+ # Device type for vLLM execution.
64
+ # --disable-async-output-proc
65
+ # Disable async output processing. This may result in lower performance.
66
+ # --disable-custom-all-reduce
67
+ # See ParallelConfig.
68
+ # --disable-fastapi-docs
69
+ # Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint
70
+ # --disable-frontend-multiprocessing
71
+ # If specified, will run the OpenAI frontend server in the same process as the
72
+ # model serving engine.
73
+ # --disable-log-requests
74
+ # Disable logging requests.
75
+ # --disable-log-stats Disable logging statistics.
76
+ # --disable-logprobs-during-spec-decoding [DISABLE_LOGPROBS_DURING_SPEC_DECODING]
77
+ # If set to True, token log probabilities are not returned during speculative
78
+ # decoding. If set to False, log probabilities are returned according to the
79
+ # settings in SamplingParams. If not specified, it defaults to True. Disabling
80
+ # log probabilities during speculative decoding reduces latency by skipping
81
+ # logprob calculation in proposal sampling, target sampling, and after accepted
82
+ # tokens are determined.
83
+ # --disable-mm-preprocessor-cache
84
+ # If true, then disables caching of the multi-modal preprocessor/mapper. (not
85
+ # recommended)
86
+ # --disable-sliding-window
87
+ # Disables sliding window, capping to sliding window size
88
+ # --distributed-executor-backend {ray,mp}
89
+ # Backend to use for distributed model workers, either "ray" or "mp"
90
+ # (multiprocessing). If the product of pipeline_parallel_size and
91
+ # tensor_parallel_size is less than or equal to the number of GPUs available,
92
+ # "mp" will be used to keep processing on a single host. Otherwise, this will
93
+ # default to "ray" if Ray is installed and fail otherwise. Note that tpu and
94
+ # hpu only support Ray for distributed inference.
95
+ # --download-dir DOWNLOAD_DIR
96
+ # Directory to download and load the weights, default to the default cache dir
97
+ # of huggingface.
98
+ # --dtype {auto,half,float16,bfloat16,float,float32}
99
+ # Data type for model weights and activations. * "auto" will use FP16 precision
100
+ # for FP32 and FP16 models, and BF16 precision for BF16 models. * "half" for
101
+ # FP16. Recommended for AWQ quantization. * "float16" is the same as "half". *
102
+ # "bfloat16" for a balance between precision and range. * "float" is shorthand
103
+ # for FP32 precision. * "float32" for FP32 precision.
104
+ # --enable-auto-tool-choice
105
+ # Enable auto tool choice for supported models. Use --tool-call-parser to
106
+ # specify which parser to use
107
+ # --enable-chunked-prefill [ENABLE_CHUNKED_PREFILL]
108
+ # If set, the prefill requests can be chunked based on the
109
+ # max_num_batched_tokens.
110
+ # --enable-lora If True, enable handling of LoRA adapters.
111
+ # --enable-lora-bias If True, enable bias for LoRA adapters.
112
+ # --enable-prefix-caching, --no-enable-prefix-caching
113
+ # Enables automatic prefix caching. Use --no-enable-prefix-caching to disable
114
+ # explicitly.
115
+ # --enable-prompt-adapter
116
+ # If True, enable handling of PromptAdapters.
117
+ # --enable-prompt-tokens-details
118
+ # If set to True, enable prompt_tokens_details in usage.
119
+ # --enable-request-id-headers
120
+ # If specified, API server will add X-Request-Id header to responses. Caution:
121
+ # this hurts performance at high QPS.
122
+ # --enforce-eager Always use eager-mode PyTorch. If False, will use eager mode and CUDA graph
123
+ # in hybrid for maximal performance and flexibility.
124
+ # --fully-sharded-loras
125
+ # By default, only half of the LoRA computation is sharded with tensor
126
+ # parallelism. Enabling this will use the fully sharded layers. At high
127
+ # sequence length, max rank or tensor parallel size, this is likely faster.
128
+ # --generation-config GENERATION_CONFIG
129
+ # The folder path to the generation config. Defaults to None, will use the
130
+ # default generation config in vLLM. If set to 'auto', the generation config
131
+ # will be automatically loaded from model. If set to a folder path, the
132
+ # generation config will be loaded from the specified folder path.
133
+ # --gpu-memory-utilization GPU_MEMORY_UTILIZATION
134
+ # The fraction of GPU memory to be used for the model executor, which can range
135
+ # from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
136
+ # utilization. If unspecified, will use the default value of 0.9. This is a
137
+ # per-instance limit, and only applies to the current vLLM instance.It does not
138
+ # matter if you have another vLLM instance running on the same GPU. For
139
+ # example, if you have two vLLM instances running on the same GPU, you can set
140
+ # the GPU memory utilization to 0.5 for each instance.
141
+ # --guided-decoding-backend {outlines,lm-format-enforcer,xgrammar}
142
+ # Which engine will be used for guided decoding (JSON schema / regex etc) by
143
+ # default. Currently support https://github.com/outlines-dev/outlines,
144
+ # https://github.com/mlc-ai/xgrammar, and https://github.com/noamgat/lm-format-
145
+ # enforcer. Can be overridden per request via guided_decoding_backend
146
+ # parameter.
147
+ # --hf-overrides HF_OVERRIDES
148
+ # Extra arguments for the HuggingFace config. This should be a JSON string that
149
+ # will be parsed into a dictionary.
150
+ # --host HOST host name
151
+ # --ignore-patterns IGNORE_PATTERNS
152
+ # The pattern(s) to ignore when loading the model.Default to `original/**/*` to
153
+ # avoid repeated loading of llama's checkpoints.
154
+ # --kv-cache-dtype {auto,fp8,fp8_e5m2,fp8_e4m3}
155
+ # Data type for kv cache storage. If "auto", will use model data type. CUDA
156
+ # 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports fp8
157
+ # (=fp8_e4m3)
158
+ # --kv-transfer-config KV_TRANSFER_CONFIG
159
+ # The configurations for distributed KV cache transfer. Should be a JSON
160
+ # string.
161
+ # --limit-mm-per-prompt LIMIT_MM_PER_PROMPT
162
+ # For each multimodal plugin, limit how many input instances to allow for each
163
+ # prompt. Expects a comma-separated list of items, e.g.: `image=16,video=2`
164
+ # allows a maximum of 16 images and 2 videos per prompt. Defaults to 1 for each
165
+ # modality.
166
+ # --load-format {auto,pt,safetensors,npcache,dummy,tensorizer,sharded_state,gguf,bitsandbytes,mistral,runai_streamer}
167
+ # The format of the model weights to load. * "auto" will try to load the
168
+ # weights in the safetensors format and fall back to the pytorch bin format if
169
+ # safetensors format is not available. * "pt" will load the weights in the
170
+ # pytorch bin format. * "safetensors" will load the weights in the safetensors
171
+ # format. * "npcache" will load the weights in pytorch format and store a numpy
172
+ # cache to speed up the loading. * "dummy" will initialize the weights with
173
+ # random values, which is mainly for profiling. * "tensorizer" will load the
174
+ # weights using tensorizer from CoreWeave. See the Tensorize vLLM Model script
175
+ # in the Examples section for more information. * "runai_streamer" will load
176
+ # the Safetensors weights using Run:aiModel Streamer * "bitsandbytes" will load
177
+ # the weights using bitsandbytes quantization.
178
+ # --logits-processor-pattern LOGITS_PROCESSOR_PATTERN
179
+ # Optional regex pattern specifying valid logits processor qualified names that
180
+ # can be passed with the `logits_processors` extra completion argument.
181
+ # Defaults to None, which allows no processors.
182
+ # --long-lora-scaling-factors LONG_LORA_SCALING_FACTORS
183
+ # Specify multiple scaling factors (which can be different from base model
184
+ # scaling factor - see eg. Long LoRA) to allow for multiple LoRA adapters
185
+ # trained with those scaling factors to be used at the same time. If not
186
+ # specified, only adapters trained with the base model scaling factor are
187
+ # allowed.
188
+ # --lora-dtype {auto,float16,bfloat16}
189
+ # Data type for LoRA. If auto, will default to base model dtype.
190
+ # --lora-extra-vocab-size LORA_EXTRA_VOCAB_SIZE
191
+ # Maximum size of extra vocabulary that can be present in a LoRA adapter (added
192
+ # to the base model vocabulary).
193
+ # --lora-modules LORA_MODULES [LORA_MODULES ...]
194
+ # LoRA module configurations in either 'name=path' formator JSON format.
195
+ # Example (old format): 'name=path' Example (new format): '{"name": "name",
196
+ # "local_path": "path", "base_model_name": "id"}'
197
+ # --max-cpu-loras MAX_CPU_LORAS
198
+ # Maximum number of LoRAs to store in CPU memory. Must be >= than max_loras.
199
+ # Defaults to max_loras.
200
+ # --max-log-len MAX_LOG_LEN
201
+ # Max number of prompt characters or prompt ID numbers being printed in log.
202
+ # Default: Unlimited
203
+ # --max-logprobs MAX_LOGPROBS
204
+ # Max number of log probs to return logprobs is specified in SamplingParams.
205
+ # --max-lora-rank MAX_LORA_RANK
206
+ # Max LoRA rank.
207
+ # --max-loras MAX_LORAS
208
+ # Max number of LoRAs in a single batch.
209
+ # --max-model-len MAX_MODEL_LEN
210
+ # Model context length. If unspecified, will be automatically derived from the
211
+ # model config.
212
+ # --max-num-batched-tokens MAX_NUM_BATCHED_TOKENS
213
+ # Maximum number of batched tokens per iteration.
214
+ # --max-num-seqs MAX_NUM_SEQS
215
+ # Maximum number of sequences per iteration.
216
+ # --max-parallel-loading-workers MAX_PARALLEL_LOADING_WORKERS
217
+ # Load model sequentially in multiple batches, to avoid RAM OOM when using
218
+ # tensor parallel and large models.
219
+ # --max-prompt-adapter-token MAX_PROMPT_ADAPTER_TOKEN
220
+ # Max number of PromptAdapters tokens
221
+ # --max-prompt-adapters MAX_PROMPT_ADAPTERS
222
+ # Max number of PromptAdapters in a batch.
223
+ # --max-seq-len-to-capture MAX_SEQ_LEN_TO_CAPTURE
224
+ # Maximum sequence length covered by CUDA graphs. When a sequence has context
225
+ # length larger than this, we fall back to eager mode. Additionally for
226
+ # encoder-decoder models, if the sequence length of the encoder input is larger
227
+ # than this, we fall back to the eager mode.
228
+ # --middleware MIDDLEWARE
229
+ # Additional ASGI middleware to apply to the app. We accept multiple
230
+ # --middleware arguments. The value should be an import path. If a function is
231
+ # provided, vLLM will add it to the server using @app.middleware('http'). If a
232
+ # class is provided, vLLM will add it to the server using app.add_middleware().
233
+ # --mm-processor-kwargs MM_PROCESSOR_KWARGS
234
+ # Overrides for the multimodal input mapping/processing, e.g., image processor.
235
+ # For example: {"num_crops": 4}.
236
+ # --model MODEL Name or path of the huggingface model to use.
237
+ # --model-loader-extra-config MODEL_LOADER_EXTRA_CONFIG
238
+ # Extra config for model loader. This will be passed to the model loader
239
+ # corresponding to the chosen load_format. This should be a JSON string that
240
+ # will be parsed into a dictionary.
241
+ # --multi-step-stream-outputs [MULTI_STEP_STREAM_OUTPUTS]
242
+ # If False, then multi-step will stream outputs at the end of all steps
243
+ # --ngram-prompt-lookup-max NGRAM_PROMPT_LOOKUP_MAX
244
+ # Max size of window for ngram prompt lookup in speculative decoding.
245
+ # --ngram-prompt-lookup-min NGRAM_PROMPT_LOOKUP_MIN
246
+ # Min size of window for ngram prompt lookup in speculative decoding.
247
+ # --num-gpu-blocks-override NUM_GPU_BLOCKS_OVERRIDE
248
+ # If specified, ignore GPU profiling result and use this number of GPU blocks.
249
+ # Used for testing preemption.
250
+ # --num-lookahead-slots NUM_LOOKAHEAD_SLOTS
251
+ # Experimental scheduling config necessary for speculative decoding. This will
252
+ # be replaced by speculative config in the future; it is present to enable
253
+ # correctness tests until then.
254
+ # --num-scheduler-steps NUM_SCHEDULER_STEPS
255
+ # Maximum number of forward steps per scheduler call.
256
+ # --num-speculative-tokens NUM_SPECULATIVE_TOKENS
257
+ # The number of speculative tokens to sample from the draft model in
258
+ # speculative decoding.
259
+ # --otlp-traces-endpoint OTLP_TRACES_ENDPOINT
260
+ # Target URL to which OpenTelemetry traces will be sent.
261
+ # --override-neuron-config OVERRIDE_NEURON_CONFIG
262
+ # Override or set neuron device configuration. e.g. {"cast_logits_dtype":
263
+ # "bloat16"}.'
264
+ # --override-pooler-config OVERRIDE_POOLER_CONFIG
265
+ # Override or set the pooling method for pooling models. e.g. {"pooling_type":
266
+ # "mean", "normalize": false}.'
267
+ # --pipeline-parallel-size PIPELINE_PARALLEL_SIZE, -pp PIPELINE_PARALLEL_SIZE
268
+ # Number of pipeline stages.
269
+ # --port PORT port number
270
+ # --preemption-mode PREEMPTION_MODE
271
+ # If 'recompute', the engine performs preemption by recomputing; If 'swap', the
272
+ # engine performs preemption by block swapping.
273
+ # --prompt-adapters PROMPT_ADAPTERS [PROMPT_ADAPTERS ...]
274
+ # Prompt adapter configurations in the format name=path. Multiple adapters can
275
+ # be specified.
276
+ # --qlora-adapter-name-or-path QLORA_ADAPTER_NAME_OR_PATH
277
+ # Name or path of the QLoRA adapter.
278
+ # --quantization {aqlm,awq,deepspeedfp,tpu_int8,fp8,fbgemm_fp8,modelopt,marlin,gguf,gptq_marlin_24,gptq_marlin,awq_marlin,gptq,compressed-tensors,bitsandbytes,qqq,hqq,experts_int8,neuron_quant,ipex,None}, -q {aqlm,awq,deepspeedfp,tpu_int8,fp8,fbgemm_fp8,modelopt,marlin,gguf,gptq_marlin_24,gptq_marlin,awq_marlin,gptq,compressed-tensors,bitsandbytes,qqq,hqq,experts_int8,neuron_quant,ipex,None}
279
+ # Method used to quantize the weights. If None, we first check the
280
+ # `quantization_config` attribute in the model config file. If that is None, we
281
+ # assume the model weights are not quantized and use `dtype` to determine the
282
+ # data type of the weights.
283
+ # --quantization-param-path QUANTIZATION_PARAM_PATH
284
+ # Path to the JSON file containing the KV cache scaling factors. This should
285
+ # generally be supplied, when KV cache dtype is FP8. Otherwise, KV cache
286
+ # scaling factors default to 1.0, which may cause accuracy issues. FP8_E5M2
287
+ # (without scaling) is only supported on cuda version greater than 11.8. On
288
+ # ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria.
289
+ # --ray-workers-use-nsight
290
+ # If specified, use nsight to profile Ray workers.
291
+ # --response-role RESPONSE_ROLE
292
+ # The role name to return if `request.add_generation_prompt=true`.
293
+ # --return-tokens-as-token-ids
294
+ # When --max-logprobs is specified, represents single tokens as strings of the
295
+ # form 'token_id:{token_id}' so that tokens that are not JSON-encodable can be
296
+ # identified.
297
+ # --revision REVISION The specific model version to use. It can be a branch name, a tag name, or a
298
+ # commit id. If unspecified, will use the default version.
299
+ # --root-path ROOT_PATH
300
+ # FastAPI root_path when app is behind a path based routing proxy
301
+ # --rope-scaling ROPE_SCALING
302
+ # RoPE scaling configuration in JSON format. For example,
303
+ # {"rope_type":"dynamic","factor":2.0}
304
+ # --rope-theta ROPE_THETA
305
+ # RoPE theta. Use with `rope_scaling`. In some cases, changing the RoPE theta
306
+ # improves the performance of the scaled model.
307
+ # --scheduler-delay-factor SCHEDULER_DELAY_FACTOR
308
+ # Apply a delay (of delay factor multiplied by previous prompt latency) before
309
+ # scheduling next prompt.
310
+ # --scheduling-policy {fcfs,priority}
311
+ # The scheduling policy to use. "fcfs" (first come first served, i.e. requests
312
+ # are handled in order of arrival; default) or "priority" (requests are handled
313
+ # based on given priority (lower value means earlier handling) and time of
314
+ # arrival deciding any ties).
315
+ # --seed SEED Random seed for operations.
316
+ # --served-model-name SERVED_MODEL_NAME [SERVED_MODEL_NAME ...]
317
+ # The model name(s) used in the API. If multiple names are provided, the server
318
+ # will respond to any of the provided names. The model name in the model field
319
+ # of a response will be the first name in this list. If not specified, the
320
+ # model name will be the same as the `--model` argument. Noted that this
321
+ # name(s) will also be used in `model_name` tag content of prometheus metrics,
322
+ # if multiple names provided, metrics tag will take the first one.
323
+ # --skip-tokenizer-init
324
+ # Skip initialization of tokenizer and detokenizer
325
+ # --spec-decoding-acceptance-method {rejection_sampler,typical_acceptance_sampler}
326
+ # Specify the acceptance method to use during draft token verification in
327
+ # speculative decoding. Two types of acceptance routines are supported: 1)
328
+ # RejectionSampler which does not allow changing the acceptance rate of draft
329
+ # tokens, 2) TypicalAcceptanceSampler which is configurable, allowing for a
330
+ # higher acceptance rate at the cost of lower quality, and vice versa.
331
+ # --speculative-disable-by-batch-size SPECULATIVE_DISABLE_BY_BATCH_SIZE
332
+ # Disable speculative decoding for new incoming requests if the number of
333
+ # enqueue requests is larger than this value.
334
+ # --speculative-disable-mqa-scorer
335
+ # If set to True, the MQA scorer will be disabled in speculative decoding and
336
+ # the engine will fall back to batch expansion.
337
+ # --speculative-draft-tensor-parallel-size SPECULATIVE_DRAFT_TENSOR_PARALLEL_SIZE, -spec-draft-tp SPECULATIVE_DRAFT_TENSOR_PARALLEL_SIZE
338
+ # Number of tensor parallel replicas for the draft model in speculative
339
+ # decoding.
340
+ # --speculative-max-model-len SPECULATIVE_MAX_MODEL_LEN
341
+ # The maximum sequence length supported by the draft model. Sequences over this
342
+ # length will skip speculation.
343
+ # --speculative-model SPECULATIVE_MODEL
344
+ # The name of the draft model to be used in speculative decoding.
345
+ # --speculative-model-quantization {aqlm,awq,deepspeedfp,tpu_int8,fp8,fbgemm_fp8,modelopt,marlin,gguf,gptq_marlin_24,gptq_marlin,awq_marlin,gptq,compressed-tensors,bitsandbytes,qqq,hqq,experts_int8,neuron_quant,ipex,None}
346
+ # Method used to quantize the weights of speculative model. If None, we first
347
+ # check the `quantization_config` attribute in the model config file. If that
348
+ # is None, we assume the model weights are not quantized and use `dtype` to
349
+ # determine the data type of the weights.
350
+ # --ssl-ca-certs SSL_CA_CERTS
351
+ # The CA certificates file
352
+ # --ssl-cert-reqs SSL_CERT_REQS
353
+ # Whether client certificate is required (see stdlib ssl module's)
354
+ # --ssl-certfile SSL_CERTFILE
355
+ # The file path to the SSL cert file
356
+ # --ssl-keyfile SSL_KEYFILE
357
+ # The file path to the SSL key file
358
+ # --swap-space SWAP_SPACE
359
+ # CPU swap space size (GiB) per GPU.
360
+ # --task {auto,generate,embedding,embed,classify,score,reward}
361
+ # The task to use the model for. Each vLLM instance only supports one task,
362
+ # even if the same model can be used for multiple tasks. When the model only
363
+ # supports one task, "auto" can be used to select it; otherwise, you must
364
+ # specify explicitly which task to use.
365
+ # --tensor-parallel-size TENSOR_PARALLEL_SIZE, -tp TENSOR_PARALLEL_SIZE
366
+ # Number of tensor parallel replicas.
367
+ # --tokenizer TOKENIZER
368
+ # Name or path of the huggingface tokenizer to use. If unspecified, model name
369
+ # or path will be used.
370
+ # --tokenizer-mode {auto,slow,mistral}
371
+ # The tokenizer mode. * "auto" will use the fast tokenizer if available. *
372
+ # "slow" will always use the slow tokenizer. * "mistral" will always use the
373
+ # `mistral_common` tokenizer.
374
+ # --tokenizer-pool-extra-config TOKENIZER_POOL_EXTRA_CONFIG
375
+ # Extra config for tokenizer pool. This should be a JSON string that will be
376
+ # parsed into a dictionary. Ignored if tokenizer_pool_size is 0.
377
+ # --tokenizer-pool-size TOKENIZER_POOL_SIZE
378
+ # Size of tokenizer pool to use for asynchronous tokenization. If 0, will use
379
+ # synchronous tokenization.
380
+ # --tokenizer-pool-type TOKENIZER_POOL_TYPE
381
+ # Type of tokenizer pool to use for asynchronous tokenization. Ignored if
382
+ # tokenizer_pool_size is 0.
383
+ # --tokenizer-revision TOKENIZER_REVISION
384
+ # Revision of the huggingface tokenizer to use. It can be a branch name, a tag
385
+ # name, or a commit id. If unspecified, will use the default version.
386
+ # --tool-call-parser {granite-20b-fc,granite,hermes,internlm,jamba,llama3_json,mistral,pythonic} or name registered in --tool-parser-plugin
387
+ # Select the tool call parser depending on the model that you're using. This is
388
+ # used to parse the model-generated tool call into OpenAI API format. Required
389
+ # for --enable-auto-tool-choice.
390
+ # --tool-parser-plugin TOOL_PARSER_PLUGIN
391
+ # Specify the tool parser plugin used to parse model-generated tool calls into
392
+ # OpenAI API format; names registered in this plugin can be used in --tool-
393
+ # call-parser.
394
+ # --trust-remote-code Trust remote code from huggingface.
395
+ # --typical-acceptance-sampler-posterior-alpha TYPICAL_ACCEPTANCE_SAMPLER_POSTERIOR_ALPHA
396
+ # A scaling factor for the entropy-based threshold for token acceptance in the
397
+ # TypicalAcceptanceSampler. Typically defaults to sqrt of --typical-acceptance-
398
+ # sampler-posterior-threshold i.e. 0.3
399
+ # --typical-acceptance-sampler-posterior-threshold TYPICAL_ACCEPTANCE_SAMPLER_POSTERIOR_THRESHOLD
400
+ # Set the lower bound threshold for the posterior probability of a token to be
401
+ # accepted. This threshold is used by the TypicalAcceptanceSampler to make
402
+ # sampling decisions during speculative decoding. Defaults to 0.09
403
+ # --use-v2-block-manager
404
+ # [DEPRECATED] block manager v1 has been removed and SelfAttnBlockSpaceManager
405
+ # (i.e. block manager v2) is now the default. Setting this flag to True or
406
+ # False has no effect on vLLM behavior.
407
+ # --uvicorn-log-level {debug,info,warning,error,critical,trace}
408
+ # log level for uvicorn
409
+ # --worker-cls WORKER_CLS
410
+ # The worker class to use for distributed execution.
411
+ # --worker-use-ray Deprecated, use --distributed-executor-backend=ray.
412
+ # -h, --help show this help message and exit
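+ #
+ # Illustrative example only: one way a few of the flags documented above can be
+ # combined. The model name and values are placeholders, not recommendations.
+ # vllm serve meta-llama/Llama-3.1-8B-Instruct \
+ #     --tensor-parallel-size 2 \
+ #     --max-model-len 8192 \
+ #     --served-model-name my-llama \
+ #     --uvicorn-log-level info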
groundingLMM/train.py ADDED
@@ -0,0 +1,671 @@
1
+ """
2
+ train.py - GLaMM Model Training on Mixed Datasets
3
+
4
+ Trains the GLaMM model using Caption, Region, and Segmentation datasets with a random sampling approach. This method
5
+ is crucial for developing a versatile model capable of handling diverse applications effectively.
6
+ """
7
+ import os
8
+ import sys
9
+ import time
10
+ import tqdm
11
+ import random
12
+ import torch
13
+ import argparse
14
+ import deepspeed
15
+ import numpy as np
16
+ import transformers
17
+ from functools import partial
18
+ from torch.utils.data import ConcatDataset
19
+ from peft import LoraConfig, get_peft_model
20
+ from torch.utils.tensorboard import SummaryWriter
21
+
22
+ from model.GLaMM import GLaMMForCausalLM
23
+ from model.llava import conversation as conversation_lib
24
+
25
+ from dataset.dataset import custom_collate_fn, HybridSegDataset, HybridRegDataset, HybridCapDataset
26
+ from tools.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, AverageMeter, ProgressMeter, dict_to_cuda,
27
+ Summary, intersectionAndUnionGPU)
28
+
29
+ from dataset.segm_datasets.RefCOCO_Segm_ds import ReferSegmDataset
30
+ from dataset.region_datasets.RefCOCO_VG_Region_ds import RefCocoGRegDataset, VisualGenomeRegDataset
31
+ from dataset.caption_datasets.COCO_Caption_ds import CocoCapDataset
32
+ from dataset.gcg_datasets.GranDf_gcg_ds import OpenPsgGCGDataset, Flickr30kGCGDataset, RefCOCOgGCGDataset
33
+
34
+
35
+ def parse_args(args):
36
+ parser = argparse.ArgumentParser(description="GLaMM Model Training")
37
+
38
+ # Model-specific settings
39
+ parser.add_argument("--version", default="MBZUAI/GLaMM-GranD-Pretrained")
40
+ parser.add_argument("--vision_pretrained", default="./checkpoints/sam_vit_h_4b8939.pth", type=str)
41
+ parser.add_argument("--vision-tower", default="openai/clip-vit-large-patch14-336", type=str)
42
+ parser.add_argument("--conv_type", default="llava_v1", type=str, choices=["llava_v1", "llava_llama_2"])
43
+ parser.add_argument("--tune_mm_mlp_adapter", action="store_true")
44
+ parser.add_argument("--freeze_mm_mlp_adapter", action="store_true")
45
+ parser.add_argument("--mm_use_im_start_end", action="store_true", default=True)
46
+ parser.add_argument("--out_dim", default=256, type=int)
47
+ parser.add_argument("--image_size", default=1024, type=int, help="Image size for grounding image encoder")
48
+ parser.add_argument("--model_max_length", default=1536, type=int)
49
+ parser.add_argument("--lora_target_modules", default="q_proj,v_proj", type=str)
50
+ parser.add_argument("--with_region", action="store_true", default=True)
51
+ parser.add_argument("--mm_vision_select_layer", default=-2, type=int)
52
+ parser.add_argument("--pretrain_mm_mlp_adapter", default="", type=str)
53
+ parser.add_argument("--precision", default='bf16', type=str)
54
+
55
+ # Dataset settings
56
+ parser.add_argument("--use_cap_data", action="store_true", help="Use caption data")
57
+ parser.add_argument("--use_reg_data", action="store_true", help="Use region data")
58
+ parser.add_argument("--use_segm_data", action="store_true", help="Use segmentation data")
59
+ parser.add_argument("--weight_cap", default=0.15, type=float, help="Sampling weight for caption data")
60
+ parser.add_argument("--weight_reg", default=0.40, type=float, help="Sampling weight for region data")
61
+ parser.add_argument("--weight_segm", default=0.45, type=float, help="Sampling weight for segmentation data")
62
+ parser.add_argument("--dataset_dir", default="./data", type=str)
63
+ parser.add_argument("--seg_dataset", default="Semantic_Segm||Refer_Segm||RefCoco_GCG||PSG_GCG||Flickr_GCG||GranDf_GCG",
64
+ type=str, help="Choose from: Semantic_Segm, Refer_Segm, RefCoco_GCG, GranDf_GCG, PSG_GCG, Flickr_GCG, GrandRefer_Segm")
65
+ parser.add_argument("--segm_sample_rates", default="5,4,3,3,3,1", type=str)
66
+ parser.add_argument("--reg_dataset", default="RefCoco_Reg||RefCocoG_Reg||RefCocoP_Reg||VisGen_Reg",
67
+ type=str, help="Choose from: RefCoco_Reg, RefCocoG_Reg, RefCocoP_Reg, VisGen_Reg, Flickr_Reg, GrandRefer_Reg")
68
+ parser.add_argument("--reg_sample_rates", default="1,1,1,1", type=str)
69
+ parser.add_argument("--cap_dataset", default="CocoCap||LLaVaInstruct", type=str,
70
+ help="Choose from: CocoCap, LLaVaInstruct, GrandCaptionDataset")
71
+ parser.add_argument("--cap_sample_rates", default="1,1", type=str)
72
+ parser.add_argument("--semantic_segm_data", default="ade20k||cocostuff||pascal_part||paco_lvis||mapillary", type=str)
73
+ parser.add_argument("--refer_segm_data", default="refcoco||refcoco+||refcocog||refclef", type=str)
74
+ parser.add_argument("--vqa_data", default="llava_instruct_150k", type=str)
75
+ parser.add_argument("--num_classes_per_sample", default=3, type=int)
76
+
77
+ # Training settings
78
+ parser.add_argument("--pretrained", action="store_true")
79
+ parser.add_argument("--resume", default="", type=str)
80
+ parser.add_argument("--auto_resume", action="store_true")
81
+ parser.add_argument("--weight", default="", type=str)
82
+ parser.add_argument("--lr", default=0.0003, type=float)
83
+ parser.add_argument("--epochs", default=10, type=int)
84
+ parser.add_argument("--steps_per_epoch", default=500, type=int)
85
+ parser.add_argument("--batch_size", default=2, type=int, help="batch size per device per step")
86
+ parser.add_argument("--grad_accumulation_steps", default=10, type=int)
87
+ parser.add_argument("--val_batch_size", default=1, type=int)
88
+ parser.add_argument("--workers", default=2, type=int)
89
+ parser.add_argument("--lora_r", default=8, type=int)
90
+ parser.add_argument("--lora_alpha", default=16, type=int)
91
+ parser.add_argument("--lora_dropout", default=0.05, type=float)
92
+ parser.add_argument("--ce_loss_weight", default=1.0, type=float)
93
+ parser.add_argument("--dice_loss_weight", default=0.5, type=float)
94
+ parser.add_argument("--bce_loss_weight", default=2.0, type=float)
95
+ parser.add_argument("--beta1", default=0.9, type=float)
96
+ parser.add_argument("--beta2", default=0.95, type=float)
97
+ parser.add_argument("--gradient_checkpointing", action="store_true", default=True)
98
+ parser.add_argument("--train_mask_decoder", action="store_true", default=True)
99
+ parser.add_argument("--use_mm_start_end", action="store_true", default=True)
100
+ parser.add_argument("--print_freq", default=1, type=int)
101
+ parser.add_argument("--start_epoch", default=0, type=int)
102
+ parser.add_argument("--local_rank", default=0, type=int, help="node rank")
103
+
104
+ # Evaluation settings
105
+ parser.add_argument("--val_dataset", default="CocoCapVal|RefCOCOgRegVal|RefCOCOgSegmVal", type=str,
106
+ help="Choose from: CocoCapVal, RefCOCOgRegVal, VisGenomeRegVal, RefCOCOgSegmVal, PsgGCGVal, "
107
+ "RefCocoGCGVal, FlickrGCGVal")
108
+ parser.add_argument("--mask_validation", action="store_true")
109
+ parser.add_argument("--no_eval", action="store_true")
110
+ parser.add_argument("--eval_only", action="store_true")
111
+
112
+ # Experiment settings
113
+ parser.add_argument("--log_base_dir", default="./output", type=str)
114
+ parser.add_argument("--exp_name", default="GlamFinetuneOS", type=str)
115
+
116
+ return parser.parse_args(args)
117
+
118
+
119
+ def initialize_environment(args):
120
+ """ Set up logging and model directories. """
121
+ args.log_dir = os.path.join(args.log_base_dir, args.exp_name)
122
+ if args.local_rank == 0:
123
+ os.makedirs(args.log_dir, exist_ok=True)
124
+ return SummaryWriter(args.log_dir)
125
+ return None
126
+
127
+
128
+ def setup_tokenizer_and_special_tokens(args):
129
+ """ Load tokenizer and add special tokens. """
130
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
131
+ args.version, model_max_length=args.model_max_length, padding_side="right", use_fast=False
132
+ )
133
+ print('\033[92m' + "---- Initialized tokenizer from: {} ----".format(args.version) + '\033[0m')
134
+ tokenizer.pad_token = tokenizer.unk_token
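+ # Reuse the <unk> token for padding, since the underlying LLaMA-style tokenizer has no dedicated pad token.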
135
+
136
+ if not args.pretrained:
137
+ if args.use_mm_start_end:
138
+ tokenizer.add_tokens(
139
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
140
+ )
141
+ # modifications specific for regions
142
+ reg_tokens = ['<bbox>', '<point>']
143
+ # Adding special tokens for pixel grounding
144
+ segmentation_tokens = ['[SEG]']
145
+ # Adding tokens for GCG
146
+ phrase_tokens = ['<p>', '</p>']
147
+ special_tokens = reg_tokens + segmentation_tokens + phrase_tokens
148
+ tokenizer.add_tokens(special_tokens, special_tokens=True)
149
+
150
+ args.bbox_token_idx = tokenizer("<bbox>", add_special_tokens=False).input_ids[0]
151
+ args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
152
+ args.bop_token_idx = tokenizer("<p>", add_special_tokens=False).input_ids[0]
153
+ args.eop_token_idx = tokenizer("</p>", add_special_tokens=False).input_ids[0]
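+ # Cache the ids of the newly added special tokens; they are passed to the model so it can
+ # locate <bbox>/<point>, [SEG] and <p>...</p> spans in the token stream.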
154
+
155
+ return tokenizer
156
+
157
+
158
+ def initialize_model(args, tokenizer):
159
+ """ Initialize the GLaMM model. """
160
+ model_args = {k: getattr(args, k) for k in
161
+ ["train_mask_decoder", "out_dim", "ce_loss_weight", "dice_loss_weight", "bce_loss_weight",
162
+ "seg_token_idx", "vision_pretrained", "vision_tower", "use_mm_start_end", "mm_vision_select_layer",
163
+ "pretrain_mm_mlp_adapter", "tune_mm_mlp_adapter", "freeze_mm_mlp_adapter", "mm_use_im_start_end",
164
+ "with_region", "bbox_token_idx", "eop_token_idx", "bop_token_idx"]}
165
+ model_args["num_level_reg_features"] = 4
166
+
167
+ model = GLaMMForCausalLM.from_pretrained(
168
+ args.version, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, **model_args
169
+ )
170
+ print('\033[92m' + "---- Initialized model from: {} ----".format(args.version) + '\033[0m')
171
+
172
+ # Configure model tokens
173
+ model.config.eos_token_id = tokenizer.eos_token_id
174
+ model.config.bos_token_id = tokenizer.bos_token_id
175
+ model.config.pad_token_id = tokenizer.pad_token_id
176
+
177
+ return model
178
+
179
+
180
+ def prepare_model_for_training(model, tokenizer, args):
181
+ # Enable input gradients
182
+ model.enable_input_require_grads()
183
+ model.gradient_checkpointing_enable()
184
+
185
+ # Initialize vision tower
186
+ print(
187
+ '\033[92m' + "---- Initialized Global Image Encoder (vision tower) from: {} ----".format(
188
+ args.vision_tower
189
+ ) + '\033[0m'
190
+ )
191
+ model.get_model().initialize_vision_modules(model.get_model().config)
192
+ vision_tower = model.get_model().get_vision_tower()
193
+ vision_tower.to(dtype=torch.bfloat16, device=args.local_rank)
194
+
195
+ # Initialize GLaMM model and adjust requires_grad
196
+ if not args.pretrained:
197
+ model.get_model().initialize_glamm_model(model.get_model().config)
198
+ else:
199
+ for param in model.get_model().grounding_encoder.parameters():
200
+ param.requires_grad = False
201
+ if model.get_model().config.train_mask_decoder:
202
+ model.get_model().grounding_encoder.mask_decoder.train()
203
+ for param in model.get_model().grounding_encoder.mask_decoder.parameters():
204
+ param.requires_grad = True
205
+
206
+ # Projection layer
207
+ model.get_model().text_hidden_fcs.train()
208
+ for param in model.get_model().text_hidden_fcs.parameters():
209
+ param.requires_grad = True
210
+
211
+ # Set requires_grad for vision tower and mm projector
212
+ for p in vision_tower.parameters():
213
+ p.requires_grad = False
214
+ for p in model.get_model().mm_projector.parameters():
215
+ p.requires_grad = False
216
+
217
+ # Set requires_grad based on LoRA training
218
+ lora_r = args.lora_r
219
+ if lora_r == 0:
220
+ for p in model.get_model().layers.parameters():
221
+ p.requires_grad = True
222
+ for p in model.get_model().mm_projector.parameters():
223
+ p.requires_grad = True
224
+
225
+ # Configure conversation library
226
+ conversation_lib.default_conversation = conversation_lib.conv_templates[args.conv_type]
227
+
228
+ # Configure LoRA if applicable
229
+ if lora_r > 0:
230
+ lora_config = setup_lora_config(model, args)
231
+ model = get_peft_model(model, lora_config)
232
+
233
+ # Resize token embeddings
234
+ model.resize_token_embeddings(len(tokenizer))
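+ # Grow the input/output embedding matrices to cover the special tokens added in setup_tokenizer_and_special_tokens().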
235
+
236
+ # Make certain modules trainable
237
+ set_trainable_modules(model)
238
+
239
+
240
+ def setup_lora_config(model, args):
241
+ """ Configure LoRA settings for the model. """
242
+
243
+ def find_proj_layers(model, target_modules):
244
+ """ Identify projection layers in the model for LoRA adaptation. """
245
+ linear_cls = torch.nn.Linear
246
+ lora_module_names = set()
247
+ for name, module in model.named_modules():
248
+ if (isinstance(module, linear_cls) and all(
249
+ x not in name for x in ["grounding_encoder", "vision_tower", "mm_projector", "text_hidden_fcs"]
250
+ ) and any(x in name for x in target_modules)):
251
+ lora_module_names.add(name)
252
+ return sorted(list(lora_module_names))
253
+
254
+ # Extracting LoRA target modules
255
+ lora_target_modules = args.lora_target_modules.split(",")
256
+ lora_module_names = find_proj_layers(model, lora_target_modules)
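+ # With the default "q_proj,v_proj" setting, only attention projection linears outside the grounding
+ # encoder, vision tower, mm projector and text projection layers receive LoRA adapters.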
257
+
258
+ # Configuring LoRA
259
+ lora_config = LoraConfig(
260
+ r=args.lora_r, lora_alpha=args.lora_alpha, target_modules=lora_module_names, lora_dropout=args.lora_dropout,
261
+ bias="none", task_type="CAUSAL_LM"
262
+ )
263
+ return lora_config
264
+
265
+
266
+ def set_trainable_modules(model):
267
+ """ Make specified modules in the model trainable. """
268
+ trainable_modules = ["lm_head", "embed_tokens", "mask_decoder", "text_hidden_fcs", "region_encoder"]
269
+ for name, param in model.named_parameters():
270
+ if any(module in name for module in trainable_modules):
271
+ print(f"Making trainable: {name}, Shape: {param.shape}")
272
+ param.requires_grad = True
273
+
274
+ def count_parameters(model):
275
+ total_params = sum(p.numel() for p in model.parameters())
276
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
277
+
278
+ print('\033[92m' + "---- Total parameters: ----{}".format(total_params) + '\033[0m')
279
+ print('\033[92m' + "---- Trainable parameters: ----{}".format(trainable_params) + '\033[0m')
280
+
281
+ count_parameters(model)
282
+
283
+
284
+ def initialize_datasets_and_loaders(args, tokenizer):
285
+ world_size = torch.cuda.device_count()
286
+ args.distributed = world_size > 1
287
+
288
+ # Common dataset arguments
289
+ common_ds_args = {"dataset_dir": args.dataset_dir, "tokenizer": tokenizer,
290
+ "global_image_encoder": args.vision_tower,
291
+ "epoch_samples": args.batch_size * args.grad_accumulation_steps * args.steps_per_epoch * world_size,
292
+ "precision": args.precision, "image_size": args.image_size,
293
+ "num_classes_per_sample": args.num_classes_per_sample}
294
+
295
+ # Training datasets
296
+ cap_train_dataset = HybridCapDataset(
297
+ **common_ds_args, dataset=args.cap_dataset, sample_rate=[float(x) for x in args.cap_sample_rates.split(",")],
298
+ batch_size=args.batch_size, ) if args.use_cap_data else None
299
+ reg_train_dataset = HybridRegDataset(
300
+ **common_ds_args, dataset=args.reg_dataset, sample_rate=[float(x) for x in args.reg_sample_rates.split(",")],
301
+ batch_size=args.batch_size, ) if args.use_reg_data else None
302
+ seg_train_dataset = HybridSegDataset(
303
+ **common_ds_args, dataset=args.seg_dataset, sample_rate=[float(x) for x in args.segm_sample_rates.split(",")],
304
+ semantic_segm_data=args.semantic_segm_data, refer_segm_data=args.refer_segm_data,
305
+ batch_size=args.batch_size, ) if args.use_segm_data else None
306
+
307
+ # Validation datasets
308
+ val_datasets = []
309
+ if not args.no_eval:
310
+ val_dataset_classes = {'CocoCapVal': CocoCapDataset,
311
+ 'RefCOCOgRegVal': RefCocoGRegDataset,
312
+ 'VisGenomeRegVal': VisualGenomeRegDataset,
313
+ 'RefCOCOgSegmVal': ReferSegmDataset,
314
+ 'PsgGCGVal': OpenPsgGCGDataset,
315
+ 'RefCocoGCGVal': RefCOCOgGCGDataset,
316
+ 'FlickrGCGVal': Flickr30kGCGDataset,
317
+ }
318
+ for val_dataset_name in args.val_dataset.split('|'):
319
+ val_dataset_class = val_dataset_classes.get(val_dataset_name)
320
+ if val_dataset_class:
321
+ if val_dataset_class == ReferSegmDataset:
322
+ # Modify this if other datasets in refer_segm_data need to be included in val
323
+ refer_segm_data = 'refcocog'
324
+ all_datasets = refer_segm_data.split("||")
325
+ for d in all_datasets:
326
+ val_dataset_class = val_dataset_class(
327
+ **common_ds_args, validation=True, refer_segm_data=d, split='val'
328
+ )
329
+ val_dataset_class._set_len(len(val_dataset_class.refer_segm_data[d]['images']))
330
+ val_datasets.append(val_dataset_class)
331
+ else:
332
+ val_datasets.append(val_dataset_class(**common_ds_args, validation=True))
333
+
334
+ return cap_train_dataset, reg_train_dataset, seg_train_dataset, val_datasets
335
+
336
+
337
+ def setup_data_loaders(args, cap_train_dataset, reg_train_dataset, seg_train_dataset, val_datasets, tokenizer):
338
+ sampler_args = {"shuffle": False, "drop_last": False}
339
+ train_loader_args = {"batch_size": args.batch_size, "shuffle": False, "num_workers": args.workers,
340
+ "pin_memory": False}
341
+ val_loader_args = {"batch_size": args.val_batch_size, "shuffle": False, "num_workers": args.workers,
342
+ "pin_memory": False}
343
+ collate_fn_args_train = partial(
344
+ custom_collate_fn, tokenizer=tokenizer, use_mm_start_end=args.use_mm_start_end, local_rank=args.local_rank,
345
+ inference=False
346
+ )
347
+ inference_mode = args.mask_validation
348
+ collate_fn_args_val = partial(
349
+ custom_collate_fn, tokenizer=tokenizer, use_mm_start_end=args.use_mm_start_end, local_rank=args.local_rank,
350
+ inference=inference_mode
351
+ )
352
+
353
+ # Training loaders
354
+ cap_train_loader = torch.utils.data.DataLoader(
355
+ cap_train_dataset, sampler=torch.utils.data.distributed.DistributedSampler(
356
+ cap_train_dataset, **sampler_args
357
+ ), collate_fn=collate_fn_args_train, **train_loader_args
358
+ ) if cap_train_dataset is not None else None
359
+ reg_train_loader = torch.utils.data.DataLoader(
360
+ reg_train_dataset, sampler=torch.utils.data.distributed.DistributedSampler(
361
+ reg_train_dataset, **sampler_args
362
+ ), collate_fn=collate_fn_args_train, **train_loader_args
363
+ ) if reg_train_dataset is not None else None
364
+ seg_train_loader = torch.utils.data.DataLoader(
365
+ seg_train_dataset, sampler=torch.utils.data.distributed.DistributedSampler(
366
+ seg_train_dataset, **sampler_args
367
+ ), collate_fn=collate_fn_args_train, **train_loader_args
368
+ ) if seg_train_dataset is not None else None
369
+
370
+ # Validation loader
371
+ val_loader = None
372
+ if val_datasets:
373
+ combined_val_datasets = ConcatDataset(val_datasets)
374
+ val_loader = torch.utils.data.DataLoader(
375
+ combined_val_datasets, **val_loader_args, collate_fn=collate_fn_args_val,
376
+ sampler=torch.utils.data.distributed.DistributedSampler(combined_val_datasets, **sampler_args), )
377
+
378
+ return cap_train_loader, reg_train_loader, seg_train_loader, val_loader
379
+
380
+
381
+ def initialize_deepspeed(model, tokenizer, args):
382
+ ds_config = {"train_micro_batch_size_per_gpu": args.batch_size,
383
+ "gradient_accumulation_steps": args.grad_accumulation_steps,
384
+ "optimizer": {"type": "AdamW", "params": {"lr": args.lr, "weight_decay": 0.0,
385
+ "betas": (args.beta1, args.beta2)}},
386
+ "scheduler": {"type": "WarmupDecayLR",
387
+ "params": {"total_num_steps": args.epochs * args.steps_per_epoch, "warmup_min_lr": 0,
388
+ "warmup_max_lr": args.lr, "warmup_num_steps": 100, "warmup_type": "linear"}},
389
+ "fp16": {"enabled": args.precision == "fp16"}, "bf16": {"enabled": args.precision == "bf16"},
390
+ "gradient_clipping": 1.0,
391
+ "zero_optimization": {"stage": 2, "contiguous_gradients": True, "overlap_comm": True,
392
+ "reduce_scatter": True, "reduce_bucket_size": 5e8,
393
+ "allgather_bucket_size": 5e8}, }
394
+
395
+ model_engine, optimizer, _, scheduler = deepspeed.initialize(
396
+ model=model, model_parameters=model.parameters(), collate_fn=partial(
397
+ custom_collate_fn, tokenizer=tokenizer, use_mm_start_end=args.use_mm_start_end, local_rank=args.local_rank
398
+ ), config=ds_config
399
+ )
400
+
401
+ return model_engine, optimizer, scheduler
402
+
403
+
404
+ def resume_training_from_checkpoint(model_engine, args):
405
+ if args.auto_resume and not args.resume:
406
+ resume = os.path.join(args.log_dir, "ckpt_model")
407
+ if os.path.exists(resume):
408
+ args.resume = resume
409
+
410
+ if args.resume:
411
+ load_path, client_state = model_engine.load_checkpoint(args.resume)
412
+ with open(os.path.join(args.resume, "latest"), "r") as f:
413
+ ckpt_dir = f.readlines()[0].strip()
414
+ args.start_epoch = int(ckpt_dir.replace("global_step", "")) // args.steps_per_epoch
415
+ print(f"Resume training from {args.resume}, start from epoch {args.start_epoch}")
416
+
417
+
418
+ def main(args):
419
+ tokenizer = setup_tokenizer_and_special_tokens(args)
420
+ model = initialize_model(args, tokenizer)
421
+ prepare_model_for_training(model, tokenizer, args)
422
+
423
+ model_engine, optimizer, scheduler = initialize_deepspeed(model, tokenizer, args)
424
+ resume_training_from_checkpoint(model_engine, args)
425
+
426
+ cap_train_dataset, reg_train_dataset, seg_train_dataset, val_datasets = (
427
+ initialize_datasets_and_loaders(args, tokenizer))
428
+ cap_train_loader, reg_train_loader, seg_train_loader, val_loader = (
429
+ setup_data_loaders(args, cap_train_dataset, reg_train_dataset, seg_train_dataset, val_datasets, tokenizer))
430
+
431
+ # Determine active datasets and their weights
432
+ active_dataloaders = []
433
+ weights = []
434
+
435
+ if args.use_cap_data:
436
+ active_dataloaders.append(('cap', cap_train_loader))
437
+ weights.append(args.weight_cap)
438
+ if args.use_reg_data:
439
+ active_dataloaders.append(('reg', reg_train_loader))
440
+ weights.append(args.weight_reg)
441
+ if args.use_segm_data:
442
+ active_dataloaders.append(('seg', seg_train_loader))
443
+ weights.append(args.weight_segm)
444
+
445
+ # Assert that at least one dataset is active
446
+ assert active_dataloaders, "Error: At least one dataset (segm, reg, or cap) must be active."
447
+
448
+ dataset_iters = {'cap': iter(cap_train_loader) if args.use_cap_data else None,
449
+ 'reg': iter(reg_train_loader) if args.use_reg_data else None,
450
+ 'seg': iter(seg_train_loader) if args.use_segm_data else None, }
451
+
452
+ writer = initialize_environment(args)
453
+
454
+ if args.eval_only:
455
+ cur_val_loss = validate_model_performance(val_loader, model_engine, 0, writer, args)[0]
456
+ exit()
457
+
458
+ epoch_seeds = [random.randint(0, 100000) for _ in range(args.epochs)]
459
+ dataset_choices = [idx for idx, _ in enumerate(active_dataloaders)]
460
+
461
+ best_giou, best_ciou, best_val_loss = 0.0, 0.0, np.inf
462
+ for epoch in range(args.start_epoch, args.epochs):
463
+ random.seed(epoch_seeds[epoch])
464
+
465
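+ # Draw one data source per step for the whole epoch up front; random.choices samples with
+ # replacement according to the configured dataset weights.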
+ step_choices = random.choices(dataset_choices, weights=weights, k=args.steps_per_epoch)
466
+
467
+ dataset_iters = train(
468
+ active_dataloaders, model_engine, epoch, scheduler, writer, dataset_iters, args, step_choices
469
+ )
470
+
471
+ if args.mask_validation:
472
+ giou, ciou = validate_model_performance(val_loader, model_engine, epoch, writer, args)
473
+ is_best = giou > best_giou
474
+ best_giou = max(giou, best_giou)
475
+ best_ciou = ciou if is_best else best_ciou
476
+ if args.local_rank == 0: # Log the progress
477
+ print(f"Epoch: {epoch}, giou: {giou}, ciou: {ciou}, best_giou: {best_giou}, best_ciou: {best_ciou}")
478
+ save_checkpoint(model_engine, args, epoch, 'giou-ciou', f"{giou:.4f}-{ciou:.4f}", is_best)
479
+ else:
480
+ cur_val_loss = validate_model_performance(val_loader, model_engine, epoch, writer, args)
481
+ is_best = cur_val_loss < best_val_loss
482
+ best_val_loss = min(cur_val_loss, best_val_loss)
483
+ if args.local_rank == 0: # Log the progress
484
+ print(f"Epoch: {epoch}, Current Validation Loss: {cur_val_loss:.4f}, Best Validation Loss: {best_val_loss:}")
485
+ save_checkpoint(model_engine, args, epoch, 'loss', f"{cur_val_loss:.4f}", is_best)
486
+
487
+
488
+ def save_checkpoint(model_engine, args, epoch, metric_name, metric_value, is_best):
489
+ """ Saves the model checkpoint. """
490
+ # If the checkpoint is the best, save it in ckpt_model_best, else in ckpt_model_last_epoch
491
+ save_dir_name = "ckpt_model_best" if is_best else "ckpt_model_last_epoch"
492
+ save_dir = os.path.join(args.log_dir, save_dir_name)
493
+ # Ensure the directory exists
494
+ if args.local_rank == 0:
495
+ os.makedirs(save_dir, exist_ok=True)
496
+ ckpt_filename = f"epoch_{epoch}_val_{metric_name}_{metric_value}.pth"
497
+ torch.save({"epoch": epoch, f"val_{metric_name}": metric_value}, os.path.join(save_dir, ckpt_filename))
498
+ torch.distributed.barrier()
499
+ model_engine.save_checkpoint(save_dir)
500
+
501
+
502
+ def train(active_datasets, model, epoch, scheduler, writer, dataset_iters, args, step_choices):
503
+ """Main training loop."""
504
+
505
+ def get_next_input(iterator, data_loader):
506
+ """Retrieve next input from the iterator, or reinitialize if necessary."""
507
+ try:
508
+ return next(iterator), iterator
509
+ except StopIteration:
510
+ new_iterator = iter(data_loader)
511
+ return next(new_iterator), new_iterator
512
+
513
+ def log_progress():
514
+ """Log training progress."""
515
+ if global_step % args.print_freq == 0:
516
+ if args.distributed:
517
+ for tracker in trackers.values():
518
+ tracker.all_reduce()
519
+
520
+ if args.local_rank == 0:
521
+ progress.display(global_step + 1)
522
+ for key, tracker in trackers.items():
523
+ writer.add_scalar(f"train/{key}", tracker.avg, global_step)
524
+ writer.add_scalar("metrics/total_secs_per_batch", batch_time.avg, global_step)
525
+ writer.add_scalar("metrics/data_secs_per_batch", data_time.avg, global_step)
526
+
527
+ for tracker in trackers.values():
528
+ tracker.reset()
529
+
530
+ batch_time = AverageMeter("Time", ":.4f")
531
+ data_time = AverageMeter("Data", ":.4f")
532
+ trackers = {"loss": AverageMeter("Loss", ":.4f"),
533
+ "ce_loss": AverageMeter("CeLoss", ":.4f"),
534
+ "mask_bce_loss": AverageMeter("MaskBCELoss", ":.4f"),
535
+ "mask_dice_loss": AverageMeter("MaskDICELoss", ":.4f"),
536
+ "mask_loss": AverageMeter("MaskLoss", ":.4f")}
537
+ progress = ProgressMeter(args.steps_per_epoch, list(trackers.values()), prefix=f"Epoch: [{epoch}]")
538
+
539
+ model.train()
540
+ end = time.time()
541
+ for global_step in range(args.steps_per_epoch):
542
+ for _ in range(args.grad_accumulation_steps):
543
+ # Select data loader based on step choice
544
+ dataset_type, data_loader = active_datasets[step_choices[global_step]]
545
+ data_batch, new_iter = get_next_input(dataset_iters[dataset_type], data_loader)
546
+ dataset_iters[dataset_type] = new_iter
547
+
548
+ data_time.update(time.time() - end)
549
+ # Prepare data and convert relevant tensors to bfloat16
550
+ data_batch = dict_to_cuda(data_batch)
551
+ for key in ["global_enc_images", "grounding_enc_images"]:
552
+ if data_batch[key] is not None:
553
+ data_batch[key] = data_batch[key].bfloat16()
554
+
555
+ output_dict = model(**data_batch)
556
+
557
+ # Update training metrics
558
+ for key, tracker in trackers.items():
559
+ if key in output_dict:
560
+ tracker.update(output_dict[key].item(), data_batch["global_enc_images"].size(0))
561
+
562
+ model.backward(output_dict["loss"])
563
+ model.step()
564
+
565
+ batch_time.update(time.time() - end)
566
+ end = time.time()
567
+ log_progress()
568
+
569
+ if global_step != 0:
570
+ curr_lr = scheduler.get_last_lr()
571
+ if args.local_rank == 0:
572
+ writer.add_scalar("train/lr", curr_lr[0], global_step)
573
+
574
+ return dataset_iters
575
+
576
+
577
+ def validate_model_performance(validation_loader, training_model, current_epoch, tensorboard_writer, args):
578
+ if args.mask_validation:
579
+ # For use with only segmentation/GCG type datasets
580
+ trackers = {"intersection": AverageMeter("Intersec", ":.4f", Summary.SUM),
581
+ "union": AverageMeter("Union", ":.4f", Summary.SUM),
582
+ "gIoU": AverageMeter("gIoU", ":.4f", Summary.SUM)}
583
+
584
+ training_model.eval()
585
+ for data_batch in tqdm.tqdm(validation_loader):
586
+ # Prepare data and convert relevant tensors to bfloat16
587
+ data_batch = dict_to_cuda(data_batch)
588
+ for key in ["global_enc_images", "grounding_enc_images"]:
589
+ data_batch[key] = data_batch[key].bfloat16()
590
+ torch.cuda.empty_cache()
591
+ # Model inference without gradient tracking
592
+ with torch.no_grad():
593
+ results = training_model(**data_batch)
594
+
595
+ predictions = results["pred_masks"]
596
+ gt_masks = results["gt_masks"][0].int()
597
+ # Note: An error at this line may suggest that the dataset used for validation does not support
598
+ # segmentation tasks. Ensure that the dataset is appropriate for segmentation analysis.
599
+ predicted_masks = (predictions[0] > 0).int()
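+ # Thresholding the mask logits at 0 corresponds to a 0.5 probability cutoff after sigmoid.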
600
+ assert len(predictions) == 1
601
+
602
+ intersection, union, accuracy_iou = 0.0, 0.0, 0.0
603
+ for target, prediction in zip(gt_masks, predicted_masks):
604
+ intersect, union_, _ = intersectionAndUnionGPU(
605
+ prediction.contiguous().clone(), target.contiguous(), 2, ignore_index=255
606
+ )
607
+ intersection += intersect
608
+ union += union_
609
+ accuracy_iou += intersect / (union_ + 1e-5)
610
+ # handles no-object targets
611
+ accuracy_iou[union_ == 0] += 1.0
612
+
613
+ intersection, union = intersection.cpu().numpy(), union.cpu().numpy()
614
+ accuracy_iou = accuracy_iou.cpu().numpy() / gt_masks.shape[0]
615
+ trackers["intersection"].update(intersection)
616
+ trackers["union"].update(union)
617
+ trackers["gIoU"].update(accuracy_iou, n=gt_masks.shape[0])
618
+
619
+ for meter in trackers.values():
620
+ meter.all_reduce()
621
+
622
+ iou_per_class = trackers["intersection"].sum / (trackers["union"].sum + 1e-10)
623
+ class_iou = iou_per_class[1]
624
+ global_iou = trackers["gIoU"].avg[1]
625
+
626
+ if args.local_rank == 0:
627
+ tensorboard_writer.add_scalar("val/giou", global_iou, current_epoch)
628
+ tensorboard_writer.add_scalar("val/ciou", class_iou, current_epoch)
629
+ print("giou: {:.4f}, ciou: {:.4f}".format(global_iou, class_iou))
630
+
631
+ return global_iou, class_iou
632
+ else:
633
+ # Initializing performance trackers
634
+ trackers = {"loss": AverageMeter("Loss", ":.4f"), "ce_loss": AverageMeter("CeLoss", ":.4f"),
635
+ "mask_bce_loss": AverageMeter("MaskBCELoss", ":.4f"),
636
+ "mask_dice_loss": AverageMeter("MaskDICELoss", ":.4f"),
637
+ "mask_loss": AverageMeter("MaskLoss", ":.4f")}
638
+
639
+ # Prepare model for validation phase
640
+ # Hack to get the loss
641
+ training_model.train()
642
+
643
+ for data_batch in tqdm.tqdm(validation_loader):
644
+ # Prepare data and convert relevant tensors to bfloat16
645
+ data_batch = dict_to_cuda(data_batch)
646
+ for key in ["global_enc_images", "grounding_enc_images"]:
647
+ if data_batch[key] is not None:
648
+ data_batch[key] = data_batch[key].bfloat16()
649
+ torch.cuda.empty_cache()
650
+ # Model inference without gradient tracking
651
+ with torch.no_grad():
652
+ predictions = training_model(**data_batch)
653
+ # Update performance metrics
654
+ for key, tracker in trackers.items():
655
+ tracker.update(predictions[key].item(), data_batch["global_enc_images"].size(0))
656
+
657
+ # Synchronize metrics across processes
658
+ for tracker in trackers.values():
659
+ tracker.all_reduce()
660
+ # Calculate average validation loss
661
+ avg_val_loss = trackers["ce_loss"].avg
662
+ # Tensorboard logging for primary process
663
+ if args.local_rank == 0:
664
+ tensorboard_writer.add_scalar("val/loss", avg_val_loss, current_epoch)
665
+
666
+ return avg_val_loss
667
+
668
+
669
+ if __name__ == "__main__":
670
+ args = parse_args(sys.argv[1:])
671
+ main(args)
lightning-hydra-template/.github/codecov.yml ADDED
@@ -0,0 +1,15 @@
1
+ coverage:
2
+ status:
3
+ # measures overall project coverage
4
+ project:
5
+ default:
6
+ threshold: 100% # coverage may drop by up to this much and still be reported as success
7
+
8
+ # measures PR or single commit coverage
9
+ patch:
10
+ default:
11
+ threshold: 100% # coverage may drop by up to this much and still be reported as success
12
+
13
+
14
+ # project: off
15
+ # patch: off
lightning-hydra-template/.github/workflows/test.yml ADDED
@@ -0,0 +1,139 @@
1
+ name: Tests
2
+
3
+ on:
4
+ push:
5
+ branches: [main]
6
+ pull_request:
7
+ branches: [main, "release/*", "dev"]
8
+
9
+ jobs:
10
+ run_tests_ubuntu:
11
+ runs-on: ${{ matrix.os }}
12
+
13
+ strategy:
14
+ fail-fast: false
15
+ matrix:
16
+ os: ["ubuntu-latest"]
17
+ python-version: ["3.8", "3.9", "3.10"]
18
+
19
+ timeout-minutes: 20
20
+
21
+ steps:
22
+ - name: Checkout
23
+ uses: actions/checkout@v3
24
+
25
+ - name: Set up Python ${{ matrix.python-version }}
26
+ uses: actions/setup-python@v3
27
+ with:
28
+ python-version: ${{ matrix.python-version }}
29
+
30
+ - name: Install dependencies
31
+ run: |
32
+ python -m pip install --upgrade pip
33
+ pip install -r requirements.txt
34
+ pip install pytest
35
+ pip install sh
36
+
37
+ - name: List dependencies
38
+ run: |
39
+ python -m pip list
40
+
41
+ - name: Run pytest
42
+ run: |
43
+ pytest -v
44
+
45
+ run_tests_macos:
46
+ runs-on: ${{ matrix.os }}
47
+
48
+ strategy:
49
+ fail-fast: false
50
+ matrix:
51
+ os: ["macos-latest"]
52
+ python-version: ["3.8", "3.9", "3.10"]
53
+
54
+ timeout-minutes: 20
55
+
56
+ steps:
57
+ - name: Checkout
58
+ uses: actions/checkout@v3
59
+
60
+ - name: Set up Python ${{ matrix.python-version }}
61
+ uses: actions/setup-python@v3
62
+ with:
63
+ python-version: ${{ matrix.python-version }}
64
+
65
+ - name: Install dependencies
66
+ run: |
67
+ python -m pip install --upgrade pip
68
+ pip install -r requirements.txt
69
+ pip install pytest
70
+ pip install sh
71
+
72
+ - name: List dependencies
73
+ run: |
74
+ python -m pip list
75
+
76
+ - name: Run pytest
77
+ run: |
78
+ pytest -v
79
+
80
+ run_tests_windows:
81
+ runs-on: ${{ matrix.os }}
82
+
83
+ strategy:
84
+ fail-fast: false
85
+ matrix:
86
+ os: ["windows-latest"]
87
+ python-version: ["3.8", "3.9", "3.10"]
88
+
89
+ timeout-minutes: 20
90
+
91
+ steps:
92
+ - name: Checkout
93
+ uses: actions/checkout@v3
94
+
95
+ - name: Set up Python ${{ matrix.python-version }}
96
+ uses: actions/setup-python@v3
97
+ with:
98
+ python-version: ${{ matrix.python-version }}
99
+
100
+ - name: Install dependencies
101
+ run: |
102
+ python -m pip install --upgrade pip
103
+ pip install -r requirements.txt
104
+ pip install pytest
105
+
106
+ - name: List dependencies
107
+ run: |
108
+ python -m pip list
109
+
110
+ - name: Run pytest
111
+ run: |
112
+ pytest -v
113
+
114
+ # upload code coverage report
115
+ code-coverage:
116
+ runs-on: ubuntu-latest
117
+
118
+ steps:
119
+ - name: Checkout
120
+ uses: actions/checkout@v2
121
+
122
+ - name: Set up Python 3.10
123
+ uses: actions/setup-python@v2
124
+ with:
125
+ python-version: "3.10"
126
+
127
+ - name: Install dependencies
128
+ run: |
129
+ python -m pip install --upgrade pip
130
+ pip install -r requirements.txt
131
+ pip install pytest
132
+ pip install pytest-cov[toml]
133
+ pip install sh
134
+
135
+ - name: Run tests and collect coverage
136
+ run: pytest --cov src # NEEDS TO BE UPDATED WHEN CHANGING THE NAME OF "src" FOLDER
137
+
138
+ - name: Upload coverage to Codecov
139
+ uses: codecov/codecov-action@v3
lightning-hydra-template/configs/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # this file is needed here to include configs when building project as a package
lightning-hydra-template/configs/local/.gitkeep ADDED
File without changes
lightning-hydra-template/configs/train.yaml ADDED
@@ -0,0 +1,49 @@
1
+ # @package _global_
2
+
3
+ # specify here default configuration
4
+ # order of defaults determines the order in which configs override each other
5
+ defaults:
6
+ - _self_
7
+ - data: mnist
8
+ - model: mnist
9
+ - callbacks: default
10
+ - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
11
+ - trainer: default
12
+ - paths: default
13
+ - extras: default
14
+ - hydra: default
15
+
16
+ # experiment configs allow for version control of specific hyperparameters
17
+ # e.g. best hyperparameters for given model and datamodule
18
+ - experiment: null
19
+
20
+ # config for hyperparameter optimization
21
+ - hparams_search: null
22
+
23
+ # optional local config for machine/user specific settings
24
+ # it's optional since it doesn't need to exist and is excluded from version control
25
+ - optional local: default
26
+
27
+ # debugging config (enable through command line, e.g. `python train.py debug=default`)
28
+ - debug: null
29
+
30
+ # task name, determines output directory path
31
+ task_name: "train"
32
+
33
+ # tags to help you identify your experiments
34
+ # you can overwrite this in experiment configs
35
+ # overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
36
+ tags: ["dev"]
37
+
38
+ # set False to skip model training
39
+ train: True
40
+
41
+ # evaluate on test set, using best model weights achieved during training
42
+ # lightning chooses best weights based on the metric specified in checkpoint callback
43
+ test: True
44
+
45
+ # simply provide checkpoint path to resume training
46
+ ckpt_path: null
47
+
48
+ # seed for random number generators in pytorch, numpy and python.random
49
+ seed: null
lightning-hydra-template/logs/.gitkeep ADDED
File without changes
lightning-hydra-template/tests/test_datamodules.py ADDED
@@ -0,0 +1,38 @@
1
+ from pathlib import Path
2
+
3
+ import pytest
4
+ import torch
5
+
6
+ from src.data.mnist_datamodule import MNISTDataModule
7
+
8
+
9
+ @pytest.mark.parametrize("batch_size", [32, 128])
10
+ def test_mnist_datamodule(batch_size: int) -> None:
11
+ """Tests `MNISTDataModule` to verify that it can be downloaded correctly, that the necessary
12
+ attributes were created (e.g., the dataloader objects), and that dtypes and batch sizes
13
+ correctly match.
14
+
15
+ :param batch_size: Batch size of the data to be loaded by the dataloader.
16
+ """
17
+ data_dir = "data/"
18
+
19
+ dm = MNISTDataModule(data_dir=data_dir, batch_size=batch_size)
20
+ dm.prepare_data()
21
+
22
+ assert not dm.data_train and not dm.data_val and not dm.data_test
23
+ assert Path(data_dir, "MNIST").exists()
24
+ assert Path(data_dir, "MNIST", "raw").exists()
25
+
26
+ dm.setup()
27
+ assert dm.data_train and dm.data_val and dm.data_test
28
+ assert dm.train_dataloader() and dm.val_dataloader() and dm.test_dataloader()
29
+
30
+ num_datapoints = len(dm.data_train) + len(dm.data_val) + len(dm.data_test)
31
+ assert num_datapoints == 70_000
32
+
33
+ batch = next(iter(dm.train_dataloader()))
34
+ x, y = batch
35
+ assert len(x) == batch_size
36
+ assert len(y) == batch_size
37
+ assert x.dtype == torch.float32
38
+ assert y.dtype == torch.int64
lightning-hydra-template/tests/test_eval.py ADDED
@@ -0,0 +1,39 @@
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import pytest
5
+ from hydra.core.hydra_config import HydraConfig
6
+ from omegaconf import DictConfig, open_dict
7
+
8
+ from src.eval import evaluate
9
+ from src.train import train
10
+
11
+
12
+ @pytest.mark.slow
13
+ def test_train_eval(tmp_path: Path, cfg_train: DictConfig, cfg_eval: DictConfig) -> None:
14
+ """Tests training and evaluation by training for 1 epoch with `train.py` then evaluating with
15
+ `eval.py`.
16
+
17
+ :param tmp_path: The temporary logging path.
18
+ :param cfg_train: A DictConfig containing a valid training configuration.
19
+ :param cfg_eval: A DictConfig containing a valid evaluation configuration.
20
+ """
21
+ assert str(tmp_path) == cfg_train.paths.output_dir == cfg_eval.paths.output_dir
22
+
23
+ with open_dict(cfg_train):
24
+ cfg_train.trainer.max_epochs = 1
25
+ cfg_train.test = True
26
+
27
+ HydraConfig().set_config(cfg_train)
28
+ train_metric_dict, _ = train(cfg_train)
29
+
30
+ assert "last.ckpt" in os.listdir(tmp_path / "checkpoints")
31
+
32
+ with open_dict(cfg_eval):
33
+ cfg_eval.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt")
34
+
35
+ HydraConfig().set_config(cfg_eval)
36
+ test_metric_dict, _ = evaluate(cfg_eval)
37
+
38
+ assert test_metric_dict["test/acc"] > 0.0
39
+ assert abs(train_metric_dict["test/acc"].item() - test_metric_dict["test/acc"].item()) < 0.001