Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.
- VLMEvalKit_old/InternVL/.github/ISSUE_TEMPLATE/1-bug-report.yml +54 -0
- VLMEvalKit_old/InternVL/.github/ISSUE_TEMPLATE/2-feature-request.yml +31 -0
- VLMEvalKit_old/InternVL/.github/ISSUE_TEMPLATE/3-documentation.yml +23 -0
- VLMEvalKit_old/InternVL/internvl_chat_llava/docs/Data.md +29 -0
- VLMEvalKit_old/InternVL/internvl_chat_llava/docs/LLaVA_Bench.md +31 -0
- VLMEvalKit_old/InternVL/internvl_g/eval/evaluate_caption.py +237 -0
- VLMEvalKit_old/InternVL/internvl_g/internvl/model/__init__.py +0 -0
- VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2/__init__.py +87 -0
- VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2/configuration_intern_vit.py +117 -0
- VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2/flash_attention.py +76 -0
- VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2/modeling_intern_vit.py +342 -0
- VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2/modeling_internvl.py +684 -0
- VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2/modeling_qllama.py +1073 -0
- VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/__init__.py +87 -0
- VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/configuration_intern_vit.py +117 -0
- VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/configuration_internvl.py +108 -0
- VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/flash_attention.py +76 -0
- VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/modeling_intern_vit.py +342 -0
- VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/modeling_internvl.py +669 -0
- VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/modeling_qllama.py +1073 -0
- VLMEvalKit_old/InternVL/internvl_g/internvl/train/dataset.py +283 -0
- VLMEvalKit_old/InternVL/internvl_g/internvl/train/internvl_stage2_finetune.py +286 -0
- VLMEvalKit_old/InternVL/internvl_g/internvl/train/trainer_monkey_patch.py +150 -0
- VLMEvalKit_old/InternVL/internvl_g/shell/finetune/internvl_stage2_finetune_coco_364_bs1024_ep5.sh +58 -0
- VLMEvalKit_old/InternVL/internvl_g/shell/finetune/internvl_stage2_finetune_flickr_364_bs1024_ep10.sh +58 -0
- VLMEvalKit_old/InternVL/internvl_g/shell/lora_finetune/internvl_stage2_finetune_flickr_224_bs1024_ep10_lora16_4gpu.sh +61 -0
- VLMEvalKit_old/InternVL/internvl_g/shell/lora_finetune/internvl_stage2_finetune_flickrcn_224_bs1024_ep10_lora16_4gpu.sh +61 -0
- VLMEvalKit_old/InternVL/segmentation/configs/_base_/datasets/ade20k_504x504.py +56 -0
- VLMEvalKit_old/InternVL/segmentation/configs/_base_/datasets/ade20k_640x640.py +54 -0
- VLMEvalKit_old/InternVL/segmentation/configs/_base_/datasets/cityscapes_832x832.py +35 -0
- VLMEvalKit_old/InternVL/segmentation/configs/_base_/datasets/coco-stuff10k.py +57 -0
- VLMEvalKit_old/InternVL/segmentation/configs/_base_/datasets/coco-stuff164k.py +54 -0
- VLMEvalKit_old/InternVL/segmentation/configs/_base_/datasets/hrf.py +59 -0
- VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/ann_r50-d8.py +46 -0
- VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/ccnet_r50-d8.py +44 -0
- VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/cgnet.py +35 -0
- VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/danet_r50-d8.py +44 -0
- VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/deeplabv3_r50-d8.py +44 -0
- VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/deeplabv3_unet_s5-d16.py +50 -0
- VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/dmnet_r50-d8.py +44 -0
- VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/emanet_r50-d8.py +47 -0
- VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/encnet_r50-d8.py +48 -0
- VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/erfnet_fcn.py +32 -0
- VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/fast_scnn.py +57 -0
- VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/fastfcn_r50-d32_jpu_psp.py +53 -0
- VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/gcnet_r50-d8.py +46 -0
- VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/isanet_r50-d8.py +45 -0
- VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/lraspp_m-v3-d8.py +25 -0
- VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/nonlocal_r50-d8.py +46 -0
- VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/pointrend_r50.py +56 -0
VLMEvalKit_old/InternVL/.github/ISSUE_TEMPLATE/1-bug-report.yml
ADDED
@@ -0,0 +1,54 @@
+name: 🐞 Bug report
+description: Create a report to help us reproduce and fix the bug
+title: "[Bug] "
+labels: ['Bug']
+
+body:
+- type: checkboxes
+  attributes:
+    label: Checklist
+    options:
+    - label: 1. I have searched related issues but cannot get the expected help.
+    - label: 2. The bug has not been fixed in the latest version.
+    - label: 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
+- type: textarea
+  attributes:
+    label: Describe the bug
+    description: A clear and concise description of what the bug is.
+  validations:
+    required: true
+- type: textarea
+  attributes:
+    label: Reproduction
+    description: |
+      1. What command or script did you run?
+    placeholder: |
+      A placeholder for the command.
+  validations:
+    required: true
+- type: textarea
+  attributes:
+    label: Environment
+    description: |
+      1. Please run `lmdeploy check_env` to collect necessary environment information and paste it here.
+      2. You may add additional information that may be helpful for locating the problem, such as
+        - Which **model** are you using?
+        - How you installed PyTorch \[e.g., pip, conda, source\]
+        - Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.)
+    placeholder: Environment here.
+    render: Shell
+  validations:
+    required: true
+- type: textarea
+  attributes:
+    label: Error traceback
+    description: |
+      If applicable, paste the error traceback here.
+    placeholder: Logs and traceback here.
+    render: Shell
+- type: markdown
+  attributes:
+    value: >
+      If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!
+
+      Thanks for your bug report. We appreciate it a lot.
VLMEvalKit_old/InternVL/.github/ISSUE_TEMPLATE/2-feature-request.yml
ADDED
@@ -0,0 +1,31 @@
+name: 🚀 Feature request
+description: Suggest an idea for this project
+title: "[Feature] "
+
+body:
+- type: markdown
+  attributes:
+    value: |
+      We strongly appreciate you creating a PR to implement this feature [here](https://github.com/OpenGVLab/InternVL/pulls)!
+      If you need our help, please fill in as much of the following form as you're able to.
+
+      **The less clear the description, the longer it will take to solve it.**
+- type: textarea
+  attributes:
+    label: Motivation
+    description: |
+      A clear and concise description of the motivation of the feature.
+      Ex1. It is inconvenient when \[....\].
+  validations:
+    required: true
+- type: textarea
+  attributes:
+    label: Related resources
+    description: |
+      If there is an official code release or third-party implementations, please also provide the information here, which would be very helpful.
+- type: textarea
+  attributes:
+    label: Additional context
+    description: |
+      Add any other context or screenshots about the feature request here.
+      If you would like to implement the feature and create a PR, please leave a comment here and that would be much appreciated.
VLMEvalKit_old/InternVL/.github/ISSUE_TEMPLATE/3-documentation.yml
ADDED
@@ -0,0 +1,23 @@
+name: 📚 Documentation
+description: Report an issue related to the documentation.
+labels: "kind/doc,status/unconfirmed"
+title: "[Docs] "
+
+body:
+- type: textarea
+  attributes:
+    label: 📚 The doc issue
+    description: >
+      A clear and concise description of the issue.
+  validations:
+    required: true
+
+- type: textarea
+  attributes:
+    label: Suggest a potential alternative/fix
+    description: >
+      Tell us how we could improve the documentation in this regard.
+- type: markdown
+  attributes:
+    value: >
+      Thanks for contributing 🎉!
VLMEvalKit_old/InternVL/internvl_chat_llava/docs/Data.md
ADDED
@@ -0,0 +1,29 @@
+## Data
+
+| Data file name | Size |
+| --- | ---: |
+| [llava_instruct_150k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/llava_instruct_150k.json) | 229 MB |
+| [llava_instruct_80k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/llava_instruct_80k.json) | 229 MB |
+| [conversation_58k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/conversation_58k.json) | 126 MB |
+| [detail_23k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/detail_23k.json) | 20.5 MB |
+| [complex_reasoning_77k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/complex_reasoning_77k.json) | 79.6 MB |
+
+### Pretraining Dataset
+The pretraining dataset used in this release is a subset of the CC-3M dataset, filtered to a more balanced concept coverage distribution. Please see [here](https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K) for a detailed description of the dataset structure and how to download the images.
+
+If you already have the CC-3M dataset on your disk, the image names follow this format: `GCC_train_000000000.jpg`. You may edit the `image` field correspondingly if necessary.
+
+| Data | Chat File | Meta Data | Size |
+| --- | --- | --- | ---: |
+| CC-3M Concept-balanced 595K | [chat.json](https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K/blob/main/chat.json) | [metadata.json](https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K/blob/main/metadata.json) | 211 MB |
+| LAION/CC/SBU BLIP-Caption Concept-balanced 558K | [blip_laion_cc_sbu_558k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/blob/main/blip_laion_cc_sbu_558k.json) | [metadata.json](#) | 181 MB |
+
+**Important notice**: Upon request from the community, since ~15% of the images in the original CC-3M dataset are no longer accessible, we upload [`images.zip`](https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K/blob/main/images.zip) to help reproduce our work in the research community. It must not be used for any other purpose. The use of these images must comply with the CC-3M license. The file may be taken down at any time when requested by the original CC-3M dataset owner or the owners of the referenced images.
+
+### GPT-4 Prompts
+
+We provide our prompts and few-shot samples for GPT-4 queries to better facilitate research in this domain. Please check out the [`prompts`](playground/data/prompts) folder for three kinds of questions: conversation, detail description, and complex reasoning.
+
+They are organized as `system_message.txt` for the system message, with pairs of `abc_caps.txt` for few-shot sample user input and `abc_conv.txt` for few-shot sample reference output.
+
+Note that you may find them in different formats. For example, `conversation` is in `jsonl`, and detail description is answer-only. The format selected in our preliminary experiments works slightly better than the limited set of alternatives we tried: `jsonl`, a more natural format, and answer-only. If interested, you may try other variants or conduct a more careful study of this. Contributions are welcome!
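
The note above about `GCC_train_000000000.jpg` image names implies rewriting the `image` field of `chat.json` when a local CC-3M copy already exists. Below is a minimal, hypothetical sketch of that remapping; the `image` key follows the doc above, while `CC3M_ROOT` and the output file name are assumptions, not part of the diff.

```python
# Hedged sketch: point the `image` field of chat.json at an existing local
# CC-3M copy. Assumes each record carries an 'image' key as described above;
# adjust the paths to your own directory layout.
import json
import os

CHAT_JSON = 'chat.json'           # downloaded from LLaVA-CC3M-Pretrain-595K
CC3M_ROOT = '/data/cc3m/images'   # assumption: your local CC-3M image folder

with open(CHAT_JSON) as f:
    records = json.load(f)

for rec in records:
    # e.g. 'GCC_train_000000000.jpg' -> '/data/cc3m/images/GCC_train_000000000.jpg'
    rec['image'] = os.path.join(CC3M_ROOT, os.path.basename(rec['image']))

with open('chat_local.json', 'w') as f:
    json.dump(records, f)
```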
VLMEvalKit_old/InternVL/internvl_chat_llava/docs/LLaVA_Bench.md
ADDED
@@ -0,0 +1,31 @@
+# LLaVA-Bench [[Download](https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild)]
+
+**-Introduction-** Large commercial multimodal chatbots were released this week, including
+- [Multimodal Bing-Chat by Microsoft](https://blogs.bing.com/search/july-2023/Bing-Chat-Enterprise-announced,-multimodal-Visual-Search-rolling-out-to-Bing-Chat) (July 18, 2023)
+- [Multimodal Bard by Google](https://bard.google.com/).
+
+These chatbots are presumably supported by proprietary large multimodal models (LMMs). Compared with open-source LMMs such as LLaVA, proprietary LMMs represent the scaling upper bound of current SoTA techniques. They share the goal of developing multimodal chatbots that follow human intents to complete various daily-life visual tasks in the wild. While how to evaluate multimodal chat ability remains under-explored, doing so provides useful feedback for studying open-source LMMs against commercial multimodal chatbots. In addition to the *LLaVA-Bench (COCO)* dataset we used to develop the early versions of LLaVA, we are releasing [*LLaVA-Bench (In-the-Wild)*](https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild) to the community for public use.
+
+## LLaVA-Bench (In-the-Wild *[Ongoing work]*)
+
+To evaluate the model's capability on more challenging tasks and its generalizability to novel domains, we collect a diverse set of 24 images with 60 questions in total, including indoor and outdoor scenes, memes, paintings, sketches, etc., and associate each image with a highly detailed, manually curated description and a proper selection of questions. This design also assesses the model's robustness to different prompts. In this release, we also categorize questions into three categories: conversation (simple QA), detailed description, and complex reasoning. We continue to expand and improve the diversity of LLaVA-Bench (In-the-Wild). We manually query Bing-Chat and Bard to get their responses.
+
+### Results
+
+The score is measured by comparing against a reference answer generated by text-only GPT-4, produced by feeding it the question along with the ground-truth image annotations as context. A text-only GPT-4 evaluator then rates both answers; we query it with the reference answer first, followed by the answer generated by the candidate model. We upload images at their original resolution to Bard and Bing-Chat to obtain their results.
+
+| Approach | Conversation | Detail | Reasoning | Overall |
+|----------------|--------------|--------|-----------|---------|
+| Bard-0718 | 83.7 | 69.7 | 78.7 | 77.8 |
+| Bing-Chat-0629 | 59.6 | 52.2 | 90.1 | 71.5 |
+| LLaVA-13B-v1-336px-0719 (beam=1) | 64.3 | 55.9 | 81.7 | 70.1 |
+| LLaVA-13B-v1-336px-0719 (beam=5) | 68.4 | 59.9 | 84.3 | 73.5 |
+
+Note that Bard sometimes refuses to answer questions about images containing humans, and Bing-Chat blurs the human faces in the images. We also provide the benchmark score for the subset without humans.
+
+| Approach | Conversation | Detail | Reasoning | Overall |
+|----------------|--------------|--------|-----------|---------|
+| Bard-0718 | 94.9 | 74.3 | 84.3 | 84.6 |
+| Bing-Chat-0629 | 55.8 | 53.6 | 93.5 | 72.6 |
+| LLaVA-13B-v1-336px-0719 (beam=1) | 62.2 | 56.4 | 82.2 | 70.0 |
+| LLaVA-13B-v1-336px-0719 (beam=5) | 65.6 | 61.7 | 85.0 | 73.6 |
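
The Results paragraph above describes a GPT-4-as-judge protocol: a text-only GPT-4 reference answer and the candidate answer are both rated, and the tables report relative scores. As a rough, hedged illustration of how such relative scores could be aggregated (the 1-10 rating scale and the field names below are assumptions, not the repository's actual evaluation script):

```python
# Illustrative aggregation of relative GPT-4 judge scores per category.
# Each sample is assumed to carry one rating for the reference answer and
# one for the candidate answer, both on a 1-10 scale.
from collections import defaultdict


def relative_scores(samples):
    """samples: iterable of dicts like
    {'category': 'conv' | 'detail' | 'complex',
     'reference_rating': float, 'candidate_rating': float}"""
    ref_sum = defaultdict(float)
    cand_sum = defaultdict(float)
    for s in samples:
        ref_sum[s['category']] += s['reference_rating']
        cand_sum[s['category']] += s['candidate_rating']
    # Relative score per category: candidate total / reference total * 100
    per_category = {c: 100.0 * cand_sum[c] / ref_sum[c] for c in ref_sum}
    overall = 100.0 * sum(cand_sum.values()) / sum(ref_sum.values())
    return per_category, overall


if __name__ == '__main__':
    demo = [
        {'category': 'conv', 'reference_rating': 9.0, 'candidate_rating': 7.5},
        {'category': 'detail', 'reference_rating': 8.0, 'candidate_rating': 6.0},
    ]
    print(relative_scores(demo))
```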
VLMEvalKit_old/InternVL/internvl_g/eval/evaluate_caption.py
ADDED
@@ -0,0 +1,237 @@
+import argparse
+import itertools
+import json
+import os
+import random
+import time
+from functools import partial
+
+import torch
+import torchvision.transforms as T
+from internvl.model.internvl_stage2 import InternVLConfig, InternVLModel
+from PIL import Image
+from pycocoevalcap.eval import COCOEvalCap
+from pycocotools.coco import COCO
+from torchvision.transforms.functional import InterpolationMode
+from tqdm import tqdm
+from transformers import LlamaTokenizer
+
+ds_collections = {
+    'flickr30k': {
+        'root': 'data/flickr30k/',
+        'annotation': 'data/flickr30k/flickr30k_test_karpathy.json',
+    },
+    'coco': {
+        'root': 'data/coco/',
+        'annotation': ['data/coco/annotations/coco_karpathy_test.json',
+                       'data/coco/annotations/coco_karpathy_test_gt.json'],
+    },
+    'nocaps': {
+        'root': 'data/nocaps/images',
+        'annotation': 'data/nocaps/nocaps_val_4500_captions.json',
+    },
+}
+
+
+class CaptionDataset(torch.utils.data.Dataset):
+
+    def __init__(self, name, root, annotation, prompt, input_size=224):
+        if name == 'coco':
+            self.images = json.load(open(annotation))
+        else:
+            self.images = json.load(open(annotation))['images']
+        self.name = name
+        self.prompt = prompt
+        self.root = root
+        self.transform = T.Compose([
+            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+            T.ToTensor(),
+            T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
+        ])
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, idx):
+        if self.name == 'coco':
+            filename = self.images[idx]['image']
+            image_id = int(filename.split('_')[-1].replace('.jpg', ''))
+            image_path = os.path.join(self.root, filename)
+        else:
+            image_id = self.images[idx]['id']
+            if 'file_name' in self.images[idx]:
+                image_path = os.path.join(self.root, self.images[idx]['file_name'])
+            else:
+                image_path = os.path.join(self.root, self.images[idx]['image'])
+        image = Image.open(image_path)
+        pixel_values = self.transform(image).unsqueeze(0)
+        return {
+            'image_id': image_id,
+            'input_text': self.prompt,
+            'pixel_values': pixel_values
+        }
+
+
+def collate_fn(inputs, tokenizer):
+    pixel_values = torch.cat([_['pixel_values'] for _ in inputs], dim=0)
+    image_ids = [_['image_id'] for _ in inputs]
+    input_texts = [_['input_text'] for _ in inputs]
+    input_tokens = tokenizer(input_texts, return_tensors='pt')
+
+    return pixel_values, image_ids, input_tokens.input_ids, input_tokens.attention_mask
+
+
+class InferenceSampler(torch.utils.data.sampler.Sampler):
+
+    def __init__(self, size):
+        self._size = int(size)
+        assert size > 0
+        self._rank = torch.distributed.get_rank()
+        self._world_size = torch.distributed.get_world_size()
+        self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
+
+    @staticmethod
+    def _get_local_indices(total_size, world_size, rank):
+        shard_size = total_size // world_size
+        left = total_size % world_size
+        shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
+
+        begin = sum(shard_sizes[:rank])
+        end = min(sum(shard_sizes[:rank + 1]), total_size)
+        return range(begin, end)
+
+    def __iter__(self):
+        yield from self._local_indices
+
+    def __len__(self):
+        return len(self._local_indices)
+
+
+def evaluate_qllama_model():
+    prompts = ['English caption:']
+    print('prompts:', prompts)
+
+    config = InternVLConfig.from_pretrained(args.checkpoint)
+    model = InternVLModel.from_pretrained(args.checkpoint, config=config).eval()
+    model = model.to(torch.float16).cuda()
+    tokenizer = LlamaTokenizer.from_pretrained(args.checkpoint)
+    tokenizer.add_eos_token = False
+
+    random.seed(args.seed)
+    summaries = []
+    for prompt in prompts:
+        for ds_name in args.datasets:
+            annotation = ds_collections[ds_name]['annotation']
+            if type(annotation) == list:
+                annotation = annotation[0]
+            if model.config.force_image_size is not None:
+                image_size = model.config.force_image_size
+            else:
+                image_size = model.config.vision_config.image_size
+            dataset = CaptionDataset(
+                name=ds_name,
+                root=ds_collections[ds_name]['root'],
+                annotation=annotation,
+                prompt=prompt,
+                input_size=image_size,
+            )
+            dataloader = torch.utils.data.DataLoader(
+                dataset=dataset,
+                sampler=InferenceSampler(len(dataset)),
+                batch_size=args.batch_size,
+                num_workers=args.num_workers,
+                pin_memory=True,
+                drop_last=False,
+                collate_fn=partial(collate_fn, tokenizer=tokenizer),
+            )
+
+            image_ids, captions = [], []
+            for _, (pixel_values, ids, input_ids, attention_mask) in tqdm(enumerate(dataloader)):
+                pred = model.generate(
+                    pixel_values=pixel_values.cuda().to(torch.float16),
+                    input_ids=input_ids.cuda(),
+                    attention_mask=attention_mask.cuda(),
+                    do_sample=False,
+                    num_beams=args.num_beams,
+                    max_new_tokens=30,
+                    min_new_tokens=8,
+                    use_cache=True
+                )
+                image_ids.extend(ids)
+                caption = [tokenizer.decode(_.cpu(), skip_special_tokens=True).strip() for _ in pred]
+                captions.extend(caption)
+                print(caption)
+
+            torch.distributed.barrier()
+
+            world_size = torch.distributed.get_world_size()
+            merged_ids = [None for _ in range(world_size)]
+            merged_captions = [None for _ in range(world_size)]
+            torch.distributed.all_gather_object(merged_ids, image_ids)
+            torch.distributed.all_gather_object(merged_captions, captions)
+
+            merged_ids = [_ for _ in itertools.chain.from_iterable(merged_ids)]
+            merged_captions = [_ for _ in itertools.chain.from_iterable(merged_captions)]
+            average_length = sum(len(x.split()) for x in merged_captions) / len(merged_captions)
+            print(f'Average length: {average_length}')
+
+            if torch.distributed.get_rank() == 0:
+                print(f'Evaluating {ds_name} ...')
+
+                results = []
+                for image_id, caption in zip(merged_ids, merged_captions):
+                    results.append({
+                        'image_id': int(image_id),
+                        'caption': caption,
+                    })
+                time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
+                results_file = f'{ds_name}_{time_prefix}.json'
+                results_file = os.path.join(args.out_dir, results_file)
+                json.dump(results, open(results_file, 'w'))
+
+                annotation = ds_collections[ds_name]['annotation']
+                if type(annotation) == list:
+                    annotation = annotation[-1]
+                coco = COCO(annotation)
+                coco_result = coco.loadRes(results_file)
+                coco_eval = COCOEvalCap(coco, coco_result)
+                coco_eval.evaluate()
+
+                summary = coco_eval.eval.items()
+                print([ds_name, prompt, average_length, summary])
+                summaries.append([ds_name, prompt, average_length, summary])
+
+            torch.distributed.barrier()
+
+    for summary in summaries:
+        print(summary)
+
+
+if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--checkpoint', type=str, default='')
+    parser.add_argument('--datasets', type=str, default='coco,flickr30k,nocaps')
+    parser.add_argument('--batch-size', type=int, default=1)
+    parser.add_argument('--num-workers', type=int, default=1)
+    parser.add_argument('--num-beams', type=int, default=5)
+    parser.add_argument('--out-dir', type=str, default='results')
+    parser.add_argument('--seed', type=int, default=0)
+    args = parser.parse_args()
+
+    os.makedirs(args.out_dir, exist_ok=True)
+
+    args.datasets = args.datasets.split(',')
+    print('datasets:', args.datasets)
+    assert args.batch_size == 1, 'Only batch size 1 is supported'
+
+    torch.distributed.init_process_group(
+        backend='nccl',
+        world_size=int(os.getenv('WORLD_SIZE', '1')),
+        rank=int(os.getenv('RANK', '0')),
+    )
+
+    torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
+
+    evaluate_qllama_model()
VLMEvalKit_old/InternVL/internvl_g/internvl/model/__init__.py
ADDED
File without changes
VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2/__init__.py
ADDED
@@ -0,0 +1,87 @@
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2023 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+
+import torch
+import torch.nn as nn
+import torchvision.transforms as T
+from torchvision.transforms import InterpolationMode
+from transformers import LlamaTokenizer
+
+from .configuration_intern_vit import InternVisionConfig
+from .configuration_internvl import InternVLConfig
+from .modeling_intern_vit import InternVisionModel
+from .modeling_internvl import InternVL_C, InternVL_G, InternVLModel
+
+__all__ = ['InternVisionConfig', 'InternVisionModel', 'InternVLConfig',
+           'InternVLModel', 'InternVL_C', 'InternVL_G']
+
+
+# Prefix the text "summarize:"
+class InternVLTokenizer(nn.Module):
+    def __init__(self, model_path):
+        super(InternVLTokenizer, self).__init__()
+        self.tokenizer = LlamaTokenizer.from_pretrained(model_path)
+        self.tokenizer.pad_token = ' '  # allow padding
+        self.tokenizer.add_eos_token = True
+
+    def forward(self, text, prefix='summarize:'):
+        if type(text) == str:
+            text = prefix + text
+        elif type(text) == list:
+            text = [prefix + item for item in text]
+        text = self.tokenizer(text, return_tensors='pt', max_length=80, truncation=True, padding='max_length').input_ids
+        return text
+
+
+def build_transform(task, image_size=224, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
+    if task == 'retrieval':
+        transform = T.Compose([
+            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+            T.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
+            T.ToTensor(),
+            T.Normalize(mean=mean, std=std)])
+    else:
+        transform = T.Compose([
+            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+            T.Resize(image_size, interpolation=InterpolationMode.BICUBIC),
+            T.CenterCrop(image_size),
+            T.ToTensor(),
+            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
+    return transform
+
+
+def load_internvl_c_huggingface(ckpt_path, device, task):
+    model = InternVL_C.from_pretrained(ckpt_path, torch_dtype=torch.float16).to(device)
+    if model.config.use_backbone_lora:
+        model.vision_model.merge_and_unload()
+        model.vision_model = model.vision_model.model
+    if model.config.use_qllama_lora:
+        model.qllama.merge_and_unload()
+        model.qllama = model.qllama.model
+    if model.config.force_image_size is not None:
+        image_size = model.config.force_image_size
+    else:
+        image_size = model.config.vision_config.image_size
+    transform = build_transform(task, image_size)
+    tokenizer = InternVLTokenizer(ckpt_path)
+    return model, transform, tokenizer
+
+
+def load_internvl_g_huggingface(ckpt_path, device, task):
+    model = InternVL_G.from_pretrained(ckpt_path, torch_dtype=torch.float16).to(device)
+    if model.config.use_backbone_lora:
+        model.vision_model.merge_and_unload()
+        model.vision_model = model.vision_model.model
+    if model.config.use_qllama_lora:
+        model.qllama.merge_and_unload()
+        model.qllama = model.qllama.model
+    if model.config.force_image_size is not None:
+        image_size = model.config.force_image_size
+    else:
+        image_size = model.config.vision_config.image_size
+    transform = build_transform(task, image_size)
+    tokenizer = InternVLTokenizer(ckpt_path)
+    return model, transform, tokenizer
VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2/configuration_intern_vit.py
ADDED
@@ -0,0 +1,117 @@
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2023 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+import os
+from typing import Union
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class InternVisionConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
+    instantiate a vision encoder according to the specified arguments, defining the model architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        num_channels (`int`, *optional*, defaults to 3):
+            Number of color channels in the input images (e.g., 3 for RGB).
+        patch_size (`int`, *optional*, defaults to 14):
+            The size (resolution) of each patch.
+        image_size (`int`, *optional*, defaults to 224):
+            The size (resolution) of each image.
+        qkv_bias (`bool`, *optional*, defaults to `False`):
+            Whether to add a bias to the queries and values in the self-attention layers.
+        hidden_size (`int`, *optional*, defaults to 3200):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_attention_heads (`int`, *optional*, defaults to 25):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 12800):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        qk_normalization (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the queries and keys in the self-attention layers.
+        num_hidden_layers (`int`, *optional*, defaults to 48):
+            Number of hidden layers in the Transformer encoder.
+        use_flash_attn (`bool`, *optional*, defaults to `True`):
+            Whether to use the flash attention mechanism.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"`, `"gelu_new"` and `"gelu"` are supported.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-6):
+            The epsilon used by the layer normalization layers.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        drop_path_rate (`float`, *optional*, defaults to 0.0):
+            Dropout rate for stochastic depth.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        initializer_factor (`float`, *optional*, defaults to 0.1):
+            A factor for layer scale.
+    """
+
+    model_type = 'intern_vit_6b'
+
+    def __init__(
+            self,
+            num_channels=3,
+            patch_size=14,
+            image_size=224,
+            qkv_bias=False,
+            hidden_size=3200,
+            num_attention_heads=25,
+            intermediate_size=12800,
+            qk_normalization=True,
+            num_hidden_layers=48,
+            use_flash_attn=True,
+            hidden_act='gelu',
+            layer_norm_eps=1e-6,
+            dropout=0.0,
+            drop_path_rate=0.0,
+            attention_dropout=0.0,
+            initializer_range=0.02,
+            initializer_factor=0.1,
+            **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.dropout = dropout
+        self.drop_path_rate = drop_path_rate
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_channels = num_channels
+        self.patch_size = patch_size
+        self.image_size = image_size
+        self.initializer_range = initializer_range
+        self.initializer_factor = initializer_factor
+        self.attention_dropout = attention_dropout
+        self.layer_norm_eps = layer_norm_eps
+        self.hidden_act = hidden_act
+        self.qkv_bias = qkv_bias
+        self.qk_normalization = qk_normalization
+        self.use_flash_attn = use_flash_attn
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
+        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+        if 'vision_config' in config_dict:
+            config_dict = config_dict['vision_config']
+
+        if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type:
+            logger.warning(
+                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+                f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
+            )
+
+        return cls.from_dict(config_dict, **kwargs)
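
Since `InternVisionConfig` above documents its defaults (`hidden_size=3200`, `num_hidden_layers=48`, `image_size=224`, and so on), a small usage sketch may help. It is illustrative only and assumes the package layout exported by `internvl_stage2/__init__.py`; it is not part of the diff.

```python
# Illustrative use of the InternVisionConfig defaults documented above.
# Assumes the `internvl` package from this diff is importable.
from internvl.model.internvl_stage2 import InternVisionConfig

# Start from the documented defaults, then override a couple of fields.
config = InternVisionConfig(image_size=448, drop_path_rate=0.1)
print(config.hidden_size)        # 3200 by default
print(config.num_hidden_layers)  # 48 by default
```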
VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2/flash_attention.py
ADDED
@@ -0,0 +1,76 @@
+# https://github.com/Dao-AILab/flash-attention/blob/v0.2.8/flash_attn/flash_attention.py
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+try:  # v1
+    from flash_attn.flash_attn_interface import \
+        flash_attn_unpadded_qkvpacked_func
+except:  # v2
+    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
+
+from flash_attn.bert_padding import pad_input, unpad_input
+
+
+class FlashAttention(nn.Module):
+    """Implement the scaled dot product attention with softmax.
+    Arguments
+    ---------
+        softmax_scale: The temperature to use for the softmax attention.
+                       (default: 1/sqrt(d_keys) where d_keys is computed at
+                       runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.0)
+    """
+
+    def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
+        super().__init__()
+        self.softmax_scale = softmax_scale
+        self.dropout_p = attention_dropout
+
+    def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
+                max_s=None, need_weights=False):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
+                if unpadded: (nnz, 3, h, d)
+            key_padding_mask: a bool tensor of shape (B, S)
+        """
+        assert not need_weights
+        assert qkv.dtype in [torch.float16, torch.bfloat16]
+        assert qkv.is_cuda
+
+        if cu_seqlens is None:
+            batch_size = qkv.shape[0]
+            seqlen = qkv.shape[1]
+            if key_padding_mask is None:
+                qkv = rearrange(qkv, 'b s ... -> (b s) ...')
+                max_s = seqlen
+                cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
+                                          device=qkv.device)
+                output = flash_attn_unpadded_qkvpacked_func(
+                    qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
+                    softmax_scale=self.softmax_scale, causal=causal
+                )
+                output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+            else:
+                nheads = qkv.shape[-2]
+                x = rearrange(qkv, 'b s three h d -> b s (three h d)')
+                x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
+                x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
+                output_unpad = flash_attn_unpadded_qkvpacked_func(
+                    x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
+                    softmax_scale=self.softmax_scale, causal=causal
+                )
+                output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
+                                             indices, batch_size, seqlen),
+                                   'b s (h d) -> b s h d', h=nheads)
+        else:
+            assert max_s is not None
+            output = flash_attn_unpadded_qkvpacked_func(
+                qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
+                softmax_scale=self.softmax_scale, causal=causal
+            )
+
+        return output, None
VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2/modeling_intern_vit.py
ADDED
@@ -0,0 +1,342 @@
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2023 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from einops import rearrange
+from timm.models.layers import DropPath
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import (BaseModelOutput,
+                                           BaseModelOutputWithPooling)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+
+from .configuration_intern_vit import InternVisionConfig
+
+try:
+    from .flash_attention import FlashAttention
+    has_flash_attn = True
+except:
+    print('FlashAttention is not installed.')
+    has_flash_attn = False
+
+
+logger = logging.get_logger(__name__)
+
+
+class InternRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+
+try:
+    from apex.normalization import FusedRMSNorm
+
+    InternRMSNorm = FusedRMSNorm  # noqa
+
+    logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
+except ImportError:
+    # using the normal InternRMSNorm
+    pass
+except Exception:
+    logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm')
+    pass
+
+
+class InternVisionEmbeddings(nn.Module):
+    def __init__(self, config: InternVisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
+        self.class_embedding = nn.Parameter(
+            torch.randn(1, 1, self.embed_dim),
+        )
+
+        self.patch_embedding = nn.Conv2d(
+            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
+        )
+
+        self.num_patches = (self.image_size // self.patch_size) ** 2
+        self.num_positions = self.num_patches + 1
+
+        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
+
+    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+        batch_size = pixel_values.shape[0]
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
+        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+        class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
+        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+        embeddings = embeddings + self.position_embedding.to(target_dtype)
+        return embeddings
+
+
+class InternAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config: InternVisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.use_flash_attn = config.use_flash_attn and has_flash_attn
+        if config.use_flash_attn and not has_flash_attn:
+            print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
+        self.head_dim = self.embed_dim // self.num_heads
+        if self.head_dim * self.num_heads != self.embed_dim:
+            raise ValueError(
+                f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
+                f' {self.num_heads}).'
+            )
+
+        self.scale = self.head_dim ** -0.5
+        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
+        self.attn_drop = nn.Dropout(config.attention_dropout)
+        self.proj_drop = nn.Dropout(config.dropout)
+
+        self.qk_normalization = config.qk_normalization
+
+        if self.qk_normalization:
+            self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+            self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+        if self.use_flash_attn:
+            self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
+        self.proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+    def _naive_attn(self, x):
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
+
+        if self.qk_normalization:
+            B_, H_, N_, D_ = q.shape
+            q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+            k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+
+        attn = ((q * self.scale) @ k.transpose(-2, -1))
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+    def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
+        qkv = self.qkv(x)
+        qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
+
+        if self.qk_normalization:
+            q, k, v = qkv.unbind(2)
+            q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
+            k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
+            qkv = torch.stack([q, k, v], dim=2)
+
+        context, _ = self.inner_attn(
+            qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
+        )
+        outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
+        outs = self.proj_drop(outs)
+        return outs
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
+        return x
+
+
+class InternMLP(nn.Module):
+    def __init__(self, config: InternVisionConfig):
+        super().__init__()
+        self.config = config
+        self.act = ACT2FN[config.hidden_act]
+        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.fc2(hidden_states)
+        return hidden_states
+
+
+class InternVisionEncoderLayer(nn.Module):
+    def __init__(self, config: InternVisionConfig, drop_path_rate: float):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+
+        self.attn = InternAttention(config)
+        self.mlp = InternMLP(config)
+        self.norm1 = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.norm2 = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+        self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
+        self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
+        self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+        self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
+        """
+        hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)
+
+        hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)
+
+        return hidden_states
+
+
+class InternVisionEncoder(nn.Module):
+    """
+    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+    [`InternEncoderLayer`].
+
+    Args:
+        config (`InternConfig`):
+            The corresponding vision configuration for the `InternEncoder`.
+    """
+
+    def __init__(self, config: InternVisionConfig):
+        super().__init__()
+        self.config = config
+        # stochastic depth decay rule
+        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
+        self.layers = nn.ModuleList([
+            InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = True
+
+    def forward(
+            self,
+            inputs_embeds,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutput]:
+        r"""
+        Args:
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                Embedded representation of the inputs. Should be float, not int tokens.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        """
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        encoder_states = () if output_hidden_states else None
+        hidden_states = inputs_embeds
+
+        for idx, encoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                encoder_states = encoder_states + (hidden_states,)
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    encoder_layer,
+                    hidden_states)
+            else:
+                layer_outputs = encoder_layer(
+                    hidden_states,
+                )
+            hidden_states = layer_outputs
+
+        if output_hidden_states:
+            encoder_states = encoder_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, encoder_states] if v is not None)
+        return BaseModelOutput(
+            last_hidden_state=hidden_states, hidden_states=encoder_states
+        )
+
+
+class InternVisionModel(PreTrainedModel):
+    main_input_name = 'pixel_values'
+    config_class = InternVisionConfig
+
+    def __init__(self, config: InternVisionConfig):
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = InternVisionEmbeddings(config)
+        self.encoder = InternVisionEncoder(config)
+
+    def resize_pos_embeddings(self, old_size, new_size, patch_size):
+        pos_emb = self.embeddings.position_embedding
+        _, num_positions, embed_dim = pos_emb.shape
+        cls_emb = pos_emb[:, :1, :]
+        pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
+        pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False)
+        pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
|
| 297 |
+
pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
|
| 298 |
+
self.embeddings.position_embedding = nn.Parameter(pos_emb)
|
| 299 |
+
logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))
|
| 300 |
+
|
| 301 |
+
def get_input_embeddings(self):
|
| 302 |
+
return self.embeddings
|
| 303 |
+
|
| 304 |
+
def forward(
|
| 305 |
+
self,
|
| 306 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 307 |
+
output_hidden_states: Optional[bool] = None,
|
| 308 |
+
return_dict: Optional[bool] = None,
|
| 309 |
+
pixel_embeds: Optional[torch.FloatTensor] = None,
|
| 310 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
| 311 |
+
output_hidden_states = (
|
| 312 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 313 |
+
)
|
| 314 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 315 |
+
|
| 316 |
+
if pixel_values is None and pixel_embeds is None:
|
| 317 |
+
raise ValueError('You have to specify pixel_values or pixel_embeds')
|
| 318 |
+
|
| 319 |
+
if pixel_embeds is not None:
|
| 320 |
+
hidden_states = pixel_embeds
|
| 321 |
+
else:
|
| 322 |
+
if len(pixel_values.shape) == 4:
|
| 323 |
+
hidden_states = self.embeddings(pixel_values)
|
| 324 |
+
else:
|
| 325 |
+
raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
|
| 326 |
+
encoder_outputs = self.encoder(
|
| 327 |
+
inputs_embeds=hidden_states,
|
| 328 |
+
output_hidden_states=output_hidden_states,
|
| 329 |
+
return_dict=return_dict,
|
| 330 |
+
)
|
| 331 |
+
last_hidden_state = encoder_outputs.last_hidden_state
|
| 332 |
+
pooled_output = last_hidden_state[:, 0, :]
|
| 333 |
+
|
| 334 |
+
if not return_dict:
|
| 335 |
+
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
| 336 |
+
|
| 337 |
+
return BaseModelOutputWithPooling(
|
| 338 |
+
last_hidden_state=last_hidden_state,
|
| 339 |
+
pooler_output=pooled_output,
|
| 340 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 341 |
+
attentions=encoder_outputs.attentions,
|
| 342 |
+
)
|
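Illustrative usage sketch (not part of the uploaded file): running the InternVisionModel defined above on a dummy batch. The small config values and the import paths are assumptions for the example; the released checkpoints ship their own config.json with much larger settings.

# Minimal sketch, assuming the internvl_g package layout shown in this upload and
# that InternVisionConfig accepts these standard kwargs (hypothetical small values).
import torch
from internvl.model.internvl_stage2.configuration_intern_vit import InternVisionConfig
from internvl.model.internvl_stage2.modeling_intern_vit import InternVisionModel

config = InternVisionConfig(
    hidden_size=64, intermediate_size=128, num_hidden_layers=2,
    num_attention_heads=4, image_size=224, patch_size=14, use_flash_attn=False)
model = InternVisionModel(config).eval()

pixel_values = torch.randn(1, 3, 224, 224)  # (batch, channels, height, width)
with torch.no_grad():
    out = model(pixel_values=pixel_values, return_dict=True)
print(out.last_hidden_state.shape)  # (1, 257, 64): 1 CLS token + (224/14)**2 patches
print(out.pooler_output.shape)      # (1, 64): the CLS-token embedding used for pooling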
VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2/modeling_internvl.py
ADDED
@@ -0,0 +1,684 @@
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from dataclasses import dataclass
from functools import partial
from typing import Any, Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.utils.checkpoint
from peft import LoraConfig, get_peft_model
from timm.models.layers import DropPath
from torch import nn
from transformers import GenerationConfig
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, logging

from .configuration_internvl import InternVLConfig
from .modeling_intern_vit import (InternVisionEmbeddings, InternVisionEncoder,
                                  InternVisionModel)
from .modeling_qllama import LlamaForCausalLM, _expand_mask, _make_causal_mask

try:
    from .flash_attention import FlashAttention  # v1/v2
except:
    print('FlashAttention is not installed.')

logger = logging.get_logger(__name__)


class InternVLPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = InternVLConfig
    base_model_prefix = 'internvl'
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [
        r'position_ids',
    ]
    _no_split_modules = ['InternAttention', 'LlamaDecoderLayer', 'LlamaForCausalLM']
    _skip_keys_device_placement = 'past_key_values'
    _keep_in_fp32_modules = ['wo']

    # def _init_weights(self, module):
    #     """Initialize the weights"""
    #     factor = self.config.initializer_range
    #     if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear):
    #         module.weight.data.normal_(mean=0.0, std=factor)
    #         if hasattr(module, 'bias') and module.bias is not None:
    #             module.bias.data.zero_()
    #     if isinstance(module, InternVisionEmbeddings):
    #         if hasattr(self.config, 'vision_config'):
    #             factor = self.config.vision_config.initializer_range
    #         nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
    #         nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
    #     elif isinstance(module, nn.LayerNorm):
    #         module.bias.data.zero_()
    #         module.weight.data.fill_(1.0)
    #     elif isinstance(module, nn.Linear) and module.bias is not None:
    #         module.bias.data.zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, InternVisionModel):
            module.gradient_checkpointing = value
        if isinstance(module, InternVisionEncoder):
            module.gradient_checkpointing = value


class CrossAttention(nn.Module):
    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., attn_head_dim=None, out_dim=None):
        super().__init__()
        if out_dim is None:
            out_dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5
        assert all_head_dim == dim

        self.q = nn.Linear(dim, all_head_dim, bias=False)
        self.k = nn.Linear(dim, all_head_dim, bias=False)
        self.v = nn.Linear(dim, all_head_dim, bias=False)

        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.k_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, out_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, k=None, v=None):
        B, N, C = x.shape
        N_k = k.shape[1]
        N_v = v.shape[1]

        q_bias, k_bias, v_bias = None, None, None
        if self.q_bias is not None:
            q_bias = self.q_bias
            k_bias = self.k_bias
            v_bias = self.v_bias

        q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
        q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)  # (B, N_head, N_q, dim)

        k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
        k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)

        v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
        v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # (B, N_head, N_q, N_k)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class AttentiveBlock(nn.Module):

    def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, attn_head_dim=None, out_dim=None):
        super().__init__()

        self.norm1_q = norm_layer(dim)
        self.norm1_k = norm_layer(dim)
        self.norm1_v = norm_layer(dim)
        self.cross_attn = CrossAttention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
            proj_drop=drop, attn_head_dim=attn_head_dim, out_dim=out_dim)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None):
        x_q = self.norm1_q(x_q + pos_q)
        x_k = self.norm1_k(x_kv + pos_k)
        x_v = self.norm1_v(x_kv)
        x = self.cross_attn(x_q, k=x_k, v=x_v)

        return x


class AttentionPoolingBlock(AttentiveBlock):

    def forward(self, x):
        x_q = x.mean(1, keepdim=True)
        x_kv, pos_q, pos_k = x, 0, 0
        x = super().forward(x_q, x_kv, pos_q, pos_k, bool_masked_pos=None, rel_pos_bias=None)
        x = x.squeeze(1)
        return x


@dataclass
class InternVLModelOutput(ModelOutput):
    """
    Class defining the outputs of [`InternVLModelOutput`].
    """

    loss: Optional[torch.FloatTensor] = None
    loss_itm: Optional[torch.FloatTensor] = None
    loss_itc: Optional[torch.FloatTensor] = None
    loss_itg: Optional[torch.FloatTensor] = None

    def to_tuple(self) -> Tuple[Any]:
        return tuple(
            self[k]
            if k not in ['loss', 'loss_itm', 'loss_itc', 'loss_itg']
            else getattr(self, k).to_tuple()
            for k in self.keys()
        )


class GatherLayer(torch.autograd.Function):
    """Gather tensors from all process, supporting backward propagation.
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
        dist.all_gather(output, input)
        return torch.stack(output, 0)

    @staticmethod
    def backward(ctx, grads):
        input, = ctx.saved_tensors
        dist.all_reduce(grads)
        grad_out = torch.zeros_like(input)
        grad_out[:] = grads[dist.get_rank()]
        return grad_out


class InternVLModel(InternVLPreTrainedModel):
    config_class = InternVLConfig
    main_input_name = 'pixel_values'

    def __init__(self, config: InternVLConfig):
        super().__init__(config)

        text_hidden_size = config.qllama_config.hidden_size
        vision_hidden_size = config.vision_config.hidden_size
        clip_embed_dim = config.clip_embed_dim
        attn_pool_num_heads = config.attn_pool_num_heads
        config.qllama_config.num_query_token = config.num_query_token
        self.num_query_token = config.num_query_token
        self.label_smoothing = config.label_smoothing

        self.vision_model = InternVisionModel(config.vision_config)  # frozen
        self.qllama = LlamaForCausalLM(config.qllama_config)  # frozen
        self.query_tokens = nn.Parameter(  # trainable
            torch.zeros(1, config.num_query_token, text_hidden_size)
        )

        self.text_projection = nn.Parameter(torch.empty(text_hidden_size, clip_embed_dim))  # frozen
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))  # trainable
        self.clip_projector = AttentionPoolingBlock(  # frozen
            dim=vision_hidden_size, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
            drop=0., attn_drop=0., norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim)
        self.clip_projector2 = AttentionPoolingBlock(  # trainable
            dim=text_hidden_size, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
            drop=0., attn_drop=0., norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim)
        self.itm_head = nn.Linear(text_hidden_size, 2)  # trainable
        self.gradient_checkpointing = True

        # Initialize weights and apply final processing
        # self.post_init()

        if config.use_backbone_lora:
            self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=config.use_backbone_lora * 2)
        if config.use_qllama_lora:
            self.wrap_qllama_lora(r=config.use_qllama_lora, lora_alpha=config.use_qllama_lora * 2)
        if config.force_image_size:
            self.vision_model.resize_pos_embeddings(
                old_size=config.vision_config.image_size,
                new_size=config.force_image_size,
                patch_size=config.vision_config.patch_size
            )

    def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
        lora_config = LoraConfig(
            r=r,
            target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'],
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )
        self.vision_model = get_peft_model(self.vision_model, lora_config)
        self.vision_model.print_trainable_parameters()

    def wrap_qllama_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
        lora_config = LoraConfig(
            r=r,
            target_modules=['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
                            'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj'],
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )
        self.qllama = get_peft_model(self.qllama, lora_config)
        self.qllama.print_trainable_parameters()

    def get_input_embeddings(self):
        return self.qllama.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.qllama.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings):
        self.qllama.set_output_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        return self.qllama.get_output_embeddings()

    @torch.no_grad()
    def _prepare_attention_mask(
            self,
            image_attention_mask: torch.LongTensor,
            attention_mask: torch.LongTensor,
            input_embeds: torch.FloatTensor,
            repeat_time: int,
    ):
        # itm, itc, itg
        attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1)
        expand_mask = _expand_mask(attention_mask, input_embeds.dtype).to(
            input_embeds.device)  # [bsz, 1, tgt_seq_len, src_seq_len]
        itm_mask, itc_mask, itg_mask = torch.chunk(expand_mask, repeat_time, dim=0)

        itc_mask[:, :, :self.num_query_token, self.num_query_token:] = torch.finfo(input_embeds.dtype).min
        itc_mask[:, :, self.num_query_token:, :self.num_query_token] = torch.finfo(input_embeds.dtype).min
        itc_mask_causal = _make_causal_mask(
            (itc_mask.shape[0], itc_mask.shape[2] - self.num_query_token),
            input_embeds.dtype,
            device=input_embeds.device
        )
        # use causal mask for text in itc
        itc_mask[:, :, self.num_query_token:, self.num_query_token:] += itc_mask_causal

        itg_mask_causal = _make_causal_mask(
            (itg_mask.shape[0], itg_mask.shape[2]),
            input_embeds.dtype,
            device=input_embeds.device
        )
        itg_mask = itg_mask + itg_mask_causal
        itg_mask[:, :, :self.num_query_token, :self.num_query_token] = 0
        attention_mask = torch.cat([itm_mask, itc_mask, itg_mask], dim=0)

        return attention_mask

    def forward(
            self,
            pixel_values: torch.FloatTensor,
            positive_input_ids: torch.FloatTensor,
            positive_attention_mask: torch.LongTensor,
            negative_input_ids: torch.FloatTensor,
            negative_attention_mask: torch.LongTensor,
            summarize_input_ids: torch.FloatTensor,
            summarize_attention_mask: torch.LongTensor,
            input_ids: torch.FloatTensor,
            attention_mask: torch.LongTensor,
            labels: torch.LongTensor,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, InternVLModelOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # step 1: forward the images through the vision encoder,
        # to get image embeddings of shape (batch_size, seq_len, hidden_size)
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict)
        image_embeds = vision_outputs[0]
        backbone_embeds = self.clip_projector(image_embeds)

        # step 2: prepare input_ids and attention_mask for three sub-tasks:
        # 1) image-text matching; 2) image-text contrastive learning; 3) image-grounded text generation.
        batch_size = input_ids.shape[0]
        self.positive_num = batch_size // 2
        input_ids = torch.cat([negative_input_ids[:-self.positive_num], positive_input_ids[-self.positive_num:],
                               summarize_input_ids, input_ids], dim=0)  # [3 * batch_size, seq_len]
        itm_attention_mask = torch.cat(
            [negative_attention_mask[:-self.positive_num], positive_attention_mask[-self.positive_num:]], dim=0)
        selected = itm_attention_mask.sum(1) - 1
        attention_mask = torch.cat(
            [itm_attention_mask, summarize_attention_mask, attention_mask], dim=0)  # [3 * batch_size, seq_len]

        repeat_time = input_ids.size(0) // batch_size
        # step 3: forward the input_ids and attention_mask through the text encoder.
        input_embeds = self.get_input_embeddings()(input_ids)
        query_tokens = self.query_tokens.repeat(repeat_time * batch_size, 1, 1)
        input_embeds = torch.cat([query_tokens, input_embeds], dim=1)
        image_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
        attention_mask = self._prepare_attention_mask(
            image_attention_mask, attention_mask, input_embeds, repeat_time
        )
        if type(self.qllama.model) == LlamaForCausalLM:
            outputs = self.qllama.model.model.forward_train(
                inputs_embeds=input_embeds,
                vision_hidden_states=image_embeds,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                repeat_time=repeat_time,
            ).last_hidden_state
        else:
            outputs = self.qllama.model.forward_train(
                inputs_embeds=input_embeds,
                vision_hidden_states=image_embeds,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                repeat_time=repeat_time,
            ).last_hidden_state
        image_embeds = outputs[:, :self.num_query_token]
        text_embeds = outputs[:, self.num_query_token:]
        image_itm, image_itc, image_itg = image_embeds.chunk(repeat_time, dim=0)
        text_itm, text_itc, text_itg = text_embeds.chunk(repeat_time, dim=0)

        ###============== Image-Text Matching ===================###
        image_itm = self.itm_head(image_itm)
        logits = image_itm.mean(dim=1)
        itm_labels = torch.cat([
            torch.zeros(batch_size - self.positive_num, dtype=torch.long, device=logits.device),
            torch.ones(self.positive_num, dtype=torch.long, device=logits.device)
        ], dim=0)
        itm_labels[selected == 1] = -100  # ignore empty texts
        loss_itm = F.cross_entropy(logits, itm_labels)
        neg_match_acc = ((logits[:batch_size - self.positive_num].argmax(dim=-1) == 0) / (
            batch_size - self.positive_num)).sum()
        pos_match_acc = ((logits[-self.positive_num:].argmax(dim=-1) == 1) / self.positive_num).sum()

        ###============== Image-Text Contrastive ===================###
        image_itc = self.clip_projector2(image_itc)

        selected = summarize_attention_mask.sum(1) - 1
        text_itc = text_itc[torch.arange(text_itc.shape[0]), selected]
        text_itc = text_itc @ self.text_projection

        # normalized features
        image_itc = image_itc / image_itc.norm(dim=1, keepdim=True)
        text_itc = text_itc / text_itc.norm(dim=1, keepdim=True)
        image_itc_all = GatherLayer.apply(image_itc).flatten(0, 1)
        text_itc_all = GatherLayer.apply(text_itc).flatten(0, 1)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        sim_i2t = logit_scale * (image_itc @ text_itc_all.t())
        sim_t2i = logit_scale * (text_itc @ image_itc_all.t())
        bs = image_itc.size(0)
        rank = dist.get_rank() if dist.is_initialized() else 0
        targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=torch.long, device=sim_i2t.device)
        targets[selected == 4] = -100  # ignore empty texts
        loss_itc = (
            F.cross_entropy(sim_i2t, targets, label_smoothing=self.label_smoothing)
            + F.cross_entropy(sim_t2i, targets, label_smoothing=self.label_smoothing)
        ) / 2

        ###============== Image-grounded Text Generation ===================###
        logits = self.qllama.lm_head(text_itg)
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        shift_logits = shift_logits.view(-1, self.qllama.config.vocab_size)
        shift_labels = shift_labels.view(-1)
        # Enable model parallelism
        shift_labels = shift_labels.to(shift_logits.device)
        loss_itg = F.cross_entropy(shift_logits, shift_labels)

        vision_sim = F.cosine_similarity(backbone_embeds.detach(), image_itc).mean()

        loss = loss_itm + loss_itc + loss_itg
        if dist.get_rank() == 0:
            print(f'loss: {loss.item()}, loss_itm: {loss_itm.item()}, loss_itc: {loss_itc.item()}, '
                  f'loss_itg: {loss_itg.item()}, vision_similarity: {round(vision_sim.item(), 5)}, '
                  f'logit scale: {round(1.0 / logit_scale.item(), 5)}, '
                  f'pos_match_acc: {round(pos_match_acc.item(), 4)}, '
                  f'neg_match_acc: {round(neg_match_acc.item(), 4)}')

        return InternVLModelOutput(
            loss=loss,
            loss_itc=loss_itc.detach(),
            loss_itm=loss_itm.detach(),
            loss_itg=loss_itg.detach(),
        )

    @torch.no_grad()
    def generate(
            self,
            pixel_values: torch.FloatTensor,
            input_ids: torch.FloatTensor,
            attention_mask: torch.LongTensor,
            generation_config: Optional[GenerationConfig] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            **generate_kwargs,
    ) -> torch.LongTensor:

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict)
        image_embeds = vision_outputs[0]

        batch_size = image_embeds.shape[0]
        input_embeds = self.get_input_embeddings()(input_ids)
        query_tokens = self.query_tokens.repeat(batch_size, 1, 1)
        input_embeds = torch.cat([query_tokens, input_embeds], dim=1)
        image_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
        attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1)

        outputs = self.qllama.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            vision_hidden_states=image_embeds,
            generation_config=generation_config,
            use_zero_attention_mask=True,
            **generate_kwargs,
        )

        return outputs

    def get_text_features(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ):
        r"""
        Returns:
            text_outputs (`CausalLMOutputWithPast`, or `tuple(torch.FloatTensor)` if `return_dict=False`):
                The language model outputs. If `return_dict=True`, the output is a [`CausalLMOutputWithPast`] that
                contains the language model logits, the past key values and the hidden states if
                `output_hidden_states=True`.
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        input_embeds = self.get_input_embeddings()(input_ids)
        attention_mask = _expand_mask(attention_mask, input_embeds.dtype).to(
            input_embeds.device)  # [bsz, 1, tgt_seq_len, src_seq_len]
        attention_mask += _make_causal_mask(
            (attention_mask.shape[0], attention_mask.shape[2]),
            input_embeds.dtype,
            device=input_embeds.device
        )
        if type(self.qllama.model) == LlamaForCausalLM:
            outputs = self.qllama.model.model.forward_train(
                inputs_embeds=input_embeds,
                vision_hidden_states=None,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            ).last_hidden_state
        else:
            outputs = self.qllama.model.forward_train(
                inputs_embeds=input_embeds,
                vision_hidden_states=None,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            ).last_hidden_state
        return outputs

    def get_image_features(
            self,
            pixel_values: torch.FloatTensor,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict)
        image_embeds = vision_outputs[0]
        backbone_embeds = image_embeds

        batch_size = image_embeds.shape[0]
        input_embeds = self.query_tokens.repeat(batch_size, 1, 1)

        attention_mask = torch.ones(input_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
        attention_mask = _expand_mask(attention_mask, input_embeds.dtype).to(
            input_embeds.device)  # [bsz, 1, tgt_seq_len, src_seq_len]
        if type(self.qllama.model) == LlamaForCausalLM:
            outputs = self.qllama.model.model.forward_train(
                inputs_embeds=input_embeds,
                vision_hidden_states=image_embeds,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            ).last_hidden_state
        else:
            outputs = self.qllama.model.forward_train(
                inputs_embeds=input_embeds,
                vision_hidden_states=image_embeds,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            ).last_hidden_state
        return backbone_embeds, outputs


class InternVL_C(InternVLModel):

    def encode_image(self, image):
        vision_outputs = self.vision_model(
            pixel_values=image,
            output_hidden_states=False,
            return_dict=True)
        image_embeds = vision_outputs[0]
        image_embeds = self.clip_projector(image_embeds)
        return image_embeds

    def encode_text(self, text):
        attention_mask = text > 0
        text_embeds = self.get_text_features(
            input_ids=text,
            attention_mask=attention_mask,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        text_embeds = text_embeds[torch.arange(text_embeds.shape[0]), attention_mask.sum(1) - 1]
        text_embeds = text_embeds @ self.text_projection
        return text_embeds

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        return logits_per_image, logits_per_text


class InternVL_G(InternVLModel):

    def encode_image(self, image):
        backbone_embeds, image_embeds = self.get_image_features(
            pixel_values=image,
            output_hidden_states=False,
            return_dict=True,
        )
        backbone_embeds = self.clip_projector(backbone_embeds)
        image_embeds = self.clip_projector2(image_embeds)
        # ensemble
        backbone_embeds = backbone_embeds / backbone_embeds.norm(dim=1, keepdim=True)
        image_embeds = image_embeds / image_embeds.norm(dim=1, keepdim=True)
        image_embeds = image_embeds + backbone_embeds
        return image_embeds

    def encode_text(self, text):
        attention_mask = text > 0
        text_embeds = self.get_text_features(
            input_ids=text,
            attention_mask=attention_mask,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        text_embeds = text_embeds[torch.arange(text_embeds.shape[0]), attention_mask.sum(1) - 1]
        text_embeds = text_embeds @ self.text_projection
        return text_embeds

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        return logits_per_image, logits_per_text
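Illustrative usage sketch (not part of the uploaded file): zero-shot image-text scoring with the CLIP-style InternVL_C head defined above. Here `model`, `pixel_values` and `text_ids` are assumed to be prepared elsewhere (e.g. via from_pretrained() and the matching tokenizer); the padding-with-0 convention matches encode_text's `text > 0` attention mask.

# Minimal sketch, assuming a ready InternVL_C `model`, image tensor `pixel_values`
# of shape (B, 3, H, W), and tokenized captions `text_ids` of shape (N, L) padded with 0.
import torch

with torch.no_grad():
    logits_per_image, logits_per_text = model(pixel_values, text_ids)
    # forward() L2-normalizes both sides, so these logits are scaled cosine similarities
    probs = logits_per_image.softmax(dim=-1)  # (B, N): image-to-text match probabilities
best_caption = probs.argmax(dim=-1)           # index of the best-matching caption per image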
VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2/modeling_qllama.py
ADDED
@@ -0,0 +1,1073 @@
| 1 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 4 |
+
# and OPT implementations in this library. It has been modified from its
|
| 5 |
+
# original forms to accommodate minor architectural differences compared
|
| 6 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 7 |
+
#
|
| 8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
+
# you may not use this file except in compliance with the License.
|
| 10 |
+
# You may obtain a copy of the License at
|
| 11 |
+
#
|
| 12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
+
#
|
| 14 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
+
# See the License for the specific language governing permissions and
|
| 18 |
+
# limitations under the License.
|
| 19 |
+
""" PyTorch QLLaMA model."""
|
| 20 |
+
import math
|
| 21 |
+
from typing import List, Optional, Tuple, Union
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import torch.utils.checkpoint
|
| 25 |
+
from torch import nn
|
| 26 |
+
from torch.nn import CrossEntropyLoss
|
| 27 |
+
from transformers import LlamaConfig
|
| 28 |
+
from transformers.activations import ACT2FN
|
| 29 |
+
from transformers.modeling_outputs import (BaseModelOutputWithPast,
|
| 30 |
+
CausalLMOutputWithPast)
|
| 31 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 32 |
+
from transformers.utils import (add_start_docstrings,
|
| 33 |
+
add_start_docstrings_to_model_forward, logging,
|
| 34 |
+
replace_return_docstrings)
|
| 35 |
+
|
| 36 |
+
logger = logging.get_logger(__name__)
|
| 37 |
+
|
| 38 |
+
_CONFIG_FOR_DOC = 'LlamaConfig'
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
| 42 |
+
def _make_causal_mask(
|
| 43 |
+
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
|
| 44 |
+
):
|
| 45 |
+
"""
|
| 46 |
+
Make causal mask used for bi-directional self-attention.
|
| 47 |
+
"""
|
| 48 |
+
bsz, tgt_len = input_ids_shape
|
| 49 |
+
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
|
| 50 |
+
mask_cond = torch.arange(mask.size(-1), device=device)
|
| 51 |
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
| 52 |
+
mask = mask.to(dtype)
|
| 53 |
+
|
| 54 |
+
if past_key_values_length > 0:
|
| 55 |
+
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
|
| 56 |
+
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
| 60 |
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
| 61 |
+
"""
|
| 62 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
| 63 |
+
"""
|
| 64 |
+
bsz, src_len = mask.size()
|
| 65 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
| 66 |
+
|
| 67 |
+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
| 68 |
+
|
| 69 |
+
inverted_mask = 1.0 - expanded_mask
|
| 70 |
+
|
| 71 |
+
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class LlamaRMSNorm(nn.Module):
|
| 75 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 76 |
+
"""
|
| 77 |
+
LlamaRMSNorm is equivalent to T5LayerNorm
|
| 78 |
+
"""
|
| 79 |
+
super().__init__()
|
| 80 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 81 |
+
self.variance_epsilon = eps
|
| 82 |
+
|
| 83 |
+
def forward(self, hidden_states):
|
| 84 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
| 85 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 86 |
+
|
| 87 |
+
# convert into half-precision if necessary
|
| 88 |
+
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
| 89 |
+
hidden_states = hidden_states.to(self.weight.dtype)
|
| 90 |
+
|
| 91 |
+
return self.weight * hidden_states
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
try:
|
| 95 |
+
from functools import partial
|
| 96 |
+
|
| 97 |
+
from apex.normalization import FusedRMSNorm
|
| 98 |
+
|
| 99 |
+
LlamaRMSNorm = partial(FusedRMSNorm, eps=1e-6) # noqa
|
| 100 |
+
print('Discovered apex.normalization.FusedRMSNorm - will use it instead of LlamaRMSNorm')
|
| 101 |
+
except ImportError:
|
| 102 |
+
# using the normal LlamaRMSNorm
|
| 103 |
+
pass
|
| 104 |
+
except Exception:
|
| 105 |
+
print('discovered apex but it failed to load, falling back to LlamaRMSNorm')
|
| 106 |
+
pass
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class LlamaRotaryEmbedding(torch.nn.Module):
|
| 110 |
+
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer('inv_freq', inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer('cos_cached', emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer('sin_cached', emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer('cos_cached', emb.cos()[None, None, :, :], persistent=False)
            self.register_buffer('sin_cached', emb.sin()[None, None, :, :], persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


class FixedLlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer('cos_cached', emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer('sin_cached', emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


LlamaRotaryEmbedding = FixedLlamaRotaryEmbedding


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
    gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
    cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
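
# Quick shape sketch for the helpers above (illustrative comment, not part of
# the original file): for a cache built with head_dim=128 and seq_len=2048,
# `cos_cached`/`sin_cached` have shape [1, 1, 2048, 128]. `apply_rotary_pos_emb`
# gathers the rows selected by `position_ids` (shape [bs, seq_len]) and rotates
# q/k of shape [bs, num_heads, seq_len, 128] element-wise, so the outputs keep
# exactly the same shapes as the inputs.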


class LlamaMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

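
# Note (illustrative comment, not part of the original file): this is the gated,
# SwiGLU-style feed-forward used by LLaMA. For example, with hidden_size=4096 and
# intermediate_size=11008, gate_proj/up_proj map [bs, seq, 4096] -> [bs, seq, 11008],
# the activated gate multiplies the up projection element-wise, and down_proj
# maps the result back to [bs, seq, 4096].
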
class LlamaAttention(nn.Module):
|
| 214 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 215 |
+
|
| 216 |
+
def __init__(self, config: LlamaConfig):
|
| 217 |
+
super().__init__()
|
| 218 |
+
self.config = config
|
| 219 |
+
self.hidden_size = config.hidden_size
|
| 220 |
+
self.num_heads = config.num_attention_heads
|
| 221 |
+
self.head_dim = self.hidden_size // self.num_heads
|
| 222 |
+
self.max_position_embeddings = config.max_position_embeddings
|
| 223 |
+
|
| 224 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
| 225 |
+
raise ValueError(
|
| 226 |
+
f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
|
| 227 |
+
f' and `num_heads`: {self.num_heads}).'
|
| 228 |
+
)
|
| 229 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
| 230 |
+
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
| 231 |
+
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
| 232 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
| 233 |
+
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
|
| 234 |
+
|
| 235 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
| 236 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
| 237 |
+
|
| 238 |
+
def forward(
|
| 239 |
+
self,
|
| 240 |
+
hidden_states: torch.Tensor,
|
| 241 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 242 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 243 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 244 |
+
output_attentions: bool = False,
|
| 245 |
+
use_cache: bool = False,
|
| 246 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 247 |
+
bsz, q_len, _ = hidden_states.size()
|
| 248 |
+
|
| 249 |
+
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 250 |
+
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 251 |
+
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 252 |
+
|
| 253 |
+
kv_seq_len = key_states.shape[-2]
|
| 254 |
+
if past_key_value is not None:
|
| 255 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
| 256 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
| 257 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
| 258 |
+
# [bsz, nh, t, hd]
|
| 259 |
+
|
| 260 |
+
if past_key_value is not None:
|
| 261 |
+
# reuse k, v, self_attention
|
| 262 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
| 263 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
| 264 |
+
|
| 265 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
| 266 |
+
|
| 267 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
| 268 |
+
|
| 269 |
+
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
| 270 |
+
raise ValueError(
|
| 271 |
+
f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
|
| 272 |
+
f' {attn_weights.size()}'
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
if attention_mask is not None:
|
| 276 |
+
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
| 277 |
+
raise ValueError(
|
| 278 |
+
f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
|
| 279 |
+
)
|
| 280 |
+
attn_weights = attn_weights + attention_mask
|
| 281 |
+
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
|
| 282 |
+
|
| 283 |
+
# upcast attention to fp32
|
| 284 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
| 285 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 286 |
+
|
| 287 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
| 288 |
+
raise ValueError(
|
| 289 |
+
f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
|
| 290 |
+
f' {attn_output.size()}'
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
attn_output = attn_output.transpose(1, 2)
|
| 294 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
| 295 |
+
|
| 296 |
+
attn_output = self.o_proj(attn_output)
|
| 297 |
+
|
| 298 |
+
if not output_attentions:
|
| 299 |
+
attn_weights = None
|
| 300 |
+
|
| 301 |
+
return attn_output, attn_weights, past_key_value
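# Shape walk-through for LlamaAttention.forward (illustrative comment, not part
# of the original file): with bsz=2, num_heads=32, head_dim=128, q_len=16 and a
# cache of 48 past tokens, query/key/value start as [2, 32, 16, 128]; after the
# cache concat, key/value become [2, 32, 64, 128], attn_weights is
# [2, 32, 16, 64], and attn_output is reshaped back to [2, 16, 4096] before o_proj.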
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
class LlamaCrossAttention(nn.Module):
|
| 305 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 306 |
+
|
| 307 |
+
def __init__(self, config: LlamaConfig):
|
| 308 |
+
super().__init__()
|
| 309 |
+
self.config = config
|
| 310 |
+
self.hidden_size = config.hidden_size
|
| 311 |
+
self.num_heads = config.num_attention_heads
|
| 312 |
+
self.head_dim = self.hidden_size // self.num_heads
|
| 313 |
+
self.max_position_embeddings = config.max_position_embeddings
|
| 314 |
+
self.vision_hidden_size = 3200
|
| 315 |
+
|
| 316 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
| 317 |
+
raise ValueError(
|
| 318 |
+
f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
|
| 319 |
+
f' and `num_heads`: {self.num_heads}).'
|
| 320 |
+
)
|
| 321 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
| 322 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
| 323 |
+
self.norm1 = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 324 |
+
|
| 325 |
+
self.k_proj = nn.Linear(self.vision_hidden_size, self.num_heads * self.head_dim, bias=False)
|
| 326 |
+
self.v_proj = nn.Linear(self.vision_hidden_size, self.num_heads * self.head_dim, bias=False)
|
| 327 |
+
self.norm2 = LlamaRMSNorm(self.vision_hidden_size, eps=config.rms_norm_eps)
|
| 328 |
+
|
| 329 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
| 330 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
| 331 |
+
|
| 332 |
+
def forward(
|
| 333 |
+
self,
|
| 334 |
+
hidden_states: torch.Tensor,
|
| 335 |
+
vision_hidden_states: torch.Tensor,
|
| 336 |
+
repeat_time: int = 1,
|
| 337 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 338 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 339 |
+
output_attentions: bool = False,
|
| 340 |
+
use_cache: bool = False,
|
| 341 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 342 |
+
hidden_states = self.norm1(hidden_states)
|
| 343 |
+
|
| 344 |
+
bsz, q_len, _ = hidden_states.size()
|
| 345 |
+
|
| 346 |
+
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 347 |
+
|
| 348 |
+
vision_hidden_states = self.norm2(vision_hidden_states)
|
| 349 |
+
|
| 350 |
+
bs_v, kv_len, _ = vision_hidden_states.size()
|
| 351 |
+
|
| 352 |
+
key_states = self.k_proj(vision_hidden_states).view(
|
| 353 |
+
bs_v, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 354 |
+
value_states = self.v_proj(vision_hidden_states).view(
|
| 355 |
+
bs_v, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 356 |
+
|
| 357 |
+
key_states = key_states.repeat(repeat_time, 1, 1, 1)
|
| 358 |
+
value_states = value_states.repeat(repeat_time, 1, 1, 1)
|
| 359 |
+
|
| 360 |
+
kv_seq_len = key_states.shape[-2]
|
| 361 |
+
if past_key_value is not None:
|
| 362 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
| 363 |
+
|
| 364 |
+
if past_key_value is not None:
|
| 365 |
+
# reuse k, v, self_attention
|
| 366 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
| 367 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
| 368 |
+
|
| 369 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
| 370 |
+
|
| 371 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
| 372 |
+
|
| 373 |
+
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
| 374 |
+
raise ValueError(
|
| 375 |
+
f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
|
| 376 |
+
f' {attn_weights.size()}'
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
if attention_mask is not None:
|
| 380 |
+
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
| 381 |
+
raise ValueError(
|
| 382 |
+
f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
|
| 383 |
+
)
|
| 384 |
+
attn_weights = attn_weights + attention_mask
|
| 385 |
+
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
|
| 386 |
+
|
| 387 |
+
# upcast attention to fp32
|
| 388 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
| 389 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 390 |
+
|
| 391 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
| 392 |
+
raise ValueError(
|
| 393 |
+
f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
|
| 394 |
+
f' {attn_output.size()}'
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
attn_output = attn_output.transpose(1, 2)
|
| 398 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
| 399 |
+
|
| 400 |
+
attn_output = self.o_proj(attn_output)
|
| 401 |
+
|
| 402 |
+
if not output_attentions:
|
| 403 |
+
attn_weights = None
|
| 404 |
+
|
| 405 |
+
return attn_output, attn_weights, past_key_value
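# Note on `repeat_time` (illustrative comment, not part of the original file):
# the vision keys/values are computed once per image and then repeated along the
# batch dimension, so one image can serve several text sequences. For example,
# with vision_hidden_states of shape [2, 256, 3200] and repeat_time=3, key/value
# become [6, num_heads, 256, head_dim] to match 6 query sequences.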
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
class LlamaDecoderLayer(nn.Module):
|
| 409 |
+
def __init__(self, config: LlamaConfig, use_cross_attn: bool):
|
| 410 |
+
super().__init__()
|
| 411 |
+
self.hidden_size = config.hidden_size
|
| 412 |
+
self.self_attn = LlamaAttention(config=config)
|
| 413 |
+
self.cross_attn = LlamaCrossAttention(config=config) if use_cross_attn else None
|
| 414 |
+
self.mlp = LlamaMLP(
|
| 415 |
+
hidden_size=self.hidden_size,
|
| 416 |
+
intermediate_size=config.intermediate_size,
|
| 417 |
+
hidden_act=config.hidden_act,
|
| 418 |
+
)
|
| 419 |
+
self.num_query_token = 96
|
| 420 |
+
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 421 |
+
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 422 |
+
|
| 423 |
+
def forward(
|
| 424 |
+
self,
|
| 425 |
+
hidden_states: torch.Tensor,
|
| 426 |
+
vision_hidden_states: torch.Tensor,
|
| 427 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 428 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 429 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 430 |
+
output_attentions: Optional[bool] = False,
|
| 431 |
+
use_cache: Optional[bool] = False,
|
| 432 |
+
repeat_time: int = 1,
|
| 433 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 434 |
+
"""
|
| 435 |
+
Args:
|
| 436 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
| 437 |
+
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
| 438 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
| 439 |
+
output_attentions (`bool`, *optional*):
|
| 440 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 441 |
+
returned tensors for more detail.
|
| 442 |
+
use_cache (`bool`, *optional*):
|
| 443 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
| 444 |
+
(see `past_key_values`).
|
| 445 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
| 446 |
+
"""
|
| 447 |
+
|
| 448 |
+
residual = hidden_states
|
| 449 |
+
|
| 450 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 451 |
+
|
| 452 |
+
# Self Attention
|
| 453 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
| 454 |
+
hidden_states=hidden_states,
|
| 455 |
+
attention_mask=attention_mask,
|
| 456 |
+
position_ids=position_ids,
|
| 457 |
+
past_key_value=past_key_value,
|
| 458 |
+
output_attentions=output_attentions,
|
| 459 |
+
use_cache=use_cache,
|
| 460 |
+
)
|
| 461 |
+
hidden_states = residual + hidden_states
|
| 462 |
+
|
| 463 |
+
# when using generate function and cache mode, the size of hidden_states is 1,
|
| 464 |
+
# so we should not use cross attention
|
| 465 |
+
if self.cross_attn is not None and hidden_states.size(1) >= self.num_query_token \
|
| 466 |
+
and vision_hidden_states is not None:
|
| 467 |
+
query_feats = hidden_states[:, :self.num_query_token, :]
|
| 468 |
+
text_feats = hidden_states[:, self.num_query_token:, :]
|
| 469 |
+
residual = query_feats
|
| 470 |
+
query_feats, _, _ = self.cross_attn(
|
| 471 |
+
hidden_states=query_feats,
|
| 472 |
+
vision_hidden_states=vision_hidden_states,
|
| 473 |
+
attention_mask=None,  # do not use an attention mask in cross attention
|
| 474 |
+
past_key_value=past_key_value,
|
| 475 |
+
output_attentions=output_attentions,
|
| 476 |
+
use_cache=use_cache,
|
| 477 |
+
repeat_time=repeat_time,
|
| 478 |
+
)
|
| 479 |
+
query_feats = residual + query_feats
|
| 480 |
+
hidden_states = torch.cat([query_feats, text_feats], dim=1)
|
| 481 |
+
|
| 482 |
+
# Fully Connected
|
| 483 |
+
residual = hidden_states
|
| 484 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 485 |
+
hidden_states = self.mlp(hidden_states)
|
| 486 |
+
hidden_states = residual + hidden_states
|
| 487 |
+
|
| 488 |
+
outputs = (hidden_states,)
|
| 489 |
+
|
| 490 |
+
if output_attentions:
|
| 491 |
+
outputs += (self_attn_weights,)
|
| 492 |
+
|
| 493 |
+
if use_cache:
|
| 494 |
+
outputs += (present_key_value,)
|
| 495 |
+
|
| 496 |
+
return outputs
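# Layout of a training sequence in this layer (illustrative comment, not part of
# the original file): the first `num_query_token` (96) positions are query slots
# and only they cross-attend to the vision features; the remaining positions are
# ordinary text tokens. With hidden_states of shape [bs, 96 + text_len, hidden],
# the layer splits at index 96, runs cross-attention on the query slice, and
# concatenates the two slices back together before the MLP.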
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
LLAMA_START_DOCSTRING = r"""
|
| 500 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 501 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 502 |
+
etc.)
|
| 503 |
+
|
| 504 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 505 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 506 |
+
and behavior.
|
| 507 |
+
|
| 508 |
+
Parameters:
|
| 509 |
+
config ([`LlamaConfig`]):
|
| 510 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
| 511 |
+
load the weights associated with the model, only the configuration. Check out the
|
| 512 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 513 |
+
"""
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
@add_start_docstrings(
|
| 517 |
+
'The bare LLaMA Model outputting raw hidden-states without any specific head on top.',
|
| 518 |
+
LLAMA_START_DOCSTRING,
|
| 519 |
+
)
|
| 520 |
+
class LlamaPreTrainedModel(PreTrainedModel):
|
| 521 |
+
config_class = LlamaConfig
|
| 522 |
+
base_model_prefix = 'model'
|
| 523 |
+
supports_gradient_checkpointing = True
|
| 524 |
+
_no_split_modules = ['LlamaDecoderLayer']
|
| 525 |
+
_keys_to_ignore_on_load_unexpected = [r'decoder\.version']
|
| 526 |
+
|
| 527 |
+
def _init_weights(self, module):
|
| 528 |
+
std = self.config.initializer_range
|
| 529 |
+
if isinstance(module, nn.Linear):
|
| 530 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 531 |
+
if module.bias is not None:
|
| 532 |
+
module.bias.data.zero_()
|
| 533 |
+
elif isinstance(module, nn.Embedding):
|
| 534 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 535 |
+
if module.padding_idx is not None:
|
| 536 |
+
module.weight.data[module.padding_idx].zero_()
|
| 537 |
+
|
| 538 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 539 |
+
if isinstance(module, LlamaModel):
|
| 540 |
+
module.gradient_checkpointing = value
|
| 541 |
+
if isinstance(module, LlamaDecoderLayer):
|
| 542 |
+
module.gradient_checkpointing = value
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
LLAMA_INPUTS_DOCSTRING = r"""
|
| 546 |
+
Args:
|
| 547 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 548 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
| 549 |
+
it.
|
| 550 |
+
|
| 551 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 552 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 553 |
+
|
| 554 |
+
[What are input IDs?](../glossary#input-ids)
|
| 555 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 556 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 557 |
+
|
| 558 |
+
- 1 for tokens that are **not masked**,
|
| 559 |
+
- 0 for tokens that are **masked**.
|
| 560 |
+
|
| 561 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 562 |
+
|
| 563 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 564 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 565 |
+
|
| 566 |
+
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
| 567 |
+
`past_key_values`).
|
| 568 |
+
|
| 569 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
| 570 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
| 571 |
+
information on the default strategy.
|
| 572 |
+
|
| 573 |
+
- 1 indicates the head is **not masked**,
|
| 574 |
+
- 0 indicates the head is **masked**.
|
| 575 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 576 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 577 |
+
config.n_positions - 1]`.
|
| 578 |
+
|
| 579 |
+
[What are position IDs?](../glossary#position-ids)
|
| 580 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
| 581 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
| 582 |
+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
| 583 |
+
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
| 584 |
+
|
| 585 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
| 586 |
+
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
| 587 |
+
|
| 588 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
| 589 |
+
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
| 590 |
+
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
| 591 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 592 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 593 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 594 |
+
model's internal embedding lookup matrix.
|
| 595 |
+
use_cache (`bool`, *optional*):
|
| 596 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
| 597 |
+
`past_key_values`).
|
| 598 |
+
output_attentions (`bool`, *optional*):
|
| 599 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 600 |
+
tensors for more detail.
|
| 601 |
+
output_hidden_states (`bool`, *optional*):
|
| 602 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 603 |
+
more detail.
|
| 604 |
+
return_dict (`bool`, *optional*):
|
| 605 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 606 |
+
"""
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
@add_start_docstrings(
|
| 610 |
+
'The bare LLaMA Model outputting raw hidden-states without any specific head on top.',
|
| 611 |
+
LLAMA_START_DOCSTRING,
|
| 612 |
+
)
|
| 613 |
+
class LlamaModel(LlamaPreTrainedModel):
|
| 614 |
+
"""
|
| 615 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
|
| 616 |
+
|
| 617 |
+
Args:
|
| 618 |
+
config: LlamaConfig
|
| 619 |
+
"""
|
| 620 |
+
|
| 621 |
+
def __init__(self, config: LlamaConfig):
|
| 622 |
+
super().__init__(config)
|
| 623 |
+
self.padding_idx = config.pad_token_id
|
| 624 |
+
self.vocab_size = config.vocab_size
|
| 625 |
+
self.cross_attention_frequency = config.cross_attention_frequency
|
| 626 |
+
self.num_query_token = config.num_query_token
|
| 627 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 628 |
+
use_cross_attn = [idx % self.cross_attention_frequency == 0 for idx in range(config.num_hidden_layers)]
|
| 629 |
+
self.layers = nn.ModuleList(
|
| 630 |
+
[LlamaDecoderLayer(config, use_cross_attn[idx]) for idx in range(config.num_hidden_layers)])
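# Illustration (comment added for clarity, not in the original file): with
# num_hidden_layers=8 and cross_attention_frequency=2, the pattern
# [idx % 2 == 0 for idx in range(8)] gives
# [True, False, True, False, True, False, True, False],
# i.e. every other decoder layer gets a LlamaCrossAttention block.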
|
| 631 |
+
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 632 |
+
self.gradient_checkpointing = False
|
| 633 |
+
# Initialize weights and apply final processing
|
| 634 |
+
# self.post_init()
|
| 635 |
+
|
| 636 |
+
def get_input_embeddings(self):
|
| 637 |
+
return self.embed_tokens
|
| 638 |
+
|
| 639 |
+
def set_input_embeddings(self, value):
|
| 640 |
+
self.embed_tokens = value
|
| 641 |
+
|
| 642 |
+
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
| 643 |
+
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
|
| 644 |
+
# create causal mask
|
| 645 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 646 |
+
combined_attention_mask = None
|
| 647 |
+
if input_shape[-1] > 1:
|
| 648 |
+
combined_attention_mask = _make_causal_mask(
|
| 649 |
+
input_shape,
|
| 650 |
+
inputs_embeds.dtype,
|
| 651 |
+
device=inputs_embeds.device,
|
| 652 |
+
past_key_values_length=past_key_values_length,
|
| 653 |
+
)
|
| 654 |
+
|
| 655 |
+
if attention_mask is not None:
|
| 656 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 657 |
+
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
|
| 658 |
+
inputs_embeds.device
|
| 659 |
+
)
|
| 660 |
+
combined_attention_mask = (
|
| 661 |
+
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
| 662 |
+
)
|
| 663 |
+
|
| 664 |
+
return combined_attention_mask
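# How the mask is assembled (illustrative comment, not part of the original
# file): for input_shape=(bsz, 7) with 2 cached tokens, `_make_causal_mask`
# yields a [bsz, 1, 7, 9] causal mask (0 where visible, a large negative value
# elsewhere), `_expand_mask` lifts the [bsz, 9] padding mask to the same shape,
# and the two are summed, so a position is attended to only if it is both
# causally visible and unpadded.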
|
| 665 |
+
|
| 666 |
+
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
| 667 |
+
def forward(
|
| 668 |
+
self,
|
| 669 |
+
input_ids: torch.LongTensor = None,
|
| 670 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 671 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 672 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 673 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 674 |
+
vision_hidden_states: Optional[torch.FloatTensor] = None,
|
| 675 |
+
repeat_time: Optional[int] = 1,
|
| 676 |
+
use_cache: Optional[bool] = None,
|
| 677 |
+
output_attentions: Optional[bool] = None,
|
| 678 |
+
output_hidden_states: Optional[bool] = None,
|
| 679 |
+
use_zero_attention_mask: Optional[bool] = None,
|
| 680 |
+
return_dict: Optional[bool] = None,
|
| 681 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
| 682 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 683 |
+
output_hidden_states = (
|
| 684 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 685 |
+
)
|
| 686 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 687 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 688 |
+
|
| 689 |
+
# retrieve input_ids and inputs_embeds
|
| 690 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 691 |
+
raise ValueError('You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time')
|
| 692 |
+
elif input_ids is not None:
|
| 693 |
+
batch_size, seq_length = input_ids.shape
|
| 694 |
+
elif inputs_embeds is not None:
|
| 695 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
| 696 |
+
else:
|
| 697 |
+
raise ValueError('You have to specify either decoder_input_ids or decoder_inputs_embeds')
|
| 698 |
+
seq_length_with_past = seq_length
|
| 699 |
+
past_key_values_length = 0
|
| 700 |
+
|
| 701 |
+
if past_key_values is not None:
|
| 702 |
+
past_key_values_length = past_key_values[0][0].shape[2]
|
| 703 |
+
seq_length_with_past = seq_length_with_past + past_key_values_length
|
| 704 |
+
|
| 705 |
+
if position_ids is None:
|
| 706 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
| 707 |
+
position_ids = torch.arange(
|
| 708 |
+
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
| 709 |
+
)
|
| 710 |
+
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
| 711 |
+
else:
|
| 712 |
+
position_ids = position_ids.view(-1, seq_length).long()
|
| 713 |
+
|
| 714 |
+
if inputs_embeds is None:
|
| 715 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 716 |
+
# embed positions
|
| 717 |
+
if attention_mask is None:
|
| 718 |
+
attention_mask = torch.ones(
|
| 719 |
+
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
| 720 |
+
)
|
| 721 |
+
attention_mask = self._prepare_decoder_attention_mask(
|
| 722 |
+
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
| 723 |
+
)
|
| 724 |
+
if use_zero_attention_mask:
|
| 725 |
+
attention_mask[:, :, :self.num_query_token, :self.num_query_token] = 0
|
| 726 |
+
|
| 727 |
+
hidden_states = inputs_embeds
|
| 728 |
+
|
| 729 |
+
if self.gradient_checkpointing and self.training:
|
| 730 |
+
if use_cache:
|
| 731 |
+
logger.warning_once(
|
| 732 |
+
'`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
|
| 733 |
+
)
|
| 734 |
+
use_cache = False
|
| 735 |
+
|
| 736 |
+
# decoder layers
|
| 737 |
+
all_hidden_states = () if output_hidden_states else None
|
| 738 |
+
all_self_attns = () if output_attentions else None
|
| 739 |
+
next_decoder_cache = () if use_cache else None
|
| 740 |
+
|
| 741 |
+
for idx, decoder_layer in enumerate(self.layers):
|
| 742 |
+
if output_hidden_states:
|
| 743 |
+
all_hidden_states += (hidden_states,)
|
| 744 |
+
|
| 745 |
+
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
| 746 |
+
|
| 747 |
+
layer_outputs = decoder_layer(
|
| 748 |
+
hidden_states,
|
| 749 |
+
vision_hidden_states,
|
| 750 |
+
attention_mask=attention_mask,
|
| 751 |
+
position_ids=position_ids,
|
| 752 |
+
past_key_value=past_key_value,
|
| 753 |
+
output_attentions=output_attentions,
|
| 754 |
+
use_cache=use_cache,
|
| 755 |
+
repeat_time=repeat_time,
|
| 756 |
+
)
|
| 757 |
+
|
| 758 |
+
hidden_states = layer_outputs[0]
|
| 759 |
+
|
| 760 |
+
if use_cache:
|
| 761 |
+
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
| 762 |
+
|
| 763 |
+
if output_attentions:
|
| 764 |
+
all_self_attns += (layer_outputs[1],)
|
| 765 |
+
|
| 766 |
+
hidden_states = self.norm(hidden_states)
|
| 767 |
+
|
| 768 |
+
# add hidden states from the last decoder layer
|
| 769 |
+
if output_hidden_states:
|
| 770 |
+
all_hidden_states += (hidden_states,)
|
| 771 |
+
|
| 772 |
+
next_cache = next_decoder_cache if use_cache else None
|
| 773 |
+
if not return_dict:
|
| 774 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
| 775 |
+
return BaseModelOutputWithPast(
|
| 776 |
+
last_hidden_state=hidden_states,
|
| 777 |
+
past_key_values=next_cache,
|
| 778 |
+
hidden_states=all_hidden_states,
|
| 779 |
+
attentions=all_self_attns,
|
| 780 |
+
)
|
| 781 |
+
|
| 782 |
+
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
| 783 |
+
def forward_train(
|
| 784 |
+
self,
|
| 785 |
+
input_ids: torch.LongTensor = None,
|
| 786 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 787 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 788 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 789 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 790 |
+
vision_hidden_states: Optional[torch.FloatTensor] = None,
|
| 791 |
+
repeat_time: Optional[int] = 1,
|
| 792 |
+
use_cache: Optional[bool] = None,
|
| 793 |
+
output_attentions: Optional[bool] = None,
|
| 794 |
+
output_hidden_states: Optional[bool] = None,
|
| 795 |
+
return_dict: Optional[bool] = None,
|
| 796 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
| 797 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 798 |
+
output_hidden_states = (
|
| 799 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 800 |
+
)
|
| 801 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 802 |
+
|
| 803 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 804 |
+
|
| 805 |
+
# retrieve input_ids and inputs_embeds
|
| 806 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 807 |
+
raise ValueError('You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time')
|
| 808 |
+
elif input_ids is not None:
|
| 809 |
+
batch_size, seq_length = input_ids.shape
|
| 810 |
+
elif inputs_embeds is not None:
|
| 811 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
| 812 |
+
else:
|
| 813 |
+
raise ValueError('You have to specify either decoder_input_ids or decoder_inputs_embeds')
|
| 814 |
+
|
| 815 |
+
seq_length_with_past = seq_length
|
| 816 |
+
past_key_values_length = 0
|
| 817 |
+
|
| 818 |
+
if past_key_values is not None:
|
| 819 |
+
past_key_values_length = past_key_values[0][0].shape[2]
|
| 820 |
+
seq_length_with_past = seq_length_with_past + past_key_values_length
|
| 821 |
+
|
| 822 |
+
if position_ids is None:
|
| 823 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
| 824 |
+
position_ids = torch.arange(
|
| 825 |
+
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
| 826 |
+
)
|
| 827 |
+
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
| 828 |
+
else:
|
| 829 |
+
position_ids = position_ids.view(-1, seq_length).long()
|
| 830 |
+
|
| 831 |
+
if inputs_embeds is None:
|
| 832 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 833 |
+
# embed positions
|
| 834 |
+
# if attention_mask is None:
|
| 835 |
+
# attention_mask = torch.ones(
|
| 836 |
+
# (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
| 837 |
+
# )
|
| 838 |
+
# attention_mask = self._prepare_decoder_attention_mask(
|
| 839 |
+
# attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
| 840 |
+
# )
|
| 841 |
+
hidden_states = inputs_embeds
|
| 842 |
+
|
| 843 |
+
if self.gradient_checkpointing and self.training:
|
| 844 |
+
if use_cache:
|
| 845 |
+
logger.warning_once(
|
| 846 |
+
'`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
|
| 847 |
+
)
|
| 848 |
+
use_cache = False
|
| 849 |
+
|
| 850 |
+
# decoder layers
|
| 851 |
+
all_hidden_states = () if output_hidden_states else None
|
| 852 |
+
all_self_attns = () if output_attentions else None
|
| 853 |
+
next_decoder_cache = () if use_cache else None
|
| 854 |
+
|
| 855 |
+
for idx, decoder_layer in enumerate(self.layers):
|
| 856 |
+
if output_hidden_states:
|
| 857 |
+
all_hidden_states += (hidden_states,)
|
| 858 |
+
|
| 859 |
+
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
| 860 |
+
|
| 861 |
+
if self.gradient_checkpointing and self.training:
|
| 862 |
+
|
| 863 |
+
def create_custom_forward(module):
|
| 864 |
+
def custom_forward(*inputs):
|
| 865 |
+
# None for past_key_value
|
| 866 |
+
return module(*inputs, output_attentions, None, repeat_time)
|
| 867 |
+
|
| 868 |
+
return custom_forward
|
| 869 |
+
|
| 870 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
| 871 |
+
create_custom_forward(decoder_layer),
|
| 872 |
+
hidden_states,
|
| 873 |
+
vision_hidden_states,
|
| 874 |
+
attention_mask,
|
| 875 |
+
position_ids,
|
| 876 |
+
None,
|
| 877 |
+
)
|
| 878 |
+
else:
|
| 879 |
+
layer_outputs = decoder_layer(
|
| 880 |
+
hidden_states,
|
| 881 |
+
vision_hidden_states,
|
| 882 |
+
attention_mask=attention_mask,
|
| 883 |
+
position_ids=position_ids,
|
| 884 |
+
past_key_value=past_key_value,
|
| 885 |
+
output_attentions=output_attentions,
|
| 886 |
+
use_cache=use_cache,
|
| 887 |
+
repeat_time=repeat_time,
|
| 888 |
+
)
|
| 889 |
+
|
| 890 |
+
hidden_states = layer_outputs[0]
|
| 891 |
+
|
| 892 |
+
if use_cache:
|
| 893 |
+
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
| 894 |
+
|
| 895 |
+
if output_attentions:
|
| 896 |
+
all_self_attns += (layer_outputs[1],)
|
| 897 |
+
|
| 898 |
+
hidden_states = self.norm(hidden_states)
|
| 899 |
+
|
| 900 |
+
# add hidden states from the last decoder layer
|
| 901 |
+
if output_hidden_states:
|
| 902 |
+
all_hidden_states += (hidden_states,)
|
| 903 |
+
|
| 904 |
+
next_cache = next_decoder_cache if use_cache else None
|
| 905 |
+
if not return_dict:
|
| 906 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
| 907 |
+
return BaseModelOutputWithPast(
|
| 908 |
+
last_hidden_state=hidden_states,
|
| 909 |
+
past_key_values=next_cache,
|
| 910 |
+
hidden_states=all_hidden_states,
|
| 911 |
+
attentions=all_self_attns,
|
| 912 |
+
)
|
| 913 |
+
|
| 914 |
+
|
| 915 |
+
class LlamaForCausalLM(LlamaPreTrainedModel):
|
| 916 |
+
def __init__(self, config):
|
| 917 |
+
super().__init__(config)
|
| 918 |
+
self.model = LlamaModel(config)
|
| 919 |
+
|
| 920 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 921 |
+
|
| 922 |
+
# Initialize weights and apply final processing
|
| 923 |
+
# self.post_init()
|
| 924 |
+
|
| 925 |
+
def get_input_embeddings(self):
|
| 926 |
+
return self.model.embed_tokens
|
| 927 |
+
|
| 928 |
+
def set_input_embeddings(self, value):
|
| 929 |
+
self.model.embed_tokens = value
|
| 930 |
+
|
| 931 |
+
def get_output_embeddings(self):
|
| 932 |
+
return self.lm_head
|
| 933 |
+
|
| 934 |
+
def set_output_embeddings(self, new_embeddings):
|
| 935 |
+
self.lm_head = new_embeddings
|
| 936 |
+
|
| 937 |
+
def set_decoder(self, decoder):
|
| 938 |
+
self.model = decoder
|
| 939 |
+
|
| 940 |
+
def get_decoder(self):
|
| 941 |
+
return self.model
|
| 942 |
+
|
| 943 |
+
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
| 944 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
| 945 |
+
def forward(
|
| 946 |
+
self,
|
| 947 |
+
input_ids: torch.LongTensor = None,
|
| 948 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 949 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 950 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 951 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 952 |
+
vision_hidden_states: Optional[torch.FloatTensor] = None,
|
| 953 |
+
labels: Optional[torch.LongTensor] = None,
|
| 954 |
+
use_cache: Optional[bool] = None,
|
| 955 |
+
output_attentions: Optional[bool] = None,
|
| 956 |
+
output_hidden_states: Optional[bool] = None,
|
| 957 |
+
use_zero_attention_mask: Optional[bool] = None,
|
| 958 |
+
return_dict: Optional[bool] = None,
|
| 959 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 960 |
+
r"""
|
| 961 |
+
Args:
|
| 962 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 963 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 964 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 965 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 966 |
+
|
| 967 |
+
Returns:
|
| 968 |
+
|
| 969 |
+
Example:
|
| 970 |
+
|
| 971 |
+
```python
|
| 972 |
+
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
| 973 |
+
|
| 974 |
+
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
| 975 |
+
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
| 976 |
+
|
| 977 |
+
>>> prompt = "Hey, are you consciours? Can you talk to me?"
|
| 978 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 979 |
+
|
| 980 |
+
>>> # Generate
|
| 981 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 982 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 983 |
+
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
|
| 984 |
+
```"""
|
| 985 |
+
|
| 986 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 987 |
+
output_hidden_states = (
|
| 988 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 989 |
+
)
|
| 990 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 991 |
+
|
| 992 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 993 |
+
outputs = self.model(
|
| 994 |
+
input_ids=input_ids,
|
| 995 |
+
attention_mask=attention_mask,
|
| 996 |
+
position_ids=position_ids,
|
| 997 |
+
past_key_values=past_key_values,
|
| 998 |
+
inputs_embeds=inputs_embeds,
|
| 999 |
+
vision_hidden_states=vision_hidden_states,
|
| 1000 |
+
use_cache=use_cache,
|
| 1001 |
+
output_attentions=output_attentions,
|
| 1002 |
+
output_hidden_states=output_hidden_states,
|
| 1003 |
+
return_dict=return_dict,
|
| 1004 |
+
use_zero_attention_mask=use_zero_attention_mask,
|
| 1005 |
+
)
|
| 1006 |
+
|
| 1007 |
+
hidden_states = outputs[0]
|
| 1008 |
+
logits = self.lm_head(hidden_states)
|
| 1009 |
+
|
| 1010 |
+
loss = None
|
| 1011 |
+
if labels is not None:
|
| 1012 |
+
# Shift so that tokens < n predict n
|
| 1013 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 1014 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 1015 |
+
# Flatten the tokens
|
| 1016 |
+
loss_fct = CrossEntropyLoss()
|
| 1017 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
| 1018 |
+
shift_labels = shift_labels.view(-1)
|
| 1019 |
+
# Enable model parallelism
|
| 1020 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
| 1021 |
+
loss = loss_fct(shift_logits, shift_labels)
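# Worked example of the shift (comment added for clarity, not in the original
# file): for labels [l0, l1, l2, l3], shift_logits keeps the predictions at
# positions 0..2 and shift_labels becomes [l1, l2, l3], so the logit at position
# t is scored against the token at t+1; positions labelled -100 are ignored by
# CrossEntropyLoss.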
|
| 1022 |
+
|
| 1023 |
+
if not return_dict:
|
| 1024 |
+
output = (logits,) + outputs[1:]
|
| 1025 |
+
return (loss,) + output if loss is not None else output
|
| 1026 |
+
|
| 1027 |
+
return CausalLMOutputWithPast(
|
| 1028 |
+
loss=loss,
|
| 1029 |
+
logits=logits,
|
| 1030 |
+
past_key_values=outputs.past_key_values,
|
| 1031 |
+
hidden_states=outputs.hidden_states,
|
| 1032 |
+
attentions=outputs.attentions,
|
| 1033 |
+
)
|
| 1034 |
+
|
| 1035 |
+
def prepare_inputs_for_generation(
|
| 1036 |
+
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None,
|
| 1037 |
+
vision_hidden_states=None, use_zero_attention_mask=None, **kwargs
|
| 1038 |
+
):
|
| 1039 |
+
if past_key_values:
|
| 1040 |
+
input_ids = input_ids[:, -1:]
|
| 1041 |
+
|
| 1042 |
+
position_ids = kwargs.get('position_ids', None)
|
| 1043 |
+
if attention_mask is not None and position_ids is None:
|
| 1044 |
+
# create position_ids on the fly for batch generation
|
| 1045 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 1046 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 1047 |
+
if past_key_values:
|
| 1048 |
+
position_ids = position_ids[:, -1].unsqueeze(-1)
|
| 1049 |
+
|
| 1050 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
| 1051 |
+
if inputs_embeds is not None and past_key_values is None:
|
| 1052 |
+
model_inputs = {'inputs_embeds': inputs_embeds}
|
| 1053 |
+
else:
|
| 1054 |
+
model_inputs = {'input_ids': input_ids}
|
| 1055 |
+
|
| 1056 |
+
model_inputs.update(
|
| 1057 |
+
{
|
| 1058 |
+
'position_ids': position_ids,
|
| 1059 |
+
'past_key_values': past_key_values,
|
| 1060 |
+
'use_cache': kwargs.get('use_cache'),
|
| 1061 |
+
'attention_mask': attention_mask,
|
| 1062 |
+
'vision_hidden_states': vision_hidden_states,
|
| 1063 |
+
'use_zero_attention_mask': use_zero_attention_mask,
|
| 1064 |
+
}
|
| 1065 |
+
)
|
| 1066 |
+
return model_inputs
|
| 1067 |
+
|
| 1068 |
+
@staticmethod
|
| 1069 |
+
def _reorder_cache(past_key_values, beam_idx):
|
| 1070 |
+
reordered_past = ()
|
| 1071 |
+
for layer_past in past_key_values:
|
| 1072 |
+
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
| 1073 |
+
return reordered_past
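

# ----------------------------------------------------------------------------
# Illustrative smoke test (not part of the original file): a minimal sketch of
# how this QLLaMA variant consumes vision features. It assumes the imports at
# the top of this module (torch, LlamaConfig, the mask helpers, etc.) and uses
# made-up, tiny sizes; the real model uses a much larger configuration.
if __name__ == '__main__':
    _cfg = LlamaConfig(
        vocab_size=128, hidden_size=64, intermediate_size=128,
        num_hidden_layers=2, num_attention_heads=4, max_position_embeddings=256)
    # These two attributes are normally injected by `InternVLConfig`.
    _cfg.num_query_token = 96
    _cfg.cross_attention_frequency = 2
    _model = LlamaForCausalLM(_cfg).eval()
    _input_ids = torch.randint(0, 128, (1, 100))  # 96 query slots + 4 text tokens
    _vision = torch.zeros(1, 5, 3200)             # cross-attention expects 3200-dim vision features
    with torch.no_grad():
        _out = _model(input_ids=_input_ids, vision_hidden_states=_vision)
    print(_out.logits.shape)                      # expected: torch.Size([1, 100, 128])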
|
VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/__init__.py
ADDED
|
@@ -0,0 +1,87 @@
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.transforms import InterpolationMode
from transformers import LlamaTokenizer

from .configuration_intern_vit import InternVisionConfig
from .configuration_internvl import InternVLConfig
from .modeling_intern_vit import InternVisionModel
from .modeling_internvl import InternVL_C, InternVL_G, InternVLModel

__all__ = ['InternVisionConfig', 'InternVisionModel', 'InternVLConfig',
           'InternVLModel', 'InternVL_C', 'InternVL_G']


# Prefix the text with "summarize:"
class InternVLTokenizer(nn.Module):
    def __init__(self, model_path):
        super(InternVLTokenizer, self).__init__()
        self.tokenizer = LlamaTokenizer.from_pretrained(model_path)
        self.tokenizer.pad_token = ' '  # allow padding
        self.tokenizer.add_eos_token = True

    def forward(self, text, prefix='summarize:'):
        if type(text) == str:
            text = prefix + text
        elif type(text) == list:
            text = [prefix + item for item in text]
        text = self.tokenizer(text, return_tensors='pt', max_length=80, truncation=True, padding='max_length').input_ids
        return text


def build_transform(task, image_size=224, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
    if task == 'retrieval':
        transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std)])
    else:
        transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize(image_size, interpolation=InterpolationMode.BICUBIC),
            T.CenterCrop(image_size),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    return transform


def load_internvl_c_huggingface(ckpt_path, device, task):
    model = InternVL_C.from_pretrained(ckpt_path, torch_dtype=torch.float16).to(device)
    if model.config.use_backbone_lora:
        model.vision_model.merge_and_unload()
        model.vision_model = model.vision_model.model
    if model.config.use_qllama_lora:
        model.qllama.merge_and_unload()
        model.qllama = model.qllama.model
    if model.config.force_image_size is not None:
        image_size = model.config.force_image_size
    else:
        image_size = model.config.vision_config.image_size
    transform = build_transform(task, image_size)
    tokenizer = InternVLTokenizer(ckpt_path)
    return model, transform, tokenizer


def load_internvl_g_huggingface(ckpt_path, device, task):
    model = InternVL_G.from_pretrained(ckpt_path, torch_dtype=torch.float16).to(device)
    if model.config.use_backbone_lora:
        model.vision_model.merge_and_unload()
        model.vision_model = model.vision_model.model
    if model.config.use_qllama_lora:
        model.qllama.merge_and_unload()
        model.qllama = model.qllama.model
    if model.config.force_image_size is not None:
        image_size = model.config.force_image_size
    else:
        image_size = model.config.vision_config.image_size
    transform = build_transform(task, image_size)
    tokenizer = InternVLTokenizer(ckpt_path)
    return model, transform, tokenizer
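

# Example usage (a minimal sketch, not executed here; the checkpoint path is a
# placeholder and `img` is assumed to be a PIL image):
#
#   model, transform, tokenizer = load_internvl_c_huggingface(
#       'path/to/internvl_checkpoint', 'cuda', task='retrieval')
#   pixel_values = transform(img).unsqueeze(0).to(torch.float16).to('cuda')
#   input_ids = tokenizer(['a photo of a cat']).to('cuda')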
|
VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/configuration_intern_vit.py
ADDED
|
@@ -0,0 +1,117 @@
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# InternVL
|
| 3 |
+
# Copyright (c) 2023 OpenGVLab
|
| 4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 5 |
+
# --------------------------------------------------------
|
| 6 |
+
import os
|
| 7 |
+
from typing import Union
|
| 8 |
+
|
| 9 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 10 |
+
from transformers.utils import logging
|
| 11 |
+
|
| 12 |
+
logger = logging.get_logger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class InternVisionConfig(PretrainedConfig):
|
| 16 |
+
r"""
|
| 17 |
+
This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
|
| 18 |
+
instantiate a vision encoder according to the specified arguments, defining the model architecture.
|
| 19 |
+
|
| 20 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 21 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
num_channels (`int`, *optional*, defaults to 3):
|
| 25 |
+
Number of color channels in the input images (e.g., 3 for RGB).
|
| 26 |
+
patch_size (`int`, *optional*, defaults to 14):
|
| 27 |
+
The size (resolution) of each patch.
|
| 28 |
+
image_size (`int`, *optional*, defaults to 224):
|
| 29 |
+
The size (resolution) of each image.
|
| 30 |
+
qkv_bias (`bool`, *optional*, defaults to `False`):
|
| 31 |
+
Whether to add a bias to the queries and values in the self-attention layers.
|
| 32 |
+
hidden_size (`int`, *optional*, defaults to 3200):
|
| 33 |
+
Dimensionality of the encoder layers and the pooler layer.
|
| 34 |
+
num_attention_heads (`int`, *optional*, defaults to 25):
|
| 35 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 36 |
+
intermediate_size (`int`, *optional*, defaults to 12800):
|
| 37 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
| 38 |
+
qk_normalization (`bool`, *optional*, defaults to `True`):
|
| 39 |
+
Whether to normalize the queries and keys in the self-attention layers.
|
| 40 |
+
num_hidden_layers (`int`, *optional*, defaults to 48):
|
| 41 |
+
Number of hidden layers in the Transformer encoder.
|
| 42 |
+
use_flash_attn (`bool`, *optional*, defaults to `True`):
|
| 43 |
+
Whether to use flash attention mechanism.
|
| 44 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
| 45 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 46 |
+
`"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported.
|
| 47 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-6):
|
| 48 |
+
The epsilon used by the layer normalization layers.
|
| 49 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
| 50 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
| 51 |
+
drop_path_rate (`float`, *optional*, defaults to 0.0):
|
| 52 |
+
Dropout rate for stochastic depth.
|
| 53 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 54 |
+
The dropout ratio for the attention probabilities.
|
| 55 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 56 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 57 |
+
initializer_factor (`float`, *optional*, defaults to 0.1):
|
| 58 |
+
A factor for layer scale.
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
model_type = 'intern_vit_6b'
|
| 62 |
+
|
| 63 |
+
def __init__(
|
| 64 |
+
self,
|
| 65 |
+
num_channels=3,
|
| 66 |
+
patch_size=14,
|
| 67 |
+
image_size=224,
|
| 68 |
+
qkv_bias=False,
|
| 69 |
+
hidden_size=3200,
|
| 70 |
+
num_attention_heads=25,
|
| 71 |
+
intermediate_size=12800,
|
| 72 |
+
qk_normalization=True,
|
| 73 |
+
num_hidden_layers=48,
|
| 74 |
+
use_flash_attn=True,
|
| 75 |
+
hidden_act='gelu',
|
| 76 |
+
layer_norm_eps=1e-6,
|
| 77 |
+
dropout=0.0,
|
| 78 |
+
drop_path_rate=0.0,
|
| 79 |
+
attention_dropout=0.0,
|
| 80 |
+
initializer_range=0.02,
|
| 81 |
+
initializer_factor=0.1,
|
| 82 |
+
**kwargs,
|
| 83 |
+
):
|
| 84 |
+
super().__init__(**kwargs)
|
| 85 |
+
|
| 86 |
+
self.hidden_size = hidden_size
|
| 87 |
+
self.intermediate_size = intermediate_size
|
| 88 |
+
self.dropout = dropout
|
| 89 |
+
self.drop_path_rate = drop_path_rate
|
| 90 |
+
self.num_hidden_layers = num_hidden_layers
|
| 91 |
+
self.num_attention_heads = num_attention_heads
|
| 92 |
+
self.num_channels = num_channels
|
| 93 |
+
self.patch_size = patch_size
|
| 94 |
+
self.image_size = image_size
|
| 95 |
+
self.initializer_range = initializer_range
|
| 96 |
+
self.initializer_factor = initializer_factor
|
| 97 |
+
self.attention_dropout = attention_dropout
|
| 98 |
+
self.layer_norm_eps = layer_norm_eps
|
| 99 |
+
self.hidden_act = hidden_act
|
| 100 |
+
self.qkv_bias = qkv_bias
|
| 101 |
+
self.qk_normalization = qk_normalization
|
| 102 |
+
self.use_flash_attn = use_flash_attn
|
| 103 |
+
|
| 104 |
+
@classmethod
|
| 105 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
|
| 106 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
| 107 |
+
|
| 108 |
+
if 'vision_config' in config_dict:
|
| 109 |
+
config_dict = config_dict['vision_config']
|
| 110 |
+
|
| 111 |
+
if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type:
|
| 112 |
+
logger.warning(
|
| 113 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
| 114 |
+
f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
return cls.from_dict(config_dict, **kwargs)
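# Example (a minimal sketch, not part of the original file): building the vision
# config programmatically and overriding a couple of defaults.
#
#   config = InternVisionConfig(image_size=448, use_flash_attn=False)
#   assert config.patch_size == 14 and config.hidden_size == 3200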
|
VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/configuration_internvl.py
ADDED
|
@@ -0,0 +1,108 @@
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# InternVL
|
| 3 |
+
# Copyright (c) 2023 OpenGVLab
|
| 4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 5 |
+
# --------------------------------------------------------
|
| 6 |
+
import copy
|
| 7 |
+
|
| 8 |
+
from transformers import LlamaConfig
|
| 9 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 10 |
+
from transformers.utils import logging
|
| 11 |
+
|
| 12 |
+
from .configuration_intern_vit import InternVisionConfig
|
| 13 |
+
|
| 14 |
+
logger = logging.get_logger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class InternVLConfig(PretrainedConfig):
|
| 18 |
+
r"""
|
| 19 |
+
[`InternVLConfig`] is the configuration class to store the configuration of a
|
| 20 |
+
[`InternVLModel`]. It is used to instantiate an InternVLModel according to the specified
|
| 21 |
+
arguments, defining the InternViT-6B and QLLaMA configs. Instantiating a configuration with
|
| 22 |
+
the defaults will yield a similar configuration to that of the InternVL architecture.
|
| 23 |
+
|
| 24 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 25 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
vision_config (`dict`, *optional*):
|
| 29 |
+
Dictionary of configuration options used to initialize [`InternVisionConfig`].
|
| 30 |
+
qllama_config (`dict`, *optional*):
|
| 31 |
+
Dictionary of configuration options used to initialize [`LLaMAConfig`].
|
| 32 |
+
clip_embed_dim (`int`, *optional*, defaults to 768):
|
| 33 |
+
Size of the embeddings from the CLIP model.
|
| 34 |
+
attn_pool_num_heads (`int`, *optional*, defaults to 16):
|
| 35 |
+
Number of attention heads used in the attention pooling layers.
|
| 36 |
+
num_query_token (`int`, *optional*, defaults to 96):
|
| 37 |
+
Number of query tokens used in the transformer.
|
| 38 |
+
label_smoothing (`float`, *optional*, defaults to 0.0):
|
| 39 |
+
The amount of label smoothing to apply.
|
| 40 |
+
cross_attention_frequency (`int`, *optional*, defaults to 2):
|
| 41 |
+
The frequency of cross-attention layers in the model.
|
| 42 |
+
use_backbone_lora (`int`, *optional*, defaults to 0):
|
| 43 |
+
If non-zero, indicates the use of LoRA in the backbone of the model.
|
| 44 |
+
use_qllama_lora (`int`, *optional*, defaults to 0):
|
| 45 |
+
If non-zero, indicates the use of LoRA in the QLLaMA of the model.
|
| 46 |
+
force_image_size (`int` or `None`, *optional*):
|
| 47 |
+
If not None, forces the model to use this specific image size.
|
| 48 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 49 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 50 |
+
kwargs (*optional*):
|
| 51 |
+
Dictionary of additional keyword arguments.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
model_type = 'internvl'
|
| 55 |
+
is_composition = True
|
| 56 |
+
|
| 57 |
+
def __init__(
|
| 58 |
+
self,
|
| 59 |
+
vision_config=None,
|
| 60 |
+
qllama_config=None,
|
| 61 |
+
clip_embed_dim=768,
|
| 62 |
+
attn_pool_num_heads=16,
|
| 63 |
+
num_query_token=96,
|
| 64 |
+
label_smoothing=0.0,
|
| 65 |
+
cross_attention_frequency=2,
|
| 66 |
+
use_backbone_lora=0,
|
| 67 |
+
use_qllama_lora=0,
|
| 68 |
+
force_image_size=None,
|
| 69 |
+
initializer_range=0.02,
|
| 70 |
+
**kwargs):
|
| 71 |
+
super().__init__(**kwargs)
|
| 72 |
+
|
| 73 |
+
if vision_config is None:
|
| 74 |
+
vision_config = {}
|
| 75 |
+
logger.info('vision_config is None. initializing the InternVisionConfig with default values.')
|
| 76 |
+
|
| 77 |
+
if qllama_config is None:
|
| 78 |
+
qllama_config = {}
|
| 79 |
+
logger.info(
|
| 80 |
+
'qllama_config is None. Initializing the InternTextConfig config with default values (`LlamaConfig`).')
|
| 81 |
+
|
| 82 |
+
self.vision_config = InternVisionConfig(**vision_config)
|
| 83 |
+
self.qllama_config = LlamaConfig(**qllama_config)
|
| 84 |
+
self.qllama_config.num_query_token = num_query_token
|
| 85 |
+
self.qllama_config.cross_attention_frequency = cross_attention_frequency
|
| 86 |
+
self.hidden_size = self.qllama_config.hidden_size
|
| 87 |
+
|
| 88 |
+
self.clip_embed_dim = clip_embed_dim
|
| 89 |
+
self.attn_pool_num_heads = attn_pool_num_heads
|
| 90 |
+
self.num_query_token = num_query_token
|
| 91 |
+
self.label_smoothing = label_smoothing
|
| 92 |
+
self.use_backbone_lora = use_backbone_lora
|
| 93 |
+
self.use_qllama_lora = use_qllama_lora
|
| 94 |
+
self.force_image_size = force_image_size
|
| 95 |
+
self.initializer_range = initializer_range
|
| 96 |
+
|
| 97 |
+
def to_dict(self):
|
| 98 |
+
"""
|
| 99 |
+
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
| 103 |
+
"""
|
| 104 |
+
output = copy.deepcopy(self.__dict__)
|
| 105 |
+
output['vision_config'] = self.vision_config.to_dict()
|
| 106 |
+
output['qllama_config'] = self.qllama_config.to_dict()
|
| 107 |
+
output['model_type'] = self.__class__.model_type
|
| 108 |
+
return output
|
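Sanity-check sketch (editorial; the numbers are illustrative, not taken from a released checkpoint): the composite config pushes `num_query_token` and `cross_attention_frequency` into the QLLaMA sub-config, mirrors the QLLaMA hidden size, and `to_dict()` folds both sub-configs back into one serializable dictionary.

cfg = InternVLConfig(
    vision_config={'hidden_size': 3200, 'num_hidden_layers': 48},    # assumed toy values
    qllama_config={'hidden_size': 4096, 'num_hidden_layers': 32},    # assumed toy values
    num_query_token=96,
)
assert cfg.qllama_config.num_query_token == 96
assert cfg.hidden_size == cfg.qllama_config.hidden_size
d = cfg.to_dict()
assert d['model_type'] == 'internvl' and 'vision_config' in d and 'qllama_config' in d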
VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/flash_attention.py
ADDED
|
@@ -0,0 +1,76 @@
# https://github.com/Dao-AILab/flash-attention/blob/v0.2.8/flash_attn/flash_attention.py
import torch
import torch.nn as nn
from einops import rearrange

try:  # v1
    from flash_attn.flash_attn_interface import \
        flash_attn_unpadded_qkvpacked_func
except:  # v2
    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func

from flash_attn.bert_padding import pad_input, unpad_input


class FlashAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
        super().__init__()
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
                max_s=None, need_weights=False):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
                if unpadded: (nnz, 3, h, d)
            key_padding_mask: a bool tensor of shape (B, S)
        """
        assert not need_weights
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda

        if cu_seqlens is None:
            batch_size = qkv.shape[0]
            seqlen = qkv.shape[1]
            if key_padding_mask is None:
                qkv = rearrange(qkv, 'b s ... -> (b s) ...')
                max_s = seqlen
                cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                          device=qkv.device)
                output = flash_attn_unpadded_qkvpacked_func(
                    qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
                    softmax_scale=self.softmax_scale, causal=causal
                )
                output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
            else:
                nheads = qkv.shape[-2]
                x = rearrange(qkv, 'b s three h d -> b s (three h d)')
                x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
                x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
                output_unpad = flash_attn_unpadded_qkvpacked_func(
                    x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
                    softmax_scale=self.softmax_scale, causal=causal
                )
                output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
                                             indices, batch_size, seqlen),
                                   'b s (h d) -> b s h d', h=nheads)
        else:
            assert max_s is not None
            output = flash_attn_unpadded_qkvpacked_func(
                qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
                softmax_scale=self.softmax_scale, causal=causal
            )

        return output, None
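Usage sketch (editorial; requires a CUDA device and the flash-attn package, shapes are illustrative): the module consumes a packed qkv tensor of shape (batch, seq, 3, heads, head_dim) in fp16/bf16 and returns the attention output plus a None placeholder for attention weights.

import torch

attn = FlashAttention(attention_dropout=0.0).eval()
b, s, h, d = 2, 197, 16, 64                                            # toy sizes
qkv = torch.randn(b, s, 3, h, d, dtype=torch.float16, device='cuda')   # packed q/k/v
out, _ = attn(qkv, causal=False)                                       # -> (b, s, h, d)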
VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/modeling_intern_vit.py
ADDED
|
@@ -0,0 +1,342 @@
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange
from timm.models.layers import DropPath
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (BaseModelOutput,
                                           BaseModelOutputWithPooling)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

from .configuration_intern_vit import InternVisionConfig

try:
    from .flash_attention import FlashAttention
    has_flash_attn = True
except:
    print('FlashAttention is not installed.')
    has_flash_attn = False


logger = logging.get_logger(__name__)


class InternRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


try:
    from apex.normalization import FusedRMSNorm

    InternRMSNorm = FusedRMSNorm  # noqa

    logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
except ImportError:
    # using the normal InternRMSNorm
    pass
except Exception:
    logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm')
    pass


class InternVisionEmbeddings(nn.Module):
    def __init__(self, config: InternVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(
            torch.randn(1, 1, self.embed_dim),
        )

        self.patch_embedding = nn.Conv2d(
            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
        class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        embeddings = embeddings + self.position_embedding.to(target_dtype)
        return embeddings

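Shape note (editorial): for an input of side `image_size` split into square patches of side `patch_size`, the embedding module emits `(image_size // patch_size) ** 2` patch tokens plus one class token, so the encoder sequence length is `num_patches + 1`. A CPU-only check with assumed toy sizes:

import torch

toy_cfg = InternVisionConfig(hidden_size=64, image_size=32, patch_size=16)   # assumed toy config
emb = InternVisionEmbeddings(toy_cfg)
tokens = emb(torch.randn(2, 3, 32, 32))
assert tokens.shape == (2, (32 // 16) ** 2 + 1, 64)   # (batch, 4 patches + 1 CLS, hidden)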
class InternAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: InternVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.use_flash_attn = config.use_flash_attn and has_flash_attn
        if config.use_flash_attn and not has_flash_attn:
            print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
                f' {self.num_heads}).'
            )

        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
        self.attn_drop = nn.Dropout(config.attention_dropout)
        self.proj_drop = nn.Dropout(config.dropout)

        self.qk_normalization = config.qk_normalization

        if self.qk_normalization:
            self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
            self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)

        if self.use_flash_attn:
            self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
        self.proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _naive_attn(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        if self.qk_normalization:
            B_, H_, N_, D_ = q.shape
            q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
            k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)

        attn = ((q * self.scale) @ k.transpose(-2, -1))
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
        qkv = self.qkv(x)
        qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)

        if self.qk_normalization:
            q, k, v = qkv.unbind(2)
            q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
            k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
            qkv = torch.stack([q, k, v], dim=2)

        context, _ = self.inner_attn(
            qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
        )
        outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
        outs = self.proj_drop(outs)
        return outs

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
        return x


class InternMLP(nn.Module):
    def __init__(self, config: InternVisionConfig):
        super().__init__()
        self.config = config
        self.act = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class InternVisionEncoderLayer(nn.Module):
    def __init__(self, config: InternVisionConfig, drop_path_rate: float):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.intermediate_size = config.intermediate_size

        self.attn = InternAttention(config)
        self.mlp = InternMLP(config)
        self.norm1 = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.norm2 = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)

        self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
        self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
        self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(
            self,
            hidden_states: torch.Tensor,
    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
        """
        hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)

        hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)

        return hidden_states


class InternVisionEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`InternEncoderLayer`].

    Args:
        config (`InternConfig`):
            The corresponding vision configuration for the `InternEncoder`.
    """

    def __init__(self, config: InternVisionConfig):
        super().__init__()
        self.config = config
        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
        self.layers = nn.ModuleList([
            InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
        self.gradient_checkpointing = True

    def forward(
            self,
            inputs_embeds,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded representation of the inputs. Should be float, not int tokens.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        hidden_states = inputs_embeds

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    encoder_layer,
                    hidden_states)
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                )
            hidden_states = layer_outputs

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states
        )


class InternVisionModel(PreTrainedModel):
    main_input_name = 'pixel_values'
    config_class = InternVisionConfig

    def __init__(self, config: InternVisionConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = InternVisionEmbeddings(config)
        self.encoder = InternVisionEncoder(config)

    def resize_pos_embeddings(self, old_size, new_size, patch_size):
        pos_emb = self.embeddings.position_embedding
        _, num_positions, embed_dim = pos_emb.shape
        cls_emb = pos_emb[:, :1, :]
        pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
        pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False)
        pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
        pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
        self.embeddings.position_embedding = nn.Parameter(pos_emb)
        logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))

    def get_input_embeddings(self):
        return self.embeddings

    def forward(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            pixel_embeds: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None and pixel_embeds is None:
            raise ValueError('You have to specify pixel_values or pixel_embeds')

        if pixel_embeds is not None:
            hidden_states = pixel_embeds
        else:
            if len(pixel_values.shape) == 4:
                hidden_states = self.embeddings(pixel_values)
            else:
                raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = encoder_outputs.last_hidden_state
        pooled_output = last_hidden_state[:, 0, :]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
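End-to-end sketch (editorial; a deliberately tiny, assumed configuration rather than the released InternViT-6B sizes): the model returns the full token sequence plus the CLS position as `pooler_output`.

import torch

tiny = InternVisionConfig(hidden_size=64, intermediate_size=128, num_hidden_layers=2,
                          num_attention_heads=4, image_size=32, patch_size=16,
                          use_flash_attn=False, drop_path_rate=0.0)   # assumed toy values
model = InternVisionModel(tiny).eval()
with torch.no_grad():
    out = model(pixel_values=torch.randn(1, 3, 32, 32), return_dict=True)
assert out.last_hidden_state.shape == (1, 5, 64)   # CLS + 4 patch tokens
assert out.pooler_output.shape == (1, 64)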
VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/modeling_internvl.py
ADDED
|
@@ -0,0 +1,669 @@
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from dataclasses import dataclass
from functools import partial
from typing import Any, Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.utils.checkpoint
from peft import LoraConfig, get_peft_model
from timm.models.layers import DropPath
from torch import nn
from transformers import GenerationConfig
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, logging

from .configuration_internvl import InternVLConfig
from .modeling_intern_vit import (InternVisionEmbeddings, InternVisionEncoder,
                                  InternVisionModel)
from .modeling_qllama import LlamaForCausalLM, _expand_mask, _make_causal_mask

try:
    from .flash_attention import FlashAttention  # v1/v2
except:
    print('FlashAttention is not installed.')

logger = logging.get_logger(__name__)


class InternVLPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = InternVLConfig
    base_model_prefix = 'internvl'
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [
        r'position_ids',
    ]
    _no_split_modules = ['InternAttention', 'LlamaDecoderLayer', 'LlamaForCausalLM']
    _skip_keys_device_placement = 'past_key_values'
    _keep_in_fp32_modules = ['wo']

    # def _init_weights(self, module):
    #     """Initialize the weights"""
    #     factor = self.config.initializer_range
    #     if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear):
    #         module.weight.data.normal_(mean=0.0, std=factor)
    #         if hasattr(module, 'bias') and module.bias is not None:
    #             module.bias.data.zero_()
    #     if isinstance(module, InternVisionEmbeddings):
    #         if hasattr(self.config, 'vision_config'):
    #             factor = self.config.vision_config.initializer_range
    #         nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
    #         nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
    #     elif isinstance(module, nn.LayerNorm):
    #         module.bias.data.zero_()
    #         module.weight.data.fill_(1.0)
    #     elif isinstance(module, nn.Linear) and module.bias is not None:
    #         module.bias.data.zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, InternVisionModel):
            module.gradient_checkpointing = value
        if isinstance(module, InternVisionEncoder):
            module.gradient_checkpointing = value


class CrossAttention(nn.Module):
    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., attn_head_dim=None, out_dim=None):
        super().__init__()
        if out_dim is None:
            out_dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5
        assert all_head_dim == dim

        self.q = nn.Linear(dim, all_head_dim, bias=False)
        self.k = nn.Linear(dim, all_head_dim, bias=False)
        self.v = nn.Linear(dim, all_head_dim, bias=False)

        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.k_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, out_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, k=None, v=None):
        B, N, C = x.shape
        N_k = k.shape[1]
        N_v = v.shape[1]

        q_bias, k_bias, v_bias = None, None, None
        if self.q_bias is not None:
            q_bias = self.q_bias
            k_bias = self.k_bias
            v_bias = self.v_bias

        q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
        q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)  # (B, N_head, N_q, dim)

        k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
        k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)

        v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
        v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # (B, N_head, N_q, N_k)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class AttentiveBlock(nn.Module):

    def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, attn_head_dim=None, out_dim=None):
        super().__init__()

        self.norm1_q = norm_layer(dim)
        self.norm1_k = norm_layer(dim)
        self.norm1_v = norm_layer(dim)
        self.cross_attn = CrossAttention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
            proj_drop=drop, attn_head_dim=attn_head_dim, out_dim=out_dim)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None):
        x_q = self.norm1_q(x_q + pos_q)
        x_k = self.norm1_k(x_kv + pos_k)
        x_v = self.norm1_v(x_kv)
        x = self.cross_attn(x_q, k=x_k, v=x_v)

        return x


class AttentionPoolingBlock(AttentiveBlock):

    def forward(self, x):
        x_q = x.mean(1, keepdim=True)
        x_kv, pos_q, pos_k = x, 0, 0
        x = super().forward(x_q, x_kv, pos_q, pos_k, bool_masked_pos=None, rel_pos_bias=None)
        x = x.squeeze(1)
        return x


@dataclass
class InternVLModelOutput(ModelOutput):
    """
    Class defining the outputs of [`InternVLModelOutput`].
    """

    loss: Optional[torch.FloatTensor] = None
    loss_itm: Optional[torch.FloatTensor] = None
    loss_itc: Optional[torch.FloatTensor] = None
    loss_itg: Optional[torch.FloatTensor] = None

    def to_tuple(self) -> Tuple[Any]:
        return tuple(
            self[k]
            if k not in ['loss', 'loss_itm', 'loss_itc', 'loss_itg']
            else getattr(self, k).to_tuple()
            for k in self.keys()
        )


class GatherLayer(torch.autograd.Function):
    """Gather tensors from all process, supporting backward propagation.
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
        dist.all_gather(output, input)
        return torch.stack(output, 0)

    @staticmethod
    def backward(ctx, grads):
        input, = ctx.saved_tensors
        dist.all_reduce(grads)
        grad_out = torch.zeros_like(input)
        grad_out[:] = grads[dist.get_rank()]
        return grad_out

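Editorial note: unlike a bare dist.all_gather (whose outputs are detached), GatherLayer keeps the gathered tensors in the autograd graph, so every rank can score its local features against the features of all ranks while gradients still flow back to the local shard. A minimal sketch, assuming torch.distributed is already initialized (e.g. via torchrun):

import torch

def global_similarity(local_feats: torch.Tensor) -> torch.Tensor:
    # local_feats: (B, D), already L2-normalized
    all_feats = GatherLayer.apply(local_feats).flatten(0, 1)   # (world_size * B, D), differentiable
    return local_feats @ all_feats.t()                         # (B, world_size * B) similarity logits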
class InternVLModel(InternVLPreTrainedModel):
    config_class = InternVLConfig
    main_input_name = 'pixel_values'

    def __init__(self, config: InternVLConfig):
        super().__init__(config)

        text_hidden_size = config.qllama_config.hidden_size
        vision_hidden_size = config.vision_config.hidden_size
        clip_embed_dim = config.clip_embed_dim
        attn_pool_num_heads = config.attn_pool_num_heads
        config.qllama_config.num_query_token = config.num_query_token
        self.num_query_token = config.num_query_token
        self.label_smoothing = config.label_smoothing

        self.vision_model = InternVisionModel(config.vision_config)  # frozen
        self.qllama = LlamaForCausalLM(config.qllama_config)  # frozen
        self.query_tokens = nn.Parameter(  # trainable
            torch.zeros(1, config.num_query_token, text_hidden_size)
        )

        self.text_projection = nn.Parameter(torch.empty(text_hidden_size, clip_embed_dim))  # frozen
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))  # trainable
        self.clip_projector = AttentionPoolingBlock(  # frozen
            dim=vision_hidden_size, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
            drop=0., attn_drop=0., norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim)
        self.clip_projector2 = AttentionPoolingBlock(  # trainable
            dim=text_hidden_size, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
            drop=0., attn_drop=0., norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim)
        self.itm_head = nn.Linear(text_hidden_size, 2)  # trainable
        self.gradient_checkpointing = True

        # Initialize weights and apply final processing
        # self.post_init()

        if config.use_backbone_lora:
            self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=config.use_backbone_lora * 2)
        if config.use_qllama_lora:
            self.wrap_qllama_lora(r=config.use_qllama_lora, lora_alpha=config.use_qllama_lora * 2)
        if config.force_image_size:
            self.vision_model.resize_pos_embeddings(
                old_size=config.vision_config.image_size,
                new_size=config.force_image_size,
                patch_size=config.vision_config.patch_size
            )

    def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
        lora_config = LoraConfig(
            r=r,
            target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'],
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )
        self.vision_model = get_peft_model(self.vision_model, lora_config)
        self.vision_model.print_trainable_parameters()

    def wrap_qllama_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
        lora_config = LoraConfig(
            r=r,
            target_modules=['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
                            'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj'],
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )
        self.qllama = get_peft_model(self.qllama, lora_config)
        self.qllama.print_trainable_parameters()

    def get_input_embeddings(self):
        return self.qllama.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.qllama.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings):
        self.qllama.set_output_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        return self.qllama.get_output_embeddings()

    @torch.no_grad()
    def _prepare_attention_mask(
            self,
            image_attention_mask: torch.LongTensor,
            attention_mask: torch.LongTensor,
            input_embeds: torch.FloatTensor,
            repeat_time: int,
    ):
        # itm, itc
        attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1)
        expand_mask = _expand_mask(attention_mask, input_embeds.dtype).to(
            input_embeds.device)  # [bsz, 1, tgt_seq_len, src_seq_len]
        itm_mask_neg, itm_mask_pos, itc_mask = torch.chunk(expand_mask, repeat_time, dim=0)

        itc_mask[:, :, :self.num_query_token, self.num_query_token:] = torch.finfo(input_embeds.dtype).min
        itc_mask[:, :, self.num_query_token:, :self.num_query_token] = torch.finfo(input_embeds.dtype).min
        itc_mask_causal = _make_causal_mask(
            (itc_mask.shape[0], itc_mask.shape[2] - self.num_query_token),
            input_embeds.dtype,
            device=input_embeds.device
        )
        # use causal mask for text in itc
        itc_mask[:, :, self.num_query_token:, self.num_query_token:] += itc_mask_causal

        attention_mask = torch.cat([itm_mask_neg, itm_mask_pos, itc_mask], dim=0)

        return attention_mask

    def forward(
            self,
            pixel_values: torch.FloatTensor,
            positive_input_ids: torch.FloatTensor,
            positive_attention_mask: torch.LongTensor,
            negative_input_ids: torch.FloatTensor,
            negative_attention_mask: torch.LongTensor,
            summarize_input_ids: torch.FloatTensor,
            summarize_attention_mask: torch.LongTensor,
            input_ids: torch.FloatTensor,
            attention_mask: torch.LongTensor,
            image_ids: torch.LongTensor,
            labels: torch.LongTensor,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, InternVLModelOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # step 1: forward the images through the vision encoder,
        # to get image embeddings of shape (batch_size, seq_len, hidden_size)
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict)
        image_embeds = vision_outputs[0]
        backbone_embeds = self.clip_projector(image_embeds)

        # step 2: prepare input_ids and attention_mask for two sub-tasks:
        # 1) image-text matching; 2) image-text contrastive learning.
        batch_size = input_ids.shape[0]
        input_ids = torch.cat([negative_input_ids, positive_input_ids,
                               summarize_input_ids], dim=0)  # [3 * batch_size, seq_len]
        itm_attention_mask = torch.cat(
            [negative_attention_mask, positive_attention_mask], dim=0)
        attention_mask = torch.cat(
            [itm_attention_mask, summarize_attention_mask], dim=0)  # [3 * batch_size, seq_len]

        repeat_time = input_ids.size(0) // batch_size
        # step 3: forward the input_ids and attention_mask through the text encoder.
        input_embeds = self.get_input_embeddings()(input_ids)
        query_tokens = self.query_tokens.repeat(repeat_time * batch_size, 1, 1)
        input_embeds = torch.cat([query_tokens, input_embeds], dim=1)
        image_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
        attention_mask = self._prepare_attention_mask(
            image_attention_mask, attention_mask, input_embeds, repeat_time
        )
        if type(self.qllama.model) == LlamaForCausalLM:
            outputs = self.qllama.model.model.forward_train(
                inputs_embeds=input_embeds,
                vision_hidden_states=image_embeds,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                repeat_time=repeat_time,
            ).last_hidden_state
        else:
            outputs = self.qllama.model.forward_train(
                inputs_embeds=input_embeds,
                vision_hidden_states=image_embeds,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                repeat_time=repeat_time,
            ).last_hidden_state
        image_embeds = outputs[:, :self.num_query_token]
        text_embeds = outputs[:, self.num_query_token:]
        image_itm_neg, image_itm_pos, image_itc = image_embeds.chunk(repeat_time, dim=0)
        text_itm_neg, text_itm_pos, text_itc = text_embeds.chunk(repeat_time, dim=0)
        image_itm = torch.cat([image_itm_neg, image_itm_pos], dim=0)

        ###============== Image-Text Matching ===================###
        image_itm = self.itm_head(image_itm)
        logits = image_itm.mean(dim=1)
        itm_labels = torch.cat([
            torch.zeros(batch_size, dtype=torch.long, device=logits.device),
            torch.ones(batch_size, dtype=torch.long, device=logits.device)
        ], dim=0)
        loss_itm = F.cross_entropy(logits, itm_labels)
        neg_match_acc = ((logits[:batch_size].argmax(dim=-1) == 0) / batch_size).sum()
        pos_match_acc = ((logits[batch_size:].argmax(dim=-1) == 1) / batch_size).sum()

        ###============== Image-Text Contrastive ===================###
        image_itc = self.clip_projector2(image_itc)

        selected = summarize_attention_mask.sum(1) - 1
        text_itc = text_itc[torch.arange(text_itc.shape[0]), selected]
        text_itc = text_itc @ self.text_projection

        # normalized features
        backbone_embeds = backbone_embeds / backbone_embeds.norm(dim=1, keepdim=True)
        image_itc = image_itc / image_itc.norm(dim=1, keepdim=True)
        text_itc = text_itc / text_itc.norm(dim=1, keepdim=True)
        backbone_embeds_all = GatherLayer.apply(backbone_embeds).flatten(0, 1)
        image_itc_all = GatherLayer.apply(image_itc).flatten(0, 1)
        text_itc_all = GatherLayer.apply(text_itc).flatten(0, 1)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        sim_i2t = logit_scale * (image_itc @ text_itc_all.t())
        sim_t2i = logit_scale * (text_itc @ image_itc_all.t())
        backbone_i2t = logit_scale * (backbone_embeds @ text_itc_all.t())
        backbone_t2i = logit_scale * (text_itc @ backbone_embeds_all.t())

        image_ids = image_ids.view(-1, 1)
        image_ids_all = GatherLayer.apply(image_ids).flatten(0, 1)
        pos_idx = torch.eq(image_ids, image_ids_all.t()).float()
        sim_targets = pos_idx / pos_idx.sum(1, keepdim=True)

        loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1) * sim_targets, dim=1).mean()
        loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1) * sim_targets, dim=1).mean()
        loss_backbone_t2i = -torch.sum(F.log_softmax(backbone_t2i, dim=1) * sim_targets, dim=1).mean()
        loss_backbone_i2t = -torch.sum(F.log_softmax(backbone_i2t, dim=1) * sim_targets, dim=1).mean()
        loss_itc = (loss_t2i + loss_i2t) / 2 + (loss_backbone_t2i + loss_backbone_i2t) / 2

        vision_sim = F.cosine_similarity(backbone_embeds.detach(), image_itc).mean()

        loss = loss_itm + loss_itc
        if dist.get_rank() == 0:
            print(f'loss: {loss.item()}, loss_itm: {loss_itm.item()}, loss_itc: {loss_itc.item()}, '
                  f'vision_similarity: {round(vision_sim.item(), 5)}, '
                  f'logit scale: {round(1.0 / logit_scale.item(), 5)}, '
                  f'pos_match_acc: {round(pos_match_acc.item(), 4)}, '
                  f'neg_match_acc: {round(neg_match_acc.item(), 4)}')

        return InternVLModelOutput(
            loss=loss,
            loss_itc=loss_itc.detach(),
            loss_itm=loss_itm.detach(),
        )

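Editorial note on the batching in forward() above: the three text streams are stacked along the batch dimension in the fixed order negative / positive / summarize, so `repeat_time` evaluates to 3 and a later `chunk(repeat_time, dim=0)` recovers the ITM-negative, ITM-positive and ITC views in that order. A toy illustration of the layout (assumed values, shapes only):

import torch

batch_size, seq_len = 4, 32
negative_input_ids = torch.zeros(batch_size, seq_len, dtype=torch.long)
positive_input_ids = torch.ones(batch_size, seq_len, dtype=torch.long)
summarize_input_ids = torch.full((batch_size, seq_len), 2, dtype=torch.long)

input_ids = torch.cat([negative_input_ids, positive_input_ids, summarize_input_ids], dim=0)
repeat_time = input_ids.size(0) // batch_size            # == 3
neg, pos, summ = input_ids.chunk(repeat_time, dim=0)     # order is preserved
assert repeat_time == 3 and torch.equal(pos, positive_input_ids)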
    @torch.no_grad()
    def generate(
            self,
            pixel_values: torch.FloatTensor,
            input_ids: torch.FloatTensor,
            attention_mask: torch.LongTensor,
            generation_config: Optional[GenerationConfig] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            **generate_kwargs,
    ) -> torch.LongTensor:

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict)
        image_embeds = vision_outputs[0]

        batch_size = image_embeds.shape[0]
        input_embeds = self.get_input_embeddings()(input_ids)
        query_tokens = self.query_tokens.repeat(batch_size, 1, 1)
        input_embeds = torch.cat([query_tokens, input_embeds], dim=1)
        image_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
        attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1)

        outputs = self.qllama.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            vision_hidden_states=image_embeds,
            generation_config=generation_config,
            use_zero_attention_mask=True,
            **generate_kwargs,
        )

        return outputs

    def get_text_features(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ):
        r"""
        Returns:
            text_outputs (`CausalLMOutputWithPast`, or `tuple(torch.FloatTensor)` if `return_dict=False`):
                The language model outputs. If `return_dict=True`, the output is a [`CausalLMOutputWithPast`] that
                contains the language model logits, the past key values and the hidden states if
                `output_hidden_states=True`.
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        input_embeds = self.get_input_embeddings()(input_ids)
        attention_mask = _expand_mask(attention_mask, input_embeds.dtype).to(
            input_embeds.device)  # [bsz, 1, tgt_seq_len, src_seq_len]
        attention_mask += _make_causal_mask(
            (attention_mask.shape[0], attention_mask.shape[2]),
            input_embeds.dtype,
            device=input_embeds.device
        )
        if type(self.qllama.model) == LlamaForCausalLM:
            outputs = self.qllama.model.model.forward_train(
                inputs_embeds=input_embeds,
                vision_hidden_states=None,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            ).last_hidden_state
        else:
            outputs = self.qllama.model.forward_train(
                inputs_embeds=input_embeds,
                vision_hidden_states=None,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            ).last_hidden_state
        return outputs

    def get_image_features(
            self,
            pixel_values: torch.FloatTensor,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict)
        image_embeds = vision_outputs[0]
        backbone_embeds = image_embeds

        batch_size = image_embeds.shape[0]
        input_embeds = self.query_tokens.repeat(batch_size, 1, 1)

        attention_mask = torch.ones(input_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
        attention_mask = _expand_mask(attention_mask, input_embeds.dtype).to(
            input_embeds.device)  # [bsz, 1, tgt_seq_len, src_seq_len]
        if type(self.qllama.model) == LlamaForCausalLM:
            outputs = self.qllama.model.model.forward_train(
                inputs_embeds=input_embeds,
                vision_hidden_states=image_embeds,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            ).last_hidden_state
        else:
            outputs = self.qllama.model.forward_train(
                inputs_embeds=input_embeds,
                vision_hidden_states=image_embeds,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            ).last_hidden_state
        return backbone_embeds, outputs


class InternVL_C(InternVLModel):

    def encode_image(self, image):
        vision_outputs = self.vision_model(
            pixel_values=image,
            output_hidden_states=False,
            return_dict=True)
        image_embeds = vision_outputs[0]
        image_embeds = self.clip_projector(image_embeds)
        return image_embeds

    def encode_text(self, text):
        attention_mask = text > 0
        text_embeds = self.get_text_features(
            input_ids=text,
            attention_mask=attention_mask,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        text_embeds = text_embeds[torch.arange(text_embeds.shape[0]), attention_mask.sum(1) - 1]
        text_embeds = text_embeds @ self.text_projection
        return text_embeds

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        return logits_per_image, logits_per_text


class InternVL_G(InternVLModel):

    def encode_image(self, image):
        backbone_embeds, image_embeds = self.get_image_features(
            pixel_values=image,
            output_hidden_states=False,
            return_dict=True,
        )
        backbone_embeds = self.clip_projector(backbone_embeds)
        image_embeds = self.clip_projector2(image_embeds)
        # ensemble
        backbone_embeds = backbone_embeds / backbone_embeds.norm(dim=1, keepdim=True)
        image_embeds = image_embeds / image_embeds.norm(dim=1, keepdim=True)
        image_embeds = image_embeds + backbone_embeds
        return image_embeds

    def encode_text(self, text):
        attention_mask = text > 0
        text_embeds = self.get_text_features(
            input_ids=text,
            attention_mask=attention_mask,
|
| 648 |
+
output_attentions=False,
|
| 649 |
+
output_hidden_states=False,
|
| 650 |
+
return_dict=True,
|
| 651 |
+
)
|
| 652 |
+
text_embeds = text_embeds[torch.arange(text_embeds.shape[0]), attention_mask.sum(1) - 1]
|
| 653 |
+
text_embeds = text_embeds @ self.text_projection
|
| 654 |
+
return text_embeds
|
| 655 |
+
|
| 656 |
+
def forward(self, image, text):
|
| 657 |
+
image_features = self.encode_image(image)
|
| 658 |
+
text_features = self.encode_text(text)
|
| 659 |
+
|
| 660 |
+
# normalized features
|
| 661 |
+
image_features = image_features / image_features.norm(dim=1, keepdim=True)
|
| 662 |
+
text_features = text_features / text_features.norm(dim=1, keepdim=True)
|
| 663 |
+
|
| 664 |
+
# cosine similarity as logits
|
| 665 |
+
logit_scale = self.logit_scale.exp()
|
| 666 |
+
logits_per_image = logit_scale * image_features @ text_features.t()
|
| 667 |
+
logits_per_text = logits_per_image.t()
|
| 668 |
+
|
| 669 |
+
return logits_per_image, logits_per_text
|
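Editorial note (not part of the diff above): both retrieval heads are CLIP-style scorers, so a minimal usage sketch looks like the following. The names `model`, `pixel_values`, and `input_ids` are assumed stand-ins for a loaded InternVL_C / InternVL_G checkpoint and its image transform and tokenizer, none of which are defined in this file.

import torch

# Assumed inputs: `pixel_values` is a preprocessed image batch [B, 3, H, W] and
# `input_ids` a tokenized caption batch [B, L] padded with id 0 (encode_text
# derives its mask from `text > 0`).
with torch.no_grad():
    logits_per_image, logits_per_text = model(pixel_values, input_ids)

# Row-wise softmax over captions gives image-to-text retrieval probabilities.
probs = logits_per_image.softmax(dim=-1)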
VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/modeling_qllama.py
ADDED
@@ -0,0 +1,1073 @@
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch QLLaMA model."""
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import LlamaConfig
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (BaseModelOutputWithPast,
                                           CausalLMOutputWithPast)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (add_start_docstrings,
                                add_start_docstrings_to_model_forward, logging,
                                replace_return_docstrings)

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = 'LlamaConfig'


# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

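Editorial note (not part of the file): elsewhere in this repository (for example `get_text_features` in modeling_internvl.py) these two helpers are combined additively, padding mask first and causal mask added on top, so disallowed positions carry a large negative value before softmax. A minimal sketch with made-up inputs:

import torch

# Assumed toy batch: two sequences of length 4, the second right-padded.
pad_mask = torch.tensor([[1, 1, 1, 1],
                         [1, 1, 0, 0]])   # 1 = real token, 0 = padding
dtype = torch.float32

expanded = _expand_mask(pad_mask, dtype)  # [bsz, 1, tgt, src] additive padding mask
causal = _make_causal_mask((expanded.shape[0], expanded.shape[2]), dtype, device=expanded.device)
combined = expanded + causal              # large negative where attention is disallowed

print(combined.shape)                     # torch.Size([2, 1, 4, 4])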
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states


try:
    from functools import partial

    from apex.normalization import FusedRMSNorm

    LlamaRMSNorm = partial(FusedRMSNorm, eps=1e-6)  # noqa
    print('Discovered apex.normalization.FusedRMSNorm - will use it instead of LlamaRMSNorm')
except ImportError:
    # using the normal LlamaRMSNorm
    pass
except Exception:
    print('discovered apex but it failed to load, falling back to LlamaRMSNorm')
    pass


class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer('inv_freq', inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer('cos_cached', emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer('sin_cached', emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer('cos_cached', emb.cos()[None, None, :, :], persistent=False)
            self.register_buffer('sin_cached', emb.sin()[None, None, :, :], persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


class FixedLlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer('cos_cached', emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer('sin_cached', emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


LlamaRotaryEmbedding = FixedLlamaRotaryEmbedding


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
    gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
    cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

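Editorial note (not part of the file): the rotary cache and `apply_rotary_pos_emb` are consumed together inside the attention module defined below. A small sketch with invented shapes (batch 1, 4 heads, 5 tokens, head dim 8) shows how the pieces compose:

import torch

q = torch.randn(1, 4, 5, 8)                  # assumed query states [bs, heads, seq, head_dim]
k = torch.randn(1, 4, 5, 8)                  # assumed key states
position_ids = torch.arange(5).unsqueeze(0)  # [1, 5]

rope = LlamaRotaryEmbedding(dim=8)           # cos/sin cache is built in __init__
cos, sin = rope(q, seq_len=5)                # each [1, 1, 5, 8]
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

print(q_rot.shape, k_rot.shape)              # torch.Size([1, 4, 5, 8]) twice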
class LlamaMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
                f' and `num_heads`: {self.num_heads}).'
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # [bsz, nh, t, hd]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f'Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is'
                f' {attn_weights.size()}'
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
                )
            attn_weights = attn_weights + attention_mask
            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
                f' {attn_output.size()}'
            )

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class LlamaCrossAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.vision_hidden_size = 3200

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
                f' and `num_heads`: {self.num_heads}).'
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.norm1 = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.k_proj = nn.Linear(self.vision_hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.vision_hidden_size, self.num_heads * self.head_dim, bias=False)
        self.norm2 = LlamaRMSNorm(self.vision_hidden_size, eps=config.rms_norm_eps)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        vision_hidden_states: torch.Tensor,
        repeat_time: int = 1,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        hidden_states = self.norm1(hidden_states)

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        vision_hidden_states = self.norm2(vision_hidden_states)

        bs_v, kv_len, _ = vision_hidden_states.size()

        key_states = self.k_proj(vision_hidden_states).view(
            bs_v, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(vision_hidden_states).view(
            bs_v, kv_len, self.num_heads, self.head_dim).transpose(1, 2)

        key_states = key_states.repeat(repeat_time, 1, 1, 1)
        value_states = value_states.repeat(repeat_time, 1, 1, 1)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f'Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is'
                f' {attn_weights.size()}'
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
                )
            attn_weights = attn_weights + attention_mask
            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
                f' {attn_output.size()}'
            )

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig, use_cross_attn: bool):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(config=config)
        self.cross_attn = LlamaCrossAttention(config=config) if use_cross_attn else None
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.num_query_token = 96
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        vision_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        repeat_time: int = 1,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # when using generate function and cache mode, the size of hidden_states is 1,
        # so we should not use cross attention
        if self.cross_attn is not None and hidden_states.size(1) >= self.num_query_token \
                and vision_hidden_states is not None:
            query_feats = hidden_states[:, :self.num_query_token, :]
            text_feats = hidden_states[:, self.num_query_token:, :]
            residual = query_feats
            query_feats, _, _ = self.cross_attn(
                hidden_states=query_feats,
                vision_hidden_states=vision_hidden_states,
                attention_mask=None,  # not use attention mask in cross attention
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                repeat_time=repeat_time,
            )
            query_feats = residual + query_feats
            hidden_states = torch.cat([query_feats, text_feats], dim=1)

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs

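Editorial note (not part of the file): in the decoder layer above, only the first `num_query_token` positions of the sequence pass through the cross-attention branch into the vision features; the trailing text positions skip it. The slice-and-concatenate pattern it relies on, shown with invented sizes:

import torch

num_query_token = 96                          # same constant the layer hard-codes
hidden = torch.randn(2, 96 + 32, 4096)        # assumed: 96 query slots followed by 32 text tokens

query_feats = hidden[:, :num_query_token, :]  # would go through cross-attention
text_feats = hidden[:, num_query_token:, :]   # left untouched by that branch
merged = torch.cat([query_feats, text_feats], dim=1)

assert merged.shape == hidden.shape           # position order is preserved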
LLAMA_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`LlamaConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    'The bare LLaMA Model outputting raw hidden-states without any specific head on top.',
    LLAMA_START_DOCSTRING,
)
class LlamaPreTrainedModel(PreTrainedModel):
    config_class = LlamaConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['LlamaDecoderLayer']
    _keys_to_ignore_on_load_unexpected = [r'decoder\.version']

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, LlamaModel):
            module.gradient_checkpointing = value
        if isinstance(module, LlamaDecoderLayer):
            module.gradient_checkpointing = value


LLAMA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    'The bare LLaMA Model outputting raw hidden-states without any specific head on top.',
    LLAMA_START_DOCSTRING,
)
class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.cross_attention_frequency = config.cross_attention_frequency
        self.num_query_token = config.num_query_token
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        use_cross_attn = [idx % self.cross_attention_frequency == 0 for idx in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config, use_cross_attn[idx]) for idx in range(config.num_hidden_layers)])
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        # self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_hidden_states: Optional[torch.FloatTensor] = None,
        repeat_time: Optional[int] = 1,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        use_zero_attention_mask: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError('You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time')
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError('You have to specify either decoder_input_ids or decoder_inputs_embeds')
        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )
        if use_zero_attention_mask:
            attention_mask[:, :, :self.num_query_token, :self.num_query_token] = 0

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            layer_outputs = decoder_layer(
                hidden_states,
                vision_hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                repeat_time=repeat_time,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward_train(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_hidden_states: Optional[torch.FloatTensor] = None,
        repeat_time: Optional[int] = 1,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError('You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time')
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError('You have to specify either decoder_input_ids or decoder_inputs_embeds')

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        # if attention_mask is None:
        #     attention_mask = torch.ones(
        #         (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
        #     )
        # attention_mask = self._prepare_decoder_attention_mask(
        #     attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        # )
        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None, repeat_time)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    vision_hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    vision_hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    repeat_time=repeat_time,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class LlamaForCausalLM(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        # self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_hidden_states: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        use_zero_attention_mask: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you consciours? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            vision_hidden_states=vision_hidden_states,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
|
| 1004 |
+
use_zero_attention_mask=use_zero_attention_mask,
|
| 1005 |
+
)
|
| 1006 |
+
|
| 1007 |
+
hidden_states = outputs[0]
|
| 1008 |
+
logits = self.lm_head(hidden_states)
|
| 1009 |
+
|
| 1010 |
+
loss = None
|
| 1011 |
+
if labels is not None:
|
| 1012 |
+
# Shift so that tokens < n predict n
|
| 1013 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 1014 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 1015 |
+
# Flatten the tokens
|
| 1016 |
+
loss_fct = CrossEntropyLoss()
|
| 1017 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
| 1018 |
+
shift_labels = shift_labels.view(-1)
|
| 1019 |
+
# Enable model parallelism
|
| 1020 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
| 1021 |
+
loss = loss_fct(shift_logits, shift_labels)
|
| 1022 |
+
|
| 1023 |
+
if not return_dict:
|
| 1024 |
+
output = (logits,) + outputs[1:]
|
| 1025 |
+
return (loss,) + output if loss is not None else output
|
| 1026 |
+
|
| 1027 |
+
return CausalLMOutputWithPast(
|
| 1028 |
+
loss=loss,
|
| 1029 |
+
logits=logits,
|
| 1030 |
+
past_key_values=outputs.past_key_values,
|
| 1031 |
+
hidden_states=outputs.hidden_states,
|
| 1032 |
+
attentions=outputs.attentions,
|
| 1033 |
+
)
|
| 1034 |
+
|
| 1035 |
+
def prepare_inputs_for_generation(
|
| 1036 |
+
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None,
|
| 1037 |
+
vision_hidden_states=None, use_zero_attention_mask=None, **kwargs
|
| 1038 |
+
):
|
| 1039 |
+
if past_key_values:
|
| 1040 |
+
input_ids = input_ids[:, -1:]
|
| 1041 |
+
|
| 1042 |
+
position_ids = kwargs.get('position_ids', None)
|
| 1043 |
+
if attention_mask is not None and position_ids is None:
|
| 1044 |
+
# create position_ids on the fly for batch generation
|
| 1045 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 1046 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 1047 |
+
if past_key_values:
|
| 1048 |
+
position_ids = position_ids[:, -1].unsqueeze(-1)
|
| 1049 |
+
|
| 1050 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
| 1051 |
+
if inputs_embeds is not None and past_key_values is None:
|
| 1052 |
+
model_inputs = {'inputs_embeds': inputs_embeds}
|
| 1053 |
+
else:
|
| 1054 |
+
model_inputs = {'input_ids': input_ids}
|
| 1055 |
+
|
| 1056 |
+
model_inputs.update(
|
| 1057 |
+
{
|
| 1058 |
+
'position_ids': position_ids,
|
| 1059 |
+
'past_key_values': past_key_values,
|
| 1060 |
+
'use_cache': kwargs.get('use_cache'),
|
| 1061 |
+
'attention_mask': attention_mask,
|
| 1062 |
+
'vision_hidden_states': vision_hidden_states,
|
| 1063 |
+
'use_zero_attention_mask': use_zero_attention_mask,
|
| 1064 |
+
}
|
| 1065 |
+
)
|
| 1066 |
+
return model_inputs
|
| 1067 |
+
|
| 1068 |
+
@staticmethod
|
| 1069 |
+
def _reorder_cache(past_key_values, beam_idx):
|
| 1070 |
+
reordered_past = ()
|
| 1071 |
+
for layer_past in past_key_values:
|
| 1072 |
+
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
| 1073 |
+
return reordered_past
|
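Note: in `prepare_inputs_for_generation` above, `position_ids` are rebuilt from the attention mask so that left-padded batches still get correct positions. A minimal sketch (not part of the repository, toy mask values assumed) of that computation:

```python
import torch

# One sequence with two left-pad tokens; positions are the running count of real tokens.
attention_mask = torch.tensor([[0, 0, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1   # [-1, -1, 0, 1, 2]
position_ids.masked_fill_(attention_mask == 0, 1)      # pad slots get a dummy position of 1
print(position_ids)  # tensor([[1, 1, 0, 1, 2]])
```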
VLMEvalKit_old/InternVL/internvl_g/internvl/train/dataset.py
ADDED
@@ -0,0 +1,283 @@
import json
import random
import re
from typing import Dict

import torch
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms.functional import InterpolationMode


def build_transform(input_size):
    # match fine-tune setting with blip2
    # https://github.com/salesforce/LAVIS/blob/main/lavis/processors/blip_processors.py
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.RandomResizedCrop(input_size, scale=(0.5, 1.0),
                            interpolation=InterpolationMode.BICUBIC),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    ])
    return transform


class FlickrDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, metas, tokenizer, data_args):
        super(FlickrDataset, self).__init__()

        f = open(metas['annotation'])
        lines = f.readlines()[1:]

        self.data_args = data_args
        self.tokenizer = tokenizer
        self.images = []
        self.image_ids = []
        self.captions = []

        for line in lines:
            image, caption = line.strip().split('.jpg,')
            image_id = int(image)
            caption = self.process_single_caption(caption)
            image = image + '.jpg'
            image_path = metas['root'] + '/' + image
            self.images.append(image_path)
            self.image_ids.append(image_id)
            self.captions.append(caption)
        print(f'There are {len(self.images)} images.')
        print(f'There are {len(self.captions)} captions.')

    def __len__(self):
        return len(self.images)

    def process_single_caption(self, caption, max_words=50):
        caption = re.sub(r"([.!\"()*#:;~])", ' ', caption.lower())
        caption = re.sub(r'\s{2,}', ' ', caption)
        caption = caption.rstrip('\n')
        caption = caption.strip(' ')

        # truncate caption
        caption_words = caption.split(' ')
        if len(caption_words) > max_words:
            caption = ' '.join(caption_words[: max_words])
        return caption

    def preprocess(self, image, caption, neg_caption):
        model_inputs = dict()

        # input image
        image_transform = build_transform(input_size=self.data_args.force_image_size)
        image = Image.open(image)
        image = image.convert('RGB')
        pixel_values = image_transform(image)
        model_inputs['pixel_values'] = pixel_values

        # for image-text matching
        pos_model_inputs = self.tokenizer(
            caption,
            max_length=self.data_args.max_seq_length,
            padding='max_length' if self.data_args.pad_to_max_length else False,
            truncation=True,
            return_tensors='pt',
        )
        model_inputs['positive_input_ids'] = pos_model_inputs['input_ids']
        model_inputs['positive_attention_mask'] = pos_model_inputs['attention_mask']
        neg_model_inputs = self.tokenizer(
            neg_caption,
            max_length=self.data_args.max_seq_length,
            padding='max_length' if self.data_args.pad_to_max_length else False,
            truncation=True,
            return_tensors='pt',
        )
        model_inputs['negative_input_ids'] = neg_model_inputs['input_ids']
        model_inputs['negative_attention_mask'] = neg_model_inputs['attention_mask']

        # for image-text contrastive learning
        summarize_model_inputs = self.tokenizer(
            'summarize:' + caption,
            max_length=self.data_args.max_seq_length,
            padding='max_length' if self.data_args.pad_to_max_length else False,
            truncation=True,
            return_tensors='pt',
        )
        model_inputs['summarize_input_ids'] = summarize_model_inputs['input_ids']
        model_inputs['summarize_attention_mask'] = summarize_model_inputs['attention_mask']

        # for image-grounded text generation
        prefix = f'English caption:'
        content = caption
        tokenized_prefix = self.tokenizer(
            prefix, padding=False, truncation=True, return_tensors='pt',
        )
        prefix_input_ids = tokenized_prefix['input_ids'][:, :-1]  # remove eos
        prefix_attention_mask = tokenized_prefix['attention_mask'][:, :-1]  # remove eos
        tokenized_content = self.tokenizer(
            content,
            max_length=self.data_args.max_seq_length - prefix_input_ids.size(1) + 1,
            padding='max_length' if self.data_args.pad_to_max_length else False,
            truncation=True,
            return_tensors='pt',
        )
        content_input_ids = tokenized_content['input_ids'][:, 1:]  # remove bos
        content_attention_mask = tokenized_content['attention_mask'][:, 1:]  # remove bos
        model_inputs['input_ids'] = torch.cat([prefix_input_ids, content_input_ids], dim=1)
        model_inputs['attention_mask'] = torch.cat([prefix_attention_mask, content_attention_mask], dim=1)
        labels = model_inputs['input_ids'].clone()
        labels[labels == self.tokenizer.pad_token_id] = -100
        labels[:, :prefix_input_ids.size(1) - 1] = -100
        model_inputs['labels'] = labels
        return model_inputs

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        i = i % len(self.images)
        j = random.randint(0, len(self.images) - 1)
        while self.image_ids[j] == self.image_ids[i]:
            j = random.randint(0, len(self.images) - 1)
        ret = self.preprocess(self.images[i], self.captions[i], self.captions[j])
        # for image-text matching
        ret['positive_input_ids'] = ret['positive_input_ids'][0]
        ret['positive_attention_mask'] = ret['positive_attention_mask'][0]
        ret['negative_input_ids'] = ret['negative_input_ids'][0]
        ret['negative_attention_mask'] = ret['negative_attention_mask'][0]
        # for image-text contrastive learning
        ret['summarize_input_ids'] = ret['summarize_input_ids'][0]
        ret['summarize_attention_mask'] = ret['summarize_attention_mask'][0]
        # for image-grounded text generation
        ret['input_ids'] = ret['input_ids'][0]
        ret['attention_mask'] = ret['attention_mask'][0]
        ret['labels'] = ret['labels'][0]
        ret['image_ids'] = torch.Tensor([self.image_ids[i]]).long()
        return ret


class COCODataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, metas, tokenizer, data_args):
        super(COCODataset, self).__init__()

        annotations = json.load(open(metas['annotation']))

        self.data_args = data_args
        self.tokenizer = tokenizer
        self.images = []
        self.image_ids = []
        self.captions = []

        for annotation in annotations:
            image_id = int(annotation['image_id'].split('_')[-1])
            caption = annotation['caption']
            caption = self.process_single_caption(caption)
            image = annotation['image']
            image_path = metas['root'] + '/' + image
            self.images.append(image_path)
            self.image_ids.append(image_id)
            self.captions.append(caption)
        print(f'There are {len(self.images)} images.')
        print(f'There are {len(self.captions)} captions.')

    def __len__(self):
        return len(self.images)

    def process_single_caption(self, caption, max_words=50):
        caption = re.sub(r"([.!\"()*#:;~])", ' ', caption.lower())
        caption = re.sub(r'\s{2,}', ' ', caption)
        caption = caption.rstrip('\n')
        caption = caption.strip(' ')

        # truncate caption
        caption_words = caption.split(' ')
        if len(caption_words) > max_words:
            caption = ' '.join(caption_words[: max_words])
        return caption

    def preprocess(self, image, caption, neg_caption):
        model_inputs = dict()

        # input image
        image_transform = build_transform(input_size=self.data_args.force_image_size)
        image = Image.open(image)
        image = image.convert('RGB')
        pixel_values = image_transform(image)
        model_inputs['pixel_values'] = pixel_values

        # for image-text matching
        pos_model_inputs = self.tokenizer(
            caption,
            max_length=self.data_args.max_seq_length,
            padding='max_length' if self.data_args.pad_to_max_length else False,
            truncation=True,
            return_tensors='pt',
        )
        model_inputs['positive_input_ids'] = pos_model_inputs['input_ids']
        model_inputs['positive_attention_mask'] = pos_model_inputs['attention_mask']
        neg_model_inputs = self.tokenizer(
            neg_caption,
            max_length=self.data_args.max_seq_length,
            padding='max_length' if self.data_args.pad_to_max_length else False,
            truncation=True,
            return_tensors='pt',
        )
        model_inputs['negative_input_ids'] = neg_model_inputs['input_ids']
        model_inputs['negative_attention_mask'] = neg_model_inputs['attention_mask']

        # for image-text contrastive learning
        summarize_model_inputs = self.tokenizer(
            'summarize:' + caption,
            max_length=self.data_args.max_seq_length,
            padding='max_length' if self.data_args.pad_to_max_length else False,
            truncation=True,
            return_tensors='pt',
        )
        model_inputs['summarize_input_ids'] = summarize_model_inputs['input_ids']
        model_inputs['summarize_attention_mask'] = summarize_model_inputs['attention_mask']

        # for image-grounded text generation
        prefix = f'English caption:'
        content = caption
        tokenized_prefix = self.tokenizer(
            prefix, padding=False, truncation=True, return_tensors='pt',
        )
        prefix_input_ids = tokenized_prefix['input_ids'][:, :-1]  # remove eos
        prefix_attention_mask = tokenized_prefix['attention_mask'][:, :-1]  # remove eos
        tokenized_content = self.tokenizer(
            content,
            max_length=self.data_args.max_seq_length - prefix_input_ids.size(1) + 1,
            padding='max_length' if self.data_args.pad_to_max_length else False,
            truncation=True,
            return_tensors='pt',
        )
        content_input_ids = tokenized_content['input_ids'][:, 1:]  # remove bos
        content_attention_mask = tokenized_content['attention_mask'][:, 1:]  # remove bos
        model_inputs['input_ids'] = torch.cat([prefix_input_ids, content_input_ids], dim=1)
        model_inputs['attention_mask'] = torch.cat([prefix_attention_mask, content_attention_mask], dim=1)
        labels = model_inputs['input_ids'].clone()
        labels[labels == self.tokenizer.pad_token_id] = -100
        labels[:, :prefix_input_ids.size(1) - 1] = -100
        model_inputs['labels'] = labels
        return model_inputs

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        i = i % len(self.images)
        j = random.randint(0, len(self.images) - 1)
        while self.image_ids[j] == self.image_ids[i]:
            j = random.randint(0, len(self.images) - 1)
        ret = self.preprocess(self.images[i], self.captions[i], self.captions[j])
        # for image-text matching
        ret['positive_input_ids'] = ret['positive_input_ids'][0]
        ret['positive_attention_mask'] = ret['positive_attention_mask'][0]
        ret['negative_input_ids'] = ret['negative_input_ids'][0]
        ret['negative_attention_mask'] = ret['negative_attention_mask'][0]
        # for image-text contrastive learning
        ret['summarize_input_ids'] = ret['summarize_input_ids'][0]
        ret['summarize_attention_mask'] = ret['summarize_attention_mask'][0]
        # for image-grounded text generation
        ret['input_ids'] = ret['input_ids'][0]
        ret['attention_mask'] = ret['attention_mask'][0]
        ret['labels'] = ret['labels'][0]
        ret['image_ids'] = torch.Tensor([self.image_ids[i]]).long()
        return ret
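Note: in the image-grounded text generation branch of `preprocess` above, the prefix tokens and padding are masked out of the labels so the loss only covers the caption. A hedged sketch (toy token ids and an assumed `pad_token_id` of 0, not values from the repository) of that masking rule:

```python
import torch

pad_token_id = 0                                           # assumed for illustration
input_ids = torch.tensor([[5, 6, 7, 11, 12, 13, 0, 0]])    # 3 prefix + 3 caption + 2 pad tokens
prefix_len = 3

labels = input_ids.clone()
labels[labels == pad_token_id] = -100                      # ignore padding
labels[:, :prefix_len - 1] = -100                          # ignore the prefix, keep its last position
print(labels)  # tensor([[-100, -100,    7,   11,   12,   13, -100, -100]])
# The last prefix position keeps its label so it learns to predict the first caption token.
```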
VLMEvalKit_old/InternVL/internvl_g/internvl/train/internvl_stage2_finetune.py
ADDED
@@ -0,0 +1,286 @@
import logging
import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import Dict, Optional

import torch.distributed as dist
import transformers
from internvl.dist_utils import init_dist
from internvl.model.internvl_stage2_retrieval import (InternVLConfig,
                                                      InternVLModel)
from internvl.train.dataset import COCODataset, FlickrDataset
from internvl.train.trainer_monkey_patch import replace_create_optimizer
from PIL import Image, ImageFile, PngImagePlugin
from transformers import (HfArgumentParser, LlamaTokenizer, Trainer,
                          TrainingArguments, default_data_collator, set_seed)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils.logging import (enable_default_handler,
                                        enable_explicit_format, set_verbosity)

IGNORE_INDEX = -100
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True
MaximumDecompressedSize = 1024
MegaByte = 2 ** 20
PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte

warnings.filterwarnings('ignore')
logger = logging.getLogger(__name__)

os.environ['TOKENIZERS_PARALLELISM'] = 'true'

ds_collections = {
    'flickr30k_en_train': {
        'root': './data/flickr30k/Images/',
        'annotation': './data/flickr30k/flickr30k_train_karpathy.txt',
    },
    'flickr30k_cn_train': {
        'root': './data/flickr30k/Images/',
        'annotation': './data/flickr30k/flickr30k_cn_train.txt',
    },
    'coco_karpathy_train': {
        'root': './data/coco/',
        'annotation': './data/coco/annotations/coco_karpathy_train.json',
    },
}


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """
    model_name_or_path: str = field(
        metadata={'help': 'Path to pretrained model or model identifier from huggingface.co/models'}
    )
    freeze_model: bool = field(
        default=False,
        metadata={'help': 'Set to True to freeze the entire model.'},
    )
    freeze_vision_model: bool = field(
        default=False,
        metadata={'help': 'Set to True to freeze the vision backbone of the model.'},
    )
    freeze_qllama: bool = field(
        default=False,
        metadata={'help': 'Set to True to freeze the QLLaMA of the model.'},
    )
    unfreeze_qllama_head: bool = field(
        default=False,
        metadata={'help': 'Set to True to unfreeze the head of the QLLaMA.'},
    )
    unfreeze_crossattn: bool = field(
        default=False,
        metadata={'help': 'Set to True to unfreeze the cross attention layers in the QLLaMA.'},
    )
    use_backbone_lora: int = field(
        default=0, metadata={'help': 'If non-zero, indicates the use of LoRA in the vision backbone of the model'}
    )
    use_qllama_lora: int = field(
        default=0, metadata={'help': 'If non-zero, indicates the use of LoRA in the QLLaMA of the model'}
    )
    use_custom_trainer: bool = field(
        default=False, metadata={'help': 'Set to True to enable the use of a custom trainer.'},
    )
    drop_path_rate: float = field(
        default=0.0, metadata={'help': 'Specify the value of drop path rate in the vision backbone. Default is 0.'}
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """
    dataset_name: Optional[str] = field(
        default='flickr30k_en_train',
        metadata={'help': 'Specify the name of dataset to be used.'},
    )
    max_seq_length: Optional[int] = field(
        default=80,
        metadata={
            'help': (
                'The maximum total input sequence length after tokenization. Sequences longer '
                'than this will be truncated, sequences shorter will be padded.'
            )
        },
    )
    force_image_size: Optional[int] = field(
        default=224,
        metadata={'help': 'Specify the image size for training models.'},
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            'help': (
                'Whether to pad all samples to model maximum sentence length. '
                'If False, will pad the samples dynamically when batching to the maximum length in the batch. More '
                'efficient on GPU but very bad for TPU.'
            )
        },
    )


def main():
    # Parse input arguments
    # See all possible arguments in src/transformers/training_args.py
    # If use DeepSpeed zero3, init_dist must before HfArgumentParser
    launcher = os.environ.get('LAUNCHER', 'slurm')
    init_dist(launcher=launcher, backend='nccl')
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
        # If we pass only one argument to the script, and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    # send_example_telemetry('finetune Flickr30K', model_args, data_args)

    # Setup logging
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    set_verbosity(log_level)
    enable_default_handler()
    enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f'Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}'
        + f'distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}'
    )
    logger.info(f'Training/evaluation parameters {training_args}')

    # Detecting last checkpoint and eventually continue from last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f'Output directory ({training_args.output_dir}) already exists and is not empty. '
                'Use --overwrite_output_dir to overcome.'
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f'Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change '
                'the `--output_dir` or add `--overwrite_output_dir` to train from scratch.'
            )
    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Load pretrained model, tokenizer, and image processor
    tokenizer = LlamaTokenizer.from_pretrained(
        model_args.model_name_or_path,
        add_eos_token=True
    )

    if 'flickr' in data_args.dataset_name:
        train_dataset = FlickrDataset(metas=ds_collections[data_args.dataset_name],
                                      tokenizer=tokenizer, data_args=data_args)
    elif 'coco' in data_args.dataset_name:
        train_dataset = COCODataset(metas=ds_collections[data_args.dataset_name],
                                    tokenizer=tokenizer, data_args=data_args)
    config = InternVLConfig.from_pretrained(model_args.model_name_or_path)
    config.vision_config.drop_path_rate = model_args.drop_path_rate
    model = InternVLModel.from_pretrained(
        model_args.model_name_or_path,
        # ignore_mismatched_sizes=True,
        config=config
    )
    if data_args.force_image_size != 224:
        model.config.force_image_size = data_args.force_image_size
        model.vision_model.resize_pos_embeddings(old_size=224, new_size=data_args.force_image_size, patch_size=14)

    model.config.use_cache = False
    model.config.qllama_config.use_cache = False
    model.qllama.gradient_checkpointing = True
    model.qllama.model.gradient_checkpointing = True
    model.vision_model.gradient_checkpointing = True
    model.vision_model.encoder.gradient_checkpointing = True

    def _freeze_params(module):
        for param in module.parameters():
            param.requires_grad = False

    if model_args.freeze_model:
        _freeze_params(model)

    if model_args.freeze_vision_model:
        model.vision_model = model.vision_model.eval()
        _freeze_params(model.vision_model)

    if model_args.freeze_qllama:
        model.qllama = model.qllama.eval()
        _freeze_params(model.qllama)

    if model_args.use_backbone_lora:
        model.wrap_backbone_lora(r=model_args.use_backbone_lora, lora_alpha=model_args.use_backbone_lora * 2)
        model.config.use_backbone_lora = model_args.use_backbone_lora

    if model_args.use_qllama_lora:
        model.wrap_qllama_lora(r=model_args.use_qllama_lora, lora_alpha=model_args.use_backbone_lora * 2)
        model.config.use_qllama_lora = model_args.use_qllama_lora

    if model_args.unfreeze_crossattn:
        for name, param in model.qllama.named_parameters():
            if 'cross_attn' in name:
                param.requires_grad = True

    if model_args.unfreeze_qllama_head:
        model.qllama.lm_head.weight.requires_grad = True
        model.text_projection.requires_grad = True

    # print trainable parameters
    if dist.get_rank() == 0:
        for name, param in model.named_parameters():
            print(name, param.requires_grad)

    # set seed for torch dataloaders
    set_seed(training_args.seed)

    # Initialize our Trainer
    if model_args.use_custom_trainer:
        replace_create_optimizer()

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=None,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()  # Saves the tokenizer too for easy upload

        metrics = train_result.metrics
        metrics['train_samples'] = len(train_dataset)
        trainer.log_metrics('train', metrics)
        trainer.save_metrics('train', metrics)
        trainer.save_state()


if __name__ == '__main__':
    main()
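Note: `main()` above first freezes whole submodules with `_freeze_params` and then selectively re-enables gradients by parameter name (e.g. the `'cross_attn'` keyword). A minimal, self-contained sketch of that freeze-then-unfreeze pattern (toy module, stand-in keyword; not the repository's model):

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))   # toy stand-in for InternVLModel

# Freeze everything first (equivalent of _freeze_params).
for param in model.parameters():
    param.requires_grad = False

# Re-enable gradients only for parameters whose names match a keyword.
for name, param in model.named_parameters():
    if name.startswith('1.'):                              # stand-in for 'cross_attn' in name
        param.requires_grad = True

print([(n, p.requires_grad) for n, p in model.named_parameters()])
```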
VLMEvalKit_old/InternVL/internvl_g/internvl/train/trainer_monkey_patch.py
ADDED
@@ -0,0 +1,150 @@
import json
import os

import torch
import torch.nn as nn
import transformers
from transformers import Trainer, logging
from transformers.trainer import is_sagemaker_mp_enabled

logger = logging.get_logger(__name__)


def get_num_layer_for_vit_and_qllama(var_name, vit_num_max_layer, llama_num_max_layer):
    if var_name in ('query_tokens', 'logit_scale',):
        return 0
    if var_name.startswith('clip_projector.'):
        return vit_num_max_layer
    if var_name.startswith('clip_projector2.') or var_name.startswith('itm_head.') or \
            var_name == 'text_projection':
        return llama_num_max_layer
    if var_name.startswith('vision_model.'):
        if 'embeddings.' in var_name:
            return 0
        if 'layers.' in var_name:
            var_name = var_name.split('layers.')[-1]
            layer_id = int(var_name.split('.')[0])
            return layer_id + 1
    if var_name.startswith('qllama.'):
        if 'embed_tokens' in var_name:
            return 0
        if 'layers.' in var_name:
            var_name = var_name.split('layers.')[-1]
            layer_id = int(var_name.split('.')[0])
            return layer_id + 1
        else:
            return llama_num_max_layer
    return 0


def param_classification(name):
    if name in ['query_tokens', 'text_projection', 'logit_scale']:
        return 'qllama'
    elif name.startswith('vision_model.'):
        return 'vit'
    elif name.startswith('qllama.'):
        return 'qllama'
    elif name.startswith('clip_projector.'):
        return 'vit'
    elif name.startswith('clip_projector2.'):
        return 'qllama'
    elif name.startswith('itm_head.'):
        return 'qllama'
    else:
        return 'other'


def create_optimizer(self):
    """
    Setup the optimizer.

    We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
    Trainer's init through `optimizers`, or subclass and override this method in a subclass.
    """
    opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

    parameter_groups = {}
    try:  # for stage2 model
        vit_num_layers = opt_model.config.vision_config.num_hidden_layers + 2
        qllama_num_layers = opt_model.config.qllama_config.num_hidden_layers + 2
    except:  # for stage3 model
        vit_num_layers = opt_model.qllama.config.vision_config.num_hidden_layers + 2
        qllama_num_layers = opt_model.qllama.config.qllama_config.num_hidden_layers + 2
    print('vit_num_layers:', vit_num_layers)
    print('qllama_num_layers:', qllama_num_layers)

    vit_layer_decay_rate = float(os.getenv('VIT_LAYER_DECAY_RATE', 1.0))
    qllama_layer_decay_rate = float(os.getenv('QLLAMA_LAYER_DECAY_RATE', 1.0))
    print('vit_layer_decay_rate:', vit_layer_decay_rate)
    print('qllama_layer_decay_rate:', qllama_layer_decay_rate)

    for name, param in opt_model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if len(param.shape) == 1 or name.endswith('.bias'):
            group_name = 'no_decay'
            this_weight_decay = 0.
        else:
            group_name = 'decay'
            this_weight_decay = self.args.weight_decay

        cls = param_classification(name)
        layer_id = get_num_layer_for_vit_and_qllama(name, vit_num_layers, qllama_num_layers)
        group_name = '%s_layer_%d_%s' % (cls, layer_id, group_name)
        if group_name not in parameter_groups:
            if cls == 'vit':
                scale = vit_layer_decay_rate ** (vit_num_layers - layer_id - 1)
            else:
                scale = qllama_layer_decay_rate ** (qllama_num_layers - layer_id - 1)
            scale = min(1.0, scale)
            parameter_groups[group_name] = {
                'weight_decay': this_weight_decay,
                'params': [],
                'param_names': [],
                'lr_scale': scale,
                'group_name': group_name,
                'lr': scale * self.args.learning_rate,
            }
        parameter_groups[group_name]['params'].append(param)
        parameter_groups[group_name]['param_names'].append(name)

    rank = torch.distributed.get_rank()
    if rank == 0:
        to_display = {}
        for key in parameter_groups:
            to_display[key] = {
                'param_names': parameter_groups[key]['param_names'],
                'lr_scale': parameter_groups[key]['lr_scale'],
                'lr': parameter_groups[key]['lr'],
                'weight_decay': parameter_groups[key]['weight_decay'],
            }
        print('Param groups = %s' % json.dumps(to_display, indent=2))

    optimizer_grouped_parameters = list(parameter_groups.values())
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

    self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
    if optimizer_cls.__name__ == 'Adam8bit':
        import bitsandbytes

        manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

        skipped = 0
        for module in opt_model.modules():
            if isinstance(module, nn.Embedding):
                skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                logger.info(f'skipped {module}: {skipped / 2 ** 20}M params')
                manager.register_module_override(module, 'weight', {'optim_bits': 32})
                logger.debug(f'bitsandbytes: will optimize {module} in fp32')
        logger.info(f'skipped: {skipped / 2 ** 20}M params')

    if is_sagemaker_mp_enabled():
        import smdistributed.modelparallel.torch as smp
        self.optimizer = smp.DistributedOptimizer(self.optimizer)

    return self.optimizer


def replace_create_optimizer():
    print('Replace original create_optimizer with custom create_optimizer')
    transformers.Trainer.create_optimizer = create_optimizer
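Note: the `create_optimizer` patch above implements layer-wise learning-rate decay: each parameter group's learning rate is the base rate scaled by `decay_rate ** (num_layers - layer_id - 1)`, so deeper (earlier) layers train more slowly. A small sketch of that schedule, using the 0.9 decay rate exported by the fine-tuning shell scripts and an assumed layer count for illustration:

```python
# Layer-wise LR decay sketch; layer counts and base_lr are illustrative assumptions.
base_lr = 1e-6
decay_rate = 0.9      # matches VIT_LAYER_DECAY_RATE / QLLAMA_LAYER_DECAY_RATE in the scripts
num_layers = 26       # e.g. 24 transformer blocks plus the embedding and head buckets

for layer_id in range(num_layers):
    scale = min(1.0, decay_rate ** (num_layers - layer_id - 1))
    print(f'layer {layer_id:2d}: lr = {scale * base_lr:.3e}')
# The topmost bucket keeps base_lr; the embedding bucket gets roughly 0.9**25 of it.
```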
VLMEvalKit_old/InternVL/internvl_g/shell/finetune/internvl_stage2_finetune_coco_364_bs1024_ep5.sh
ADDED
@@ -0,0 +1,58 @@
set -x

export VIT_LAYER_DECAY_RATE=0.9
export QLLAMA_LAYER_DECAY_RATE=0.9

PARTITION=${PARTITION:-"VC2"}
GPUS=${GPUS:-32}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
QUOTA_TYPE=${QUOTA_TYPE:-"reserved"}
NODES=$((GPUS / GPUS_PER_NODE))
CPUS_PER_TASK=${CPUS_PER_TASK:-10}
SRUN_ARGS=${SRUN_ARGS:-""}


export PYTHONPATH="${PYTHONPATH}:$(pwd)"

# number of gpus: 32
# batch size per gpu: 32
# gradient accumulation steps: 1
# total batch size: 1024
# epoch: 5
srun -p ${PARTITION} \
  --gres=gpu:${GPUS_PER_NODE} \
  --nodes=${NODES} \
  --ntasks=${GPUS} \
  --ntasks-per-node=${GPUS_PER_NODE} \
  --cpus-per-task=${CPUS_PER_TASK} \
  --kill-on-bad-exit=1 \
  --quotatype=${QUOTA_TYPE} \
  ${SRUN_ARGS} \
  python -u internvl/train/internvl_stage2_finetune.py \
  --dataset_name 'coco_karpathy_train' \
  --model_name_or_path "./pretrained/InternVL-14B-224px" \
  --output_dir "./work_dirs/internvl_stage2_finetune_coco_364_bs1024_ep5" \
  --overwrite_output_dir True \
  --force_image_size 364 \
  --drop_path_rate 0.3 \
  --use_custom_trainer \
  --dataloader_num_workers 2 \
  --pad_to_max_length True \
  --bf16 True \
  --num_train_epochs 5 \
  --per_device_train_batch_size 32 \
  --gradient_accumulation_steps 1 \
  --evaluation_strategy "no" \
  --save_strategy "steps" \
  --save_steps 100 \
  --save_total_limit 5 \
  --learning_rate 1e-6 \
  --weight_decay 0.05 \
  --warmup_steps 100 \
  --lr_scheduler_type "cosine" \
  --logging_steps 1 \
  --max_seq_length 80 \
  --do_train True \
  --optim adamw_torch \
  --deepspeed "zero_stage1_config.json" \
  --report_to "tensorboard"
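Note: the "total batch size: 1024" comment in the script above follows from the effective batch-size arithmetic; a quick check with the script's values:

```python
# Effective batch size = GPUs x per-device batch size x gradient accumulation steps.
gpus = 32
per_device_train_batch_size = 32
gradient_accumulation_steps = 1
print(gpus * per_device_train_batch_size * gradient_accumulation_steps)  # 1024
```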
VLMEvalKit_old/InternVL/internvl_g/shell/finetune/internvl_stage2_finetune_flickr_364_bs1024_ep10.sh
ADDED
@@ -0,0 +1,58 @@
set -x

export VIT_LAYER_DECAY_RATE=0.9
export QLLAMA_LAYER_DECAY_RATE=0.9

PARTITION=${PARTITION:-"VC2"}
GPUS=${GPUS:-32}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
QUOTA_TYPE=${QUOTA_TYPE:-"reserved"}
NODES=$((GPUS / GPUS_PER_NODE))
CPUS_PER_TASK=${CPUS_PER_TASK:-10}
SRUN_ARGS=${SRUN_ARGS:-""}


export PYTHONPATH="${PYTHONPATH}:$(pwd)"

# number of gpus: 32
# batch size per gpu: 32
# gradient accumulation steps: 1
# total batch size: 1024
# epoch: 10
srun -p ${PARTITION} \
  --gres=gpu:${GPUS_PER_NODE} \
  --nodes=${NODES} \
  --ntasks=${GPUS} \
  --ntasks-per-node=${GPUS_PER_NODE} \
  --cpus-per-task=${CPUS_PER_TASK} \
  --kill-on-bad-exit=1 \
  --quotatype=${QUOTA_TYPE} \
  ${SRUN_ARGS} \
  python -u internvl/train/internvl_stage2_finetune.py \
  --dataset_name 'flickr30k_en_train' \
  --model_name_or_path "./pretrained/InternVL-14B-224px" \
  --output_dir "./work_dirs/internvl_stage2_finetune_flickr_364_bs1024_ep10" \
  --overwrite_output_dir True \
  --force_image_size 364 \
  --drop_path_rate 0.3 \
  --use_custom_trainer \
  --dataloader_num_workers 2 \
  --pad_to_max_length True \
  --bf16 True \
  --num_train_epochs 10 \
  --per_device_train_batch_size 32 \
  --gradient_accumulation_steps 1 \
  --evaluation_strategy "no" \
  --save_strategy "steps" \
  --save_steps 100 \
  --save_total_limit 5 \
  --learning_rate 1e-6 \
  --weight_decay 0.05 \
  --warmup_steps 100 \
  --lr_scheduler_type "cosine" \
  --logging_steps 1 \
  --max_seq_length 80 \
  --do_train True \
  --optim adamw_torch \
  --deepspeed "zero_stage1_config.json" \
  --report_to "tensorboard"
VLMEvalKit_old/InternVL/internvl_g/shell/lora_finetune/internvl_stage2_finetune_flickr_224_bs1024_ep10_lora16_4gpu.sh
ADDED
@@ -0,0 +1,61 @@
set -x

GPUS=${GPUS:-4}
BATCH_SIZE=${BATCH_SIZE:-32}


export PYTHONPATH="${PYTHONPATH}:$(pwd)"
export MASTER_PORT=34229
export TF_CPP_MIN_LOG_LEVEL=3
export LAUNCHER=pytorch

OUTPUT_DIR='work_dirs/internvl_stage2_finetune_flickr_364_bs1024_ep10_lora_4gpu'

if [ ! -d "$OUTPUT_DIR" ]; then
  mkdir -p "$OUTPUT_DIR"
fi

# number of gpus: 32
# batch size per gpu: 32
# gradient accumulation steps: 1
# total batch size: 1024
# epoch: 10
torchrun \
  --nnodes=1 \
  --node_rank=0 \
  --master_addr=127.0.0.1 \
  --nproc_per_node=${GPUS} \
  --master_port=${MASTER_PORT} \
  internvl/train/internvl_stage2_finetune.py \
  --dataset_name 'flickr30k_en_train' \
  --model_name_or_path "./pretrained/InternVL-14B-224px" \
  --output_dir ${OUTPUT_DIR} \
  --overwrite_output_dir True \
  --freeze_model \
  --freeze_vision_model \
  --freeze_qllama \
  --unfreeze_qllama_head \
  --use_backbone_lora 16 \
  --use_qllama_lora 16 \
  --force_image_size 224 \
  --drop_path_rate 0.0 \
  --dataloader_num_workers 2 \
  --pad_to_max_length True \
  --bf16 True \
  --num_train_epochs 10 \
  --per_device_train_batch_size ${BATCH_SIZE} \
  --gradient_accumulation_steps 1 \
  --evaluation_strategy "no" \
  --save_strategy "steps" \
  --save_steps 100 \
  --save_total_limit 5 \
  --learning_rate 1e-6 \
  --weight_decay 0.05 \
  --warmup_steps 100 \
  --lr_scheduler_type "cosine" \
  --logging_steps 1 \
  --max_seq_length 80 \
  --do_train True \
  --optim adamw_torch \
  --deepspeed "zero_stage3_config.json" \
  --report_to "tensorboard"
VLMEvalKit_old/InternVL/internvl_g/shell/lora_finetune/internvl_stage2_finetune_flickrcn_224_bs1024_ep10_lora16_4gpu.sh
ADDED
@@ -0,0 +1,61 @@
set -x

GPUS=${GPUS:-4}
BATCH_SIZE=${BATCH_SIZE:-32}


export PYTHONPATH="${PYTHONPATH}:$(pwd)"
export MASTER_PORT=34229
export TF_CPP_MIN_LOG_LEVEL=3
export LAUNCHER=pytorch

OUTPUT_DIR='work_dirs/internvl_stage2_finetune_flickrcn_364_bs1024_ep10_lora_4gpu'

if [ ! -d "$OUTPUT_DIR" ]; then
  mkdir -p "$OUTPUT_DIR"
fi

# number of gpus: 32
# batch size per gpu: 32
# gradient accumulation steps: 1
# total batch size: 1024
# epoch: 10
torchrun \
  --nnodes=1 \
  --node_rank=0 \
  --master_addr=127.0.0.1 \
  --nproc_per_node=${GPUS} \
  --master_port=${MASTER_PORT} \
  internvl/train/internvl_stage2_finetune.py \
  --dataset_name 'flickr30k_cn_train' \
  --model_name_or_path "./pretrained/InternVL-14B-224px" \
  --output_dir ${OUTPUT_DIR} \
  --overwrite_output_dir True \
  --freeze_model \
  --freeze_vision_model \
  --freeze_qllama \
  --unfreeze_qllama_head \
  --use_backbone_lora 16 \
  --use_qllama_lora 16 \
  --force_image_size 224 \
  --drop_path_rate 0.0 \
  --dataloader_num_workers 2 \
  --pad_to_max_length True \
  --bf16 True \
  --num_train_epochs 10 \
  --per_device_train_batch_size ${BATCH_SIZE} \
  --gradient_accumulation_steps 1 \
  --evaluation_strategy "no" \
  --save_strategy "steps" \
  --save_steps 100 \
  --save_total_limit 5 \
  --learning_rate 1e-6 \
  --weight_decay 0.05 \
  --warmup_steps 100 \
  --lr_scheduler_type "cosine" \
  --logging_steps 1 \
  --max_seq_length 80 \
  --do_train True \
  --optim adamw_torch \
  --deepspeed "zero_stage3_config.json" \
  --report_to "tensorboard"
VLMEvalKit_old/InternVL/segmentation/configs/_base_/datasets/ade20k_504x504.py
ADDED
@@ -0,0 +1,56 @@
+# dataset settings
+dataset_type = 'ADE20KDataset'
+data_root = 'data/ade/ADEChallengeData2016'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+crop_size = (504, 504)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', reduce_zero_label=True),
+    dict(type='Resize', img_scale=(2016, 504), ratio_range=(0.5, 2.0)),
+    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='PhotoMetricDistortion'),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(2016, 504),
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+        flip=False,
+        transforms=[
+            dict(type='SETR_Resize', keep_ratio=True,
+                 crop_size=crop_size, setr_multi_scale=True),
+            dict(type='ResizeToMultiple', size_divisor=14),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    samples_per_gpu=4,
+    workers_per_gpu=4,
+    train=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='images/training',
+        ann_dir='annotations/training',
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='images/validation',
+        ann_dir='annotations/validation',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='images/validation',
+        ann_dir='annotations/validation',
+        pipeline=test_pipeline))
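Note: these `_base_` dataset files are never run on their own; in the usual mmsegmentation convention a top-level experiment config composes a dataset base, a model base, a runtime base and a schedule base, then overrides individual fields. A minimal sketch (the file name and the extra `_base_` paths here are illustrative, not part of this diff):

# hypothetical experiment config, e.g. configs/my_exp/deeplabv3_r50_ade20k_504.py
_base_ = [
    '../_base_/datasets/ade20k_504x504.py',   # the dataset settings above
    '../_base_/models/deeplabv3_r50-d8.py',   # any model base from this folder
    '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_160k.py',
]
# per-experiment overrides are plain dict updates on the merged config
model = dict(
    decode_head=dict(num_classes=150),        # ADE20K has 150 classes
    auxiliary_head=dict(num_classes=150))
data = dict(samples_per_gpu=2)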
VLMEvalKit_old/InternVL/segmentation/configs/_base_/datasets/ade20k_640x640.py
ADDED
@@ -0,0 +1,54 @@
+# dataset settings
+dataset_type = 'ADE20KDataset'
+data_root = 'data/ade/ADEChallengeData2016'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+crop_size = (640, 640)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', reduce_zero_label=True),
+    dict(type='Resize', img_scale=(2560, 640), ratio_range=(0.5, 2.0)),
+    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='PhotoMetricDistortion'),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(2560, 640),
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    samples_per_gpu=4,
+    workers_per_gpu=4,
+    train=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='images/training',
+        ann_dir='annotations/training',
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='images/validation',
+        ann_dir='annotations/validation',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='images/validation',
+        ann_dir='annotations/validation',
+        pipeline=test_pipeline))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/datasets/cityscapes_832x832.py
ADDED
@@ -0,0 +1,35 @@
+_base_ = './cityscapes.py'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+crop_size = (832, 832)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations'),
+    dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
+    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='PhotoMetricDistortion'),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(2048, 1024),
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    train=dict(pipeline=train_pipeline),
+    val=dict(pipeline=test_pipeline),
+    test=dict(pipeline=test_pipeline))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/datasets/coco-stuff10k.py
ADDED
@@ -0,0 +1,57 @@
+# dataset settings
+dataset_type = 'COCOStuffDataset'
+data_root = 'data/coco_stuff10k'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+crop_size = (512, 512)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', reduce_zero_label=True),
+    dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
+    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='PhotoMetricDistortion'),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(2048, 512),
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    samples_per_gpu=4,
+    workers_per_gpu=4,
+    train=dict(
+        type=dataset_type,
+        data_root=data_root,
+        reduce_zero_label=True,
+        img_dir='images/train2014',
+        ann_dir='annotations/train2014',
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        data_root=data_root,
+        reduce_zero_label=True,
+        img_dir='images/test2014',
+        ann_dir='annotations/test2014',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        data_root=data_root,
+        reduce_zero_label=True,
+        img_dir='images/test2014',
+        ann_dir='annotations/test2014',
+        pipeline=test_pipeline))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/datasets/coco-stuff164k.py
ADDED
@@ -0,0 +1,54 @@
+# dataset settings
+dataset_type = 'COCOStuffDataset'
+data_root = 'data/coco_stuff164k'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+crop_size = (512, 512)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations'),
+    dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
+    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='PhotoMetricDistortion'),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(2048, 512),
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    samples_per_gpu=4,
+    workers_per_gpu=4,
+    train=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='images/train2017',
+        ann_dir='annotations/train2017',
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='images/val2017',
+        ann_dir='annotations/val2017',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='images/val2017',
+        ann_dir='annotations/val2017',
+        pipeline=test_pipeline))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/datasets/hrf.py
ADDED
@@ -0,0 +1,59 @@
+# dataset settings
+dataset_type = 'HRFDataset'
+data_root = 'data/HRF'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+img_scale = (2336, 3504)
+crop_size = (256, 256)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations'),
+    dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
+    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='PhotoMetricDistortion'),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_semantic_seg'])
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=img_scale,
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img'])
+        ])
+]
+
+data = dict(
+    samples_per_gpu=4,
+    workers_per_gpu=4,
+    train=dict(
+        type='RepeatDataset',
+        times=40000,
+        dataset=dict(
+            type=dataset_type,
+            data_root=data_root,
+            img_dir='images/training',
+            ann_dir='annotations/training',
+            pipeline=train_pipeline)),
+    val=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='images/validation',
+        ann_dir='annotations/validation',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='images/validation',
+        ann_dir='annotations/validation',
+        pipeline=test_pipeline))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/ann_r50-d8.py
ADDED
@@ -0,0 +1,46 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 1, 1),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=dict(
+        type='ANNHead',
+        in_channels=[1024, 2048],
+        in_index=[2, 3],
+        channels=512,
+        project_channels=256,
+        query_scales=(1, ),
+        key_pool_scales=(1, 3, 6, 8),
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/ccnet_r50-d8.py
ADDED
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 1, 1),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=dict(
+        type='CCHead',
+        in_channels=2048,
+        in_index=3,
+        channels=512,
+        recurrence=2,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/cgnet.py
ADDED
@@ -0,0 +1,35 @@
+# model settings
+norm_cfg = dict(type='SyncBN', eps=1e-03, requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    backbone=dict(
+        type='CGNet',
+        norm_cfg=norm_cfg,
+        in_channels=3,
+        num_channels=(32, 64, 128),
+        num_blocks=(3, 21),
+        dilations=(2, 4),
+        reductions=(8, 16)),
+    decode_head=dict(
+        type='FCNHead',
+        in_channels=256,
+        in_index=2,
+        channels=256,
+        num_convs=0,
+        concat_input=False,
+        dropout_ratio=0,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        loss_decode=dict(
+            type='CrossEntropyLoss',
+            use_sigmoid=False,
+            loss_weight=1.0,
+            class_weight=[
+                2.5959933, 6.7415504, 3.5354059, 9.8663225, 9.690899, 9.369352,
+                10.289121, 9.953208, 4.3097677, 9.490387, 7.674431, 9.396905,
+                10.347791, 6.3927646, 10.226669, 10.241062, 10.280587,
+                10.396974, 10.055647
+            ])),
+    # model training and testing settings
+    train_cfg=dict(sampler=None),
+    test_cfg=dict(mode='whole'))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/danet_r50-d8.py
ADDED
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 1, 1),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=dict(
+        type='DAHead',
+        in_channels=2048,
+        in_index=3,
+        channels=512,
+        pam_channels=64,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/deeplabv3_r50-d8.py
ADDED
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 1, 1),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=dict(
+        type='ASPPHead',
+        in_channels=2048,
+        in_index=3,
+        channels=512,
+        dilations=(1, 12, 24, 36),
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
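Note: a model base like the one above only declares the network as a config dict; mmsegmentation 0.x builds the actual module from it. A minimal sketch under that assumption (the config path is illustrative, and SyncBN layers assume a distributed launch in practice):

from mmcv import Config
from mmseg.models import build_segmentor

# Illustrative path; point it at this repo's copy of the config.
cfg = Config.fromfile('segmentation/configs/_base_/models/deeplabv3_r50-d8.py')
# train_cfg/test_cfg are already embedded in the model dict above,
# so nothing else is needed to build the segmentor.
model = build_segmentor(cfg.model)
model.init_weights()           # downloads the open-mmlab ResNet-50 weights declared in `pretrained`
print(type(model).__name__)    # EncoderDecoder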
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/deeplabv3_unet_s5-d16.py
ADDED
@@ -0,0 +1,50 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained=None,
+    backbone=dict(
+        type='UNet',
+        in_channels=3,
+        base_channels=64,
+        num_stages=5,
+        strides=(1, 1, 1, 1, 1),
+        enc_num_convs=(2, 2, 2, 2, 2),
+        dec_num_convs=(2, 2, 2, 2),
+        downsamples=(True, True, True, True),
+        enc_dilations=(1, 1, 1, 1, 1),
+        dec_dilations=(1, 1, 1, 1),
+        with_cp=False,
+        conv_cfg=None,
+        norm_cfg=norm_cfg,
+        act_cfg=dict(type='ReLU'),
+        upsample_cfg=dict(type='InterpConv'),
+        norm_eval=False),
+    decode_head=dict(
+        type='ASPPHead',
+        in_channels=64,
+        in_index=4,
+        channels=16,
+        dilations=(1, 12, 24, 36),
+        dropout_ratio=0.1,
+        num_classes=2,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=128,
+        in_index=3,
+        channels=64,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=2,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='slide', crop_size=256, stride=170))
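Note: unlike the other model bases here, this UNet variant tests with `mode='slide'`: inference runs on overlapping 256x256 crops placed every 170 pixels and the class logits are averaged where crops overlap. A small sketch of the window-count arithmetic that this crop/stride choice implies (illustrative only, not code from the repo):

import math

def num_windows(size, crop=256, stride=170):
    # Crops needed to cover one spatial dimension with the test_cfg above.
    return max(int(math.ceil((size - crop) / stride)) + 1, 1)

# e.g. an HRF-sized 2336 x 3504 retina image
h_wins, w_wins = num_windows(2336), num_windows(3504)
print(h_wins, w_wins, h_wins * w_wins)  # 14 21 294 crops in total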
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/dmnet_r50-d8.py
ADDED
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 1, 1),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=dict(
+        type='DMHead',
+        in_channels=2048,
+        in_index=3,
+        channels=512,
+        filter_sizes=(1, 3, 5, 7),
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=dict(type='SyncBN', requires_grad=True),
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/emanet_r50-d8.py
ADDED
@@ -0,0 +1,47 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 1, 1),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=dict(
+        type='EMAHead',
+        in_channels=2048,
+        in_index=3,
+        channels=256,
+        ema_channels=512,
+        num_bases=64,
+        num_stages=3,
+        momentum=0.1,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/encnet_r50-d8.py
ADDED
@@ -0,0 +1,48 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 1, 1),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=dict(
+        type='EncHead',
+        in_channels=[512, 1024, 2048],
+        in_index=(1, 2, 3),
+        channels=512,
+        num_codes=32,
+        use_se_loss=True,
+        add_lateral=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+        loss_se_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.2)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/erfnet_fcn.py
ADDED
@@ -0,0 +1,32 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained=None,
+    backbone=dict(
+        type='ERFNet',
+        in_channels=3,
+        enc_downsample_channels=(16, 64, 128),
+        enc_stage_non_bottlenecks=(5, 8),
+        enc_non_bottleneck_dilations=(2, 4, 8, 16),
+        enc_non_bottleneck_channels=(64, 128),
+        dec_upsample_channels=(64, 16),
+        dec_stages_non_bottleneck=(2, 2),
+        dec_non_bottleneck_channels=(64, 16),
+        dropout_ratio=0.1,
+        init_cfg=None),
+    decode_head=dict(
+        type='FCNHead',
+        in_channels=16,
+        channels=128,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/fast_scnn.py
ADDED
@@ -0,0 +1,57 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True, momentum=0.01)
+model = dict(
+    type='EncoderDecoder',
+    backbone=dict(
+        type='FastSCNN',
+        downsample_dw_channels=(32, 48),
+        global_in_channels=64,
+        global_block_channels=(64, 96, 128),
+        global_block_strides=(2, 2, 1),
+        global_out_channels=128,
+        higher_in_channels=64,
+        lower_in_channels=128,
+        fusion_out_channels=128,
+        out_indices=(0, 1, 2),
+        norm_cfg=norm_cfg,
+        align_corners=False),
+    decode_head=dict(
+        type='DepthwiseSeparableFCNHead',
+        in_channels=128,
+        channels=128,
+        concat_input=False,
+        num_classes=19,
+        in_index=-1,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1)),
+    auxiliary_head=[
+        dict(
+            type='FCNHead',
+            in_channels=128,
+            channels=32,
+            num_convs=1,
+            num_classes=19,
+            in_index=-2,
+            norm_cfg=norm_cfg,
+            concat_input=False,
+            align_corners=False,
+            loss_decode=dict(
+                type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
+        dict(
+            type='FCNHead',
+            in_channels=64,
+            channels=32,
+            num_convs=1,
+            num_classes=19,
+            in_index=-3,
+            norm_cfg=norm_cfg,
+            concat_input=False,
+            align_corners=False,
+            loss_decode=dict(
+                type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
+    ],
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/fastfcn_r50-d32_jpu_psp.py
ADDED
@@ -0,0 +1,53 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 2, 2),
+        out_indices=(1, 2, 3),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    neck=dict(
+        type='JPU',
+        in_channels=(512, 1024, 2048),
+        mid_channels=512,
+        start_level=0,
+        end_level=-1,
+        dilations=(1, 2, 4, 8),
+        align_corners=False,
+        norm_cfg=norm_cfg),
+    decode_head=dict(
+        type='PSPHead',
+        in_channels=2048,
+        in_index=2,
+        channels=512,
+        pool_scales=(1, 2, 3, 6),
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=1,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/gcnet_r50-d8.py
ADDED
@@ -0,0 +1,46 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 1, 1),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=dict(
+        type='GCHead',
+        in_channels=2048,
+        in_index=3,
+        channels=512,
+        ratio=1 / 4.,
+        pooling_type='att',
+        fusion_types=('channel_add', ),
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/isanet_r50-d8.py
ADDED
@@ -0,0 +1,45 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 1, 1),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=dict(
+        type='ISAHead',
+        in_channels=2048,
+        in_index=3,
+        channels=512,
+        isa_channels=256,
+        down_factor=(8, 8),
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/lraspp_m-v3-d8.py
ADDED
@@ -0,0 +1,25 @@
+# model settings
+norm_cfg = dict(type='SyncBN', eps=0.001, requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    backbone=dict(
+        type='MobileNetV3',
+        arch='large',
+        out_indices=(1, 3, 16),
+        norm_cfg=norm_cfg),
+    decode_head=dict(
+        type='LRASPPHead',
+        in_channels=(16, 24, 960),
+        in_index=(0, 1, 2),
+        channels=128,
+        input_transform='multiple_select',
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        act_cfg=dict(type='ReLU'),
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/nonlocal_r50-d8.py
ADDED
@@ -0,0 +1,46 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 1, 1),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=dict(
+        type='NLHead',
+        in_channels=2048,
+        in_index=3,
+        channels=512,
+        dropout_ratio=0.1,
+        reduction=2,
+        use_scale=True,
+        mode='embedded_gaussian',
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/pointrend_r50.py
ADDED
@@ -0,0 +1,56 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='CascadeEncoderDecoder',
+    num_stages=2,
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 1, 1),
+        strides=(1, 2, 2, 2),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=4),
+    decode_head=[
+        dict(
+            type='FPNHead',
+            in_channels=[256, 256, 256, 256],
+            in_index=[0, 1, 2, 3],
+            feature_strides=[4, 8, 16, 32],
+            channels=128,
+            dropout_ratio=-1,
+            num_classes=19,
+            norm_cfg=norm_cfg,
+            align_corners=False,
+            loss_decode=dict(
+                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+        dict(
+            type='PointHead',
+            in_channels=[256],
+            in_index=[0],
+            channels=256,
+            num_fcs=3,
+            coarse_pred_each_layer=True,
+            dropout_ratio=-1,
+            num_classes=19,
+            align_corners=False,
+            loss_decode=dict(
+                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
+    ],
+    # model training and testing settings
+    train_cfg=dict(
+        num_points=2048, oversample_ratio=3, importance_sample_ratio=0.75),
+    test_cfg=dict(
+        mode='whole',
+        subdivision_steps=2,
+        subdivision_num_points=8196,
+        scale_factor=2))