tuandunghcmut commited on
Commit
564e917
·
verified ·
1 Parent(s): d34bd2a

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. VLMEvalKit_old/InternVL/.github/ISSUE_TEMPLATE/1-bug-report.yml +54 -0
  2. VLMEvalKit_old/InternVL/.github/ISSUE_TEMPLATE/2-feature-request.yml +31 -0
  3. VLMEvalKit_old/InternVL/.github/ISSUE_TEMPLATE/3-documentation.yml +23 -0
  4. VLMEvalKit_old/InternVL/internvl_chat_llava/docs/Data.md +29 -0
  5. VLMEvalKit_old/InternVL/internvl_chat_llava/docs/LLaVA_Bench.md +31 -0
  6. VLMEvalKit_old/InternVL/internvl_g/eval/evaluate_caption.py +237 -0
  7. VLMEvalKit_old/InternVL/internvl_g/internvl/model/__init__.py +0 -0
  8. VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2/__init__.py +87 -0
  9. VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2/configuration_intern_vit.py +117 -0
  10. VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2/flash_attention.py +76 -0
  11. VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2/modeling_intern_vit.py +342 -0
  12. VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2/modeling_internvl.py +684 -0
  13. VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2/modeling_qllama.py +1073 -0
  14. VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/__init__.py +87 -0
  15. VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/configuration_intern_vit.py +117 -0
  16. VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/configuration_internvl.py +108 -0
  17. VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/flash_attention.py +76 -0
  18. VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/modeling_intern_vit.py +342 -0
  19. VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/modeling_internvl.py +669 -0
  20. VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/modeling_qllama.py +1073 -0
  21. VLMEvalKit_old/InternVL/internvl_g/internvl/train/dataset.py +283 -0
  22. VLMEvalKit_old/InternVL/internvl_g/internvl/train/internvl_stage2_finetune.py +286 -0
  23. VLMEvalKit_old/InternVL/internvl_g/internvl/train/trainer_monkey_patch.py +150 -0
  24. VLMEvalKit_old/InternVL/internvl_g/shell/finetune/internvl_stage2_finetune_coco_364_bs1024_ep5.sh +58 -0
  25. VLMEvalKit_old/InternVL/internvl_g/shell/finetune/internvl_stage2_finetune_flickr_364_bs1024_ep10.sh +58 -0
  26. VLMEvalKit_old/InternVL/internvl_g/shell/lora_finetune/internvl_stage2_finetune_flickr_224_bs1024_ep10_lora16_4gpu.sh +61 -0
  27. VLMEvalKit_old/InternVL/internvl_g/shell/lora_finetune/internvl_stage2_finetune_flickrcn_224_bs1024_ep10_lora16_4gpu.sh +61 -0
  28. VLMEvalKit_old/InternVL/segmentation/configs/_base_/datasets/ade20k_504x504.py +56 -0
  29. VLMEvalKit_old/InternVL/segmentation/configs/_base_/datasets/ade20k_640x640.py +54 -0
  30. VLMEvalKit_old/InternVL/segmentation/configs/_base_/datasets/cityscapes_832x832.py +35 -0
  31. VLMEvalKit_old/InternVL/segmentation/configs/_base_/datasets/coco-stuff10k.py +57 -0
  32. VLMEvalKit_old/InternVL/segmentation/configs/_base_/datasets/coco-stuff164k.py +54 -0
  33. VLMEvalKit_old/InternVL/segmentation/configs/_base_/datasets/hrf.py +59 -0
  34. VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/ann_r50-d8.py +46 -0
  35. VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/ccnet_r50-d8.py +44 -0
  36. VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/cgnet.py +35 -0
  37. VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/danet_r50-d8.py +44 -0
  38. VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/deeplabv3_r50-d8.py +44 -0
  39. VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/deeplabv3_unet_s5-d16.py +50 -0
  40. VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/dmnet_r50-d8.py +44 -0
  41. VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/emanet_r50-d8.py +47 -0
  42. VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/encnet_r50-d8.py +48 -0
  43. VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/erfnet_fcn.py +32 -0
  44. VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/fast_scnn.py +57 -0
  45. VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/fastfcn_r50-d32_jpu_psp.py +53 -0
  46. VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/gcnet_r50-d8.py +46 -0
  47. VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/isanet_r50-d8.py +45 -0
  48. VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/lraspp_m-v3-d8.py +25 -0
  49. VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/nonlocal_r50-d8.py +46 -0
  50. VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/pointrend_r50.py +56 -0
VLMEvalKit_old/InternVL/.github/ISSUE_TEMPLATE/1-bug-report.yml ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: 🐞 Bug report
2
+ description: Create a report to help us reproduce and fix the bug
3
+ title: "[Bug] "
4
+ labels: ['Bug']
5
+
6
+ body:
7
+ - type: checkboxes
8
+ attributes:
9
+ label: Checklist
10
+ options:
11
+ - label: 1. I have searched related issues but cannot get the expected help.
12
+ - label: 2. The bug has not been fixed in the latest version.
13
+ - label: 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
14
+ - type: textarea
15
+ attributes:
16
+ label: Describe the bug
17
+ description: A clear and concise description of what the bug is.
18
+ validations:
19
+ required: true
20
+ - type: textarea
21
+ attributes:
22
+ label: Reproduction
23
+ description: |
24
+ 1. What command or script did you run?
25
+ placeholder: |
26
+ A placeholder for the command.
27
+ validations:
28
+ required: true
29
+ - type: textarea
30
+ attributes:
31
+ label: Environment
32
+ description: |
33
+ 1. Please run `lmdeploy check_env` to collect necessary environment information and paste it here.
34
+ 2. You may add addition that may be helpful for locating the problem, such as
35
+ - Which **model** are you using?
36
+ - How you installed PyTorch \[e.g., pip, conda, source\]
37
+ - Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.)
38
+ placeholder: Environment here.
39
+ render: Shell
40
+ validations:
41
+ required: true
42
+ - type: textarea
43
+ attributes:
44
+ label: Error traceback
45
+ description: |
46
+ If applicable, paste the error trackback here.
47
+ placeholder: Logs and traceback here.
48
+ render: Shell
49
+ - type: markdown
50
+ attributes:
51
+ value: >
52
+ If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!
53
+
54
+ Thanks for your bug report. We appreciate it a lot.
VLMEvalKit_old/InternVL/.github/ISSUE_TEMPLATE/2-feature-request.yml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: 🚀 Feature request
2
+ description: Suggest an idea for this project
3
+ title: "[Feature] "
4
+
5
+ body:
6
+ - type: markdown
7
+ attributes:
8
+ value: |
9
+ We strongly appreciate you creating a PR to implement this feature [here](https://github.com/OpenGVLab/InternVL/pulls)!
10
+ If you need our help, please fill in as much of the following form as you're able to.
11
+
12
+ **The less clear the description, the longer it will take to solve it.**
13
+ - type: textarea
14
+ attributes:
15
+ label: Motivation
16
+ description: |
17
+ A clear and concise description of the motivation of the feature.
18
+ Ex1. It is inconvenient when \[....\].
19
+ validations:
20
+ required: true
21
+ - type: textarea
22
+ attributes:
23
+ label: Related resources
24
+ description: |
25
+ If there is an official code release or third-party implementations, please also provide the information here, which would be very helpful.
26
+ - type: textarea
27
+ attributes:
28
+ label: Additional context
29
+ description: |
30
+ Add any other context or screenshots about the feature request here.
31
+ If you would like to implement the feature and create a PR, please leave a comment here and that would be much appreciated.
VLMEvalKit_old/InternVL/.github/ISSUE_TEMPLATE/3-documentation.yml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: 📚 Documentation
2
+ description: Report an issue related to the documentation.
3
+ labels: "kind/doc,status/unconfirmed"
4
+ title: "[Docs] "
5
+
6
+ body:
7
+ - type: textarea
8
+ attributes:
9
+ label: 📚 The doc issue
10
+ description: >
11
+ A clear and concise description the issue.
12
+ validations:
13
+ required: true
14
+
15
+ - type: textarea
16
+ attributes:
17
+ label: Suggest a potential alternative/fix
18
+ description: >
19
+ Tell us how we could improve the documentation in this regard.
20
+ - type: markdown
21
+ attributes:
22
+ value: >
23
+ Thanks for contributing 🎉!
VLMEvalKit_old/InternVL/internvl_chat_llava/docs/Data.md ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Data
2
+
3
+ | Data file name | Size |
4
+ | --- | ---: |
5
+ | [llava_instruct_150k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/llava_instruct_150k.json) | 229 MB |
6
+ | [llava_instruct_80k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/llava_instruct_80k.json) | 229 MB |
7
+ | [conversation_58k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/conversation_58k.json) | 126 MB |
8
+ | [detail_23k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/detail_23k.json) | 20.5 MB |
9
+ | [complex_reasoning_77k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/complex_reasoning_77k.json) | 79.6 MB |
10
+
11
+ ### Pretraining Dataset
12
+ The pretraining dataset used in this release is a subset of CC-3M dataset, filtered with a more balanced concept coverage distribution. Please see [here](https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K) for a detailed description of the dataset structure and how to download the images.
13
+
14
+ If you already have CC-3M dataset on your disk, the image names follow this format: `GCC_train_000000000.jpg`. You may edit the `image` field correspondingly if necessary.
15
+
16
+ | Data | Chat File | Meta Data | Size |
17
+ | --- | --- | --- | ---: |
18
+ | CC-3M Concept-balanced 595K | [chat.json](https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K/blob/main/chat.json) | [metadata.json](https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K/blob/main/metadata.json) | 211 MB
19
+ | LAION/CC/SBU BLIP-Caption Concept-balanced 558K | [blip_laion_cc_sbu_558k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/blob/main/blip_laion_cc_sbu_558k.json) | [metadata.json](#) | 181 MB
20
+
21
+ **Important notice**: Upon the request from the community, as ~15% images of the original CC-3M dataset are no longer accessible, we upload [`images.zip`](https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K/blob/main/images.zip) for better reproducing our work in research community. It must not be used for any other purposes. The use of these images must comply with the CC-3M license. This may be taken down at any time when requested by the original CC-3M dataset owner or owners of the referenced images.
22
+
23
+ ### GPT-4 Prompts
24
+
25
+ We provide our prompts and few-shot samples for GPT-4 queries, to better facilitate research in this domain. Please check out the [`prompts`](playground/data/prompts) folder for three kinds of questions: conversation, detail description, and complex reasoning.
26
+
27
+ They are organized in a format of `system_message.txt` for system message, pairs of `abc_caps.txt` for few-shot sample user input, and `abc_conv.txt` for few-shot sample reference output.
28
+
29
+ Note that you may find them in different format. For example, `conversation` is in `jsonl`, and detail description is answer-only. The selected format in our preliminary experiments works slightly better than a limited set of alternatives that we tried: `jsonl`, more natural format, answer-only. If interested, you may try other variants or conduct more careful study in this. Contributions are welcomed!
VLMEvalKit_old/InternVL/internvl_chat_llava/docs/LLaVA_Bench.md ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LLaVA-Bench [[Download](https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild)]
2
+
3
+ **-Introduction-** Large commercial multimodal chatbots have been released in this week, including
4
+ - [Multimodal Bing-Chat by Microsoft](https://blogs.bing.com/search/july-2023/Bing-Chat-Enterprise-announced,-multimodal-Visual-Search-rolling-out-to-Bing-Chat) (July 18, 2023)
5
+ - [Multimodal Bard by Google](https://bard.google.com/).
6
+
7
+ These chatbots are presumably supported by proprietary large multimodal models (LMM). Compared with the open-source LMM such as LLaVA, proprietary LMM represent the scaling success upperbound of the current SoTA techniques. They share the goal of developing multimodal chatbots that follow human intents to complete various daily-life visual tasks in the wild. While it remains less explored how to evaluate multimodal chat ability, it provides useful feedback to study open-source LMMs against the commercial multimodal chatbots. In addition to the *LLaVA-Bench (COCO)* dataset we used to develop the early versions of LLaVA, we are releasing [*LLaVA-Bench (In-the-Wild)*](https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild) to the community for the public use.
8
+
9
+ ## LLaVA-Bench (In-the-Wild *[Ongoing work]*)
10
+
11
+ To evaluate the model's capability in more challenging tasks and generalizability to novel domains, we collect a diverse set of 24 images with 60 questions in total, including indoor and outdoor scenes, memes, paintings, sketches, etc, and associate each image with a highly-detailed and manually-curated description and a proper selection of questions. Such design also assesses the model's robustness to different prompts. In this release, we also categorize questions into three categories: conversation (simple QA), detailed description, and complex reasoning. We continue to expand and improve the diversity of the LLaVA-Bench (In-the-Wild). We manually query Bing-Chat and Bard to get the responses.
12
+
13
+ ### Results
14
+
15
+ The score is measured by comparing against a reference answer generated by text-only GPT-4. It is generated by feeding the question, along with the ground truth image annotations as the context. A text-only GPT-4 evaluator rates both answers. We query GPT-4 by putting the reference answer first, and then the answer generated by the candidate model. We upload images at their original resolution to Bard and Bing-Chat to obtain the results.
16
+
17
+ | Approach | Conversation | Detail | Reasoning | Overall |
18
+ |----------------|--------------|--------|-----------|---------|
19
+ | Bard-0718 | 83.7 | 69.7 | 78.7 | 77.8 |
20
+ | Bing-Chat-0629 | 59.6 | 52.2 | 90.1 | 71.5 |
21
+ | LLaVA-13B-v1-336px-0719 (beam=1) | 64.3 | 55.9 | 81.7 | 70.1 |
22
+ | LLaVA-13B-v1-336px-0719 (beam=5) | 68.4 | 59.9 | 84.3 | 73.5 |
23
+
24
+ Note that Bard sometimes refuses to answer questions about images containing humans, and Bing-Chat blurs the human faces in the images. We also provide the benchmark score for the subset without humans.
25
+
26
+ | Approach | Conversation | Detail | Reasoning | Overall |
27
+ |----------------|--------------|--------|-----------|---------|
28
+ | Bard-0718 | 94.9 | 74.3 | 84.3 | 84.6 |
29
+ | Bing-Chat-0629 | 55.8 | 53.6 | 93.5 | 72.6 |
30
+ | LLaVA-13B-v1-336px-0719 (beam=1) | 62.2 | 56.4 | 82.2 | 70.0 |
31
+ | LLaVA-13B-v1-336px-0719 (beam=5) | 65.6 | 61.7 | 85.0 | 73.6 |
VLMEvalKit_old/InternVL/internvl_g/eval/evaluate_caption.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import itertools
3
+ import json
4
+ import os
5
+ import random
6
+ import time
7
+ from functools import partial
8
+
9
+ import torch
10
+ import torchvision.transforms as T
11
+ from internvl.model.internvl_stage2 import InternVLConfig, InternVLModel
12
+ from PIL import Image
13
+ from pycocoevalcap.eval import COCOEvalCap
14
+ from pycocotools.coco import COCO
15
+ from torchvision.transforms.functional import InterpolationMode
16
+ from tqdm import tqdm
17
+ from transformers import LlamaTokenizer
18
+
19
+ ds_collections = {
20
+ 'flickr30k': {
21
+ 'root': 'data/flickr30k/',
22
+ 'annotation': 'data/flickr30k/flickr30k_test_karpathy.json',
23
+ },
24
+ 'coco': {
25
+ 'root': 'data/coco/',
26
+ 'annotation': ['data/coco/annotations/coco_karpathy_test.json',
27
+ 'data/coco/annotations/coco_karpathy_test_gt.json'],
28
+ },
29
+ 'nocaps': {
30
+ 'root': 'data/nocaps/images',
31
+ 'annotation': 'data/nocaps/nocaps_val_4500_captions.json',
32
+ },
33
+ }
34
+
35
+
36
+ class CaptionDataset(torch.utils.data.Dataset):
37
+
38
+ def __init__(self, name, root, annotation, prompt, input_size=224):
39
+ if name == 'coco':
40
+ self.images = json.load(open(annotation))
41
+ else:
42
+ self.images = json.load(open(annotation))['images']
43
+ self.name = name
44
+ self.prompt = prompt
45
+ self.root = root
46
+ self.transform = T.Compose([
47
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
48
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
49
+ T.ToTensor(),
50
+ T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
51
+ ])
52
+
53
+ def __len__(self):
54
+ return len(self.images)
55
+
56
+ def __getitem__(self, idx):
57
+ if self.name == 'coco':
58
+ filename = self.images[idx]['image']
59
+ image_id = int(filename.split('_')[-1].replace('.jpg', ''))
60
+ image_path = os.path.join(self.root, filename)
61
+ else:
62
+ image_id = self.images[idx]['id']
63
+ if 'file_name' in self.images[idx]:
64
+ image_path = os.path.join(self.root, self.images[idx]['file_name'])
65
+ else:
66
+ image_path = os.path.join(self.root, self.images[idx]['image'])
67
+ image = Image.open(image_path)
68
+ pixel_values = self.transform(image).unsqueeze(0)
69
+ return {
70
+ 'image_id': image_id,
71
+ 'input_text': self.prompt,
72
+ 'pixel_values': pixel_values
73
+ }
74
+
75
+
76
+ def collate_fn(inputs, tokenizer):
77
+ pixel_values = torch.cat([_['pixel_values'] for _ in inputs], dim=0)
78
+ image_ids = [_['image_id'] for _ in inputs]
79
+ input_texts = [_['input_text'] for _ in inputs]
80
+ input_tokens = tokenizer(input_texts, return_tensors='pt')
81
+
82
+ return pixel_values, image_ids, input_tokens.input_ids, input_tokens.attention_mask
83
+
84
+
85
+ class InferenceSampler(torch.utils.data.sampler.Sampler):
86
+
87
+ def __init__(self, size):
88
+ self._size = int(size)
89
+ assert size > 0
90
+ self._rank = torch.distributed.get_rank()
91
+ self._world_size = torch.distributed.get_world_size()
92
+ self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
93
+
94
+ @staticmethod
95
+ def _get_local_indices(total_size, world_size, rank):
96
+ shard_size = total_size // world_size
97
+ left = total_size % world_size
98
+ shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
99
+
100
+ begin = sum(shard_sizes[:rank])
101
+ end = min(sum(shard_sizes[:rank + 1]), total_size)
102
+ return range(begin, end)
103
+
104
+ def __iter__(self):
105
+ yield from self._local_indices
106
+
107
+ def __len__(self):
108
+ return len(self._local_indices)
109
+
110
+
111
+ def evaluate_qllama_model():
112
+ prompts = ['English caption:']
113
+ print('prompts:', prompts)
114
+
115
+ config = InternVLConfig.from_pretrained(args.checkpoint)
116
+ model = InternVLModel.from_pretrained(args.checkpoint, config=config).eval()
117
+ model = model.to(torch.float16).cuda()
118
+ tokenizer = LlamaTokenizer.from_pretrained(args.checkpoint)
119
+ tokenizer.add_eos_token = False
120
+
121
+ random.seed(args.seed)
122
+ summaries = []
123
+ for prompt in prompts:
124
+ for ds_name in args.datasets:
125
+ annotation = ds_collections[ds_name]['annotation']
126
+ if type(annotation) == list:
127
+ annotation = annotation[0]
128
+ if model.config.force_image_size is not None:
129
+ image_size = model.config.force_image_size
130
+ else:
131
+ image_size = model.config.vision_config.image_size
132
+ dataset = CaptionDataset(
133
+ name=ds_name,
134
+ root=ds_collections[ds_name]['root'],
135
+ annotation=annotation,
136
+ prompt=prompt,
137
+ input_size=image_size,
138
+ )
139
+ dataloader = torch.utils.data.DataLoader(
140
+ dataset=dataset,
141
+ sampler=InferenceSampler(len(dataset)),
142
+ batch_size=args.batch_size,
143
+ num_workers=args.num_workers,
144
+ pin_memory=True,
145
+ drop_last=False,
146
+ collate_fn=partial(collate_fn, tokenizer=tokenizer),
147
+ )
148
+
149
+ image_ids, captions = [], []
150
+ for _, (pixel_values, ids, input_ids, attention_mask) in tqdm(enumerate(dataloader)):
151
+ pred = model.generate(
152
+ pixel_values=pixel_values.cuda().to(torch.float16),
153
+ input_ids=input_ids.cuda(),
154
+ attention_mask=attention_mask.cuda(),
155
+ do_sample=False,
156
+ num_beams=args.num_beams,
157
+ max_new_tokens=30,
158
+ min_new_tokens=8,
159
+ use_cache=True
160
+ )
161
+ image_ids.extend(ids)
162
+ caption = [tokenizer.decode(_.cpu(), skip_special_tokens=True).strip() for _ in pred]
163
+ captions.extend(caption)
164
+ print(caption)
165
+
166
+ torch.distributed.barrier()
167
+
168
+ world_size = torch.distributed.get_world_size()
169
+ merged_ids = [None for _ in range(world_size)]
170
+ merged_captions = [None for _ in range(world_size)]
171
+ torch.distributed.all_gather_object(merged_ids, image_ids)
172
+ torch.distributed.all_gather_object(merged_captions, captions)
173
+
174
+ merged_ids = [_ for _ in itertools.chain.from_iterable(merged_ids)]
175
+ merged_captions = [_ for _ in itertools.chain.from_iterable(merged_captions)]
176
+ average_length = sum(len(x.split()) for x in merged_captions) / len(merged_captions)
177
+ print(f'Average length: {average_length}')
178
+
179
+ if torch.distributed.get_rank() == 0:
180
+ print(f'Evaluating {ds_name} ...')
181
+
182
+ results = []
183
+ for image_id, caption in zip(merged_ids, merged_captions):
184
+ results.append({
185
+ 'image_id': int(image_id),
186
+ 'caption': caption,
187
+ })
188
+ time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
189
+ results_file = f'{ds_name}_{time_prefix}.json'
190
+ results_file = os.path.join(args.out_dir, results_file)
191
+ json.dump(results, open(results_file, 'w'))
192
+
193
+ annotation = ds_collections[ds_name]['annotation']
194
+ if type(annotation) == list:
195
+ annotation = annotation[-1]
196
+ coco = COCO(annotation)
197
+ coco_result = coco.loadRes(results_file)
198
+ coco_eval = COCOEvalCap(coco, coco_result)
199
+ coco_eval.evaluate()
200
+
201
+ summary = coco_eval.eval.items()
202
+ print([ds_name, prompt, average_length, summary])
203
+ summaries.append([ds_name, prompt, average_length, summary])
204
+
205
+ torch.distributed.barrier()
206
+
207
+ for summary in summaries:
208
+ print(summary)
209
+
210
+
211
+ if __name__ == '__main__':
212
+
213
+ parser = argparse.ArgumentParser()
214
+ parser.add_argument('--checkpoint', type=str, default='')
215
+ parser.add_argument('--datasets', type=str, default='coco,flickr30k,nocaps')
216
+ parser.add_argument('--batch-size', type=int, default=1)
217
+ parser.add_argument('--num-workers', type=int, default=1)
218
+ parser.add_argument('--num-beams', type=int, default=5)
219
+ parser.add_argument('--out-dir', type=str, default='results')
220
+ parser.add_argument('--seed', type=int, default=0)
221
+ args = parser.parse_args()
222
+
223
+ os.makedirs(args.out_dir, exist_ok=True)
224
+
225
+ args.datasets = args.datasets.split(',')
226
+ print('datasets:', args.datasets)
227
+ assert args.batch_size == 1, 'Only batch size 1 is supported'
228
+
229
+ torch.distributed.init_process_group(
230
+ backend='nccl',
231
+ world_size=int(os.getenv('WORLD_SIZE', '1')),
232
+ rank=int(os.getenv('RANK', '0')),
233
+ )
234
+
235
+ torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
236
+
237
+ evaluate_qllama_model()
VLMEvalKit_old/InternVL/internvl_g/internvl/model/__init__.py ADDED
File without changes
VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2/__init__.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2023 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torchvision.transforms as T
10
+ from torchvision.transforms import InterpolationMode
11
+ from transformers import LlamaTokenizer
12
+
13
+ from .configuration_intern_vit import InternVisionConfig
14
+ from .configuration_internvl import InternVLConfig
15
+ from .modeling_intern_vit import InternVisionModel
16
+ from .modeling_internvl import InternVL_C, InternVL_G, InternVLModel
17
+
18
+ __all__ = ['InternVisionConfig', 'InternVisionModel', 'InternVLConfig',
19
+ 'InternVLModel', 'InternVL_C', 'InternVL_G']
20
+
21
+
22
+ # Prefix the text "summarize:"
23
+ class InternVLTokenizer(nn.Module):
24
+ def __init__(self, model_path):
25
+ super(InternVLTokenizer, self).__init__()
26
+ self.tokenizer = LlamaTokenizer.from_pretrained(model_path)
27
+ self.tokenizer.pad_token = ' ' # allow padding
28
+ self.tokenizer.add_eos_token = True
29
+
30
+ def forward(self, text, prefix='summarize:'):
31
+ if type(text) == str:
32
+ text = prefix + text
33
+ elif type(text) == list:
34
+ text = [prefix + item for item in text]
35
+ text = self.tokenizer(text, return_tensors='pt', max_length=80, truncation=True, padding='max_length').input_ids
36
+ return text
37
+
38
+
39
+ def build_transform(task, image_size=224, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
40
+ if task == 'retrieval':
41
+ transform = T.Compose([
42
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
43
+ T.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
44
+ T.ToTensor(),
45
+ T.Normalize(mean=mean, std=std)])
46
+ else:
47
+ transform = T.Compose([
48
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
49
+ T.Resize(image_size, interpolation=InterpolationMode.BICUBIC),
50
+ T.CenterCrop(image_size),
51
+ T.ToTensor(),
52
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
53
+ return transform
54
+
55
+
56
+ def load_internvl_c_huggingface(ckpt_path, device, task):
57
+ model = InternVL_C.from_pretrained(ckpt_path, torch_dtype=torch.float16).to(device)
58
+ if model.config.use_backbone_lora:
59
+ model.vision_model.merge_and_unload()
60
+ model.vision_model = model.vision_model.model
61
+ if model.config.use_qllama_lora:
62
+ model.qllama.merge_and_unload()
63
+ model.qllama = model.qllama.model
64
+ if model.config.force_image_size is not None:
65
+ image_size = model.config.force_image_size
66
+ else:
67
+ image_size = model.config.vision_config.image_size
68
+ transform = build_transform(task, image_size)
69
+ tokenizer = InternVLTokenizer(ckpt_path)
70
+ return model, transform, tokenizer
71
+
72
+
73
+ def load_internvl_g_huggingface(ckpt_path, device, task):
74
+ model = InternVL_G.from_pretrained(ckpt_path, torch_dtype=torch.float16).to(device)
75
+ if model.config.use_backbone_lora:
76
+ model.vision_model.merge_and_unload()
77
+ model.vision_model = model.vision_model.model
78
+ if model.config.use_qllama_lora:
79
+ model.qllama.merge_and_unload()
80
+ model.qllama = model.qllama.model
81
+ if model.config.force_image_size is not None:
82
+ image_size = model.config.force_image_size
83
+ else:
84
+ image_size = model.config.vision_config.image_size
85
+ transform = build_transform(task, image_size)
86
+ tokenizer = InternVLTokenizer(ckpt_path)
87
+ return model, transform, tokenizer
VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2/configuration_intern_vit.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2023 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+ import os
7
+ from typing import Union
8
+
9
+ from transformers.configuration_utils import PretrainedConfig
10
+ from transformers.utils import logging
11
+
12
+ logger = logging.get_logger(__name__)
13
+
14
+
15
+ class InternVisionConfig(PretrainedConfig):
16
+ r"""
17
+ This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
18
+ instantiate a vision encoder according to the specified arguments, defining the model architecture.
19
+
20
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
21
+ documentation from [`PretrainedConfig`] for more information.
22
+
23
+ Args:
24
+ num_channels (`int`, *optional*, defaults to 3):
25
+ Number of color channels in the input images (e.g., 3 for RGB).
26
+ patch_size (`int`, *optional*, defaults to 14):
27
+ The size (resolution) of each patch.
28
+ image_size (`int`, *optional*, defaults to 224):
29
+ The size (resolution) of each image.
30
+ qkv_bias (`bool`, *optional*, defaults to `False`):
31
+ Whether to add a bias to the queries and values in the self-attention layers.
32
+ hidden_size (`int`, *optional*, defaults to 3200):
33
+ Dimensionality of the encoder layers and the pooler layer.
34
+ num_attention_heads (`int`, *optional*, defaults to 25):
35
+ Number of attention heads for each attention layer in the Transformer encoder.
36
+ intermediate_size (`int`, *optional*, defaults to 12800):
37
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
38
+ qk_normalization (`bool`, *optional*, defaults to `True`):
39
+ Whether to normalize the queries and keys in the self-attention layers.
40
+ num_hidden_layers (`int`, *optional*, defaults to 48):
41
+ Number of hidden layers in the Transformer encoder.
42
+ use_flash_attn (`bool`, *optional*, defaults to `True`):
43
+ Whether to use flash attention mechanism.
44
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
45
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
46
+ `"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported.
47
+ layer_norm_eps (`float`, *optional*, defaults to 1e-6):
48
+ The epsilon used by the layer normalization layers.
49
+ dropout (`float`, *optional*, defaults to 0.0):
50
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
51
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
52
+ Dropout rate for stochastic depth.
53
+ attention_dropout (`float`, *optional*, defaults to 0.0):
54
+ The dropout ratio for the attention probabilities.
55
+ initializer_range (`float`, *optional*, defaults to 0.02):
56
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
57
+ initializer_factor (`float`, *optional*, defaults to 0.1):
58
+ A factor for layer scale.
59
+ """
60
+
61
+ model_type = 'intern_vit_6b'
62
+
63
+ def __init__(
64
+ self,
65
+ num_channels=3,
66
+ patch_size=14,
67
+ image_size=224,
68
+ qkv_bias=False,
69
+ hidden_size=3200,
70
+ num_attention_heads=25,
71
+ intermediate_size=12800,
72
+ qk_normalization=True,
73
+ num_hidden_layers=48,
74
+ use_flash_attn=True,
75
+ hidden_act='gelu',
76
+ layer_norm_eps=1e-6,
77
+ dropout=0.0,
78
+ drop_path_rate=0.0,
79
+ attention_dropout=0.0,
80
+ initializer_range=0.02,
81
+ initializer_factor=0.1,
82
+ **kwargs,
83
+ ):
84
+ super().__init__(**kwargs)
85
+
86
+ self.hidden_size = hidden_size
87
+ self.intermediate_size = intermediate_size
88
+ self.dropout = dropout
89
+ self.drop_path_rate = drop_path_rate
90
+ self.num_hidden_layers = num_hidden_layers
91
+ self.num_attention_heads = num_attention_heads
92
+ self.num_channels = num_channels
93
+ self.patch_size = patch_size
94
+ self.image_size = image_size
95
+ self.initializer_range = initializer_range
96
+ self.initializer_factor = initializer_factor
97
+ self.attention_dropout = attention_dropout
98
+ self.layer_norm_eps = layer_norm_eps
99
+ self.hidden_act = hidden_act
100
+ self.qkv_bias = qkv_bias
101
+ self.qk_normalization = qk_normalization
102
+ self.use_flash_attn = use_flash_attn
103
+
104
+ @classmethod
105
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
106
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
107
+
108
+ if 'vision_config' in config_dict:
109
+ config_dict = config_dict['vision_config']
110
+
111
+ if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type:
112
+ logger.warning(
113
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
114
+ f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
115
+ )
116
+
117
+ return cls.from_dict(config_dict, **kwargs)
VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2/flash_attention.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/Dao-AILab/flash-attention/blob/v0.2.8/flash_attn/flash_attention.py
2
+ import torch
3
+ import torch.nn as nn
4
+ from einops import rearrange
5
+
6
+ try: # v1
7
+ from flash_attn.flash_attn_interface import \
8
+ flash_attn_unpadded_qkvpacked_func
9
+ except: # v2
10
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
11
+
12
+ from flash_attn.bert_padding import pad_input, unpad_input
13
+
14
+
15
+ class FlashAttention(nn.Module):
16
+ """Implement the scaled dot product attention with softmax.
17
+ Arguments
18
+ ---------
19
+ softmax_scale: The temperature to use for the softmax attention.
20
+ (default: 1/sqrt(d_keys) where d_keys is computed at
21
+ runtime)
22
+ attention_dropout: The dropout rate to apply to the attention
23
+ (default: 0.0)
24
+ """
25
+
26
+ def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
27
+ super().__init__()
28
+ self.softmax_scale = softmax_scale
29
+ self.dropout_p = attention_dropout
30
+
31
+ def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
32
+ max_s=None, need_weights=False):
33
+ """Implements the multihead softmax attention.
34
+ Arguments
35
+ ---------
36
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
37
+ if unpadded: (nnz, 3, h, d)
38
+ key_padding_mask: a bool tensor of shape (B, S)
39
+ """
40
+ assert not need_weights
41
+ assert qkv.dtype in [torch.float16, torch.bfloat16]
42
+ assert qkv.is_cuda
43
+
44
+ if cu_seqlens is None:
45
+ batch_size = qkv.shape[0]
46
+ seqlen = qkv.shape[1]
47
+ if key_padding_mask is None:
48
+ qkv = rearrange(qkv, 'b s ... -> (b s) ...')
49
+ max_s = seqlen
50
+ cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
51
+ device=qkv.device)
52
+ output = flash_attn_unpadded_qkvpacked_func(
53
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
54
+ softmax_scale=self.softmax_scale, causal=causal
55
+ )
56
+ output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
57
+ else:
58
+ nheads = qkv.shape[-2]
59
+ x = rearrange(qkv, 'b s three h d -> b s (three h d)')
60
+ x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
61
+ x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
62
+ output_unpad = flash_attn_unpadded_qkvpacked_func(
63
+ x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
64
+ softmax_scale=self.softmax_scale, causal=causal
65
+ )
66
+ output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
67
+ indices, batch_size, seqlen),
68
+ 'b s (h d) -> b s h d', h=nheads)
69
+ else:
70
+ assert max_s is not None
71
+ output = flash_attn_unpadded_qkvpacked_func(
72
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
73
+ softmax_scale=self.softmax_scale, causal=causal
74
+ )
75
+
76
+ return output, None
VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2/modeling_intern_vit.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2023 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+ from typing import Optional, Tuple, Union
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint
11
+ from einops import rearrange
12
+ from timm.models.layers import DropPath
13
+ from torch import nn
14
+ from transformers.activations import ACT2FN
15
+ from transformers.modeling_outputs import (BaseModelOutput,
16
+ BaseModelOutputWithPooling)
17
+ from transformers.modeling_utils import PreTrainedModel
18
+ from transformers.utils import logging
19
+
20
+ from .configuration_intern_vit import InternVisionConfig
21
+
22
+ try:
23
+ from .flash_attention import FlashAttention
24
+ has_flash_attn = True
25
+ except:
26
+ print('FlashAttention is not installed.')
27
+ has_flash_attn = False
28
+
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
+ class InternRMSNorm(nn.Module):
34
+ def __init__(self, hidden_size, eps=1e-6):
35
+ super().__init__()
36
+ self.weight = nn.Parameter(torch.ones(hidden_size))
37
+ self.variance_epsilon = eps
38
+
39
+ def forward(self, hidden_states):
40
+ input_dtype = hidden_states.dtype
41
+ hidden_states = hidden_states.to(torch.float32)
42
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
43
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
44
+ return self.weight * hidden_states.to(input_dtype)
45
+
46
+
47
+ try:
48
+ from apex.normalization import FusedRMSNorm
49
+
50
+ InternRMSNorm = FusedRMSNorm # noqa
51
+
52
+ logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
53
+ except ImportError:
54
+ # using the normal InternRMSNorm
55
+ pass
56
+ except Exception:
57
+ logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm')
58
+ pass
59
+
60
+
61
+ class InternVisionEmbeddings(nn.Module):
62
+ def __init__(self, config: InternVisionConfig):
63
+ super().__init__()
64
+ self.config = config
65
+ self.embed_dim = config.hidden_size
66
+ self.image_size = config.image_size
67
+ self.patch_size = config.patch_size
68
+
69
+ self.class_embedding = nn.Parameter(
70
+ torch.randn(1, 1, self.embed_dim),
71
+ )
72
+
73
+ self.patch_embedding = nn.Conv2d(
74
+ in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
75
+ )
76
+
77
+ self.num_patches = (self.image_size // self.patch_size) ** 2
78
+ self.num_positions = self.num_patches + 1
79
+
80
+ self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
81
+
82
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
83
+ batch_size = pixel_values.shape[0]
84
+ target_dtype = self.patch_embedding.weight.dtype
85
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
86
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
87
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
88
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
89
+ embeddings = embeddings + self.position_embedding.to(target_dtype)
90
+ return embeddings
91
+
92
+
93
+ class InternAttention(nn.Module):
94
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
95
+
96
+ def __init__(self, config: InternVisionConfig):
97
+ super().__init__()
98
+ self.config = config
99
+ self.embed_dim = config.hidden_size
100
+ self.num_heads = config.num_attention_heads
101
+ self.use_flash_attn = config.use_flash_attn and has_flash_attn
102
+ if config.use_flash_attn and not has_flash_attn:
103
+ print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
104
+ self.head_dim = self.embed_dim // self.num_heads
105
+ if self.head_dim * self.num_heads != self.embed_dim:
106
+ raise ValueError(
107
+ f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
108
+ f' {self.num_heads}).'
109
+ )
110
+
111
+ self.scale = self.head_dim ** -0.5
112
+ self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
113
+ self.attn_drop = nn.Dropout(config.attention_dropout)
114
+ self.proj_drop = nn.Dropout(config.dropout)
115
+
116
+ self.qk_normalization = config.qk_normalization
117
+
118
+ if self.qk_normalization:
119
+ self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
120
+ self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
121
+
122
+ if self.use_flash_attn:
123
+ self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
124
+ self.proj = nn.Linear(self.embed_dim, self.embed_dim)
125
+
126
+ def _naive_attn(self, x):
127
+ B, N, C = x.shape
128
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
129
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
130
+
131
+ if self.qk_normalization:
132
+ B_, H_, N_, D_ = q.shape
133
+ q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
134
+ k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
135
+
136
+ attn = ((q * self.scale) @ k.transpose(-2, -1))
137
+ attn = attn.softmax(dim=-1)
138
+ attn = self.attn_drop(attn)
139
+
140
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
141
+ x = self.proj(x)
142
+ x = self.proj_drop(x)
143
+ return x
144
+
145
+ def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
146
+ qkv = self.qkv(x)
147
+ qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
148
+
149
+ if self.qk_normalization:
150
+ q, k, v = qkv.unbind(2)
151
+ q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
152
+ k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
153
+ qkv = torch.stack([q, k, v], dim=2)
154
+
155
+ context, _ = self.inner_attn(
156
+ qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
157
+ )
158
+ outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
159
+ outs = self.proj_drop(outs)
160
+ return outs
161
+
162
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
163
+ x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
164
+ return x
165
+
166
+
167
+ class InternMLP(nn.Module):
168
+ def __init__(self, config: InternVisionConfig):
169
+ super().__init__()
170
+ self.config = config
171
+ self.act = ACT2FN[config.hidden_act]
172
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
173
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
174
+
175
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
176
+ hidden_states = self.fc1(hidden_states)
177
+ hidden_states = self.act(hidden_states)
178
+ hidden_states = self.fc2(hidden_states)
179
+ return hidden_states
180
+
181
+
182
+ class InternVisionEncoderLayer(nn.Module):
183
+ def __init__(self, config: InternVisionConfig, drop_path_rate: float):
184
+ super().__init__()
185
+ self.embed_dim = config.hidden_size
186
+ self.intermediate_size = config.intermediate_size
187
+
188
+ self.attn = InternAttention(config)
189
+ self.mlp = InternMLP(config)
190
+ self.norm1 = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
191
+ self.norm2 = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
192
+
193
+ self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
194
+ self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
195
+ self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
196
+ self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
197
+
198
+ def forward(
199
+ self,
200
+ hidden_states: torch.Tensor,
201
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
202
+ """
203
+ Args:
204
+ hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
205
+ """
206
+ hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)
207
+
208
+ hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)
209
+
210
+ return hidden_states
211
+
212
+
213
+ class InternVisionEncoder(nn.Module):
214
+ """
215
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
216
+ [`InternEncoderLayer`].
217
+
218
+ Args:
219
+ config (`InternConfig`):
220
+ The corresponding vision configuration for the `InternEncoder`.
221
+ """
222
+
223
+ def __init__(self, config: InternVisionConfig):
224
+ super().__init__()
225
+ self.config = config
226
+ # stochastic depth decay rule
227
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
228
+ self.layers = nn.ModuleList([
229
+ InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
230
+ self.gradient_checkpointing = True
231
+
232
+ def forward(
233
+ self,
234
+ inputs_embeds,
235
+ output_hidden_states: Optional[bool] = None,
236
+ return_dict: Optional[bool] = None,
237
+ ) -> Union[Tuple, BaseModelOutput]:
238
+ r"""
239
+ Args:
240
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
241
+ Embedded representation of the inputs. Should be float, not int tokens.
242
+ output_hidden_states (`bool`, *optional*):
243
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
244
+ for more detail.
245
+ return_dict (`bool`, *optional*):
246
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
247
+ """
248
+ output_hidden_states = (
249
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
250
+ )
251
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
252
+
253
+ encoder_states = () if output_hidden_states else None
254
+ hidden_states = inputs_embeds
255
+
256
+ for idx, encoder_layer in enumerate(self.layers):
257
+ if output_hidden_states:
258
+ encoder_states = encoder_states + (hidden_states,)
259
+ if self.gradient_checkpointing and self.training:
260
+ layer_outputs = torch.utils.checkpoint.checkpoint(
261
+ encoder_layer,
262
+ hidden_states)
263
+ else:
264
+ layer_outputs = encoder_layer(
265
+ hidden_states,
266
+ )
267
+ hidden_states = layer_outputs
268
+
269
+ if output_hidden_states:
270
+ encoder_states = encoder_states + (hidden_states,)
271
+
272
+ if not return_dict:
273
+ return tuple(v for v in [hidden_states, encoder_states] if v is not None)
274
+ return BaseModelOutput(
275
+ last_hidden_state=hidden_states, hidden_states=encoder_states
276
+ )
277
+
278
+
279
+ class InternVisionModel(PreTrainedModel):
280
+ main_input_name = 'pixel_values'
281
+ config_class = InternVisionConfig
282
+
283
+ def __init__(self, config: InternVisionConfig):
284
+ super().__init__(config)
285
+ self.config = config
286
+
287
+ self.embeddings = InternVisionEmbeddings(config)
288
+ self.encoder = InternVisionEncoder(config)
289
+
290
+ def resize_pos_embeddings(self, old_size, new_size, patch_size):
291
+ pos_emb = self.embeddings.position_embedding
292
+ _, num_positions, embed_dim = pos_emb.shape
293
+ cls_emb = pos_emb[:, :1, :]
294
+ pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
295
+ pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False)
296
+ pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
297
+ pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
298
+ self.embeddings.position_embedding = nn.Parameter(pos_emb)
299
+ logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))
300
+
301
+ def get_input_embeddings(self):
302
+ return self.embeddings
303
+
304
+ def forward(
305
+ self,
306
+ pixel_values: Optional[torch.FloatTensor] = None,
307
+ output_hidden_states: Optional[bool] = None,
308
+ return_dict: Optional[bool] = None,
309
+ pixel_embeds: Optional[torch.FloatTensor] = None,
310
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
311
+ output_hidden_states = (
312
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
313
+ )
314
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
315
+
316
+ if pixel_values is None and pixel_embeds is None:
317
+ raise ValueError('You have to specify pixel_values or pixel_embeds')
318
+
319
+ if pixel_embeds is not None:
320
+ hidden_states = pixel_embeds
321
+ else:
322
+ if len(pixel_values.shape) == 4:
323
+ hidden_states = self.embeddings(pixel_values)
324
+ else:
325
+ raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
326
+ encoder_outputs = self.encoder(
327
+ inputs_embeds=hidden_states,
328
+ output_hidden_states=output_hidden_states,
329
+ return_dict=return_dict,
330
+ )
331
+ last_hidden_state = encoder_outputs.last_hidden_state
332
+ pooled_output = last_hidden_state[:, 0, :]
333
+
334
+ if not return_dict:
335
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
336
+
337
+ return BaseModelOutputWithPooling(
338
+ last_hidden_state=last_hidden_state,
339
+ pooler_output=pooled_output,
340
+ hidden_states=encoder_outputs.hidden_states,
341
+ attentions=encoder_outputs.attentions,
342
+ )
VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2/modeling_internvl.py ADDED
@@ -0,0 +1,684 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2023 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+ from dataclasses import dataclass
7
+ from functools import partial
8
+ from typing import Any, Optional, Tuple, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.distributed as dist
13
+ import torch.nn.functional as F
14
+ import torch.utils.checkpoint
15
+ from peft import LoraConfig, get_peft_model
16
+ from timm.models.layers import DropPath
17
+ from torch import nn
18
+ from transformers import GenerationConfig
19
+ from transformers.modeling_utils import PreTrainedModel
20
+ from transformers.utils import ModelOutput, logging
21
+
22
+ from .configuration_internvl import InternVLConfig
23
+ from .modeling_intern_vit import (InternVisionEmbeddings, InternVisionEncoder,
24
+ InternVisionModel)
25
+ from .modeling_qllama import LlamaForCausalLM, _expand_mask, _make_causal_mask
26
+
27
+ try:
28
+ from .flash_attention import FlashAttention # v1/v2
29
+ except:
30
+ print('FlashAttention is not installed.')
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
+ class InternVLPreTrainedModel(PreTrainedModel):
36
+ """
37
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
38
+ models.
39
+ """
40
+
41
+ config_class = InternVLConfig
42
+ base_model_prefix = 'internvl'
43
+ supports_gradient_checkpointing = True
44
+ _keys_to_ignore_on_load_missing = [
45
+ r'position_ids',
46
+ ]
47
+ _no_split_modules = ['InternAttention', 'LlamaDecoderLayer', 'LlamaForCausalLM']
48
+ _skip_keys_device_placement = 'past_key_values'
49
+ _keep_in_fp32_modules = ['wo']
50
+
51
+ # def _init_weights(self, module):
52
+ # """Initialize the weights"""
53
+ # factor = self.config.initializer_range
54
+ # if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear):
55
+ # module.weight.data.normal_(mean=0.0, std=factor)
56
+ # if hasattr(module, 'bias') and module.bias is not None:
57
+ # module.bias.data.zero_()
58
+ # if isinstance(module, InternVisionEmbeddings):
59
+ # if hasattr(self.config, 'vision_config'):
60
+ # factor = self.config.vision_config.initializer_range
61
+ # nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
62
+ # nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
63
+ # elif isinstance(module, nn.LayerNorm):
64
+ # module.bias.data.zero_()
65
+ # module.weight.data.fill_(1.0)
66
+ # elif isinstance(module, nn.Linear) and module.bias is not None:
67
+ # module.bias.data.zero_()
68
+
69
+ def _set_gradient_checkpointing(self, module, value=False):
70
+ if isinstance(module, InternVisionModel):
71
+ module.gradient_checkpointing = value
72
+ if isinstance(module, InternVisionEncoder):
73
+ module.gradient_checkpointing = value
74
+
75
+
76
+ class CrossAttention(nn.Module):
77
+ def __init__(
78
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
79
+ proj_drop=0., attn_head_dim=None, out_dim=None):
80
+ super().__init__()
81
+ if out_dim is None:
82
+ out_dim = dim
83
+ self.num_heads = num_heads
84
+ head_dim = dim // num_heads
85
+ if attn_head_dim is not None:
86
+ head_dim = attn_head_dim
87
+ all_head_dim = head_dim * self.num_heads
88
+ self.scale = qk_scale or head_dim ** -0.5
89
+ assert all_head_dim == dim
90
+
91
+ self.q = nn.Linear(dim, all_head_dim, bias=False)
92
+ self.k = nn.Linear(dim, all_head_dim, bias=False)
93
+ self.v = nn.Linear(dim, all_head_dim, bias=False)
94
+
95
+ if qkv_bias:
96
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
97
+ self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
98
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
99
+ else:
100
+ self.q_bias = None
101
+ self.k_bias = None
102
+ self.v_bias = None
103
+
104
+ self.attn_drop = nn.Dropout(attn_drop)
105
+ self.proj = nn.Linear(all_head_dim, out_dim)
106
+ self.proj_drop = nn.Dropout(proj_drop)
107
+
108
+ def forward(self, x, k=None, v=None):
109
+ B, N, C = x.shape
110
+ N_k = k.shape[1]
111
+ N_v = v.shape[1]
112
+
113
+ q_bias, k_bias, v_bias = None, None, None
114
+ if self.q_bias is not None:
115
+ q_bias = self.q_bias
116
+ k_bias = self.k_bias
117
+ v_bias = self.v_bias
118
+
119
+ q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
120
+ q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, N_head, N_q, dim)
121
+
122
+ k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
123
+ k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
124
+
125
+ v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
126
+ v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
127
+
128
+ q = q * self.scale
129
+ attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k)
130
+
131
+ attn = attn.softmax(dim=-1)
132
+ attn = self.attn_drop(attn)
133
+
134
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
135
+ x = self.proj(x)
136
+ x = self.proj_drop(x)
137
+
138
+ return x
139
+
140
+
141
+ class AttentiveBlock(nn.Module):
142
+
143
+ def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
144
+ drop_path=0., norm_layer=nn.LayerNorm, attn_head_dim=None, out_dim=None):
145
+ super().__init__()
146
+
147
+ self.norm1_q = norm_layer(dim)
148
+ self.norm1_k = norm_layer(dim)
149
+ self.norm1_v = norm_layer(dim)
150
+ self.cross_attn = CrossAttention(
151
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
152
+ proj_drop=drop, attn_head_dim=attn_head_dim, out_dim=out_dim)
153
+
154
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
155
+
156
+ def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None):
157
+ x_q = self.norm1_q(x_q + pos_q)
158
+ x_k = self.norm1_k(x_kv + pos_k)
159
+ x_v = self.norm1_v(x_kv)
160
+ x = self.cross_attn(x_q, k=x_k, v=x_v)
161
+
162
+ return x
163
+
164
+
165
+ class AttentionPoolingBlock(AttentiveBlock):
166
+
167
+ def forward(self, x):
168
+ x_q = x.mean(1, keepdim=True)
169
+ x_kv, pos_q, pos_k = x, 0, 0
170
+ x = super().forward(x_q, x_kv, pos_q, pos_k, bool_masked_pos=None, rel_pos_bias=None)
171
+ x = x.squeeze(1)
172
+ return x
173
+
174
+
175
+ @dataclass
176
+ class InternVLModelOutput(ModelOutput):
177
+ """
178
+ Class defining the outputs of [`InternVLModelOutput`].
179
+ """
180
+
181
+ loss: Optional[torch.FloatTensor] = None
182
+ loss_itm: Optional[torch.FloatTensor] = None
183
+ loss_itc: Optional[torch.FloatTensor] = None
184
+ loss_itg: Optional[torch.FloatTensor] = None
185
+
186
+ def to_tuple(self) -> Tuple[Any]:
187
+ return tuple(
188
+ self[k]
189
+ if k not in ['loss', 'loss_itm', 'loss_itc', 'loss_itg']
190
+ else getattr(self, k).to_tuple()
191
+ for k in self.keys()
192
+ )
193
+
194
+
195
+ class GatherLayer(torch.autograd.Function):
196
+ """Gather tensors from all process, supporting backward propagation.
197
+ """
198
+
199
+ @staticmethod
200
+ def forward(ctx, input):
201
+ ctx.save_for_backward(input)
202
+ output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
203
+ dist.all_gather(output, input)
204
+ return torch.stack(output, 0)
205
+
206
+ @staticmethod
207
+ def backward(ctx, grads):
208
+ input, = ctx.saved_tensors
209
+ dist.all_reduce(grads)
210
+ grad_out = torch.zeros_like(input)
211
+ grad_out[:] = grads[dist.get_rank()]
212
+ return grad_out
213
+
214
+
215
+ class InternVLModel(InternVLPreTrainedModel):
216
+ config_class = InternVLConfig
217
+ main_input_name = 'pixel_values'
218
+
219
+ def __init__(self, config: InternVLConfig):
220
+ super().__init__(config)
221
+
222
+ text_hidden_size = config.qllama_config.hidden_size
223
+ vision_hidden_size = config.vision_config.hidden_size
224
+ clip_embed_dim = config.clip_embed_dim
225
+ attn_pool_num_heads = config.attn_pool_num_heads
226
+ config.qllama_config.num_query_token = config.num_query_token
227
+ self.num_query_token = config.num_query_token
228
+ self.label_smoothing = config.label_smoothing
229
+
230
+ self.vision_model = InternVisionModel(config.vision_config) # frozen
231
+ self.qllama = LlamaForCausalLM(config.qllama_config) # frozen
232
+ self.query_tokens = nn.Parameter( # trainable
233
+ torch.zeros(1, config.num_query_token, text_hidden_size)
234
+ )
235
+
236
+ self.text_projection = nn.Parameter(torch.empty(text_hidden_size, clip_embed_dim)) # frozen
237
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) # trainable
238
+ self.clip_projector = AttentionPoolingBlock( # frozen
239
+ dim=vision_hidden_size, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
240
+ drop=0., attn_drop=0., norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim)
241
+ self.clip_projector2 = AttentionPoolingBlock( # trainable
242
+ dim=text_hidden_size, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
243
+ drop=0., attn_drop=0., norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim)
244
+ self.itm_head = nn.Linear(text_hidden_size, 2) # trainable
245
+ self.gradient_checkpointing = True
246
+
247
+ # Initialize weights and apply final processing
248
+ # self.post_init()
249
+
250
+ if config.use_backbone_lora:
251
+ self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=config.use_backbone_lora * 2)
252
+ if config.use_qllama_lora:
253
+ self.wrap_qllama_lora(r=config.use_qllama_lora, lora_alpha=config.use_qllama_lora * 2)
254
+ if config.force_image_size:
255
+ self.vision_model.resize_pos_embeddings(
256
+ old_size=config.vision_config.image_size,
257
+ new_size=config.force_image_size,
258
+ patch_size=config.vision_config.patch_size
259
+ )
260
+
261
+ def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
262
+ lora_config = LoraConfig(
263
+ r=r,
264
+ target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'],
265
+ lora_alpha=lora_alpha,
266
+ lora_dropout=lora_dropout,
267
+ )
268
+ self.vision_model = get_peft_model(self.vision_model, lora_config)
269
+ self.vision_model.print_trainable_parameters()
270
+
271
+ def wrap_qllama_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
272
+ lora_config = LoraConfig(
273
+ r=r,
274
+ target_modules=['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
275
+ 'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj'],
276
+ lora_alpha=lora_alpha,
277
+ lora_dropout=lora_dropout,
278
+ )
279
+ self.qllama = get_peft_model(self.qllama, lora_config)
280
+ self.qllama.print_trainable_parameters()
281
+
282
+ def get_input_embeddings(self):
283
+ return self.qllama.get_input_embeddings()
284
+
285
+ def set_input_embeddings(self, value):
286
+ self.qllama.set_input_embeddings(value)
287
+
288
+ def set_output_embeddings(self, new_embeddings):
289
+ self.qllama.set_output_embeddings(new_embeddings)
290
+
291
+ def get_output_embeddings(self) -> nn.Module:
292
+ return self.qllama.get_output_embeddings()
293
+
294
+ @torch.no_grad()
295
+ def _prepare_attention_mask(
296
+ self,
297
+ image_attention_mask: torch.LongTensor,
298
+ attention_mask: torch.LongTensor,
299
+ input_embeds: torch.FloatTensor,
300
+ repeat_time: int,
301
+ ):
302
+ # itm, itc, itg
303
+ attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1)
304
+ expand_mask = _expand_mask(attention_mask, input_embeds.dtype).to(
305
+ input_embeds.device) # [bsz, 1, tgt_seq_len, src_seq_len]
306
+ itm_mask, itc_mask, itg_mask = torch.chunk(expand_mask, repeat_time, dim=0)
307
+
308
+ itc_mask[:, :, :self.num_query_token, self.num_query_token:] = torch.finfo(input_embeds.dtype).min
309
+ itc_mask[:, :, self.num_query_token:, :self.num_query_token] = torch.finfo(input_embeds.dtype).min
310
+ itc_mask_causal = _make_causal_mask(
311
+ (itc_mask.shape[0], itc_mask.shape[2] - self.num_query_token),
312
+ input_embeds.dtype,
313
+ device=input_embeds.device
314
+ )
315
+ # use causal mask for text in itc
316
+ itc_mask[:, :, self.num_query_token:, self.num_query_token:] += itc_mask_causal
317
+
318
+ itg_mask_causal = _make_causal_mask(
319
+ (itg_mask.shape[0], itg_mask.shape[2]),
320
+ input_embeds.dtype,
321
+ device=input_embeds.device
322
+ )
323
+ itg_mask = itg_mask + itg_mask_causal
324
+ itg_mask[:, :, :self.num_query_token, :self.num_query_token] = 0
325
+ attention_mask = torch.cat([itm_mask, itc_mask, itg_mask], dim=0)
326
+
327
+ return attention_mask
328
+
329
+ def forward(
330
+ self,
331
+ pixel_values: torch.FloatTensor,
332
+ positive_input_ids: torch.FloatTensor,
333
+ positive_attention_mask: torch.LongTensor,
334
+ negative_input_ids: torch.FloatTensor,
335
+ negative_attention_mask: torch.LongTensor,
336
+ summarize_input_ids: torch.FloatTensor,
337
+ summarize_attention_mask: torch.LongTensor,
338
+ input_ids: torch.FloatTensor,
339
+ attention_mask: torch.LongTensor,
340
+ labels: torch.LongTensor,
341
+ output_attentions: Optional[bool] = None,
342
+ output_hidden_states: Optional[bool] = None,
343
+ return_dict: Optional[bool] = None,
344
+ ) -> Union[Tuple, InternVLModelOutput]:
345
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
346
+
347
+ # step 1: forward the images through the vision encoder,
348
+ # to get image embeddings of shape (batch_size, seq_len, hidden_size)
349
+ vision_outputs = self.vision_model(
350
+ pixel_values=pixel_values,
351
+ output_hidden_states=output_hidden_states,
352
+ return_dict=return_dict)
353
+ image_embeds = vision_outputs[0]
354
+ backbone_embeds = self.clip_projector(image_embeds)
355
+
356
+ # step 2: prepare input_ids and attention_mask for three sub-tasks:
357
+ # 1) image-text matching; 2) image-text contrastive learning; 3) image-grounded text generation.
358
+ batch_size = input_ids.shape[0]
359
+ self.positive_num = batch_size // 2
360
+ input_ids = torch.cat([negative_input_ids[:-self.positive_num], positive_input_ids[-self.positive_num:],
361
+ summarize_input_ids, input_ids], dim=0) # [3 * batch_size, seq_len]
362
+ itm_attention_mask = torch.cat(
363
+ [negative_attention_mask[:-self.positive_num], positive_attention_mask[-self.positive_num:]], dim=0)
364
+ selected = itm_attention_mask.sum(1) - 1
365
+ attention_mask = torch.cat(
366
+ [itm_attention_mask, summarize_attention_mask, attention_mask], dim=0) # [3 * batch_size, seq_len]
367
+
368
+ repeat_time = input_ids.size(0) // batch_size
369
+ # step 3: forward the input_ids and attention_mask through the text encoder.
370
+ input_embeds = self.get_input_embeddings()(input_ids)
371
+ query_tokens = self.query_tokens.repeat(repeat_time * batch_size, 1, 1)
372
+ input_embeds = torch.cat([query_tokens, input_embeds], dim=1)
373
+ image_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
374
+ attention_mask = self._prepare_attention_mask(
375
+ image_attention_mask, attention_mask, input_embeds, repeat_time
376
+ )
377
+ if type(self.qllama.model) == LlamaForCausalLM:
378
+ outputs = self.qllama.model.model.forward_train(
379
+ inputs_embeds=input_embeds,
380
+ vision_hidden_states=image_embeds,
381
+ attention_mask=attention_mask,
382
+ output_attentions=output_attentions,
383
+ output_hidden_states=output_hidden_states,
384
+ return_dict=return_dict,
385
+ repeat_time=repeat_time,
386
+ ).last_hidden_state
387
+ else:
388
+ outputs = self.qllama.model.forward_train(
389
+ inputs_embeds=input_embeds,
390
+ vision_hidden_states=image_embeds,
391
+ attention_mask=attention_mask,
392
+ output_attentions=output_attentions,
393
+ output_hidden_states=output_hidden_states,
394
+ return_dict=return_dict,
395
+ repeat_time=repeat_time,
396
+ ).last_hidden_state
397
+ image_embeds = outputs[:, :self.num_query_token]
398
+ text_embeds = outputs[:, self.num_query_token:]
399
+ image_itm, image_itc, image_itg = image_embeds.chunk(repeat_time, dim=0)
400
+ text_itm, text_itc, text_itg = text_embeds.chunk(repeat_time, dim=0)
401
+
402
+ ###============== Image-Text Matching ===================###
403
+ image_itm = self.itm_head(image_itm)
404
+ logits = image_itm.mean(dim=1)
405
+ itm_labels = torch.cat([
406
+ torch.zeros(batch_size - self.positive_num, dtype=torch.long, device=logits.device),
407
+ torch.ones(self.positive_num, dtype=torch.long, device=logits.device)
408
+ ], dim=0)
409
+ itm_labels[selected == 1] = -100 # ignore empty texts
410
+ loss_itm = F.cross_entropy(logits, itm_labels)
411
+ neg_match_acc = ((logits[:batch_size - self.positive_num].argmax(dim=-1) == 0) / (
412
+ batch_size - self.positive_num)).sum()
413
+ pos_match_acc = ((logits[-self.positive_num:].argmax(dim=-1) == 1) / self.positive_num).sum()
414
+
415
+ ###============== Image-Text Contrastive ===================###
416
+ image_itc = self.clip_projector2(image_itc)
417
+
418
+ selected = summarize_attention_mask.sum(1) - 1
419
+ text_itc = text_itc[torch.arange(text_itc.shape[0]), selected]
420
+ text_itc = text_itc @ self.text_projection
421
+
422
+ # normalized features
423
+ image_itc = image_itc / image_itc.norm(dim=1, keepdim=True)
424
+ text_itc = text_itc / text_itc.norm(dim=1, keepdim=True)
425
+ image_itc_all = GatherLayer.apply(image_itc).flatten(0, 1)
426
+ text_itc_all = GatherLayer.apply(text_itc).flatten(0, 1)
427
+
428
+ # cosine similarity as logits
429
+ logit_scale = self.logit_scale.exp()
430
+ sim_i2t = logit_scale * (image_itc @ text_itc_all.t())
431
+ sim_t2i = logit_scale * (text_itc @ image_itc_all.t())
432
+ bs = image_itc.size(0)
433
+ rank = dist.get_rank() if dist.is_initialized() else 0
434
+ targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=torch.long, device=sim_i2t.device)
435
+ targets[selected == 4] = -100 # ignore empty texts
436
+ loss_itc = (
437
+ F.cross_entropy(sim_i2t, targets, label_smoothing=self.label_smoothing)
438
+ + F.cross_entropy(sim_t2i, targets, label_smoothing=self.label_smoothing)
439
+ ) / 2
440
+
441
+ ###============== Image-grounded Text Generation ===================###
442
+ logits = self.qllama.lm_head(text_itg)
443
+ # Shift so that tokens < n predict n
444
+ shift_logits = logits[..., :-1, :].contiguous()
445
+ shift_labels = labels[..., 1:].contiguous()
446
+ # Flatten the tokens
447
+ shift_logits = shift_logits.view(-1, self.qllama.config.vocab_size)
448
+ shift_labels = shift_labels.view(-1)
449
+ # Enable model parallelism
450
+ shift_labels = shift_labels.to(shift_logits.device)
451
+ loss_itg = F.cross_entropy(shift_logits, shift_labels)
452
+
453
+ vision_sim = F.cosine_similarity(backbone_embeds.detach(), image_itc).mean()
454
+
455
+ loss = loss_itm + loss_itc + loss_itg
456
+ if dist.get_rank() == 0:
457
+ print(f'loss: {loss.item()}, loss_itm: {loss_itm.item()}, loss_itc: {loss_itc.item()}, '
458
+ f'loss_itg: {loss_itg.item()}, vision_similarity: {round(vision_sim.item(), 5)}, '
459
+ f'logit scale: {round(1.0 / logit_scale.item(), 5)}, '
460
+ f'pos_match_acc: {round(pos_match_acc.item(), 4)}, '
461
+ f'neg_match_acc: {round(neg_match_acc.item(), 4)}')
462
+
463
+ return InternVLModelOutput(
464
+ loss=loss,
465
+ loss_itc=loss_itc.detach(),
466
+ loss_itm=loss_itm.detach(),
467
+ loss_itg=loss_itg.detach(),
468
+ )
469
+
470
+ @torch.no_grad()
471
+ def generate(
472
+ self,
473
+ pixel_values: torch.FloatTensor,
474
+ input_ids: torch.FloatTensor,
475
+ attention_mask: torch.LongTensor,
476
+ generation_config: Optional[GenerationConfig] = None,
477
+ output_hidden_states: Optional[bool] = None,
478
+ return_dict: Optional[bool] = None,
479
+ **generate_kwargs,
480
+ ) -> torch.LongTensor:
481
+
482
+ vision_outputs = self.vision_model(
483
+ pixel_values=pixel_values,
484
+ output_hidden_states=output_hidden_states,
485
+ return_dict=return_dict)
486
+ image_embeds = vision_outputs[0]
487
+
488
+ batch_size = image_embeds.shape[0]
489
+ input_embeds = self.get_input_embeddings()(input_ids)
490
+ query_tokens = self.query_tokens.repeat(batch_size, 1, 1)
491
+ input_embeds = torch.cat([query_tokens, input_embeds], dim=1)
492
+ image_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
493
+ attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1)
494
+
495
+ outputs = self.qllama.generate(
496
+ inputs_embeds=input_embeds,
497
+ attention_mask=attention_mask,
498
+ vision_hidden_states=image_embeds,
499
+ generation_config=generation_config,
500
+ use_zero_attention_mask=True,
501
+ **generate_kwargs,
502
+ )
503
+
504
+ return outputs
505
+
506
+ def get_text_features(
507
+ self,
508
+ input_ids: torch.Tensor,
509
+ attention_mask: torch.Tensor,
510
+ output_attentions: Optional[bool] = None,
511
+ output_hidden_states: Optional[bool] = None,
512
+ return_dict: Optional[bool] = None,
513
+ ):
514
+ r"""
515
+ Returns:
516
+ text_outputs (`CausalLMOutputWithPast`, or `tuple(torch.FloatTensor)` if `return_dict=False`):
517
+ The language model outputs. If `return_dict=True`, the output is a [`CausalLMOutputWithPast`] that
518
+ contains the language model logits, the past key values and the hidden states if
519
+ `output_hidden_states=True`.
520
+ ```"""
521
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
522
+ output_hidden_states = (
523
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
524
+ )
525
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
526
+
527
+ input_embeds = self.get_input_embeddings()(input_ids)
528
+ attention_mask = _expand_mask(attention_mask, input_embeds.dtype).to(
529
+ input_embeds.device) # [bsz, 1, tgt_seq_len, src_seq_len]
530
+ attention_mask += _make_causal_mask(
531
+ (attention_mask.shape[0], attention_mask.shape[2]),
532
+ input_embeds.dtype,
533
+ device=input_embeds.device
534
+ )
535
+ if type(self.qllama.model) == LlamaForCausalLM:
536
+ outputs = self.qllama.model.model.forward_train(
537
+ inputs_embeds=input_embeds,
538
+ vision_hidden_states=None,
539
+ attention_mask=attention_mask,
540
+ output_attentions=output_attentions,
541
+ output_hidden_states=output_hidden_states,
542
+ return_dict=return_dict,
543
+ ).last_hidden_state
544
+ else:
545
+ outputs = self.qllama.model.forward_train(
546
+ inputs_embeds=input_embeds,
547
+ vision_hidden_states=None,
548
+ attention_mask=attention_mask,
549
+ output_attentions=output_attentions,
550
+ output_hidden_states=output_hidden_states,
551
+ return_dict=return_dict,
552
+ ).last_hidden_state
553
+ return outputs
554
+
555
+ def get_image_features(
556
+ self,
557
+ pixel_values: torch.FloatTensor,
558
+ output_attentions: Optional[bool] = None,
559
+ output_hidden_states: Optional[bool] = None,
560
+ return_dict: Optional[bool] = None,
561
+ ):
562
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
563
+ output_hidden_states = (
564
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
565
+ )
566
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
567
+
568
+ vision_outputs = self.vision_model(
569
+ pixel_values=pixel_values,
570
+ output_hidden_states=output_hidden_states,
571
+ return_dict=return_dict)
572
+ image_embeds = vision_outputs[0]
573
+ backbone_embeds = image_embeds
574
+
575
+ batch_size = image_embeds.shape[0]
576
+ input_embeds = self.query_tokens.repeat(batch_size, 1, 1)
577
+
578
+ attention_mask = torch.ones(input_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
579
+ attention_mask = _expand_mask(attention_mask, input_embeds.dtype).to(
580
+ input_embeds.device) # [bsz, 1, tgt_seq_len, src_seq_len]
581
+ if type(self.qllama.model) == LlamaForCausalLM:
582
+ outputs = self.qllama.model.model.forward_train(
583
+ inputs_embeds=input_embeds,
584
+ vision_hidden_states=image_embeds,
585
+ attention_mask=attention_mask,
586
+ output_attentions=output_attentions,
587
+ output_hidden_states=output_hidden_states,
588
+ return_dict=return_dict,
589
+ ).last_hidden_state
590
+ else:
591
+ outputs = self.qllama.model.forward_train(
592
+ inputs_embeds=input_embeds,
593
+ vision_hidden_states=image_embeds,
594
+ attention_mask=attention_mask,
595
+ output_attentions=output_attentions,
596
+ output_hidden_states=output_hidden_states,
597
+ return_dict=return_dict,
598
+ ).last_hidden_state
599
+ return backbone_embeds, outputs
600
+
601
+
602
+ class InternVL_C(InternVLModel):
603
+
604
+ def encode_image(self, image):
605
+ vision_outputs = self.vision_model(
606
+ pixel_values=image,
607
+ output_hidden_states=False,
608
+ return_dict=True)
609
+ image_embeds = vision_outputs[0]
610
+ image_embeds = self.clip_projector(image_embeds)
611
+ return image_embeds
612
+
613
+ def encode_text(self, text):
614
+ attention_mask = text > 0
615
+ text_embeds = self.get_text_features(
616
+ input_ids=text,
617
+ attention_mask=attention_mask,
618
+ output_attentions=False,
619
+ output_hidden_states=False,
620
+ return_dict=True,
621
+ )
622
+ text_embeds = text_embeds[torch.arange(text_embeds.shape[0]), attention_mask.sum(1) - 1]
623
+ text_embeds = text_embeds @ self.text_projection
624
+ return text_embeds
625
+
626
+ def forward(self, image, text):
627
+ image_features = self.encode_image(image)
628
+ text_features = self.encode_text(text)
629
+
630
+ # normalized features
631
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
632
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
633
+
634
+ # cosine similarity as logits
635
+ logit_scale = self.logit_scale.exp()
636
+ logits_per_image = logit_scale * image_features @ text_features.t()
637
+ logits_per_text = logits_per_image.t()
638
+
639
+ return logits_per_image, logits_per_text
640
+
641
+
642
+ class InternVL_G(InternVLModel):
643
+
644
+ def encode_image(self, image):
645
+ backbone_embeds, image_embeds = self.get_image_features(
646
+ pixel_values=image,
647
+ output_hidden_states=False,
648
+ return_dict=True,
649
+ )
650
+ backbone_embeds = self.clip_projector(backbone_embeds)
651
+ image_embeds = self.clip_projector2(image_embeds)
652
+ # ensemble
653
+ backbone_embeds = backbone_embeds / backbone_embeds.norm(dim=1, keepdim=True)
654
+ image_embeds = image_embeds / image_embeds.norm(dim=1, keepdim=True)
655
+ image_embeds = image_embeds + backbone_embeds
656
+ return image_embeds
657
+
658
+ def encode_text(self, text):
659
+ attention_mask = text > 0
660
+ text_embeds = self.get_text_features(
661
+ input_ids=text,
662
+ attention_mask=attention_mask,
663
+ output_attentions=False,
664
+ output_hidden_states=False,
665
+ return_dict=True,
666
+ )
667
+ text_embeds = text_embeds[torch.arange(text_embeds.shape[0]), attention_mask.sum(1) - 1]
668
+ text_embeds = text_embeds @ self.text_projection
669
+ return text_embeds
670
+
671
+ def forward(self, image, text):
672
+ image_features = self.encode_image(image)
673
+ text_features = self.encode_text(text)
674
+
675
+ # normalized features
676
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
677
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
678
+
679
+ # cosine similarity as logits
680
+ logit_scale = self.logit_scale.exp()
681
+ logits_per_image = logit_scale * image_features @ text_features.t()
682
+ logits_per_text = logits_per_image.t()
683
+
684
+ return logits_per_image, logits_per_text
VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2/modeling_qllama.py ADDED
@@ -0,0 +1,1073 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
4
+ # and OPT implementations in this library. It has been modified from its
5
+ # original forms to accommodate minor architectural differences compared
6
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ """ PyTorch QLLaMA model."""
20
+ import math
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import CrossEntropyLoss
27
+ from transformers import LlamaConfig
28
+ from transformers.activations import ACT2FN
29
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
30
+ CausalLMOutputWithPast)
31
+ from transformers.modeling_utils import PreTrainedModel
32
+ from transformers.utils import (add_start_docstrings,
33
+ add_start_docstrings_to_model_forward, logging,
34
+ replace_return_docstrings)
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+ _CONFIG_FOR_DOC = 'LlamaConfig'
39
+
40
+
41
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
42
+ def _make_causal_mask(
43
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
44
+ ):
45
+ """
46
+ Make causal mask used for bi-directional self-attention.
47
+ """
48
+ bsz, tgt_len = input_ids_shape
49
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
50
+ mask_cond = torch.arange(mask.size(-1), device=device)
51
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
52
+ mask = mask.to(dtype)
53
+
54
+ if past_key_values_length > 0:
55
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
56
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
57
+
58
+
59
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
60
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
61
+ """
62
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
63
+ """
64
+ bsz, src_len = mask.size()
65
+ tgt_len = tgt_len if tgt_len is not None else src_len
66
+
67
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
68
+
69
+ inverted_mask = 1.0 - expanded_mask
70
+
71
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
72
+
73
+
74
+ class LlamaRMSNorm(nn.Module):
75
+ def __init__(self, hidden_size, eps=1e-6):
76
+ """
77
+ LlamaRMSNorm is equivalent to T5LayerNorm
78
+ """
79
+ super().__init__()
80
+ self.weight = nn.Parameter(torch.ones(hidden_size))
81
+ self.variance_epsilon = eps
82
+
83
+ def forward(self, hidden_states):
84
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
85
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
86
+
87
+ # convert into half-precision if necessary
88
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
89
+ hidden_states = hidden_states.to(self.weight.dtype)
90
+
91
+ return self.weight * hidden_states
92
+
93
+
94
+ try:
95
+ from functools import partial
96
+
97
+ from apex.normalization import FusedRMSNorm
98
+
99
+ LlamaRMSNorm = partial(FusedRMSNorm, eps=1e-6) # noqa
100
+ print('Discovered apex.normalization.FusedRMSNorm - will use it instead of LlamaRMSNorm')
101
+ except ImportError:
102
+ # using the normal LlamaRMSNorm
103
+ pass
104
+ except Exception:
105
+ print('discovered apex but it failed to load, falling back to LlamaRMSNorm')
106
+ pass
107
+
108
+
109
+ class LlamaRotaryEmbedding(torch.nn.Module):
110
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
111
+ super().__init__()
112
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
113
+ self.register_buffer('inv_freq', inv_freq)
114
+
115
+ # Build here to make `torch.jit.trace` work.
116
+ self.max_seq_len_cached = max_position_embeddings
117
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
118
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
119
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
120
+ emb = torch.cat((freqs, freqs), dim=-1)
121
+ self.register_buffer('cos_cached', emb.cos()[None, None, :, :], persistent=False)
122
+ self.register_buffer('sin_cached', emb.sin()[None, None, :, :], persistent=False)
123
+
124
+ def forward(self, x, seq_len=None):
125
+ # x: [bs, num_attention_heads, seq_len, head_size]
126
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
127
+ if seq_len > self.max_seq_len_cached:
128
+ self.max_seq_len_cached = seq_len
129
+ t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
130
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
131
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
132
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
133
+ self.register_buffer('cos_cached', emb.cos()[None, None, :, :], persistent=False)
134
+ self.register_buffer('sin_cached', emb.sin()[None, None, :, :], persistent=False)
135
+ return (
136
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
137
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
138
+ )
139
+
140
+
141
+ class FixedLlamaRotaryEmbedding(torch.nn.Module):
142
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
143
+ super().__init__()
144
+
145
+ self.dim = dim
146
+ self.max_position_embeddings = max_position_embeddings
147
+ self.base = base
148
+ self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
149
+
150
+ # Build here to make `torch.jit.trace` work.
151
+ self._set_cos_sin_cache(
152
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
153
+ )
154
+
155
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
156
+ self.max_seq_len_cached = seq_len
157
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
158
+
159
+ freqs = torch.outer(t, self.inv_freq)
160
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
161
+ emb = torch.cat((freqs, freqs), dim=-1)
162
+ self.register_buffer('cos_cached', emb.cos()[None, None, :, :], persistent=False)
163
+ self.register_buffer('sin_cached', emb.sin()[None, None, :, :], persistent=False)
164
+
165
+ def forward(self, x, seq_len=None):
166
+ # x: [bs, num_attention_heads, seq_len, head_size]
167
+ if seq_len > self.max_seq_len_cached:
168
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
169
+
170
+ return (
171
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
172
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
173
+ )
174
+
175
+
176
+ LlamaRotaryEmbedding = FixedLlamaRotaryEmbedding
177
+
178
+
179
+ def rotate_half(x):
180
+ """Rotates half the hidden dims of the input."""
181
+ x1 = x[..., : x.shape[-1] // 2]
182
+ x2 = x[..., x.shape[-1] // 2:]
183
+ return torch.cat((-x2, x1), dim=-1)
184
+
185
+
186
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
187
+ gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
188
+ gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
189
+ cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
190
+ sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
191
+ q_embed = (q * cos) + (rotate_half(q) * sin)
192
+ k_embed = (k * cos) + (rotate_half(k) * sin)
193
+ return q_embed, k_embed
194
+
195
+
196
+ class LlamaMLP(nn.Module):
197
+ def __init__(
198
+ self,
199
+ hidden_size: int,
200
+ intermediate_size: int,
201
+ hidden_act: str,
202
+ ):
203
+ super().__init__()
204
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
205
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
206
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
207
+ self.act_fn = ACT2FN[hidden_act]
208
+
209
+ def forward(self, x):
210
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
211
+
212
+
213
+ class LlamaAttention(nn.Module):
214
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
215
+
216
+ def __init__(self, config: LlamaConfig):
217
+ super().__init__()
218
+ self.config = config
219
+ self.hidden_size = config.hidden_size
220
+ self.num_heads = config.num_attention_heads
221
+ self.head_dim = self.hidden_size // self.num_heads
222
+ self.max_position_embeddings = config.max_position_embeddings
223
+
224
+ if (self.head_dim * self.num_heads) != self.hidden_size:
225
+ raise ValueError(
226
+ f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
227
+ f' and `num_heads`: {self.num_heads}).'
228
+ )
229
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
230
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
231
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
232
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
233
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
234
+
235
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
236
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
237
+
238
+ def forward(
239
+ self,
240
+ hidden_states: torch.Tensor,
241
+ attention_mask: Optional[torch.Tensor] = None,
242
+ position_ids: Optional[torch.LongTensor] = None,
243
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
244
+ output_attentions: bool = False,
245
+ use_cache: bool = False,
246
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
247
+ bsz, q_len, _ = hidden_states.size()
248
+
249
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
250
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
251
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
252
+
253
+ kv_seq_len = key_states.shape[-2]
254
+ if past_key_value is not None:
255
+ kv_seq_len += past_key_value[0].shape[-2]
256
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
257
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
258
+ # [bsz, nh, t, hd]
259
+
260
+ if past_key_value is not None:
261
+ # reuse k, v, self_attention
262
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
263
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
264
+
265
+ past_key_value = (key_states, value_states) if use_cache else None
266
+
267
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
268
+
269
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
270
+ raise ValueError(
271
+ f'Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is'
272
+ f' {attn_weights.size()}'
273
+ )
274
+
275
+ if attention_mask is not None:
276
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
277
+ raise ValueError(
278
+ f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
279
+ )
280
+ attn_weights = attn_weights + attention_mask
281
+ attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
282
+
283
+ # upcast attention to fp32
284
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
285
+ attn_output = torch.matmul(attn_weights, value_states)
286
+
287
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
288
+ raise ValueError(
289
+ f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
290
+ f' {attn_output.size()}'
291
+ )
292
+
293
+ attn_output = attn_output.transpose(1, 2)
294
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
295
+
296
+ attn_output = self.o_proj(attn_output)
297
+
298
+ if not output_attentions:
299
+ attn_weights = None
300
+
301
+ return attn_output, attn_weights, past_key_value
302
+
303
+
304
+ class LlamaCrossAttention(nn.Module):
305
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
306
+
307
+ def __init__(self, config: LlamaConfig):
308
+ super().__init__()
309
+ self.config = config
310
+ self.hidden_size = config.hidden_size
311
+ self.num_heads = config.num_attention_heads
312
+ self.head_dim = self.hidden_size // self.num_heads
313
+ self.max_position_embeddings = config.max_position_embeddings
314
+ self.vision_hidden_size = 3200
315
+
316
+ if (self.head_dim * self.num_heads) != self.hidden_size:
317
+ raise ValueError(
318
+ f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
319
+ f' and `num_heads`: {self.num_heads}).'
320
+ )
321
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
322
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
323
+ self.norm1 = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
324
+
325
+ self.k_proj = nn.Linear(self.vision_hidden_size, self.num_heads * self.head_dim, bias=False)
326
+ self.v_proj = nn.Linear(self.vision_hidden_size, self.num_heads * self.head_dim, bias=False)
327
+ self.norm2 = LlamaRMSNorm(self.vision_hidden_size, eps=config.rms_norm_eps)
328
+
329
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
330
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
331
+
332
+ def forward(
333
+ self,
334
+ hidden_states: torch.Tensor,
335
+ vision_hidden_states: torch.Tensor,
336
+ repeat_time: int = 1,
337
+ attention_mask: Optional[torch.Tensor] = None,
338
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
339
+ output_attentions: bool = False,
340
+ use_cache: bool = False,
341
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
342
+ hidden_states = self.norm1(hidden_states)
343
+
344
+ bsz, q_len, _ = hidden_states.size()
345
+
346
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
347
+
348
+ vision_hidden_states = self.norm2(vision_hidden_states)
349
+
350
+ bs_v, kv_len, _ = vision_hidden_states.size()
351
+
352
+ key_states = self.k_proj(vision_hidden_states).view(
353
+ bs_v, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
354
+ value_states = self.v_proj(vision_hidden_states).view(
355
+ bs_v, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
356
+
357
+ key_states = key_states.repeat(repeat_time, 1, 1, 1)
358
+ value_states = value_states.repeat(repeat_time, 1, 1, 1)
359
+
360
+ kv_seq_len = key_states.shape[-2]
361
+ if past_key_value is not None:
362
+ kv_seq_len += past_key_value[0].shape[-2]
363
+
364
+ if past_key_value is not None:
365
+ # reuse k, v, self_attention
366
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
367
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
368
+
369
+ past_key_value = (key_states, value_states) if use_cache else None
370
+
371
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
372
+
373
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
374
+ raise ValueError(
375
+ f'Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is'
376
+ f' {attn_weights.size()}'
377
+ )
378
+
379
+ if attention_mask is not None:
380
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
381
+ raise ValueError(
382
+ f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
383
+ )
384
+ attn_weights = attn_weights + attention_mask
385
+ attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
386
+
387
+ # upcast attention to fp32
388
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
389
+ attn_output = torch.matmul(attn_weights, value_states)
390
+
391
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
392
+ raise ValueError(
393
+ f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
394
+ f' {attn_output.size()}'
395
+ )
396
+
397
+ attn_output = attn_output.transpose(1, 2)
398
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
399
+
400
+ attn_output = self.o_proj(attn_output)
401
+
402
+ if not output_attentions:
403
+ attn_weights = None
404
+
405
+ return attn_output, attn_weights, past_key_value
406
+
407
+
408
+ class LlamaDecoderLayer(nn.Module):
409
+ def __init__(self, config: LlamaConfig, use_cross_attn: bool):
410
+ super().__init__()
411
+ self.hidden_size = config.hidden_size
412
+ self.self_attn = LlamaAttention(config=config)
413
+ self.cross_attn = LlamaCrossAttention(config=config) if use_cross_attn else None
414
+ self.mlp = LlamaMLP(
415
+ hidden_size=self.hidden_size,
416
+ intermediate_size=config.intermediate_size,
417
+ hidden_act=config.hidden_act,
418
+ )
419
+ self.num_query_token = 96
420
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
421
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
422
+
423
+ def forward(
424
+ self,
425
+ hidden_states: torch.Tensor,
426
+ vision_hidden_states: torch.Tensor,
427
+ attention_mask: Optional[torch.Tensor] = None,
428
+ position_ids: Optional[torch.LongTensor] = None,
429
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
430
+ output_attentions: Optional[bool] = False,
431
+ use_cache: Optional[bool] = False,
432
+ repeat_time: int = 1,
433
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
434
+ """
435
+ Args:
436
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
437
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
438
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
439
+ output_attentions (`bool`, *optional*):
440
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
441
+ returned tensors for more detail.
442
+ use_cache (`bool`, *optional*):
443
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
444
+ (see `past_key_values`).
445
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
446
+ """
447
+
448
+ residual = hidden_states
449
+
450
+ hidden_states = self.input_layernorm(hidden_states)
451
+
452
+ # Self Attention
453
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
454
+ hidden_states=hidden_states,
455
+ attention_mask=attention_mask,
456
+ position_ids=position_ids,
457
+ past_key_value=past_key_value,
458
+ output_attentions=output_attentions,
459
+ use_cache=use_cache,
460
+ )
461
+ hidden_states = residual + hidden_states
462
+
463
+ # when using generate function and cache mode, the size of hidden_states is 1,
464
+ # so we should not use cross attention
465
+ if self.cross_attn is not None and hidden_states.size(1) >= self.num_query_token \
466
+ and vision_hidden_states is not None:
467
+ query_feats = hidden_states[:, :self.num_query_token, :]
468
+ text_feats = hidden_states[:, self.num_query_token:, :]
469
+ residual = query_feats
470
+ query_feats, _, _ = self.cross_attn(
471
+ hidden_states=query_feats,
472
+ vision_hidden_states=vision_hidden_states,
473
+ attention_mask=None, # not use attention mask in cross attention
474
+ past_key_value=past_key_value,
475
+ output_attentions=output_attentions,
476
+ use_cache=use_cache,
477
+ repeat_time=repeat_time,
478
+ )
479
+ query_feats = residual + query_feats
480
+ hidden_states = torch.cat([query_feats, text_feats], dim=1)
481
+
482
+ # Fully Connected
483
+ residual = hidden_states
484
+ hidden_states = self.post_attention_layernorm(hidden_states)
485
+ hidden_states = self.mlp(hidden_states)
486
+ hidden_states = residual + hidden_states
487
+
488
+ outputs = (hidden_states,)
489
+
490
+ if output_attentions:
491
+ outputs += (self_attn_weights,)
492
+
493
+ if use_cache:
494
+ outputs += (present_key_value,)
495
+
496
+ return outputs
497
+
498
+
499
+ LLAMA_START_DOCSTRING = r"""
500
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
501
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
502
+ etc.)
503
+
504
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
505
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
506
+ and behavior.
507
+
508
+ Parameters:
509
+ config ([`LlamaConfig`]):
510
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
511
+ load the weights associated with the model, only the configuration. Check out the
512
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
513
+ """
514
+
515
+
516
+ @add_start_docstrings(
517
+ 'The bare LLaMA Model outputting raw hidden-states without any specific head on top.',
518
+ LLAMA_START_DOCSTRING,
519
+ )
520
+ class LlamaPreTrainedModel(PreTrainedModel):
521
+ config_class = LlamaConfig
522
+ base_model_prefix = 'model'
523
+ supports_gradient_checkpointing = True
524
+ _no_split_modules = ['LlamaDecoderLayer']
525
+ _keys_to_ignore_on_load_unexpected = [r'decoder\.version']
526
+
527
+ def _init_weights(self, module):
528
+ std = self.config.initializer_range
529
+ if isinstance(module, nn.Linear):
530
+ module.weight.data.normal_(mean=0.0, std=std)
531
+ if module.bias is not None:
532
+ module.bias.data.zero_()
533
+ elif isinstance(module, nn.Embedding):
534
+ module.weight.data.normal_(mean=0.0, std=std)
535
+ if module.padding_idx is not None:
536
+ module.weight.data[module.padding_idx].zero_()
537
+
538
+ def _set_gradient_checkpointing(self, module, value=False):
539
+ if isinstance(module, LlamaModel):
540
+ module.gradient_checkpointing = value
541
+ if isinstance(module, LlamaDecoderLayer):
542
+ module.gradient_checkpointing = value
543
+
544
+
545
+ LLAMA_INPUTS_DOCSTRING = r"""
546
+ Args:
547
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
548
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
549
+ it.
550
+
551
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
552
+ [`PreTrainedTokenizer.__call__`] for details.
553
+
554
+ [What are input IDs?](../glossary#input-ids)
555
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
556
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
557
+
558
+ - 1 for tokens that are **not masked**,
559
+ - 0 for tokens that are **masked**.
560
+
561
+ [What are attention masks?](../glossary#attention-mask)
562
+
563
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
564
+ [`PreTrainedTokenizer.__call__`] for details.
565
+
566
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
567
+ `past_key_values`).
568
+
569
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
570
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
571
+ information on the default strategy.
572
+
573
+ - 1 indicates the head is **not masked**,
574
+ - 0 indicates the head is **masked**.
575
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
576
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
577
+ config.n_positions - 1]`.
578
+
579
+ [What are position IDs?](../glossary#position-ids)
580
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
581
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
582
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
583
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
584
+
585
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
586
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
587
+
588
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
589
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
590
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
591
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
592
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
593
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
594
+ model's internal embedding lookup matrix.
595
+ use_cache (`bool`, *optional*):
596
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
597
+ `past_key_values`).
598
+ output_attentions (`bool`, *optional*):
599
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
600
+ tensors for more detail.
601
+ output_hidden_states (`bool`, *optional*):
602
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
603
+ more detail.
604
+ return_dict (`bool`, *optional*):
605
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
606
+ """
607
+
608
+
609
+ @add_start_docstrings(
610
+ 'The bare LLaMA Model outputting raw hidden-states without any specific head on top.',
611
+ LLAMA_START_DOCSTRING,
612
+ )
613
+ class LlamaModel(LlamaPreTrainedModel):
614
+ """
615
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
616
+
617
+ Args:
618
+ config: LlamaConfig
619
+ """
620
+
621
+ def __init__(self, config: LlamaConfig):
622
+ super().__init__(config)
623
+ self.padding_idx = config.pad_token_id
624
+ self.vocab_size = config.vocab_size
625
+ self.cross_attention_frequency = config.cross_attention_frequency
626
+ self.num_query_token = config.num_query_token
627
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
628
+ use_cross_attn = [idx % self.cross_attention_frequency == 0 for idx in range(config.num_hidden_layers)]
629
+ self.layers = nn.ModuleList(
630
+ [LlamaDecoderLayer(config, use_cross_attn[idx]) for idx in range(config.num_hidden_layers)])
631
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
632
+ self.gradient_checkpointing = False
633
+ # Initialize weights and apply final processing
634
+ # self.post_init()
635
+
636
+ def get_input_embeddings(self):
637
+ return self.embed_tokens
638
+
639
+ def set_input_embeddings(self, value):
640
+ self.embed_tokens = value
641
+
642
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
643
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
644
+ # create causal mask
645
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
646
+ combined_attention_mask = None
647
+ if input_shape[-1] > 1:
648
+ combined_attention_mask = _make_causal_mask(
649
+ input_shape,
650
+ inputs_embeds.dtype,
651
+ device=inputs_embeds.device,
652
+ past_key_values_length=past_key_values_length,
653
+ )
654
+
655
+ if attention_mask is not None:
656
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
657
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
658
+ inputs_embeds.device
659
+ )
660
+ combined_attention_mask = (
661
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
662
+ )
663
+
664
+ return combined_attention_mask
665
+
666
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
667
+ def forward(
668
+ self,
669
+ input_ids: torch.LongTensor = None,
670
+ attention_mask: Optional[torch.Tensor] = None,
671
+ position_ids: Optional[torch.LongTensor] = None,
672
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
673
+ inputs_embeds: Optional[torch.FloatTensor] = None,
674
+ vision_hidden_states: Optional[torch.FloatTensor] = None,
675
+ repeat_time: Optional[int] = 1,
676
+ use_cache: Optional[bool] = None,
677
+ output_attentions: Optional[bool] = None,
678
+ output_hidden_states: Optional[bool] = None,
679
+ use_zero_attention_mask: Optional[bool] = None,
680
+ return_dict: Optional[bool] = None,
681
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
682
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
683
+ output_hidden_states = (
684
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
685
+ )
686
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
687
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
688
+
689
+ # retrieve input_ids and inputs_embeds
690
+ if input_ids is not None and inputs_embeds is not None:
691
+ raise ValueError('You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time')
692
+ elif input_ids is not None:
693
+ batch_size, seq_length = input_ids.shape
694
+ elif inputs_embeds is not None:
695
+ batch_size, seq_length, _ = inputs_embeds.shape
696
+ else:
697
+ raise ValueError('You have to specify either decoder_input_ids or decoder_inputs_embeds')
698
+ seq_length_with_past = seq_length
699
+ past_key_values_length = 0
700
+
701
+ if past_key_values is not None:
702
+ past_key_values_length = past_key_values[0][0].shape[2]
703
+ seq_length_with_past = seq_length_with_past + past_key_values_length
704
+
705
+ if position_ids is None:
706
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
707
+ position_ids = torch.arange(
708
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
709
+ )
710
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
711
+ else:
712
+ position_ids = position_ids.view(-1, seq_length).long()
713
+
714
+ if inputs_embeds is None:
715
+ inputs_embeds = self.embed_tokens(input_ids)
716
+ # embed positions
717
+ if attention_mask is None:
718
+ attention_mask = torch.ones(
719
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
720
+ )
721
+ attention_mask = self._prepare_decoder_attention_mask(
722
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
723
+ )
724
+ if use_zero_attention_mask:
725
+ attention_mask[:, :, :self.num_query_token, :self.num_query_token] = 0
726
+
727
+ hidden_states = inputs_embeds
728
+
729
+ if self.gradient_checkpointing and self.training:
730
+ if use_cache:
731
+ logger.warning_once(
732
+ '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
733
+ )
734
+ use_cache = False
735
+
736
+ # decoder layers
737
+ all_hidden_states = () if output_hidden_states else None
738
+ all_self_attns = () if output_attentions else None
739
+ next_decoder_cache = () if use_cache else None
740
+
741
+ for idx, decoder_layer in enumerate(self.layers):
742
+ if output_hidden_states:
743
+ all_hidden_states += (hidden_states,)
744
+
745
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
746
+
747
+ layer_outputs = decoder_layer(
748
+ hidden_states,
749
+ vision_hidden_states,
750
+ attention_mask=attention_mask,
751
+ position_ids=position_ids,
752
+ past_key_value=past_key_value,
753
+ output_attentions=output_attentions,
754
+ use_cache=use_cache,
755
+ repeat_time=repeat_time,
756
+ )
757
+
758
+ hidden_states = layer_outputs[0]
759
+
760
+ if use_cache:
761
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
762
+
763
+ if output_attentions:
764
+ all_self_attns += (layer_outputs[1],)
765
+
766
+ hidden_states = self.norm(hidden_states)
767
+
768
+ # add hidden states from the last decoder layer
769
+ if output_hidden_states:
770
+ all_hidden_states += (hidden_states,)
771
+
772
+ next_cache = next_decoder_cache if use_cache else None
773
+ if not return_dict:
774
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
775
+ return BaseModelOutputWithPast(
776
+ last_hidden_state=hidden_states,
777
+ past_key_values=next_cache,
778
+ hidden_states=all_hidden_states,
779
+ attentions=all_self_attns,
780
+ )
781
+
782
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
783
+ def forward_train(
784
+ self,
785
+ input_ids: torch.LongTensor = None,
786
+ attention_mask: Optional[torch.Tensor] = None,
787
+ position_ids: Optional[torch.LongTensor] = None,
788
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
789
+ inputs_embeds: Optional[torch.FloatTensor] = None,
790
+ vision_hidden_states: Optional[torch.FloatTensor] = None,
791
+ repeat_time: Optional[int] = 1,
792
+ use_cache: Optional[bool] = None,
793
+ output_attentions: Optional[bool] = None,
794
+ output_hidden_states: Optional[bool] = None,
795
+ return_dict: Optional[bool] = None,
796
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
797
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
798
+ output_hidden_states = (
799
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
800
+ )
801
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
802
+
803
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
804
+
805
+ # retrieve input_ids and inputs_embeds
806
+ if input_ids is not None and inputs_embeds is not None:
807
+ raise ValueError('You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time')
808
+ elif input_ids is not None:
809
+ batch_size, seq_length = input_ids.shape
810
+ elif inputs_embeds is not None:
811
+ batch_size, seq_length, _ = inputs_embeds.shape
812
+ else:
813
+ raise ValueError('You have to specify either decoder_input_ids or decoder_inputs_embeds')
814
+
815
+ seq_length_with_past = seq_length
816
+ past_key_values_length = 0
817
+
818
+ if past_key_values is not None:
819
+ past_key_values_length = past_key_values[0][0].shape[2]
820
+ seq_length_with_past = seq_length_with_past + past_key_values_length
821
+
822
+ if position_ids is None:
823
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
824
+ position_ids = torch.arange(
825
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
826
+ )
827
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
828
+ else:
829
+ position_ids = position_ids.view(-1, seq_length).long()
830
+
831
+ if inputs_embeds is None:
832
+ inputs_embeds = self.embed_tokens(input_ids)
833
+ # embed positions
834
+ # if attention_mask is None:
835
+ # attention_mask = torch.ones(
836
+ # (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
837
+ # )
838
+ # attention_mask = self._prepare_decoder_attention_mask(
839
+ # attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
840
+ # )
841
+ hidden_states = inputs_embeds
842
+
843
+ if self.gradient_checkpointing and self.training:
844
+ if use_cache:
845
+ logger.warning_once(
846
+ '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
847
+ )
848
+ use_cache = False
849
+
850
+ # decoder layers
851
+ all_hidden_states = () if output_hidden_states else None
852
+ all_self_attns = () if output_attentions else None
853
+ next_decoder_cache = () if use_cache else None
854
+
855
+ for idx, decoder_layer in enumerate(self.layers):
856
+ if output_hidden_states:
857
+ all_hidden_states += (hidden_states,)
858
+
859
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
860
+
861
+ if self.gradient_checkpointing and self.training:
862
+
863
+ def create_custom_forward(module):
864
+ def custom_forward(*inputs):
865
+ # None for past_key_value
866
+ return module(*inputs, output_attentions, None, repeat_time)
867
+
868
+ return custom_forward
869
+
870
+ layer_outputs = torch.utils.checkpoint.checkpoint(
871
+ create_custom_forward(decoder_layer),
872
+ hidden_states,
873
+ vision_hidden_states,
874
+ attention_mask,
875
+ position_ids,
876
+ None,
877
+ )
878
+ else:
879
+ layer_outputs = decoder_layer(
880
+ hidden_states,
881
+ vision_hidden_states,
882
+ attention_mask=attention_mask,
883
+ position_ids=position_ids,
884
+ past_key_value=past_key_value,
885
+ output_attentions=output_attentions,
886
+ use_cache=use_cache,
887
+ repeat_time=repeat_time,
888
+ )
889
+
890
+ hidden_states = layer_outputs[0]
891
+
892
+ if use_cache:
893
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
894
+
895
+ if output_attentions:
896
+ all_self_attns += (layer_outputs[1],)
897
+
898
+ hidden_states = self.norm(hidden_states)
899
+
900
+ # add hidden states from the last decoder layer
901
+ if output_hidden_states:
902
+ all_hidden_states += (hidden_states,)
903
+
904
+ next_cache = next_decoder_cache if use_cache else None
905
+ if not return_dict:
906
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
907
+ return BaseModelOutputWithPast(
908
+ last_hidden_state=hidden_states,
909
+ past_key_values=next_cache,
910
+ hidden_states=all_hidden_states,
911
+ attentions=all_self_attns,
912
+ )
913
+
914
+
915
+ class LlamaForCausalLM(LlamaPreTrainedModel):
916
+ def __init__(self, config):
917
+ super().__init__(config)
918
+ self.model = LlamaModel(config)
919
+
920
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
921
+
922
+ # Initialize weights and apply final processing
923
+ # self.post_init()
924
+
925
+ def get_input_embeddings(self):
926
+ return self.model.embed_tokens
927
+
928
+ def set_input_embeddings(self, value):
929
+ self.model.embed_tokens = value
930
+
931
+ def get_output_embeddings(self):
932
+ return self.lm_head
933
+
934
+ def set_output_embeddings(self, new_embeddings):
935
+ self.lm_head = new_embeddings
936
+
937
+ def set_decoder(self, decoder):
938
+ self.model = decoder
939
+
940
+ def get_decoder(self):
941
+ return self.model
942
+
943
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
944
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
945
+ def forward(
946
+ self,
947
+ input_ids: torch.LongTensor = None,
948
+ attention_mask: Optional[torch.Tensor] = None,
949
+ position_ids: Optional[torch.LongTensor] = None,
950
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
951
+ inputs_embeds: Optional[torch.FloatTensor] = None,
952
+ vision_hidden_states: Optional[torch.FloatTensor] = None,
953
+ labels: Optional[torch.LongTensor] = None,
954
+ use_cache: Optional[bool] = None,
955
+ output_attentions: Optional[bool] = None,
956
+ output_hidden_states: Optional[bool] = None,
957
+ use_zero_attention_mask: Optional[bool] = None,
958
+ return_dict: Optional[bool] = None,
959
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
960
+ r"""
961
+ Args:
962
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
963
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
964
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
965
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
966
+
967
+ Returns:
968
+
969
+ Example:
970
+
971
+ ```python
972
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
973
+
974
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
975
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
976
+
977
+ >>> prompt = "Hey, are you consciours? Can you talk to me?"
978
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
979
+
980
+ >>> # Generate
981
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
982
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
983
+ "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
984
+ ```"""
985
+
986
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
987
+ output_hidden_states = (
988
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
989
+ )
990
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
991
+
992
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
993
+ outputs = self.model(
994
+ input_ids=input_ids,
995
+ attention_mask=attention_mask,
996
+ position_ids=position_ids,
997
+ past_key_values=past_key_values,
998
+ inputs_embeds=inputs_embeds,
999
+ vision_hidden_states=vision_hidden_states,
1000
+ use_cache=use_cache,
1001
+ output_attentions=output_attentions,
1002
+ output_hidden_states=output_hidden_states,
1003
+ return_dict=return_dict,
1004
+ use_zero_attention_mask=use_zero_attention_mask,
1005
+ )
1006
+
1007
+ hidden_states = outputs[0]
1008
+ logits = self.lm_head(hidden_states)
1009
+
1010
+ loss = None
1011
+ if labels is not None:
1012
+ # Shift so that tokens < n predict n
1013
+ shift_logits = logits[..., :-1, :].contiguous()
1014
+ shift_labels = labels[..., 1:].contiguous()
1015
+ # Flatten the tokens
1016
+ loss_fct = CrossEntropyLoss()
1017
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1018
+ shift_labels = shift_labels.view(-1)
1019
+ # Enable model parallelism
1020
+ shift_labels = shift_labels.to(shift_logits.device)
1021
+ loss = loss_fct(shift_logits, shift_labels)
1022
+
1023
+ if not return_dict:
1024
+ output = (logits,) + outputs[1:]
1025
+ return (loss,) + output if loss is not None else output
1026
+
1027
+ return CausalLMOutputWithPast(
1028
+ loss=loss,
1029
+ logits=logits,
1030
+ past_key_values=outputs.past_key_values,
1031
+ hidden_states=outputs.hidden_states,
1032
+ attentions=outputs.attentions,
1033
+ )
1034
+
1035
+ def prepare_inputs_for_generation(
1036
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None,
1037
+ vision_hidden_states=None, use_zero_attention_mask=None, **kwargs
1038
+ ):
1039
+ if past_key_values:
1040
+ input_ids = input_ids[:, -1:]
1041
+
1042
+ position_ids = kwargs.get('position_ids', None)
1043
+ if attention_mask is not None and position_ids is None:
1044
+ # create position_ids on the fly for batch generation
1045
+ position_ids = attention_mask.long().cumsum(-1) - 1
1046
+ position_ids.masked_fill_(attention_mask == 0, 1)
1047
+ if past_key_values:
1048
+ position_ids = position_ids[:, -1].unsqueeze(-1)
1049
+
1050
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1051
+ if inputs_embeds is not None and past_key_values is None:
1052
+ model_inputs = {'inputs_embeds': inputs_embeds}
1053
+ else:
1054
+ model_inputs = {'input_ids': input_ids}
1055
+
1056
+ model_inputs.update(
1057
+ {
1058
+ 'position_ids': position_ids,
1059
+ 'past_key_values': past_key_values,
1060
+ 'use_cache': kwargs.get('use_cache'),
1061
+ 'attention_mask': attention_mask,
1062
+ 'vision_hidden_states': vision_hidden_states,
1063
+ 'use_zero_attention_mask': use_zero_attention_mask,
1064
+ }
1065
+ )
1066
+ return model_inputs
1067
+
1068
+ @staticmethod
1069
+ def _reorder_cache(past_key_values, beam_idx):
1070
+ reordered_past = ()
1071
+ for layer_past in past_key_values:
1072
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1073
+ return reordered_past
VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/__init__.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2023 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torchvision.transforms as T
10
+ from torchvision.transforms import InterpolationMode
11
+ from transformers import LlamaTokenizer
12
+
13
+ from .configuration_intern_vit import InternVisionConfig
14
+ from .configuration_internvl import InternVLConfig
15
+ from .modeling_intern_vit import InternVisionModel
16
+ from .modeling_internvl import InternVL_C, InternVL_G, InternVLModel
17
+
18
+ __all__ = ['InternVisionConfig', 'InternVisionModel', 'InternVLConfig',
19
+ 'InternVLModel', 'InternVL_C', 'InternVL_G']
20
+
21
+
22
+ # Prefix the text "summarize:"
23
+ class InternVLTokenizer(nn.Module):
24
+ def __init__(self, model_path):
25
+ super(InternVLTokenizer, self).__init__()
26
+ self.tokenizer = LlamaTokenizer.from_pretrained(model_path)
27
+ self.tokenizer.pad_token = ' ' # allow padding
28
+ self.tokenizer.add_eos_token = True
29
+
30
+ def forward(self, text, prefix='summarize:'):
31
+ if type(text) == str:
32
+ text = prefix + text
33
+ elif type(text) == list:
34
+ text = [prefix + item for item in text]
35
+ text = self.tokenizer(text, return_tensors='pt', max_length=80, truncation=True, padding='max_length').input_ids
36
+ return text
37
+
38
+
39
+ def build_transform(task, image_size=224, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
40
+ if task == 'retrieval':
41
+ transform = T.Compose([
42
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
43
+ T.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
44
+ T.ToTensor(),
45
+ T.Normalize(mean=mean, std=std)])
46
+ else:
47
+ transform = T.Compose([
48
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
49
+ T.Resize(image_size, interpolation=InterpolationMode.BICUBIC),
50
+ T.CenterCrop(image_size),
51
+ T.ToTensor(),
52
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
53
+ return transform
54
+
55
+
56
+ def load_internvl_c_huggingface(ckpt_path, device, task):
57
+ model = InternVL_C.from_pretrained(ckpt_path, torch_dtype=torch.float16).to(device)
58
+ if model.config.use_backbone_lora:
59
+ model.vision_model.merge_and_unload()
60
+ model.vision_model = model.vision_model.model
61
+ if model.config.use_qllama_lora:
62
+ model.qllama.merge_and_unload()
63
+ model.qllama = model.qllama.model
64
+ if model.config.force_image_size is not None:
65
+ image_size = model.config.force_image_size
66
+ else:
67
+ image_size = model.config.vision_config.image_size
68
+ transform = build_transform(task, image_size)
69
+ tokenizer = InternVLTokenizer(ckpt_path)
70
+ return model, transform, tokenizer
71
+
72
+
73
+ def load_internvl_g_huggingface(ckpt_path, device, task):
74
+ model = InternVL_G.from_pretrained(ckpt_path, torch_dtype=torch.float16).to(device)
75
+ if model.config.use_backbone_lora:
76
+ model.vision_model.merge_and_unload()
77
+ model.vision_model = model.vision_model.model
78
+ if model.config.use_qllama_lora:
79
+ model.qllama.merge_and_unload()
80
+ model.qllama = model.qllama.model
81
+ if model.config.force_image_size is not None:
82
+ image_size = model.config.force_image_size
83
+ else:
84
+ image_size = model.config.vision_config.image_size
85
+ transform = build_transform(task, image_size)
86
+ tokenizer = InternVLTokenizer(ckpt_path)
87
+ return model, transform, tokenizer
VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/configuration_intern_vit.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2023 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+ import os
7
+ from typing import Union
8
+
9
+ from transformers.configuration_utils import PretrainedConfig
10
+ from transformers.utils import logging
11
+
12
+ logger = logging.get_logger(__name__)
13
+
14
+
15
+ class InternVisionConfig(PretrainedConfig):
16
+ r"""
17
+ This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
18
+ instantiate a vision encoder according to the specified arguments, defining the model architecture.
19
+
20
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
21
+ documentation from [`PretrainedConfig`] for more information.
22
+
23
+ Args:
24
+ num_channels (`int`, *optional*, defaults to 3):
25
+ Number of color channels in the input images (e.g., 3 for RGB).
26
+ patch_size (`int`, *optional*, defaults to 14):
27
+ The size (resolution) of each patch.
28
+ image_size (`int`, *optional*, defaults to 224):
29
+ The size (resolution) of each image.
30
+ qkv_bias (`bool`, *optional*, defaults to `False`):
31
+ Whether to add a bias to the queries and values in the self-attention layers.
32
+ hidden_size (`int`, *optional*, defaults to 3200):
33
+ Dimensionality of the encoder layers and the pooler layer.
34
+ num_attention_heads (`int`, *optional*, defaults to 25):
35
+ Number of attention heads for each attention layer in the Transformer encoder.
36
+ intermediate_size (`int`, *optional*, defaults to 12800):
37
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
38
+ qk_normalization (`bool`, *optional*, defaults to `True`):
39
+ Whether to normalize the queries and keys in the self-attention layers.
40
+ num_hidden_layers (`int`, *optional*, defaults to 48):
41
+ Number of hidden layers in the Transformer encoder.
42
+ use_flash_attn (`bool`, *optional*, defaults to `True`):
43
+ Whether to use flash attention mechanism.
44
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
45
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
46
+ `"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported.
47
+ layer_norm_eps (`float`, *optional*, defaults to 1e-6):
48
+ The epsilon used by the layer normalization layers.
49
+ dropout (`float`, *optional*, defaults to 0.0):
50
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
51
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
52
+ Dropout rate for stochastic depth.
53
+ attention_dropout (`float`, *optional*, defaults to 0.0):
54
+ The dropout ratio for the attention probabilities.
55
+ initializer_range (`float`, *optional*, defaults to 0.02):
56
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
57
+ initializer_factor (`float`, *optional*, defaults to 0.1):
58
+ A factor for layer scale.
59
+ """
60
+
61
+ model_type = 'intern_vit_6b'
62
+
63
+ def __init__(
64
+ self,
65
+ num_channels=3,
66
+ patch_size=14,
67
+ image_size=224,
68
+ qkv_bias=False,
69
+ hidden_size=3200,
70
+ num_attention_heads=25,
71
+ intermediate_size=12800,
72
+ qk_normalization=True,
73
+ num_hidden_layers=48,
74
+ use_flash_attn=True,
75
+ hidden_act='gelu',
76
+ layer_norm_eps=1e-6,
77
+ dropout=0.0,
78
+ drop_path_rate=0.0,
79
+ attention_dropout=0.0,
80
+ initializer_range=0.02,
81
+ initializer_factor=0.1,
82
+ **kwargs,
83
+ ):
84
+ super().__init__(**kwargs)
85
+
86
+ self.hidden_size = hidden_size
87
+ self.intermediate_size = intermediate_size
88
+ self.dropout = dropout
89
+ self.drop_path_rate = drop_path_rate
90
+ self.num_hidden_layers = num_hidden_layers
91
+ self.num_attention_heads = num_attention_heads
92
+ self.num_channels = num_channels
93
+ self.patch_size = patch_size
94
+ self.image_size = image_size
95
+ self.initializer_range = initializer_range
96
+ self.initializer_factor = initializer_factor
97
+ self.attention_dropout = attention_dropout
98
+ self.layer_norm_eps = layer_norm_eps
99
+ self.hidden_act = hidden_act
100
+ self.qkv_bias = qkv_bias
101
+ self.qk_normalization = qk_normalization
102
+ self.use_flash_attn = use_flash_attn
103
+
104
+ @classmethod
105
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
106
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
107
+
108
+ if 'vision_config' in config_dict:
109
+ config_dict = config_dict['vision_config']
110
+
111
+ if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type:
112
+ logger.warning(
113
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
114
+ f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
115
+ )
116
+
117
+ return cls.from_dict(config_dict, **kwargs)
VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/configuration_internvl.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2023 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+ import copy
7
+
8
+ from transformers import LlamaConfig
9
+ from transformers.configuration_utils import PretrainedConfig
10
+ from transformers.utils import logging
11
+
12
+ from .configuration_intern_vit import InternVisionConfig
13
+
14
+ logger = logging.get_logger(__name__)
15
+
16
+
17
+ class InternVLConfig(PretrainedConfig):
18
+ r"""
19
+ [`InternVLConfig`] is the configuration class to store the configuration of a
20
+ [`InternVLModel`]. It is used to instantiate a InternVLModel according to the specified
21
+ arguments, defining the InternViT-6B and QLLaMA configs. Instantiating a configuration with
22
+ the defaults will yield a similar configuration to that of the InternVL architecture.
23
+
24
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
25
+ documentation from [`PretrainedConfig`] for more information.
26
+
27
+ Args:
28
+ vision_config (`dict`, *optional*):
29
+ Dictionary of configuration options used to initialize [`InternVisionConfig`].
30
+ qllama_config (`dict`, *optional*):
31
+ Dictionary of configuration options used to initialize [`LLaMAConfig`].
32
+ clip_embed_dim (`int`, *optional*, defaults to 768):
33
+ Size of the embeddings from the CLIP model.
34
+ attn_pool_num_heads (`int`, *optional*, defaults to 16):
35
+ Number of attention heads used in the attention pooling layers.
36
+ num_query_token (`int`, *optional*, defaults to 96):
37
+ Number of query tokens used in the transformer.
38
+ label_smoothing (`float`, *optional*, defaults to 0.0):
39
+ The amount of label smoothing to apply.
40
+ cross_attention_frequency (`int`, *optional*, defaults to 2):
41
+ The frequency of cross-attention layers in the model.
42
+ use_backbone_lora (`int`, *optional*, defaults to 0):
43
+ If non-zero, indicates the use of LoRA in the backbone of the model.
44
+ use_qllama_lora (`int`, *optional*, defaults to 0):
45
+ If non-zero, indicates the use of LoRA in the QLLaMA of the model.
46
+ force_image_size (`int` or `None`, *optional*):
47
+ If not None, forces the model to use this specific image size.
48
+ initializer_range (`float`, *optional*, defaults to 0.02):
49
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
50
+ kwargs (*optional*):
51
+ Dictionary of additional keyword arguments.
52
+ """
53
+
54
+ model_type = 'internvl'
55
+ is_composition = True
56
+
57
+ def __init__(
58
+ self,
59
+ vision_config=None,
60
+ qllama_config=None,
61
+ clip_embed_dim=768,
62
+ attn_pool_num_heads=16,
63
+ num_query_token=96,
64
+ label_smoothing=0.0,
65
+ cross_attention_frequency=2,
66
+ use_backbone_lora=0,
67
+ use_qllama_lora=0,
68
+ force_image_size=None,
69
+ initializer_range=0.02,
70
+ **kwargs):
71
+ super().__init__(**kwargs)
72
+
73
+ if vision_config is None:
74
+ vision_config = {}
75
+ logger.info('vision_config is None. initializing the InternVisionConfig with default values.')
76
+
77
+ if qllama_config is None:
78
+ qllama_config = {}
79
+ logger.info(
80
+ 'qllama_config is None. Initializing the InternTextConfig config with default values (`LlamaConfig`).')
81
+
82
+ self.vision_config = InternVisionConfig(**vision_config)
83
+ self.qllama_config = LlamaConfig(**qllama_config)
84
+ self.qllama_config.num_query_token = num_query_token
85
+ self.qllama_config.cross_attention_frequency = cross_attention_frequency
86
+ self.hidden_size = self.qllama_config.hidden_size
87
+
88
+ self.clip_embed_dim = clip_embed_dim
89
+ self.attn_pool_num_heads = attn_pool_num_heads
90
+ self.num_query_token = num_query_token
91
+ self.label_smoothing = label_smoothing
92
+ self.use_backbone_lora = use_backbone_lora
93
+ self.use_qllama_lora = use_qllama_lora
94
+ self.force_image_size = force_image_size
95
+ self.initializer_range = initializer_range
96
+
97
+ def to_dict(self):
98
+ """
99
+ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
100
+
101
+ Returns:
102
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
103
+ """
104
+ output = copy.deepcopy(self.__dict__)
105
+ output['vision_config'] = self.vision_config.to_dict()
106
+ output['qllama_config'] = self.qllama_config.to_dict()
107
+ output['model_type'] = self.__class__.model_type
108
+ return output
VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/flash_attention.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/Dao-AILab/flash-attention/blob/v0.2.8/flash_attn/flash_attention.py
2
+ import torch
3
+ import torch.nn as nn
4
+ from einops import rearrange
5
+
6
+ try: # v1
7
+ from flash_attn.flash_attn_interface import \
8
+ flash_attn_unpadded_qkvpacked_func
9
+ except: # v2
10
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
11
+
12
+ from flash_attn.bert_padding import pad_input, unpad_input
13
+
14
+
15
+ class FlashAttention(nn.Module):
16
+ """Implement the scaled dot product attention with softmax.
17
+ Arguments
18
+ ---------
19
+ softmax_scale: The temperature to use for the softmax attention.
20
+ (default: 1/sqrt(d_keys) where d_keys is computed at
21
+ runtime)
22
+ attention_dropout: The dropout rate to apply to the attention
23
+ (default: 0.0)
24
+ """
25
+
26
+ def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
27
+ super().__init__()
28
+ self.softmax_scale = softmax_scale
29
+ self.dropout_p = attention_dropout
30
+
31
+ def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
32
+ max_s=None, need_weights=False):
33
+ """Implements the multihead softmax attention.
34
+ Arguments
35
+ ---------
36
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
37
+ if unpadded: (nnz, 3, h, d)
38
+ key_padding_mask: a bool tensor of shape (B, S)
39
+ """
40
+ assert not need_weights
41
+ assert qkv.dtype in [torch.float16, torch.bfloat16]
42
+ assert qkv.is_cuda
43
+
44
+ if cu_seqlens is None:
45
+ batch_size = qkv.shape[0]
46
+ seqlen = qkv.shape[1]
47
+ if key_padding_mask is None:
48
+ qkv = rearrange(qkv, 'b s ... -> (b s) ...')
49
+ max_s = seqlen
50
+ cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
51
+ device=qkv.device)
52
+ output = flash_attn_unpadded_qkvpacked_func(
53
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
54
+ softmax_scale=self.softmax_scale, causal=causal
55
+ )
56
+ output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
57
+ else:
58
+ nheads = qkv.shape[-2]
59
+ x = rearrange(qkv, 'b s three h d -> b s (three h d)')
60
+ x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
61
+ x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
62
+ output_unpad = flash_attn_unpadded_qkvpacked_func(
63
+ x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
64
+ softmax_scale=self.softmax_scale, causal=causal
65
+ )
66
+ output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
67
+ indices, batch_size, seqlen),
68
+ 'b s (h d) -> b s h d', h=nheads)
69
+ else:
70
+ assert max_s is not None
71
+ output = flash_attn_unpadded_qkvpacked_func(
72
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
73
+ softmax_scale=self.softmax_scale, causal=causal
74
+ )
75
+
76
+ return output, None
VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/modeling_intern_vit.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2023 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+ from typing import Optional, Tuple, Union
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint
11
+ from einops import rearrange
12
+ from timm.models.layers import DropPath
13
+ from torch import nn
14
+ from transformers.activations import ACT2FN
15
+ from transformers.modeling_outputs import (BaseModelOutput,
16
+ BaseModelOutputWithPooling)
17
+ from transformers.modeling_utils import PreTrainedModel
18
+ from transformers.utils import logging
19
+
20
+ from .configuration_intern_vit import InternVisionConfig
21
+
22
+ try:
23
+ from .flash_attention import FlashAttention
24
+ has_flash_attn = True
25
+ except:
26
+ print('FlashAttention is not installed.')
27
+ has_flash_attn = False
28
+
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
+ class InternRMSNorm(nn.Module):
34
+ def __init__(self, hidden_size, eps=1e-6):
35
+ super().__init__()
36
+ self.weight = nn.Parameter(torch.ones(hidden_size))
37
+ self.variance_epsilon = eps
38
+
39
+ def forward(self, hidden_states):
40
+ input_dtype = hidden_states.dtype
41
+ hidden_states = hidden_states.to(torch.float32)
42
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
43
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
44
+ return self.weight * hidden_states.to(input_dtype)
45
+
46
+
47
+ try:
48
+ from apex.normalization import FusedRMSNorm
49
+
50
+ InternRMSNorm = FusedRMSNorm # noqa
51
+
52
+ logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
53
+ except ImportError:
54
+ # using the normal InternRMSNorm
55
+ pass
56
+ except Exception:
57
+ logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm')
58
+ pass
59
+
60
+
61
+ class InternVisionEmbeddings(nn.Module):
62
+ def __init__(self, config: InternVisionConfig):
63
+ super().__init__()
64
+ self.config = config
65
+ self.embed_dim = config.hidden_size
66
+ self.image_size = config.image_size
67
+ self.patch_size = config.patch_size
68
+
69
+ self.class_embedding = nn.Parameter(
70
+ torch.randn(1, 1, self.embed_dim),
71
+ )
72
+
73
+ self.patch_embedding = nn.Conv2d(
74
+ in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
75
+ )
76
+
77
+ self.num_patches = (self.image_size // self.patch_size) ** 2
78
+ self.num_positions = self.num_patches + 1
79
+
80
+ self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
81
+
82
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
83
+ batch_size = pixel_values.shape[0]
84
+ target_dtype = self.patch_embedding.weight.dtype
85
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
86
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
87
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
88
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
89
+ embeddings = embeddings + self.position_embedding.to(target_dtype)
90
+ return embeddings
91
+
92
+
93
+ class InternAttention(nn.Module):
94
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
95
+
96
+ def __init__(self, config: InternVisionConfig):
97
+ super().__init__()
98
+ self.config = config
99
+ self.embed_dim = config.hidden_size
100
+ self.num_heads = config.num_attention_heads
101
+ self.use_flash_attn = config.use_flash_attn and has_flash_attn
102
+ if config.use_flash_attn and not has_flash_attn:
103
+ print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
104
+ self.head_dim = self.embed_dim // self.num_heads
105
+ if self.head_dim * self.num_heads != self.embed_dim:
106
+ raise ValueError(
107
+ f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
108
+ f' {self.num_heads}).'
109
+ )
110
+
111
+ self.scale = self.head_dim ** -0.5
112
+ self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
113
+ self.attn_drop = nn.Dropout(config.attention_dropout)
114
+ self.proj_drop = nn.Dropout(config.dropout)
115
+
116
+ self.qk_normalization = config.qk_normalization
117
+
118
+ if self.qk_normalization:
119
+ self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
120
+ self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
121
+
122
+ if self.use_flash_attn:
123
+ self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
124
+ self.proj = nn.Linear(self.embed_dim, self.embed_dim)
125
+
126
+ def _naive_attn(self, x):
127
+ B, N, C = x.shape
128
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
129
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
130
+
131
+ if self.qk_normalization:
132
+ B_, H_, N_, D_ = q.shape
133
+ q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
134
+ k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
135
+
136
+ attn = ((q * self.scale) @ k.transpose(-2, -1))
137
+ attn = attn.softmax(dim=-1)
138
+ attn = self.attn_drop(attn)
139
+
140
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
141
+ x = self.proj(x)
142
+ x = self.proj_drop(x)
143
+ return x
144
+
145
+ def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
146
+ qkv = self.qkv(x)
147
+ qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
148
+
149
+ if self.qk_normalization:
150
+ q, k, v = qkv.unbind(2)
151
+ q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
152
+ k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
153
+ qkv = torch.stack([q, k, v], dim=2)
154
+
155
+ context, _ = self.inner_attn(
156
+ qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
157
+ )
158
+ outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
159
+ outs = self.proj_drop(outs)
160
+ return outs
161
+
162
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
163
+ x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
164
+ return x
165
+
166
+
167
+ class InternMLP(nn.Module):
168
+ def __init__(self, config: InternVisionConfig):
169
+ super().__init__()
170
+ self.config = config
171
+ self.act = ACT2FN[config.hidden_act]
172
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
173
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
174
+
175
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
176
+ hidden_states = self.fc1(hidden_states)
177
+ hidden_states = self.act(hidden_states)
178
+ hidden_states = self.fc2(hidden_states)
179
+ return hidden_states
180
+
181
+
182
+ class InternVisionEncoderLayer(nn.Module):
183
+ def __init__(self, config: InternVisionConfig, drop_path_rate: float):
184
+ super().__init__()
185
+ self.embed_dim = config.hidden_size
186
+ self.intermediate_size = config.intermediate_size
187
+
188
+ self.attn = InternAttention(config)
189
+ self.mlp = InternMLP(config)
190
+ self.norm1 = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
191
+ self.norm2 = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
192
+
193
+ self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
194
+ self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
195
+ self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
196
+ self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
197
+
198
+ def forward(
199
+ self,
200
+ hidden_states: torch.Tensor,
201
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
202
+ """
203
+ Args:
204
+ hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
205
+ """
206
+ hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)
207
+
208
+ hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)
209
+
210
+ return hidden_states
211
+
212
+
213
+ class InternVisionEncoder(nn.Module):
214
+ """
215
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
216
+ [`InternEncoderLayer`].
217
+
218
+ Args:
219
+ config (`InternConfig`):
220
+ The corresponding vision configuration for the `InternEncoder`.
221
+ """
222
+
223
+ def __init__(self, config: InternVisionConfig):
224
+ super().__init__()
225
+ self.config = config
226
+ # stochastic depth decay rule
227
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
228
+ self.layers = nn.ModuleList([
229
+ InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
230
+ self.gradient_checkpointing = True
231
+
232
+ def forward(
233
+ self,
234
+ inputs_embeds,
235
+ output_hidden_states: Optional[bool] = None,
236
+ return_dict: Optional[bool] = None,
237
+ ) -> Union[Tuple, BaseModelOutput]:
238
+ r"""
239
+ Args:
240
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
241
+ Embedded representation of the inputs. Should be float, not int tokens.
242
+ output_hidden_states (`bool`, *optional*):
243
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
244
+ for more detail.
245
+ return_dict (`bool`, *optional*):
246
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
247
+ """
248
+ output_hidden_states = (
249
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
250
+ )
251
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
252
+
253
+ encoder_states = () if output_hidden_states else None
254
+ hidden_states = inputs_embeds
255
+
256
+ for idx, encoder_layer in enumerate(self.layers):
257
+ if output_hidden_states:
258
+ encoder_states = encoder_states + (hidden_states,)
259
+ if self.gradient_checkpointing and self.training:
260
+ layer_outputs = torch.utils.checkpoint.checkpoint(
261
+ encoder_layer,
262
+ hidden_states)
263
+ else:
264
+ layer_outputs = encoder_layer(
265
+ hidden_states,
266
+ )
267
+ hidden_states = layer_outputs
268
+
269
+ if output_hidden_states:
270
+ encoder_states = encoder_states + (hidden_states,)
271
+
272
+ if not return_dict:
273
+ return tuple(v for v in [hidden_states, encoder_states] if v is not None)
274
+ return BaseModelOutput(
275
+ last_hidden_state=hidden_states, hidden_states=encoder_states
276
+ )
277
+
278
+
279
+ class InternVisionModel(PreTrainedModel):
280
+ main_input_name = 'pixel_values'
281
+ config_class = InternVisionConfig
282
+
283
+ def __init__(self, config: InternVisionConfig):
284
+ super().__init__(config)
285
+ self.config = config
286
+
287
+ self.embeddings = InternVisionEmbeddings(config)
288
+ self.encoder = InternVisionEncoder(config)
289
+
290
+ def resize_pos_embeddings(self, old_size, new_size, patch_size):
291
+ pos_emb = self.embeddings.position_embedding
292
+ _, num_positions, embed_dim = pos_emb.shape
293
+ cls_emb = pos_emb[:, :1, :]
294
+ pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
295
+ pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False)
296
+ pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
297
+ pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
298
+ self.embeddings.position_embedding = nn.Parameter(pos_emb)
299
+ logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))
300
+
301
+ def get_input_embeddings(self):
302
+ return self.embeddings
303
+
304
+ def forward(
305
+ self,
306
+ pixel_values: Optional[torch.FloatTensor] = None,
307
+ output_hidden_states: Optional[bool] = None,
308
+ return_dict: Optional[bool] = None,
309
+ pixel_embeds: Optional[torch.FloatTensor] = None,
310
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
311
+ output_hidden_states = (
312
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
313
+ )
314
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
315
+
316
+ if pixel_values is None and pixel_embeds is None:
317
+ raise ValueError('You have to specify pixel_values or pixel_embeds')
318
+
319
+ if pixel_embeds is not None:
320
+ hidden_states = pixel_embeds
321
+ else:
322
+ if len(pixel_values.shape) == 4:
323
+ hidden_states = self.embeddings(pixel_values)
324
+ else:
325
+ raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
326
+ encoder_outputs = self.encoder(
327
+ inputs_embeds=hidden_states,
328
+ output_hidden_states=output_hidden_states,
329
+ return_dict=return_dict,
330
+ )
331
+ last_hidden_state = encoder_outputs.last_hidden_state
332
+ pooled_output = last_hidden_state[:, 0, :]
333
+
334
+ if not return_dict:
335
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
336
+
337
+ return BaseModelOutputWithPooling(
338
+ last_hidden_state=last_hidden_state,
339
+ pooler_output=pooled_output,
340
+ hidden_states=encoder_outputs.hidden_states,
341
+ attentions=encoder_outputs.attentions,
342
+ )
VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/modeling_internvl.py ADDED
@@ -0,0 +1,669 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2023 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+ from dataclasses import dataclass
7
+ from functools import partial
8
+ from typing import Any, Optional, Tuple, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.distributed as dist
13
+ import torch.nn.functional as F
14
+ import torch.utils.checkpoint
15
+ from peft import LoraConfig, get_peft_model
16
+ from timm.models.layers import DropPath
17
+ from torch import nn
18
+ from transformers import GenerationConfig
19
+ from transformers.modeling_utils import PreTrainedModel
20
+ from transformers.utils import ModelOutput, logging
21
+
22
+ from .configuration_internvl import InternVLConfig
23
+ from .modeling_intern_vit import (InternVisionEmbeddings, InternVisionEncoder,
24
+ InternVisionModel)
25
+ from .modeling_qllama import LlamaForCausalLM, _expand_mask, _make_causal_mask
26
+
27
+ try:
28
+ from .flash_attention import FlashAttention # v1/v2
29
+ except:
30
+ print('FlashAttention is not installed.')
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
+ class InternVLPreTrainedModel(PreTrainedModel):
36
+ """
37
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
38
+ models.
39
+ """
40
+
41
+ config_class = InternVLConfig
42
+ base_model_prefix = 'internvl'
43
+ supports_gradient_checkpointing = True
44
+ _keys_to_ignore_on_load_missing = [
45
+ r'position_ids',
46
+ ]
47
+ _no_split_modules = ['InternAttention', 'LlamaDecoderLayer', 'LlamaForCausalLM']
48
+ _skip_keys_device_placement = 'past_key_values'
49
+ _keep_in_fp32_modules = ['wo']
50
+
51
+ # def _init_weights(self, module):
52
+ # """Initialize the weights"""
53
+ # factor = self.config.initializer_range
54
+ # if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear):
55
+ # module.weight.data.normal_(mean=0.0, std=factor)
56
+ # if hasattr(module, 'bias') and module.bias is not None:
57
+ # module.bias.data.zero_()
58
+ # if isinstance(module, InternVisionEmbeddings):
59
+ # if hasattr(self.config, 'vision_config'):
60
+ # factor = self.config.vision_config.initializer_range
61
+ # nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
62
+ # nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
63
+ # elif isinstance(module, nn.LayerNorm):
64
+ # module.bias.data.zero_()
65
+ # module.weight.data.fill_(1.0)
66
+ # elif isinstance(module, nn.Linear) and module.bias is not None:
67
+ # module.bias.data.zero_()
68
+
69
+ def _set_gradient_checkpointing(self, module, value=False):
70
+ if isinstance(module, InternVisionModel):
71
+ module.gradient_checkpointing = value
72
+ if isinstance(module, InternVisionEncoder):
73
+ module.gradient_checkpointing = value
74
+
75
+
76
+ class CrossAttention(nn.Module):
77
+ def __init__(
78
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
79
+ proj_drop=0., attn_head_dim=None, out_dim=None):
80
+ super().__init__()
81
+ if out_dim is None:
82
+ out_dim = dim
83
+ self.num_heads = num_heads
84
+ head_dim = dim // num_heads
85
+ if attn_head_dim is not None:
86
+ head_dim = attn_head_dim
87
+ all_head_dim = head_dim * self.num_heads
88
+ self.scale = qk_scale or head_dim ** -0.5
89
+ assert all_head_dim == dim
90
+
91
+ self.q = nn.Linear(dim, all_head_dim, bias=False)
92
+ self.k = nn.Linear(dim, all_head_dim, bias=False)
93
+ self.v = nn.Linear(dim, all_head_dim, bias=False)
94
+
95
+ if qkv_bias:
96
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
97
+ self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
98
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
99
+ else:
100
+ self.q_bias = None
101
+ self.k_bias = None
102
+ self.v_bias = None
103
+
104
+ self.attn_drop = nn.Dropout(attn_drop)
105
+ self.proj = nn.Linear(all_head_dim, out_dim)
106
+ self.proj_drop = nn.Dropout(proj_drop)
107
+
108
+ def forward(self, x, k=None, v=None):
109
+ B, N, C = x.shape
110
+ N_k = k.shape[1]
111
+ N_v = v.shape[1]
112
+
113
+ q_bias, k_bias, v_bias = None, None, None
114
+ if self.q_bias is not None:
115
+ q_bias = self.q_bias
116
+ k_bias = self.k_bias
117
+ v_bias = self.v_bias
118
+
119
+ q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
120
+ q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, N_head, N_q, dim)
121
+
122
+ k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
123
+ k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
124
+
125
+ v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
126
+ v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
127
+
128
+ q = q * self.scale
129
+ attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k)
130
+
131
+ attn = attn.softmax(dim=-1)
132
+ attn = self.attn_drop(attn)
133
+
134
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
135
+ x = self.proj(x)
136
+ x = self.proj_drop(x)
137
+
138
+ return x
139
+
140
+
141
+ class AttentiveBlock(nn.Module):
142
+
143
+ def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
144
+ drop_path=0., norm_layer=nn.LayerNorm, attn_head_dim=None, out_dim=None):
145
+ super().__init__()
146
+
147
+ self.norm1_q = norm_layer(dim)
148
+ self.norm1_k = norm_layer(dim)
149
+ self.norm1_v = norm_layer(dim)
150
+ self.cross_attn = CrossAttention(
151
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
152
+ proj_drop=drop, attn_head_dim=attn_head_dim, out_dim=out_dim)
153
+
154
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
155
+
156
+ def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None):
157
+ x_q = self.norm1_q(x_q + pos_q)
158
+ x_k = self.norm1_k(x_kv + pos_k)
159
+ x_v = self.norm1_v(x_kv)
160
+ x = self.cross_attn(x_q, k=x_k, v=x_v)
161
+
162
+ return x
163
+
164
+
165
+ class AttentionPoolingBlock(AttentiveBlock):
166
+
167
+ def forward(self, x):
168
+ x_q = x.mean(1, keepdim=True)
169
+ x_kv, pos_q, pos_k = x, 0, 0
170
+ x = super().forward(x_q, x_kv, pos_q, pos_k, bool_masked_pos=None, rel_pos_bias=None)
171
+ x = x.squeeze(1)
172
+ return x
173
+
174
+
175
+ @dataclass
176
+ class InternVLModelOutput(ModelOutput):
177
+ """
178
+ Class defining the outputs of [`InternVLModelOutput`].
179
+ """
180
+
181
+ loss: Optional[torch.FloatTensor] = None
182
+ loss_itm: Optional[torch.FloatTensor] = None
183
+ loss_itc: Optional[torch.FloatTensor] = None
184
+ loss_itg: Optional[torch.FloatTensor] = None
185
+
186
+ def to_tuple(self) -> Tuple[Any]:
187
+ return tuple(
188
+ self[k]
189
+ if k not in ['loss', 'loss_itm', 'loss_itc', 'loss_itg']
190
+ else getattr(self, k).to_tuple()
191
+ for k in self.keys()
192
+ )
193
+
194
+
195
+ class GatherLayer(torch.autograd.Function):
196
+ """Gather tensors from all process, supporting backward propagation.
197
+ """
198
+
199
+ @staticmethod
200
+ def forward(ctx, input):
201
+ ctx.save_for_backward(input)
202
+ output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
203
+ dist.all_gather(output, input)
204
+ return torch.stack(output, 0)
205
+
206
+ @staticmethod
207
+ def backward(ctx, grads):
208
+ input, = ctx.saved_tensors
209
+ dist.all_reduce(grads)
210
+ grad_out = torch.zeros_like(input)
211
+ grad_out[:] = grads[dist.get_rank()]
212
+ return grad_out
213
+
214
+
215
+ class InternVLModel(InternVLPreTrainedModel):
216
+ config_class = InternVLConfig
217
+ main_input_name = 'pixel_values'
218
+
219
+ def __init__(self, config: InternVLConfig):
220
+ super().__init__(config)
221
+
222
+ text_hidden_size = config.qllama_config.hidden_size
223
+ vision_hidden_size = config.vision_config.hidden_size
224
+ clip_embed_dim = config.clip_embed_dim
225
+ attn_pool_num_heads = config.attn_pool_num_heads
226
+ config.qllama_config.num_query_token = config.num_query_token
227
+ self.num_query_token = config.num_query_token
228
+ self.label_smoothing = config.label_smoothing
229
+
230
+ self.vision_model = InternVisionModel(config.vision_config) # frozen
231
+ self.qllama = LlamaForCausalLM(config.qllama_config) # frozen
232
+ self.query_tokens = nn.Parameter( # trainable
233
+ torch.zeros(1, config.num_query_token, text_hidden_size)
234
+ )
235
+
236
+ self.text_projection = nn.Parameter(torch.empty(text_hidden_size, clip_embed_dim)) # frozen
237
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) # trainable
238
+ self.clip_projector = AttentionPoolingBlock( # frozen
239
+ dim=vision_hidden_size, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
240
+ drop=0., attn_drop=0., norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim)
241
+ self.clip_projector2 = AttentionPoolingBlock( # trainable
242
+ dim=text_hidden_size, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
243
+ drop=0., attn_drop=0., norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim)
244
+ self.itm_head = nn.Linear(text_hidden_size, 2) # trainable
245
+ self.gradient_checkpointing = True
246
+
247
+ # Initialize weights and apply final processing
248
+ # self.post_init()
249
+
250
+ if config.use_backbone_lora:
251
+ self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=config.use_backbone_lora * 2)
252
+ if config.use_qllama_lora:
253
+ self.wrap_qllama_lora(r=config.use_qllama_lora, lora_alpha=config.use_qllama_lora * 2)
254
+ if config.force_image_size:
255
+ self.vision_model.resize_pos_embeddings(
256
+ old_size=config.vision_config.image_size,
257
+ new_size=config.force_image_size,
258
+ patch_size=config.vision_config.patch_size
259
+ )
260
+
261
+ def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
262
+ lora_config = LoraConfig(
263
+ r=r,
264
+ target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'],
265
+ lora_alpha=lora_alpha,
266
+ lora_dropout=lora_dropout,
267
+ )
268
+ self.vision_model = get_peft_model(self.vision_model, lora_config)
269
+ self.vision_model.print_trainable_parameters()
270
+
271
+ def wrap_qllama_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
272
+ lora_config = LoraConfig(
273
+ r=r,
274
+ target_modules=['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
275
+ 'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj'],
276
+ lora_alpha=lora_alpha,
277
+ lora_dropout=lora_dropout,
278
+ )
279
+ self.qllama = get_peft_model(self.qllama, lora_config)
280
+ self.qllama.print_trainable_parameters()
281
+
282
+ def get_input_embeddings(self):
283
+ return self.qllama.get_input_embeddings()
284
+
285
+ def set_input_embeddings(self, value):
286
+ self.qllama.set_input_embeddings(value)
287
+
288
+ def set_output_embeddings(self, new_embeddings):
289
+ self.qllama.set_output_embeddings(new_embeddings)
290
+
291
+ def get_output_embeddings(self) -> nn.Module:
292
+ return self.qllama.get_output_embeddings()
293
+
294
+ @torch.no_grad()
295
+ def _prepare_attention_mask(
296
+ self,
297
+ image_attention_mask: torch.LongTensor,
298
+ attention_mask: torch.LongTensor,
299
+ input_embeds: torch.FloatTensor,
300
+ repeat_time: int,
301
+ ):
302
+ # itm, itc
303
+ attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1)
304
+ expand_mask = _expand_mask(attention_mask, input_embeds.dtype).to(
305
+ input_embeds.device) # [bsz, 1, tgt_seq_len, src_seq_len]
306
+ itm_mask_neg, itm_mask_pos, itc_mask = torch.chunk(expand_mask, repeat_time, dim=0)
307
+
308
+ itc_mask[:, :, :self.num_query_token, self.num_query_token:] = torch.finfo(input_embeds.dtype).min
309
+ itc_mask[:, :, self.num_query_token:, :self.num_query_token] = torch.finfo(input_embeds.dtype).min
310
+ itc_mask_causal = _make_causal_mask(
311
+ (itc_mask.shape[0], itc_mask.shape[2] - self.num_query_token),
312
+ input_embeds.dtype,
313
+ device=input_embeds.device
314
+ )
315
+ # use causal mask for text in itc
316
+ itc_mask[:, :, self.num_query_token:, self.num_query_token:] += itc_mask_causal
317
+
318
+ attention_mask = torch.cat([itm_mask_neg, itm_mask_pos, itc_mask], dim=0)
319
+
320
+ return attention_mask
321
+
322
+ def forward(
323
+ self,
324
+ pixel_values: torch.FloatTensor,
325
+ positive_input_ids: torch.FloatTensor,
326
+ positive_attention_mask: torch.LongTensor,
327
+ negative_input_ids: torch.FloatTensor,
328
+ negative_attention_mask: torch.LongTensor,
329
+ summarize_input_ids: torch.FloatTensor,
330
+ summarize_attention_mask: torch.LongTensor,
331
+ input_ids: torch.FloatTensor,
332
+ attention_mask: torch.LongTensor,
333
+ image_ids: torch.LongTensor,
334
+ labels: torch.LongTensor,
335
+ output_attentions: Optional[bool] = None,
336
+ output_hidden_states: Optional[bool] = None,
337
+ return_dict: Optional[bool] = None,
338
+ ) -> Union[Tuple, InternVLModelOutput]:
339
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
340
+
341
+ # step 1: forward the images through the vision encoder,
342
+ # to get image embeddings of shape (batch_size, seq_len, hidden_size)
343
+ vision_outputs = self.vision_model(
344
+ pixel_values=pixel_values,
345
+ output_hidden_states=output_hidden_states,
346
+ return_dict=return_dict)
347
+ image_embeds = vision_outputs[0]
348
+ backbone_embeds = self.clip_projector(image_embeds)
349
+
350
+ # step 2: prepare input_ids and attention_mask for two sub-tasks:
351
+ # 1) image-text matching; 2) image-text contrastive learning.
352
+ batch_size = input_ids.shape[0]
353
+ input_ids = torch.cat([negative_input_ids, positive_input_ids,
354
+ summarize_input_ids], dim=0) # [3 * batch_size, seq_len]
355
+ itm_attention_mask = torch.cat(
356
+ [negative_attention_mask, positive_attention_mask], dim=0)
357
+ attention_mask = torch.cat(
358
+ [itm_attention_mask, summarize_attention_mask], dim=0) # [3 * batch_size, seq_len]
359
+
360
+ repeat_time = input_ids.size(0) // batch_size
361
+ # step 3: forward the input_ids and attention_mask through the text encoder.
362
+ input_embeds = self.get_input_embeddings()(input_ids)
363
+ query_tokens = self.query_tokens.repeat(repeat_time * batch_size, 1, 1)
364
+ input_embeds = torch.cat([query_tokens, input_embeds], dim=1)
365
+ image_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
366
+ attention_mask = self._prepare_attention_mask(
367
+ image_attention_mask, attention_mask, input_embeds, repeat_time
368
+ )
369
+ if type(self.qllama.model) == LlamaForCausalLM:
370
+ outputs = self.qllama.model.model.forward_train(
371
+ inputs_embeds=input_embeds,
372
+ vision_hidden_states=image_embeds,
373
+ attention_mask=attention_mask,
374
+ output_attentions=output_attentions,
375
+ output_hidden_states=output_hidden_states,
376
+ return_dict=return_dict,
377
+ repeat_time=repeat_time,
378
+ ).last_hidden_state
379
+ else:
380
+ outputs = self.qllama.model.forward_train(
381
+ inputs_embeds=input_embeds,
382
+ vision_hidden_states=image_embeds,
383
+ attention_mask=attention_mask,
384
+ output_attentions=output_attentions,
385
+ output_hidden_states=output_hidden_states,
386
+ return_dict=return_dict,
387
+ repeat_time=repeat_time,
388
+ ).last_hidden_state
389
+ image_embeds = outputs[:, :self.num_query_token]
390
+ text_embeds = outputs[:, self.num_query_token:]
391
+ image_itm_neg, image_itm_pos, image_itc = image_embeds.chunk(repeat_time, dim=0)
392
+ text_itm_neg, text_itm_pos, text_itc = text_embeds.chunk(repeat_time, dim=0)
393
+ image_itm = torch.cat([image_itm_neg, image_itm_pos], dim=0)
394
+
395
+ ###============== Image-Text Matching ===================###
396
+ image_itm = self.itm_head(image_itm)
397
+ logits = image_itm.mean(dim=1)
398
+ itm_labels = torch.cat([
399
+ torch.zeros(batch_size, dtype=torch.long, device=logits.device),
400
+ torch.ones(batch_size, dtype=torch.long, device=logits.device)
401
+ ], dim=0)
402
+ loss_itm = F.cross_entropy(logits, itm_labels)
403
+ neg_match_acc = ((logits[:batch_size].argmax(dim=-1) == 0) / batch_size).sum()
404
+ pos_match_acc = ((logits[batch_size:].argmax(dim=-1) == 1) / batch_size).sum()
405
+
406
+ ###============== Image-Text Contrastive ===================###
407
+ image_itc = self.clip_projector2(image_itc)
408
+
409
+ selected = summarize_attention_mask.sum(1) - 1
410
+ text_itc = text_itc[torch.arange(text_itc.shape[0]), selected]
411
+ text_itc = text_itc @ self.text_projection
412
+
413
+ # normalized features
414
+ backbone_embeds = backbone_embeds / backbone_embeds.norm(dim=1, keepdim=True)
415
+ image_itc = image_itc / image_itc.norm(dim=1, keepdim=True)
416
+ text_itc = text_itc / text_itc.norm(dim=1, keepdim=True)
417
+ backbone_embeds_all = GatherLayer.apply(backbone_embeds).flatten(0, 1)
418
+ image_itc_all = GatherLayer.apply(image_itc).flatten(0, 1)
419
+ text_itc_all = GatherLayer.apply(text_itc).flatten(0, 1)
420
+
421
+ # cosine similarity as logits
422
+ logit_scale = self.logit_scale.exp()
423
+ sim_i2t = logit_scale * (image_itc @ text_itc_all.t())
424
+ sim_t2i = logit_scale * (text_itc @ image_itc_all.t())
425
+ backbone_i2t = logit_scale * (backbone_embeds @ text_itc_all.t())
426
+ backbone_t2i = logit_scale * (text_itc @ backbone_embeds_all.t())
427
+
428
+ image_ids = image_ids.view(-1, 1)
429
+ image_ids_all = GatherLayer.apply(image_ids).flatten(0, 1)
430
+ pos_idx = torch.eq(image_ids, image_ids_all.t()).float()
431
+ sim_targets = pos_idx / pos_idx.sum(1, keepdim=True)
432
+
433
+ loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1) * sim_targets, dim=1).mean()
434
+ loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1) * sim_targets, dim=1).mean()
435
+ loss_backbone_t2i = -torch.sum(F.log_softmax(backbone_t2i, dim=1) * sim_targets, dim=1).mean()
436
+ loss_backbone_i2t = -torch.sum(F.log_softmax(backbone_i2t, dim=1) * sim_targets, dim=1).mean()
437
+ loss_itc = (loss_t2i + loss_i2t) / 2 + (loss_backbone_t2i + loss_backbone_i2t) / 2
438
+
439
+ vision_sim = F.cosine_similarity(backbone_embeds.detach(), image_itc).mean()
440
+
441
+ loss = loss_itm + loss_itc
442
+ if dist.get_rank() == 0:
443
+ print(f'loss: {loss.item()}, loss_itm: {loss_itm.item()}, loss_itc: {loss_itc.item()}, '
444
+ f'vision_similarity: {round(vision_sim.item(), 5)}, '
445
+ f'logit scale: {round(1.0 / logit_scale.item(), 5)}, '
446
+ f'pos_match_acc: {round(pos_match_acc.item(), 4)}, '
447
+ f'neg_match_acc: {round(neg_match_acc.item(), 4)}')
448
+
449
+ return InternVLModelOutput(
450
+ loss=loss,
451
+ loss_itc=loss_itc.detach(),
452
+ loss_itm=loss_itm.detach(),
453
+ )
454
+
455
+ @torch.no_grad()
456
+ def generate(
457
+ self,
458
+ pixel_values: torch.FloatTensor,
459
+ input_ids: torch.FloatTensor,
460
+ attention_mask: torch.LongTensor,
461
+ generation_config: Optional[GenerationConfig] = None,
462
+ output_hidden_states: Optional[bool] = None,
463
+ return_dict: Optional[bool] = None,
464
+ **generate_kwargs,
465
+ ) -> torch.LongTensor:
466
+
467
+ vision_outputs = self.vision_model(
468
+ pixel_values=pixel_values,
469
+ output_hidden_states=output_hidden_states,
470
+ return_dict=return_dict)
471
+ image_embeds = vision_outputs[0]
472
+
473
+ batch_size = image_embeds.shape[0]
474
+ input_embeds = self.get_input_embeddings()(input_ids)
475
+ query_tokens = self.query_tokens.repeat(batch_size, 1, 1)
476
+ input_embeds = torch.cat([query_tokens, input_embeds], dim=1)
477
+ image_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
478
+ attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1)
479
+
480
+ outputs = self.qllama.generate(
481
+ inputs_embeds=input_embeds,
482
+ attention_mask=attention_mask,
483
+ vision_hidden_states=image_embeds,
484
+ generation_config=generation_config,
485
+ use_zero_attention_mask=True,
486
+ **generate_kwargs,
487
+ )
488
+
489
+ return outputs
490
+
491
+ def get_text_features(
492
+ self,
493
+ input_ids: torch.Tensor,
494
+ attention_mask: torch.Tensor,
495
+ output_attentions: Optional[bool] = None,
496
+ output_hidden_states: Optional[bool] = None,
497
+ return_dict: Optional[bool] = None,
498
+ ):
499
+ r"""
500
+ Returns:
501
+ text_outputs (`CausalLMOutputWithPast`, or `tuple(torch.FloatTensor)` if `return_dict=False`):
502
+ The language model outputs. If `return_dict=True`, the output is a [`CausalLMOutputWithPast`] that
503
+ contains the language model logits, the past key values and the hidden states if
504
+ `output_hidden_states=True`.
505
+ ```"""
506
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
507
+ output_hidden_states = (
508
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
509
+ )
510
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
511
+
512
+ input_embeds = self.get_input_embeddings()(input_ids)
513
+ attention_mask = _expand_mask(attention_mask, input_embeds.dtype).to(
514
+ input_embeds.device) # [bsz, 1, tgt_seq_len, src_seq_len]
515
+ attention_mask += _make_causal_mask(
516
+ (attention_mask.shape[0], attention_mask.shape[2]),
517
+ input_embeds.dtype,
518
+ device=input_embeds.device
519
+ )
520
+ if type(self.qllama.model) == LlamaForCausalLM:
521
+ outputs = self.qllama.model.model.forward_train(
522
+ inputs_embeds=input_embeds,
523
+ vision_hidden_states=None,
524
+ attention_mask=attention_mask,
525
+ output_attentions=output_attentions,
526
+ output_hidden_states=output_hidden_states,
527
+ return_dict=return_dict,
528
+ ).last_hidden_state
529
+ else:
530
+ outputs = self.qllama.model.forward_train(
531
+ inputs_embeds=input_embeds,
532
+ vision_hidden_states=None,
533
+ attention_mask=attention_mask,
534
+ output_attentions=output_attentions,
535
+ output_hidden_states=output_hidden_states,
536
+ return_dict=return_dict,
537
+ ).last_hidden_state
538
+ return outputs
539
+
540
+ def get_image_features(
541
+ self,
542
+ pixel_values: torch.FloatTensor,
543
+ output_attentions: Optional[bool] = None,
544
+ output_hidden_states: Optional[bool] = None,
545
+ return_dict: Optional[bool] = None,
546
+ ):
547
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
548
+ output_hidden_states = (
549
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
550
+ )
551
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
552
+
553
+ vision_outputs = self.vision_model(
554
+ pixel_values=pixel_values,
555
+ output_hidden_states=output_hidden_states,
556
+ return_dict=return_dict)
557
+ image_embeds = vision_outputs[0]
558
+ backbone_embeds = image_embeds
559
+
560
+ batch_size = image_embeds.shape[0]
561
+ input_embeds = self.query_tokens.repeat(batch_size, 1, 1)
562
+
563
+ attention_mask = torch.ones(input_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
564
+ attention_mask = _expand_mask(attention_mask, input_embeds.dtype).to(
565
+ input_embeds.device) # [bsz, 1, tgt_seq_len, src_seq_len]
566
+ if type(self.qllama.model) == LlamaForCausalLM:
567
+ outputs = self.qllama.model.model.forward_train(
568
+ inputs_embeds=input_embeds,
569
+ vision_hidden_states=image_embeds,
570
+ attention_mask=attention_mask,
571
+ output_attentions=output_attentions,
572
+ output_hidden_states=output_hidden_states,
573
+ return_dict=return_dict,
574
+ ).last_hidden_state
575
+ else:
576
+ outputs = self.qllama.model.forward_train(
577
+ inputs_embeds=input_embeds,
578
+ vision_hidden_states=image_embeds,
579
+ attention_mask=attention_mask,
580
+ output_attentions=output_attentions,
581
+ output_hidden_states=output_hidden_states,
582
+ return_dict=return_dict,
583
+ ).last_hidden_state
584
+ return backbone_embeds, outputs
585
+
586
+
587
+ class InternVL_C(InternVLModel):
588
+
589
+ def encode_image(self, image):
590
+ vision_outputs = self.vision_model(
591
+ pixel_values=image,
592
+ output_hidden_states=False,
593
+ return_dict=True)
594
+ image_embeds = vision_outputs[0]
595
+ image_embeds = self.clip_projector(image_embeds)
596
+ return image_embeds
597
+
598
+ def encode_text(self, text):
599
+ attention_mask = text > 0
600
+ text_embeds = self.get_text_features(
601
+ input_ids=text,
602
+ attention_mask=attention_mask,
603
+ output_attentions=False,
604
+ output_hidden_states=False,
605
+ return_dict=True,
606
+ )
607
+ text_embeds = text_embeds[torch.arange(text_embeds.shape[0]), attention_mask.sum(1) - 1]
608
+ text_embeds = text_embeds @ self.text_projection
609
+ return text_embeds
610
+
611
+ def forward(self, image, text):
612
+ image_features = self.encode_image(image)
613
+ text_features = self.encode_text(text)
614
+
615
+ # normalized features
616
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
617
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
618
+
619
+ # cosine similarity as logits
620
+ logit_scale = self.logit_scale.exp()
621
+ logits_per_image = logit_scale * image_features @ text_features.t()
622
+ logits_per_text = logits_per_image.t()
623
+
624
+ return logits_per_image, logits_per_text
625
+
626
+
627
+ class InternVL_G(InternVLModel):
628
+
629
+ def encode_image(self, image):
630
+ backbone_embeds, image_embeds = self.get_image_features(
631
+ pixel_values=image,
632
+ output_hidden_states=False,
633
+ return_dict=True,
634
+ )
635
+ backbone_embeds = self.clip_projector(backbone_embeds)
636
+ image_embeds = self.clip_projector2(image_embeds)
637
+ # ensemble
638
+ backbone_embeds = backbone_embeds / backbone_embeds.norm(dim=1, keepdim=True)
639
+ image_embeds = image_embeds / image_embeds.norm(dim=1, keepdim=True)
640
+ image_embeds = image_embeds + backbone_embeds
641
+ return image_embeds
642
+
643
+ def encode_text(self, text):
644
+ attention_mask = text > 0
645
+ text_embeds = self.get_text_features(
646
+ input_ids=text,
647
+ attention_mask=attention_mask,
648
+ output_attentions=False,
649
+ output_hidden_states=False,
650
+ return_dict=True,
651
+ )
652
+ text_embeds = text_embeds[torch.arange(text_embeds.shape[0]), attention_mask.sum(1) - 1]
653
+ text_embeds = text_embeds @ self.text_projection
654
+ return text_embeds
655
+
656
+ def forward(self, image, text):
657
+ image_features = self.encode_image(image)
658
+ text_features = self.encode_text(text)
659
+
660
+ # normalized features
661
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
662
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
663
+
664
+ # cosine similarity as logits
665
+ logit_scale = self.logit_scale.exp()
666
+ logits_per_image = logit_scale * image_features @ text_features.t()
667
+ logits_per_text = logits_per_image.t()
668
+
669
+ return logits_per_image, logits_per_text
VLMEvalKit_old/InternVL/internvl_g/internvl/model/internvl_stage2_retrieval/modeling_qllama.py ADDED
@@ -0,0 +1,1073 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
4
+ # and OPT implementations in this library. It has been modified from its
5
+ # original forms to accommodate minor architectural differences compared
6
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ """ PyTorch QLLaMA model."""
20
+ import math
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import CrossEntropyLoss
27
+ from transformers import LlamaConfig
28
+ from transformers.activations import ACT2FN
29
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
30
+ CausalLMOutputWithPast)
31
+ from transformers.modeling_utils import PreTrainedModel
32
+ from transformers.utils import (add_start_docstrings,
33
+ add_start_docstrings_to_model_forward, logging,
34
+ replace_return_docstrings)
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+ _CONFIG_FOR_DOC = 'LlamaConfig'
39
+
40
+
41
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
42
+ def _make_causal_mask(
43
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
44
+ ):
45
+ """
46
+ Make causal mask used for bi-directional self-attention.
47
+ """
48
+ bsz, tgt_len = input_ids_shape
49
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
50
+ mask_cond = torch.arange(mask.size(-1), device=device)
51
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
52
+ mask = mask.to(dtype)
53
+
54
+ if past_key_values_length > 0:
55
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
56
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
57
+
58
+
59
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
60
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
61
+ """
62
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
63
+ """
64
+ bsz, src_len = mask.size()
65
+ tgt_len = tgt_len if tgt_len is not None else src_len
66
+
67
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
68
+
69
+ inverted_mask = 1.0 - expanded_mask
70
+
71
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
72
+
73
+
74
+ class LlamaRMSNorm(nn.Module):
75
+ def __init__(self, hidden_size, eps=1e-6):
76
+ """
77
+ LlamaRMSNorm is equivalent to T5LayerNorm
78
+ """
79
+ super().__init__()
80
+ self.weight = nn.Parameter(torch.ones(hidden_size))
81
+ self.variance_epsilon = eps
82
+
83
+ def forward(self, hidden_states):
84
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
85
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
86
+
87
+ # convert into half-precision if necessary
88
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
89
+ hidden_states = hidden_states.to(self.weight.dtype)
90
+
91
+ return self.weight * hidden_states
92
+
93
+
94
+ try:
95
+ from functools import partial
96
+
97
+ from apex.normalization import FusedRMSNorm
98
+
99
+ LlamaRMSNorm = partial(FusedRMSNorm, eps=1e-6) # noqa
100
+ print('Discovered apex.normalization.FusedRMSNorm - will use it instead of LlamaRMSNorm')
101
+ except ImportError:
102
+ # using the normal LlamaRMSNorm
103
+ pass
104
+ except Exception:
105
+ print('discovered apex but it failed to load, falling back to LlamaRMSNorm')
106
+ pass
107
+
108
+
109
+ class LlamaRotaryEmbedding(torch.nn.Module):
110
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
111
+ super().__init__()
112
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
113
+ self.register_buffer('inv_freq', inv_freq)
114
+
115
+ # Build here to make `torch.jit.trace` work.
116
+ self.max_seq_len_cached = max_position_embeddings
117
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
118
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
119
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
120
+ emb = torch.cat((freqs, freqs), dim=-1)
121
+ self.register_buffer('cos_cached', emb.cos()[None, None, :, :], persistent=False)
122
+ self.register_buffer('sin_cached', emb.sin()[None, None, :, :], persistent=False)
123
+
124
+ def forward(self, x, seq_len=None):
125
+ # x: [bs, num_attention_heads, seq_len, head_size]
126
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
127
+ if seq_len > self.max_seq_len_cached:
128
+ self.max_seq_len_cached = seq_len
129
+ t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
130
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
131
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
132
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
133
+ self.register_buffer('cos_cached', emb.cos()[None, None, :, :], persistent=False)
134
+ self.register_buffer('sin_cached', emb.sin()[None, None, :, :], persistent=False)
135
+ return (
136
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
137
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
138
+ )
139
+
140
+
141
+ class FixedLlamaRotaryEmbedding(torch.nn.Module):
142
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
143
+ super().__init__()
144
+
145
+ self.dim = dim
146
+ self.max_position_embeddings = max_position_embeddings
147
+ self.base = base
148
+ self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
149
+
150
+ # Build here to make `torch.jit.trace` work.
151
+ self._set_cos_sin_cache(
152
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
153
+ )
154
+
155
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
156
+ self.max_seq_len_cached = seq_len
157
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
158
+
159
+ freqs = torch.outer(t, self.inv_freq)
160
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
161
+ emb = torch.cat((freqs, freqs), dim=-1)
162
+ self.register_buffer('cos_cached', emb.cos()[None, None, :, :], persistent=False)
163
+ self.register_buffer('sin_cached', emb.sin()[None, None, :, :], persistent=False)
164
+
165
+ def forward(self, x, seq_len=None):
166
+ # x: [bs, num_attention_heads, seq_len, head_size]
167
+ if seq_len > self.max_seq_len_cached:
168
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
169
+
170
+ return (
171
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
172
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
173
+ )
174
+
175
+
176
+ LlamaRotaryEmbedding = FixedLlamaRotaryEmbedding
177
+
178
+
179
+ def rotate_half(x):
180
+ """Rotates half the hidden dims of the input."""
181
+ x1 = x[..., : x.shape[-1] // 2]
182
+ x2 = x[..., x.shape[-1] // 2:]
183
+ return torch.cat((-x2, x1), dim=-1)
184
+
185
+
186
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
187
+ gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
188
+ gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
189
+ cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
190
+ sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
191
+ q_embed = (q * cos) + (rotate_half(q) * sin)
192
+ k_embed = (k * cos) + (rotate_half(k) * sin)
193
+ return q_embed, k_embed
194
+
195
+
196
+ class LlamaMLP(nn.Module):
197
+ def __init__(
198
+ self,
199
+ hidden_size: int,
200
+ intermediate_size: int,
201
+ hidden_act: str,
202
+ ):
203
+ super().__init__()
204
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
205
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
206
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
207
+ self.act_fn = ACT2FN[hidden_act]
208
+
209
+ def forward(self, x):
210
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
211
+
212
+
213
+ class LlamaAttention(nn.Module):
214
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
215
+
216
+ def __init__(self, config: LlamaConfig):
217
+ super().__init__()
218
+ self.config = config
219
+ self.hidden_size = config.hidden_size
220
+ self.num_heads = config.num_attention_heads
221
+ self.head_dim = self.hidden_size // self.num_heads
222
+ self.max_position_embeddings = config.max_position_embeddings
223
+
224
+ if (self.head_dim * self.num_heads) != self.hidden_size:
225
+ raise ValueError(
226
+ f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
227
+ f' and `num_heads`: {self.num_heads}).'
228
+ )
229
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
230
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
231
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
232
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
233
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
234
+
235
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
236
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
237
+
238
+ def forward(
239
+ self,
240
+ hidden_states: torch.Tensor,
241
+ attention_mask: Optional[torch.Tensor] = None,
242
+ position_ids: Optional[torch.LongTensor] = None,
243
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
244
+ output_attentions: bool = False,
245
+ use_cache: bool = False,
246
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
247
+ bsz, q_len, _ = hidden_states.size()
248
+
249
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
250
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
251
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
252
+
253
+ kv_seq_len = key_states.shape[-2]
254
+ if past_key_value is not None:
255
+ kv_seq_len += past_key_value[0].shape[-2]
256
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
257
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
258
+ # [bsz, nh, t, hd]
259
+
260
+ if past_key_value is not None:
261
+ # reuse k, v, self_attention
262
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
263
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
264
+
265
+ past_key_value = (key_states, value_states) if use_cache else None
266
+
267
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
268
+
269
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
270
+ raise ValueError(
271
+ f'Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is'
272
+ f' {attn_weights.size()}'
273
+ )
274
+
275
+ if attention_mask is not None:
276
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
277
+ raise ValueError(
278
+ f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
279
+ )
280
+ attn_weights = attn_weights + attention_mask
281
+ attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
282
+
283
+ # upcast attention to fp32
284
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
285
+ attn_output = torch.matmul(attn_weights, value_states)
286
+
287
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
288
+ raise ValueError(
289
+ f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
290
+ f' {attn_output.size()}'
291
+ )
292
+
293
+ attn_output = attn_output.transpose(1, 2)
294
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
295
+
296
+ attn_output = self.o_proj(attn_output)
297
+
298
+ if not output_attentions:
299
+ attn_weights = None
300
+
301
+ return attn_output, attn_weights, past_key_value
302
+
303
+
304
+ class LlamaCrossAttention(nn.Module):
305
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
306
+
307
+ def __init__(self, config: LlamaConfig):
308
+ super().__init__()
309
+ self.config = config
310
+ self.hidden_size = config.hidden_size
311
+ self.num_heads = config.num_attention_heads
312
+ self.head_dim = self.hidden_size // self.num_heads
313
+ self.max_position_embeddings = config.max_position_embeddings
314
+ self.vision_hidden_size = 3200
315
+
316
+ if (self.head_dim * self.num_heads) != self.hidden_size:
317
+ raise ValueError(
318
+ f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
319
+ f' and `num_heads`: {self.num_heads}).'
320
+ )
321
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
322
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
323
+ self.norm1 = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
324
+
325
+ self.k_proj = nn.Linear(self.vision_hidden_size, self.num_heads * self.head_dim, bias=False)
326
+ self.v_proj = nn.Linear(self.vision_hidden_size, self.num_heads * self.head_dim, bias=False)
327
+ self.norm2 = LlamaRMSNorm(self.vision_hidden_size, eps=config.rms_norm_eps)
328
+
329
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
330
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
331
+
332
+ def forward(
333
+ self,
334
+ hidden_states: torch.Tensor,
335
+ vision_hidden_states: torch.Tensor,
336
+ repeat_time: int = 1,
337
+ attention_mask: Optional[torch.Tensor] = None,
338
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
339
+ output_attentions: bool = False,
340
+ use_cache: bool = False,
341
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
342
+ hidden_states = self.norm1(hidden_states)
343
+
344
+ bsz, q_len, _ = hidden_states.size()
345
+
346
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
347
+
348
+ vision_hidden_states = self.norm2(vision_hidden_states)
349
+
350
+ bs_v, kv_len, _ = vision_hidden_states.size()
351
+
352
+ key_states = self.k_proj(vision_hidden_states).view(
353
+ bs_v, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
354
+ value_states = self.v_proj(vision_hidden_states).view(
355
+ bs_v, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
356
+
357
+ key_states = key_states.repeat(repeat_time, 1, 1, 1)
358
+ value_states = value_states.repeat(repeat_time, 1, 1, 1)
359
+
360
+ kv_seq_len = key_states.shape[-2]
361
+ if past_key_value is not None:
362
+ kv_seq_len += past_key_value[0].shape[-2]
363
+
364
+ if past_key_value is not None:
365
+ # reuse k, v, self_attention
366
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
367
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
368
+
369
+ past_key_value = (key_states, value_states) if use_cache else None
370
+
371
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
372
+
373
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
374
+ raise ValueError(
375
+ f'Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is'
376
+ f' {attn_weights.size()}'
377
+ )
378
+
379
+ if attention_mask is not None:
380
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
381
+ raise ValueError(
382
+ f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
383
+ )
384
+ attn_weights = attn_weights + attention_mask
385
+ attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
386
+
387
+ # upcast attention to fp32
388
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
389
+ attn_output = torch.matmul(attn_weights, value_states)
390
+
391
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
392
+ raise ValueError(
393
+ f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
394
+ f' {attn_output.size()}'
395
+ )
396
+
397
+ attn_output = attn_output.transpose(1, 2)
398
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
399
+
400
+ attn_output = self.o_proj(attn_output)
401
+
402
+ if not output_attentions:
403
+ attn_weights = None
404
+
405
+ return attn_output, attn_weights, past_key_value
406
+
407
+
408
+ class LlamaDecoderLayer(nn.Module):
409
+ def __init__(self, config: LlamaConfig, use_cross_attn: bool):
410
+ super().__init__()
411
+ self.hidden_size = config.hidden_size
412
+ self.self_attn = LlamaAttention(config=config)
413
+ self.cross_attn = LlamaCrossAttention(config=config) if use_cross_attn else None
414
+ self.mlp = LlamaMLP(
415
+ hidden_size=self.hidden_size,
416
+ intermediate_size=config.intermediate_size,
417
+ hidden_act=config.hidden_act,
418
+ )
419
+ self.num_query_token = 96
420
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
421
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
422
+
423
+ def forward(
424
+ self,
425
+ hidden_states: torch.Tensor,
426
+ vision_hidden_states: torch.Tensor,
427
+ attention_mask: Optional[torch.Tensor] = None,
428
+ position_ids: Optional[torch.LongTensor] = None,
429
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
430
+ output_attentions: Optional[bool] = False,
431
+ use_cache: Optional[bool] = False,
432
+ repeat_time: int = 1,
433
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
434
+ """
435
+ Args:
436
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
437
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
438
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
439
+ output_attentions (`bool`, *optional*):
440
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
441
+ returned tensors for more detail.
442
+ use_cache (`bool`, *optional*):
443
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
444
+ (see `past_key_values`).
445
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
446
+ """
447
+
448
+ residual = hidden_states
449
+
450
+ hidden_states = self.input_layernorm(hidden_states)
451
+
452
+ # Self Attention
453
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
454
+ hidden_states=hidden_states,
455
+ attention_mask=attention_mask,
456
+ position_ids=position_ids,
457
+ past_key_value=past_key_value,
458
+ output_attentions=output_attentions,
459
+ use_cache=use_cache,
460
+ )
461
+ hidden_states = residual + hidden_states
462
+
463
+ # when using generate function and cache mode, the size of hidden_states is 1,
464
+ # so we should not use cross attention
465
+ if self.cross_attn is not None and hidden_states.size(1) >= self.num_query_token \
466
+ and vision_hidden_states is not None:
467
+ query_feats = hidden_states[:, :self.num_query_token, :]
468
+ text_feats = hidden_states[:, self.num_query_token:, :]
469
+ residual = query_feats
470
+ query_feats, _, _ = self.cross_attn(
471
+ hidden_states=query_feats,
472
+ vision_hidden_states=vision_hidden_states,
473
+ attention_mask=None, # not use attention mask in cross attention
474
+ past_key_value=past_key_value,
475
+ output_attentions=output_attentions,
476
+ use_cache=use_cache,
477
+ repeat_time=repeat_time,
478
+ )
479
+ query_feats = residual + query_feats
480
+ hidden_states = torch.cat([query_feats, text_feats], dim=1)
481
+
482
+ # Fully Connected
483
+ residual = hidden_states
484
+ hidden_states = self.post_attention_layernorm(hidden_states)
485
+ hidden_states = self.mlp(hidden_states)
486
+ hidden_states = residual + hidden_states
487
+
488
+ outputs = (hidden_states,)
489
+
490
+ if output_attentions:
491
+ outputs += (self_attn_weights,)
492
+
493
+ if use_cache:
494
+ outputs += (present_key_value,)
495
+
496
+ return outputs
497
+
498
+
499
+ LLAMA_START_DOCSTRING = r"""
500
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
501
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
502
+ etc.)
503
+
504
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
505
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
506
+ and behavior.
507
+
508
+ Parameters:
509
+ config ([`LlamaConfig`]):
510
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
511
+ load the weights associated with the model, only the configuration. Check out the
512
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
513
+ """
514
+
515
+
516
+ @add_start_docstrings(
517
+ 'The bare LLaMA Model outputting raw hidden-states without any specific head on top.',
518
+ LLAMA_START_DOCSTRING,
519
+ )
520
+ class LlamaPreTrainedModel(PreTrainedModel):
521
+ config_class = LlamaConfig
522
+ base_model_prefix = 'model'
523
+ supports_gradient_checkpointing = True
524
+ _no_split_modules = ['LlamaDecoderLayer']
525
+ _keys_to_ignore_on_load_unexpected = [r'decoder\.version']
526
+
527
+ def _init_weights(self, module):
528
+ std = self.config.initializer_range
529
+ if isinstance(module, nn.Linear):
530
+ module.weight.data.normal_(mean=0.0, std=std)
531
+ if module.bias is not None:
532
+ module.bias.data.zero_()
533
+ elif isinstance(module, nn.Embedding):
534
+ module.weight.data.normal_(mean=0.0, std=std)
535
+ if module.padding_idx is not None:
536
+ module.weight.data[module.padding_idx].zero_()
537
+
538
+ def _set_gradient_checkpointing(self, module, value=False):
539
+ if isinstance(module, LlamaModel):
540
+ module.gradient_checkpointing = value
541
+ if isinstance(module, LlamaDecoderLayer):
542
+ module.gradient_checkpointing = value
543
+
544
+
545
+ LLAMA_INPUTS_DOCSTRING = r"""
546
+ Args:
547
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
548
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
549
+ it.
550
+
551
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
552
+ [`PreTrainedTokenizer.__call__`] for details.
553
+
554
+ [What are input IDs?](../glossary#input-ids)
555
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
556
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
557
+
558
+ - 1 for tokens that are **not masked**,
559
+ - 0 for tokens that are **masked**.
560
+
561
+ [What are attention masks?](../glossary#attention-mask)
562
+
563
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
564
+ [`PreTrainedTokenizer.__call__`] for details.
565
+
566
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
567
+ `past_key_values`).
568
+
569
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
570
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
571
+ information on the default strategy.
572
+
573
+ - 1 indicates the head is **not masked**,
574
+ - 0 indicates the head is **masked**.
575
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
576
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
577
+ config.n_positions - 1]`.
578
+
579
+ [What are position IDs?](../glossary#position-ids)
580
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
581
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
582
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
583
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
584
+
585
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
586
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
587
+
588
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
589
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
590
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
591
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
592
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
593
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
594
+ model's internal embedding lookup matrix.
595
+ use_cache (`bool`, *optional*):
596
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
597
+ `past_key_values`).
598
+ output_attentions (`bool`, *optional*):
599
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
600
+ tensors for more detail.
601
+ output_hidden_states (`bool`, *optional*):
602
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
603
+ more detail.
604
+ return_dict (`bool`, *optional*):
605
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
606
+ """
607
+
608
+
609
+ @add_start_docstrings(
610
+ 'The bare LLaMA Model outputting raw hidden-states without any specific head on top.',
611
+ LLAMA_START_DOCSTRING,
612
+ )
613
+ class LlamaModel(LlamaPreTrainedModel):
614
+ """
615
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
616
+
617
+ Args:
618
+ config: LlamaConfig
619
+ """
620
+
621
+ def __init__(self, config: LlamaConfig):
622
+ super().__init__(config)
623
+ self.padding_idx = config.pad_token_id
624
+ self.vocab_size = config.vocab_size
625
+ self.cross_attention_frequency = config.cross_attention_frequency
626
+ self.num_query_token = config.num_query_token
627
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
628
+ use_cross_attn = [idx % self.cross_attention_frequency == 0 for idx in range(config.num_hidden_layers)]
629
+ self.layers = nn.ModuleList(
630
+ [LlamaDecoderLayer(config, use_cross_attn[idx]) for idx in range(config.num_hidden_layers)])
631
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
632
+ self.gradient_checkpointing = False
633
+ # Initialize weights and apply final processing
634
+ # self.post_init()
635
+
636
+ def get_input_embeddings(self):
637
+ return self.embed_tokens
638
+
639
+ def set_input_embeddings(self, value):
640
+ self.embed_tokens = value
641
+
642
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
643
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
644
+ # create causal mask
645
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
646
+ combined_attention_mask = None
647
+ if input_shape[-1] > 1:
648
+ combined_attention_mask = _make_causal_mask(
649
+ input_shape,
650
+ inputs_embeds.dtype,
651
+ device=inputs_embeds.device,
652
+ past_key_values_length=past_key_values_length,
653
+ )
654
+
655
+ if attention_mask is not None:
656
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
657
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
658
+ inputs_embeds.device
659
+ )
660
+ combined_attention_mask = (
661
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
662
+ )
663
+
664
+ return combined_attention_mask
665
+
666
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
667
+ def forward(
668
+ self,
669
+ input_ids: torch.LongTensor = None,
670
+ attention_mask: Optional[torch.Tensor] = None,
671
+ position_ids: Optional[torch.LongTensor] = None,
672
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
673
+ inputs_embeds: Optional[torch.FloatTensor] = None,
674
+ vision_hidden_states: Optional[torch.FloatTensor] = None,
675
+ repeat_time: Optional[int] = 1,
676
+ use_cache: Optional[bool] = None,
677
+ output_attentions: Optional[bool] = None,
678
+ output_hidden_states: Optional[bool] = None,
679
+ use_zero_attention_mask: Optional[bool] = None,
680
+ return_dict: Optional[bool] = None,
681
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
682
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
683
+ output_hidden_states = (
684
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
685
+ )
686
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
687
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
688
+
689
+ # retrieve input_ids and inputs_embeds
690
+ if input_ids is not None and inputs_embeds is not None:
691
+ raise ValueError('You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time')
692
+ elif input_ids is not None:
693
+ batch_size, seq_length = input_ids.shape
694
+ elif inputs_embeds is not None:
695
+ batch_size, seq_length, _ = inputs_embeds.shape
696
+ else:
697
+ raise ValueError('You have to specify either decoder_input_ids or decoder_inputs_embeds')
698
+ seq_length_with_past = seq_length
699
+ past_key_values_length = 0
700
+
701
+ if past_key_values is not None:
702
+ past_key_values_length = past_key_values[0][0].shape[2]
703
+ seq_length_with_past = seq_length_with_past + past_key_values_length
704
+
705
+ if position_ids is None:
706
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
707
+ position_ids = torch.arange(
708
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
709
+ )
710
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
711
+ else:
712
+ position_ids = position_ids.view(-1, seq_length).long()
713
+
714
+ if inputs_embeds is None:
715
+ inputs_embeds = self.embed_tokens(input_ids)
716
+ # embed positions
717
+ if attention_mask is None:
718
+ attention_mask = torch.ones(
719
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
720
+ )
721
+ attention_mask = self._prepare_decoder_attention_mask(
722
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
723
+ )
724
+ if use_zero_attention_mask:
725
+ attention_mask[:, :, :self.num_query_token, :self.num_query_token] = 0
726
+
727
+ hidden_states = inputs_embeds
728
+
729
+ if self.gradient_checkpointing and self.training:
730
+ if use_cache:
731
+ logger.warning_once(
732
+ '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
733
+ )
734
+ use_cache = False
735
+
736
+ # decoder layers
737
+ all_hidden_states = () if output_hidden_states else None
738
+ all_self_attns = () if output_attentions else None
739
+ next_decoder_cache = () if use_cache else None
740
+
741
+ for idx, decoder_layer in enumerate(self.layers):
742
+ if output_hidden_states:
743
+ all_hidden_states += (hidden_states,)
744
+
745
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
746
+
747
+ layer_outputs = decoder_layer(
748
+ hidden_states,
749
+ vision_hidden_states,
750
+ attention_mask=attention_mask,
751
+ position_ids=position_ids,
752
+ past_key_value=past_key_value,
753
+ output_attentions=output_attentions,
754
+ use_cache=use_cache,
755
+ repeat_time=repeat_time,
756
+ )
757
+
758
+ hidden_states = layer_outputs[0]
759
+
760
+ if use_cache:
761
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
762
+
763
+ if output_attentions:
764
+ all_self_attns += (layer_outputs[1],)
765
+
766
+ hidden_states = self.norm(hidden_states)
767
+
768
+ # add hidden states from the last decoder layer
769
+ if output_hidden_states:
770
+ all_hidden_states += (hidden_states,)
771
+
772
+ next_cache = next_decoder_cache if use_cache else None
773
+ if not return_dict:
774
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
775
+ return BaseModelOutputWithPast(
776
+ last_hidden_state=hidden_states,
777
+ past_key_values=next_cache,
778
+ hidden_states=all_hidden_states,
779
+ attentions=all_self_attns,
780
+ )
781
+
782
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
783
+ def forward_train(
784
+ self,
785
+ input_ids: torch.LongTensor = None,
786
+ attention_mask: Optional[torch.Tensor] = None,
787
+ position_ids: Optional[torch.LongTensor] = None,
788
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
789
+ inputs_embeds: Optional[torch.FloatTensor] = None,
790
+ vision_hidden_states: Optional[torch.FloatTensor] = None,
791
+ repeat_time: Optional[int] = 1,
792
+ use_cache: Optional[bool] = None,
793
+ output_attentions: Optional[bool] = None,
794
+ output_hidden_states: Optional[bool] = None,
795
+ return_dict: Optional[bool] = None,
796
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
797
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
798
+ output_hidden_states = (
799
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
800
+ )
801
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
802
+
803
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
804
+
805
+ # retrieve input_ids and inputs_embeds
806
+ if input_ids is not None and inputs_embeds is not None:
807
+ raise ValueError('You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time')
808
+ elif input_ids is not None:
809
+ batch_size, seq_length = input_ids.shape
810
+ elif inputs_embeds is not None:
811
+ batch_size, seq_length, _ = inputs_embeds.shape
812
+ else:
813
+ raise ValueError('You have to specify either decoder_input_ids or decoder_inputs_embeds')
814
+
815
+ seq_length_with_past = seq_length
816
+ past_key_values_length = 0
817
+
818
+ if past_key_values is not None:
819
+ past_key_values_length = past_key_values[0][0].shape[2]
820
+ seq_length_with_past = seq_length_with_past + past_key_values_length
821
+
822
+ if position_ids is None:
823
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
824
+ position_ids = torch.arange(
825
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
826
+ )
827
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
828
+ else:
829
+ position_ids = position_ids.view(-1, seq_length).long()
830
+
831
+ if inputs_embeds is None:
832
+ inputs_embeds = self.embed_tokens(input_ids)
833
+ # embed positions
834
+ # if attention_mask is None:
835
+ # attention_mask = torch.ones(
836
+ # (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
837
+ # )
838
+ # attention_mask = self._prepare_decoder_attention_mask(
839
+ # attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
840
+ # )
841
+ hidden_states = inputs_embeds
842
+
843
+ if self.gradient_checkpointing and self.training:
844
+ if use_cache:
845
+ logger.warning_once(
846
+ '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
847
+ )
848
+ use_cache = False
849
+
850
+ # decoder layers
851
+ all_hidden_states = () if output_hidden_states else None
852
+ all_self_attns = () if output_attentions else None
853
+ next_decoder_cache = () if use_cache else None
854
+
855
+ for idx, decoder_layer in enumerate(self.layers):
856
+ if output_hidden_states:
857
+ all_hidden_states += (hidden_states,)
858
+
859
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
860
+
861
+ if self.gradient_checkpointing and self.training:
862
+
863
+ def create_custom_forward(module):
864
+ def custom_forward(*inputs):
865
+ # None for past_key_value
866
+ return module(*inputs, output_attentions, None, repeat_time)
867
+
868
+ return custom_forward
869
+
870
+ layer_outputs = torch.utils.checkpoint.checkpoint(
871
+ create_custom_forward(decoder_layer),
872
+ hidden_states,
873
+ vision_hidden_states,
874
+ attention_mask,
875
+ position_ids,
876
+ None,
877
+ )
878
+ else:
879
+ layer_outputs = decoder_layer(
880
+ hidden_states,
881
+ vision_hidden_states,
882
+ attention_mask=attention_mask,
883
+ position_ids=position_ids,
884
+ past_key_value=past_key_value,
885
+ output_attentions=output_attentions,
886
+ use_cache=use_cache,
887
+ repeat_time=repeat_time,
888
+ )
889
+
890
+ hidden_states = layer_outputs[0]
891
+
892
+ if use_cache:
893
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
894
+
895
+ if output_attentions:
896
+ all_self_attns += (layer_outputs[1],)
897
+
898
+ hidden_states = self.norm(hidden_states)
899
+
900
+ # add hidden states from the last decoder layer
901
+ if output_hidden_states:
902
+ all_hidden_states += (hidden_states,)
903
+
904
+ next_cache = next_decoder_cache if use_cache else None
905
+ if not return_dict:
906
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
907
+ return BaseModelOutputWithPast(
908
+ last_hidden_state=hidden_states,
909
+ past_key_values=next_cache,
910
+ hidden_states=all_hidden_states,
911
+ attentions=all_self_attns,
912
+ )
913
+
914
+
915
+ class LlamaForCausalLM(LlamaPreTrainedModel):
916
+ def __init__(self, config):
917
+ super().__init__(config)
918
+ self.model = LlamaModel(config)
919
+
920
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
921
+
922
+ # Initialize weights and apply final processing
923
+ # self.post_init()
924
+
925
+ def get_input_embeddings(self):
926
+ return self.model.embed_tokens
927
+
928
+ def set_input_embeddings(self, value):
929
+ self.model.embed_tokens = value
930
+
931
+ def get_output_embeddings(self):
932
+ return self.lm_head
933
+
934
+ def set_output_embeddings(self, new_embeddings):
935
+ self.lm_head = new_embeddings
936
+
937
+ def set_decoder(self, decoder):
938
+ self.model = decoder
939
+
940
+ def get_decoder(self):
941
+ return self.model
942
+
943
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
944
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
945
+ def forward(
946
+ self,
947
+ input_ids: torch.LongTensor = None,
948
+ attention_mask: Optional[torch.Tensor] = None,
949
+ position_ids: Optional[torch.LongTensor] = None,
950
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
951
+ inputs_embeds: Optional[torch.FloatTensor] = None,
952
+ vision_hidden_states: Optional[torch.FloatTensor] = None,
953
+ labels: Optional[torch.LongTensor] = None,
954
+ use_cache: Optional[bool] = None,
955
+ output_attentions: Optional[bool] = None,
956
+ output_hidden_states: Optional[bool] = None,
957
+ use_zero_attention_mask: Optional[bool] = None,
958
+ return_dict: Optional[bool] = None,
959
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
960
+ r"""
961
+ Args:
962
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
963
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
964
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
965
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
966
+
967
+ Returns:
968
+
969
+ Example:
970
+
971
+ ```python
972
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
973
+
974
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
975
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
976
+
977
+ >>> prompt = "Hey, are you consciours? Can you talk to me?"
978
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
979
+
980
+ >>> # Generate
981
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
982
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
983
+ "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
984
+ ```"""
985
+
986
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
987
+ output_hidden_states = (
988
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
989
+ )
990
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
991
+
992
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
993
+ outputs = self.model(
994
+ input_ids=input_ids,
995
+ attention_mask=attention_mask,
996
+ position_ids=position_ids,
997
+ past_key_values=past_key_values,
998
+ inputs_embeds=inputs_embeds,
999
+ vision_hidden_states=vision_hidden_states,
1000
+ use_cache=use_cache,
1001
+ output_attentions=output_attentions,
1002
+ output_hidden_states=output_hidden_states,
1003
+ return_dict=return_dict,
1004
+ use_zero_attention_mask=use_zero_attention_mask,
1005
+ )
1006
+
1007
+ hidden_states = outputs[0]
1008
+ logits = self.lm_head(hidden_states)
1009
+
1010
+ loss = None
1011
+ if labels is not None:
1012
+ # Shift so that tokens < n predict n
1013
+ shift_logits = logits[..., :-1, :].contiguous()
1014
+ shift_labels = labels[..., 1:].contiguous()
1015
+ # Flatten the tokens
1016
+ loss_fct = CrossEntropyLoss()
1017
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1018
+ shift_labels = shift_labels.view(-1)
1019
+ # Enable model parallelism
1020
+ shift_labels = shift_labels.to(shift_logits.device)
1021
+ loss = loss_fct(shift_logits, shift_labels)
1022
+
1023
+ if not return_dict:
1024
+ output = (logits,) + outputs[1:]
1025
+ return (loss,) + output if loss is not None else output
1026
+
1027
+ return CausalLMOutputWithPast(
1028
+ loss=loss,
1029
+ logits=logits,
1030
+ past_key_values=outputs.past_key_values,
1031
+ hidden_states=outputs.hidden_states,
1032
+ attentions=outputs.attentions,
1033
+ )
1034
+
1035
+ def prepare_inputs_for_generation(
1036
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None,
1037
+ vision_hidden_states=None, use_zero_attention_mask=None, **kwargs
1038
+ ):
1039
+ if past_key_values:
1040
+ input_ids = input_ids[:, -1:]
1041
+
1042
+ position_ids = kwargs.get('position_ids', None)
1043
+ if attention_mask is not None and position_ids is None:
1044
+ # create position_ids on the fly for batch generation
1045
+ position_ids = attention_mask.long().cumsum(-1) - 1
1046
+ position_ids.masked_fill_(attention_mask == 0, 1)
1047
+ if past_key_values:
1048
+ position_ids = position_ids[:, -1].unsqueeze(-1)
1049
+
1050
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1051
+ if inputs_embeds is not None and past_key_values is None:
1052
+ model_inputs = {'inputs_embeds': inputs_embeds}
1053
+ else:
1054
+ model_inputs = {'input_ids': input_ids}
1055
+
1056
+ model_inputs.update(
1057
+ {
1058
+ 'position_ids': position_ids,
1059
+ 'past_key_values': past_key_values,
1060
+ 'use_cache': kwargs.get('use_cache'),
1061
+ 'attention_mask': attention_mask,
1062
+ 'vision_hidden_states': vision_hidden_states,
1063
+ 'use_zero_attention_mask': use_zero_attention_mask,
1064
+ }
1065
+ )
1066
+ return model_inputs
1067
+
1068
+ @staticmethod
1069
+ def _reorder_cache(past_key_values, beam_idx):
1070
+ reordered_past = ()
1071
+ for layer_past in past_key_values:
1072
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1073
+ return reordered_past
VLMEvalKit_old/InternVL/internvl_g/internvl/train/dataset.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+ import re
4
+ from typing import Dict
5
+
6
+ import torch
7
+ import torchvision.transforms as T
8
+ from PIL import Image
9
+ from torch.utils.data import Dataset
10
+ from torchvision.transforms.functional import InterpolationMode
11
+
12
+
13
+ def build_transform(input_size):
14
+ # match fine-tune setting with blip2
15
+ # https://github.com/salesforce/LAVIS/blob/main/lavis/processors/blip_processors.py
16
+ transform = T.Compose([
17
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
18
+ T.RandomResizedCrop(input_size, scale=(0.5, 1.0),
19
+ interpolation=InterpolationMode.BICUBIC),
20
+ T.RandomHorizontalFlip(),
21
+ T.ToTensor(),
22
+ T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
23
+ ])
24
+ return transform
25
+
26
+
27
+ class FlickrDataset(Dataset):
28
+ """Dataset for supervised fine-tuning."""
29
+
30
+ def __init__(self, metas, tokenizer, data_args):
31
+ super(FlickrDataset, self).__init__()
32
+
33
+ f = open(metas['annotation'])
34
+ lines = f.readlines()[1:]
35
+
36
+ self.data_args = data_args
37
+ self.tokenizer = tokenizer
38
+ self.images = []
39
+ self.image_ids = []
40
+ self.captions = []
41
+
42
+ for line in lines:
43
+ image, caption = line.strip().split('.jpg,')
44
+ image_id = int(image)
45
+ caption = self.process_single_caption(caption)
46
+ image = image + '.jpg'
47
+ image_path = metas['root'] + '/' + image
48
+ self.images.append(image_path)
49
+ self.image_ids.append(image_id)
50
+ self.captions.append(caption)
51
+ print(f'There are {len(self.images)} images.')
52
+ print(f'There are {len(self.captions)} captions.')
53
+
54
+ def __len__(self):
55
+ return len(self.images)
56
+
57
+ def process_single_caption(self, caption, max_words=50):
58
+ caption = re.sub(r"([.!\"()*#:;~])", ' ', caption.lower())
59
+ caption = re.sub(r'\s{2,}', ' ', caption)
60
+ caption = caption.rstrip('\n')
61
+ caption = caption.strip(' ')
62
+
63
+ # truncate caption
64
+ caption_words = caption.split(' ')
65
+ if len(caption_words) > max_words:
66
+ caption = ' '.join(caption_words[: max_words])
67
+ return caption
68
+
69
+ def preprocess(self, image, caption, neg_caption):
70
+ model_inputs = dict()
71
+
72
+ # input image
73
+ image_transform = build_transform(input_size=self.data_args.force_image_size)
74
+ image = Image.open(image)
75
+ image = image.convert('RGB')
76
+ pixel_values = image_transform(image)
77
+ model_inputs['pixel_values'] = pixel_values
78
+
79
+ # for image-text matching
80
+ pos_model_inputs = self.tokenizer(
81
+ caption,
82
+ max_length=self.data_args.max_seq_length,
83
+ padding='max_length' if self.data_args.pad_to_max_length else False,
84
+ truncation=True,
85
+ return_tensors='pt',
86
+ )
87
+ model_inputs['positive_input_ids'] = pos_model_inputs['input_ids']
88
+ model_inputs['positive_attention_mask'] = pos_model_inputs['attention_mask']
89
+ neg_model_inputs = self.tokenizer(
90
+ neg_caption,
91
+ max_length=self.data_args.max_seq_length,
92
+ padding='max_length' if self.data_args.pad_to_max_length else False,
93
+ truncation=True,
94
+ return_tensors='pt',
95
+ )
96
+ model_inputs['negative_input_ids'] = neg_model_inputs['input_ids']
97
+ model_inputs['negative_attention_mask'] = neg_model_inputs['attention_mask']
98
+
99
+ # for image-text contrastive learning
100
+ summarize_model_inputs = self.tokenizer(
101
+ 'summarize:' + caption,
102
+ max_length=self.data_args.max_seq_length,
103
+ padding='max_length' if self.data_args.pad_to_max_length else False,
104
+ truncation=True,
105
+ return_tensors='pt',
106
+ )
107
+ model_inputs['summarize_input_ids'] = summarize_model_inputs['input_ids']
108
+ model_inputs['summarize_attention_mask'] = summarize_model_inputs['attention_mask']
109
+
110
+ # for image-grounded text generation
111
+ prefix = f'English caption:'
112
+ content = caption
113
+ tokenized_prefix = self.tokenizer(
114
+ prefix, padding=False, truncation=True, return_tensors='pt',
115
+ )
116
+ prefix_input_ids = tokenized_prefix['input_ids'][:, :-1] # remove eos
117
+ prefix_attention_mask = tokenized_prefix['attention_mask'][:, :-1] # remove eos
118
+ tokenized_content = self.tokenizer(
119
+ content,
120
+ max_length=self.data_args.max_seq_length - prefix_input_ids.size(1) + 1,
121
+ padding='max_length' if self.data_args.pad_to_max_length else False,
122
+ truncation=True,
123
+ return_tensors='pt',
124
+ )
125
+ content_input_ids = tokenized_content['input_ids'][:, 1:] # remove bos
126
+ content_attention_mask = tokenized_content['attention_mask'][:, 1:] # remove bos
127
+ model_inputs['input_ids'] = torch.cat([prefix_input_ids, content_input_ids], dim=1)
128
+ model_inputs['attention_mask'] = torch.cat([prefix_attention_mask, content_attention_mask], dim=1)
129
+ labels = model_inputs['input_ids'].clone()
130
+ labels[labels == self.tokenizer.pad_token_id] = -100
131
+ labels[:, :prefix_input_ids.size(1) - 1] = -100
132
+ model_inputs['labels'] = labels
133
+ return model_inputs
134
+
135
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
136
+ i = i % len(self.images)
137
+ j = random.randint(0, len(self.images) - 1)
138
+ while self.image_ids[j] == self.image_ids[i]:
139
+ j = random.randint(0, len(self.images) - 1)
140
+ ret = self.preprocess(self.images[i], self.captions[i], self.captions[j])
141
+ # for image-text matching
142
+ ret['positive_input_ids'] = ret['positive_input_ids'][0]
143
+ ret['positive_attention_mask'] = ret['positive_attention_mask'][0]
144
+ ret['negative_input_ids'] = ret['negative_input_ids'][0]
145
+ ret['negative_attention_mask'] = ret['negative_attention_mask'][0]
146
+ # for image-text contrastive learning
147
+ ret['summarize_input_ids'] = ret['summarize_input_ids'][0]
148
+ ret['summarize_attention_mask'] = ret['summarize_attention_mask'][0]
149
+ # for image-grounded text generation
150
+ ret['input_ids'] = ret['input_ids'][0]
151
+ ret['attention_mask'] = ret['attention_mask'][0]
152
+ ret['labels'] = ret['labels'][0]
153
+ ret['image_ids'] = torch.Tensor([self.image_ids[i]]).long()
154
+ return ret
155
+
156
+
157
+ class COCODataset(Dataset):
158
+ """Dataset for supervised fine-tuning."""
159
+
160
+ def __init__(self, metas, tokenizer, data_args):
161
+ super(COCODataset, self).__init__()
162
+
163
+ annotations = json.load(open(metas['annotation']))
164
+
165
+ self.data_args = data_args
166
+ self.tokenizer = tokenizer
167
+ self.images = []
168
+ self.image_ids = []
169
+ self.captions = []
170
+
171
+ for annotation in annotations:
172
+ image_id = int(annotation['image_id'].split('_')[-1])
173
+ caption = annotation['caption']
174
+ caption = self.process_single_caption(caption)
175
+ image = annotation['image']
176
+ image_path = metas['root'] + '/' + image
177
+ self.images.append(image_path)
178
+ self.image_ids.append(image_id)
179
+ self.captions.append(caption)
180
+ print(f'There are {len(self.images)} images.')
181
+ print(f'There are {len(self.captions)} captions.')
182
+
183
+ def __len__(self):
184
+ return len(self.images)
185
+
186
+ def process_single_caption(self, caption, max_words=50):
187
+ caption = re.sub(r"([.!\"()*#:;~])", ' ', caption.lower())
188
+ caption = re.sub(r'\s{2,}', ' ', caption)
189
+ caption = caption.rstrip('\n')
190
+ caption = caption.strip(' ')
191
+
192
+ # truncate caption
193
+ caption_words = caption.split(' ')
194
+ if len(caption_words) > max_words:
195
+ caption = ' '.join(caption_words[: max_words])
196
+ return caption
197
+
198
+ def preprocess(self, image, caption, neg_caption):
199
+ model_inputs = dict()
200
+
201
+ # input image
202
+ image_transform = build_transform(input_size=self.data_args.force_image_size)
203
+ image = Image.open(image)
204
+ image = image.convert('RGB')
205
+ pixel_values = image_transform(image)
206
+ model_inputs['pixel_values'] = pixel_values
207
+
208
+ # for image-text matching
209
+ pos_model_inputs = self.tokenizer(
210
+ caption,
211
+ max_length=self.data_args.max_seq_length,
212
+ padding='max_length' if self.data_args.pad_to_max_length else False,
213
+ truncation=True,
214
+ return_tensors='pt',
215
+ )
216
+ model_inputs['positive_input_ids'] = pos_model_inputs['input_ids']
217
+ model_inputs['positive_attention_mask'] = pos_model_inputs['attention_mask']
218
+ neg_model_inputs = self.tokenizer(
219
+ neg_caption,
220
+ max_length=self.data_args.max_seq_length,
221
+ padding='max_length' if self.data_args.pad_to_max_length else False,
222
+ truncation=True,
223
+ return_tensors='pt',
224
+ )
225
+ model_inputs['negative_input_ids'] = neg_model_inputs['input_ids']
226
+ model_inputs['negative_attention_mask'] = neg_model_inputs['attention_mask']
227
+
228
+ # for image-text contrastive learning
229
+ summarize_model_inputs = self.tokenizer(
230
+ 'summarize:' + caption,
231
+ max_length=self.data_args.max_seq_length,
232
+ padding='max_length' if self.data_args.pad_to_max_length else False,
233
+ truncation=True,
234
+ return_tensors='pt',
235
+ )
236
+ model_inputs['summarize_input_ids'] = summarize_model_inputs['input_ids']
237
+ model_inputs['summarize_attention_mask'] = summarize_model_inputs['attention_mask']
238
+
239
+ # for image-grounded text generation
240
+ prefix = f'English caption:'
241
+ content = caption
242
+ tokenized_prefix = self.tokenizer(
243
+ prefix, padding=False, truncation=True, return_tensors='pt',
244
+ )
245
+ prefix_input_ids = tokenized_prefix['input_ids'][:, :-1] # remove eos
246
+ prefix_attention_mask = tokenized_prefix['attention_mask'][:, :-1] # remove eos
247
+ tokenized_content = self.tokenizer(
248
+ content,
249
+ max_length=self.data_args.max_seq_length - prefix_input_ids.size(1) + 1,
250
+ padding='max_length' if self.data_args.pad_to_max_length else False,
251
+ truncation=True,
252
+ return_tensors='pt',
253
+ )
254
+ content_input_ids = tokenized_content['input_ids'][:, 1:] # remove bos
255
+ content_attention_mask = tokenized_content['attention_mask'][:, 1:] # remove bos
256
+ model_inputs['input_ids'] = torch.cat([prefix_input_ids, content_input_ids], dim=1)
257
+ model_inputs['attention_mask'] = torch.cat([prefix_attention_mask, content_attention_mask], dim=1)
258
+ labels = model_inputs['input_ids'].clone()
259
+ labels[labels == self.tokenizer.pad_token_id] = -100
260
+ labels[:, :prefix_input_ids.size(1) - 1] = -100
261
+ model_inputs['labels'] = labels
262
+ return model_inputs
263
+
264
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
265
+ i = i % len(self.images)
266
+ j = random.randint(0, len(self.images) - 1)
267
+ while self.image_ids[j] == self.image_ids[i]:
268
+ j = random.randint(0, len(self.images) - 1)
269
+ ret = self.preprocess(self.images[i], self.captions[i], self.captions[j])
270
+ # for image-text matching
271
+ ret['positive_input_ids'] = ret['positive_input_ids'][0]
272
+ ret['positive_attention_mask'] = ret['positive_attention_mask'][0]
273
+ ret['negative_input_ids'] = ret['negative_input_ids'][0]
274
+ ret['negative_attention_mask'] = ret['negative_attention_mask'][0]
275
+ # for image-text contrastive learning
276
+ ret['summarize_input_ids'] = ret['summarize_input_ids'][0]
277
+ ret['summarize_attention_mask'] = ret['summarize_attention_mask'][0]
278
+ # for image-grounded text generation
279
+ ret['input_ids'] = ret['input_ids'][0]
280
+ ret['attention_mask'] = ret['attention_mask'][0]
281
+ ret['labels'] = ret['labels'][0]
282
+ ret['image_ids'] = torch.Tensor([self.image_ids[i]]).long()
283
+ return ret
VLMEvalKit_old/InternVL/internvl_g/internvl/train/internvl_stage2_finetune.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import sys
4
+ import warnings
5
+ from dataclasses import dataclass, field
6
+ from typing import Dict, Optional
7
+
8
+ import torch.distributed as dist
9
+ import transformers
10
+ from internvl.dist_utils import init_dist
11
+ from internvl.model.internvl_stage2_retrieval import (InternVLConfig,
12
+ InternVLModel)
13
+ from internvl.train.dataset import COCODataset, FlickrDataset
14
+ from internvl.train.trainer_monkey_patch import replace_create_optimizer
15
+ from PIL import Image, ImageFile, PngImagePlugin
16
+ from transformers import (HfArgumentParser, LlamaTokenizer, Trainer,
17
+ TrainingArguments, default_data_collator, set_seed)
18
+ from transformers.trainer_utils import get_last_checkpoint
19
+ from transformers.utils.logging import (enable_default_handler,
20
+ enable_explicit_format, set_verbosity)
21
+
22
+ IGNORE_INDEX = -100
23
+ Image.MAX_IMAGE_PIXELS = None
24
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
25
+ MaximumDecompressedSize = 1024
26
+ MegaByte = 2 ** 20
27
+ PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
28
+
29
+ warnings.filterwarnings('ignore')
30
+ logger = logging.getLogger(__name__)
31
+
32
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
33
+
34
+ ds_collections = {
35
+ 'flickr30k_en_train': {
36
+ 'root': './data/flickr30k/Images/',
37
+ 'annotation': './data/flickr30k/flickr30k_train_karpathy.txt',
38
+ },
39
+ 'flickr30k_cn_train': {
40
+ 'root': './data/flickr30k/Images/',
41
+ 'annotation': './data/flickr30k/flickr30k_cn_train.txt',
42
+ },
43
+ 'coco_karpathy_train': {
44
+ 'root': './data/coco/',
45
+ 'annotation': './data/coco/annotations/coco_karpathy_train.json',
46
+ },
47
+ }
48
+
49
+
50
+ @dataclass
51
+ class ModelArguments:
52
+ """
53
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
54
+ """
55
+ model_name_or_path: str = field(
56
+ metadata={'help': 'Path to pretrained model or model identifier from huggingface.co/models'}
57
+ )
58
+ freeze_model: bool = field(
59
+ default=False,
60
+ metadata={'help': 'Set to True to freeze the entire model.'},
61
+ )
62
+ freeze_vision_model: bool = field(
63
+ default=False,
64
+ metadata={'help': 'Set to True to freeze the vision backbone of the model.'},
65
+ )
66
+ freeze_qllama: bool = field(
67
+ default=False,
68
+ metadata={'help': 'Set to True to freeze the QLLaMA of the model.'},
69
+ )
70
+ unfreeze_qllama_head: bool = field(
71
+ default=False,
72
+ metadata={'help': 'Set to True to unfreeze the head of the QLLaMA.'},
73
+ )
74
+ unfreeze_crossattn: bool = field(
75
+ default=False,
76
+ metadata={'help': 'Set to True to unfreeze the cross attention layers in the QLLaMA.'},
77
+ )
78
+ use_backbone_lora: int = field(
79
+ default=0, metadata={'help': 'If non-zero, indicates the use of LoRA in the vision backbone of the model'}
80
+ )
81
+ use_qllama_lora: int = field(
82
+ default=0, metadata={'help': 'If non-zero, indicates the use of LoRA in the QLLaMA of the model'}
83
+ )
84
+ use_custom_trainer: bool = field(
85
+ default=False, metadata={'help': 'Set to True to enable the use of a custom trainer.'},
86
+ )
87
+ drop_path_rate: float = field(
88
+ default=0.0, metadata={'help': 'Specify the value of drop path rate in the vision backbone. Default is 0.'}
89
+ )
90
+
91
+
92
+ @dataclass
93
+ class DataTrainingArguments:
94
+ """
95
+ Arguments pertaining to what data we are going to input our model for training and eval.
96
+ """
97
+ dataset_name: Optional[str] = field(
98
+ default='flickr30k_en_train',
99
+ metadata={'help': 'Specify the name of dataset to be used.'},
100
+ )
101
+ max_seq_length: Optional[int] = field(
102
+ default=80,
103
+ metadata={
104
+ 'help': (
105
+ 'The maximum total input sequence length after tokenization. Sequences longer '
106
+ 'than this will be truncated, sequences shorter will be padded.'
107
+ )
108
+ },
109
+ )
110
+ force_image_size: Optional[int] = field(
111
+ default=224,
112
+ metadata={'help': 'Specify the image size for training models.'},
113
+ )
114
+ pad_to_max_length: bool = field(
115
+ default=False,
116
+ metadata={
117
+ 'help': (
118
+ 'Whether to pad all samples to model maximum sentence length. '
119
+ 'If False, will pad the samples dynamically when batching to the maximum length in the batch. More '
120
+ 'efficient on GPU but very bad for TPU.'
121
+ )
122
+ },
123
+ )
124
+
125
+
126
+ def main():
127
+ # Parse input arguments
128
+ # See all possible arguments in src/transformers/training_args.py
129
+ # If use DeepSpeed zero3, init_dist must before HfArgumentParser
130
+ launcher = os.environ.get('LAUNCHER', 'slurm')
131
+ init_dist(launcher=launcher, backend='nccl')
132
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
133
+ if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
134
+ # If we pass only one argument to the script, and it's the path to a json file,
135
+ # let's parse it to get our arguments.
136
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
137
+ else:
138
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
139
+
140
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
141
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
142
+ # send_example_telemetry('finetune Flickr30K', model_args, data_args)
143
+
144
+ # Setup logging
145
+ logging.basicConfig(
146
+ format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
147
+ datefmt='%m/%d/%Y %H:%M:%S',
148
+ handlers=[logging.StreamHandler(sys.stdout)],
149
+ )
150
+
151
+ if training_args.should_log:
152
+ # The default of training_args.log_level is passive, so we set log level at info here to have that default.
153
+ transformers.utils.logging.set_verbosity_info()
154
+
155
+ log_level = training_args.get_process_log_level()
156
+ logger.setLevel(log_level)
157
+ set_verbosity(log_level)
158
+ enable_default_handler()
159
+ enable_explicit_format()
160
+
161
+ # Log on each process the small summary:
162
+ logger.warning(
163
+ f'Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}'
164
+ + f'distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}'
165
+ )
166
+ logger.info(f'Training/evaluation parameters {training_args}')
167
+
168
+ # Detecting last checkpoint and eventually continue from last checkpoint.
169
+ last_checkpoint = None
170
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
171
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
172
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
173
+ raise ValueError(
174
+ f'Output directory ({training_args.output_dir}) already exists and is not empty. '
175
+ 'Use --overwrite_output_dir to overcome.'
176
+ )
177
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
178
+ logger.info(
179
+ f'Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change '
180
+ 'the `--output_dir` or add `--overwrite_output_dir` to train from scratch.'
181
+ )
182
+ # Set seed before initializing model.
183
+ set_seed(training_args.seed)
184
+
185
+ # Load pretrained model, tokenizer, and image processor
186
+ tokenizer = LlamaTokenizer.from_pretrained(
187
+ model_args.model_name_or_path,
188
+ add_eos_token=True
189
+ )
190
+
191
+ if 'flickr' in data_args.dataset_name:
192
+ train_dataset = FlickrDataset(metas=ds_collections[data_args.dataset_name],
193
+ tokenizer=tokenizer, data_args=data_args)
194
+ elif 'coco' in data_args.dataset_name:
195
+ train_dataset = COCODataset(metas=ds_collections[data_args.dataset_name],
196
+ tokenizer=tokenizer, data_args=data_args)
197
+ config = InternVLConfig.from_pretrained(model_args.model_name_or_path)
198
+ config.vision_config.drop_path_rate = model_args.drop_path_rate
199
+ model = InternVLModel.from_pretrained(
200
+ model_args.model_name_or_path,
201
+ # ignore_mismatched_sizes=True,
202
+ config=config
203
+ )
204
+ if data_args.force_image_size != 224:
205
+ model.config.force_image_size = data_args.force_image_size
206
+ model.vision_model.resize_pos_embeddings(old_size=224, new_size=data_args.force_image_size, patch_size=14)
207
+
208
+ model.config.use_cache = False
209
+ model.config.qllama_config.use_cache = False
210
+ model.qllama.gradient_checkpointing = True
211
+ model.qllama.model.gradient_checkpointing = True
212
+ model.vision_model.gradient_checkpointing = True
213
+ model.vision_model.encoder.gradient_checkpointing = True
214
+
215
+ def _freeze_params(module):
216
+ for param in module.parameters():
217
+ param.requires_grad = False
218
+
219
+ if model_args.freeze_model:
220
+ _freeze_params(model)
221
+
222
+ if model_args.freeze_vision_model:
223
+ model.vision_model = model.vision_model.eval()
224
+ _freeze_params(model.vision_model)
225
+
226
+ if model_args.freeze_qllama:
227
+ model.qllama = model.qllama.eval()
228
+ _freeze_params(model.qllama)
229
+
230
+ if model_args.use_backbone_lora:
231
+ model.wrap_backbone_lora(r=model_args.use_backbone_lora, lora_alpha=model_args.use_backbone_lora * 2)
232
+ model.config.use_backbone_lora = model_args.use_backbone_lora
233
+
234
+ if model_args.use_qllama_lora:
235
+ model.wrap_qllama_lora(r=model_args.use_qllama_lora, lora_alpha=model_args.use_backbone_lora * 2)
236
+ model.config.use_qllama_lora = model_args.use_qllama_lora
237
+
238
+ if model_args.unfreeze_crossattn:
239
+ for name, param in model.qllama.named_parameters():
240
+ if 'cross_attn' in name:
241
+ param.requires_grad = True
242
+
243
+ if model_args.unfreeze_qllama_head:
244
+ model.qllama.lm_head.weight.requires_grad = True
245
+ model.text_projection.requires_grad = True
246
+
247
+ # print trainable parameters
248
+ if dist.get_rank() == 0:
249
+ for name, param in model.named_parameters():
250
+ print(name, param.requires_grad)
251
+
252
+ # set seed for torch dataloaders
253
+ set_seed(training_args.seed)
254
+
255
+ # Initialize our Trainer
256
+ if model_args.use_custom_trainer:
257
+ replace_create_optimizer()
258
+
259
+ trainer = Trainer(
260
+ model=model,
261
+ args=training_args,
262
+ train_dataset=train_dataset if training_args.do_train else None,
263
+ eval_dataset=None,
264
+ tokenizer=tokenizer,
265
+ data_collator=default_data_collator,
266
+ )
267
+
268
+ # Training
269
+ if training_args.do_train:
270
+ checkpoint = None
271
+ if training_args.resume_from_checkpoint is not None:
272
+ checkpoint = training_args.resume_from_checkpoint
273
+ elif last_checkpoint is not None:
274
+ checkpoint = last_checkpoint
275
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
276
+ trainer.save_model() # Saves the tokenizer too for easy upload
277
+
278
+ metrics = train_result.metrics
279
+ metrics['train_samples'] = len(train_dataset)
280
+ trainer.log_metrics('train', metrics)
281
+ trainer.save_metrics('train', metrics)
282
+ trainer.save_state()
283
+
284
+
285
+ if __name__ == '__main__':
286
+ main()
VLMEvalKit_old/InternVL/internvl_g/internvl/train/trainer_monkey_patch.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import transformers
7
+ from transformers import Trainer, logging
8
+ from transformers.trainer import is_sagemaker_mp_enabled
9
+
10
+ logger = logging.get_logger(__name__)
11
+
12
+
13
+ def get_num_layer_for_vit_and_qllama(var_name, vit_num_max_layer, llama_num_max_layer):
14
+ if var_name in ('query_tokens', 'logit_scale',):
15
+ return 0
16
+ if var_name.startswith('clip_projector.'):
17
+ return vit_num_max_layer
18
+ if var_name.startswith('clip_projector2.') or var_name.startswith('itm_head.') or \
19
+ var_name == 'text_projection':
20
+ return llama_num_max_layer
21
+ if var_name.startswith('vision_model.'):
22
+ if 'embeddings.' in var_name:
23
+ return 0
24
+ if 'layers.' in var_name:
25
+ var_name = var_name.split('layers.')[-1]
26
+ layer_id = int(var_name.split('.')[0])
27
+ return layer_id + 1
28
+ if var_name.startswith('qllama.'):
29
+ if 'embed_tokens' in var_name:
30
+ return 0
31
+ if 'layers.' in var_name:
32
+ var_name = var_name.split('layers.')[-1]
33
+ layer_id = int(var_name.split('.')[0])
34
+ return layer_id + 1
35
+ else:
36
+ return llama_num_max_layer
37
+ return 0
38
+
39
+
40
+ def param_classification(name):
41
+ if name in ['query_tokens', 'text_projection', 'logit_scale']:
42
+ return 'qllama'
43
+ elif name.startswith('vision_model.'):
44
+ return 'vit'
45
+ elif name.startswith('qllama.'):
46
+ return 'qllama'
47
+ elif name.startswith('clip_projector.'):
48
+ return 'vit'
49
+ elif name.startswith('clip_projector2.'):
50
+ return 'qllama'
51
+ elif name.startswith('itm_head.'):
52
+ return 'qllama'
53
+ else:
54
+ return 'other'
55
+
56
+
57
+ def create_optimizer(self):
58
+ """
59
+ Setup the optimizer.
60
+
61
+ We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
62
+ Trainer's init through `optimizers`, or subclass and override this method in a subclass.
63
+ """
64
+ opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
65
+
66
+ parameter_groups = {}
67
+ try: # for stage2 model
68
+ vit_num_layers = opt_model.config.vision_config.num_hidden_layers + 2
69
+ qllama_num_layers = opt_model.config.qllama_config.num_hidden_layers + 2
70
+ except: # for stage3 model
71
+ vit_num_layers = opt_model.qllama.config.vision_config.num_hidden_layers + 2
72
+ qllama_num_layers = opt_model.qllama.config.qllama_config.num_hidden_layers + 2
73
+ print('vit_num_layers:', vit_num_layers)
74
+ print('qllama_num_layers:', qllama_num_layers)
75
+
76
+ vit_layer_decay_rate = float(os.getenv('VIT_LAYER_DECAY_RATE', 1.0))
77
+ qllama_layer_decay_rate = float(os.getenv('QLLAMA_LAYER_DECAY_RATE', 1.0))
78
+ print('vit_layer_decay_rate:', vit_layer_decay_rate)
79
+ print('qllama_layer_decay_rate:', qllama_layer_decay_rate)
80
+
81
+ for name, param in opt_model.named_parameters():
82
+ if not param.requires_grad:
83
+ continue # frozen weights
84
+ if len(param.shape) == 1 or name.endswith('.bias'):
85
+ group_name = 'no_decay'
86
+ this_weight_decay = 0.
87
+ else:
88
+ group_name = 'decay'
89
+ this_weight_decay = self.args.weight_decay
90
+
91
+ cls = param_classification(name)
92
+ layer_id = get_num_layer_for_vit_and_qllama(name, vit_num_layers, qllama_num_layers)
93
+ group_name = '%s_layer_%d_%s' % (cls, layer_id, group_name)
94
+ if group_name not in parameter_groups:
95
+ if cls == 'vit':
96
+ scale = vit_layer_decay_rate ** (vit_num_layers - layer_id - 1)
97
+ else:
98
+ scale = qllama_layer_decay_rate ** (qllama_num_layers - layer_id - 1)
99
+ scale = min(1.0, scale)
100
+ parameter_groups[group_name] = {
101
+ 'weight_decay': this_weight_decay,
102
+ 'params': [],
103
+ 'param_names': [],
104
+ 'lr_scale': scale,
105
+ 'group_name': group_name,
106
+ 'lr': scale * self.args.learning_rate,
107
+ }
108
+ parameter_groups[group_name]['params'].append(param)
109
+ parameter_groups[group_name]['param_names'].append(name)
110
+
111
+ rank = torch.distributed.get_rank()
112
+ if rank == 0:
113
+ to_display = {}
114
+ for key in parameter_groups:
115
+ to_display[key] = {
116
+ 'param_names': parameter_groups[key]['param_names'],
117
+ 'lr_scale': parameter_groups[key]['lr_scale'],
118
+ 'lr': parameter_groups[key]['lr'],
119
+ 'weight_decay': parameter_groups[key]['weight_decay'],
120
+ }
121
+ print('Param groups = %s' % json.dumps(to_display, indent=2))
122
+
123
+ optimizer_grouped_parameters = list(parameter_groups.values())
124
+ optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
125
+
126
+ self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
127
+ if optimizer_cls.__name__ == 'Adam8bit':
128
+ import bitsandbytes
129
+
130
+ manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
131
+
132
+ skipped = 0
133
+ for module in opt_model.modules():
134
+ if isinstance(module, nn.Embedding):
135
+ skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
136
+ logger.info(f'skipped {module}: {skipped / 2 ** 20}M params')
137
+ manager.register_module_override(module, 'weight', {'optim_bits': 32})
138
+ logger.debug(f'bitsandbytes: will optimize {module} in fp32')
139
+ logger.info(f'skipped: {skipped / 2 ** 20}M params')
140
+
141
+ if is_sagemaker_mp_enabled():
142
+ import smdistributed.modelparallel.torch as smp
143
+ self.optimizer = smp.DistributedOptimizer(self.optimizer)
144
+
145
+ return self.optimizer
146
+
147
+
148
+ def replace_create_optimizer():
149
+ print('Replace original create_optimizer with custom create_optimizer')
150
+ transformers.Trainer.create_optimizer = create_optimizer
VLMEvalKit_old/InternVL/internvl_g/shell/finetune/internvl_stage2_finetune_coco_364_bs1024_ep5.sh ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ set -x
2
+
3
+ export VIT_LAYER_DECAY_RATE=0.9
4
+ export QLLAMA_LAYER_DECAY_RATE=0.9
5
+
6
+ PARTITION=${PARTITION:-"VC2"}
7
+ GPUS=${GPUS:-32}
8
+ GPUS_PER_NODE=${GPUS_PER_NODE:-8}
9
+ QUOTA_TYPE=${QUOTA_TYPE:-"reserved"}
10
+ NODES=$((GPUS / GPUS_PER_NODE))
11
+ CPUS_PER_TASK=${CPUS_PER_TASK:-10}
12
+ SRUN_ARGS=${SRUN_ARGS:-""}
13
+
14
+
15
+ export PYTHONPATH="${PYTHONPATH}:$(pwd)"
16
+
17
+ # number of gpus: 32
18
+ # batch size per gpu: 32
19
+ # gradient accumulation steps: 1
20
+ # total batch size: 1024
21
+ # epoch: 5
22
+ srun -p ${PARTITION} \
23
+ --gres=gpu:${GPUS_PER_NODE} \
24
+ --nodes=${NODES} \
25
+ --ntasks=${GPUS} \
26
+ --ntasks-per-node=${GPUS_PER_NODE} \
27
+ --cpus-per-task=${CPUS_PER_TASK} \
28
+ --kill-on-bad-exit=1 \
29
+ --quotatype=${QUOTA_TYPE} \
30
+ ${SRUN_ARGS} \
31
+ python -u internvl/train/internvl_stage2_finetune.py \
32
+ --dataset_name 'coco_karpathy_train' \
33
+ --model_name_or_path "./pretrained/InternVL-14B-224px" \
34
+ --output_dir "./work_dirs/internvl_stage2_finetune_coco_364_bs1024_ep5" \
35
+ --overwrite_output_dir True \
36
+ --force_image_size 364 \
37
+ --drop_path_rate 0.3 \
38
+ --use_custom_trainer \
39
+ --dataloader_num_workers 2 \
40
+ --pad_to_max_length True \
41
+ --bf16 True \
42
+ --num_train_epochs 5 \
43
+ --per_device_train_batch_size 32 \
44
+ --gradient_accumulation_steps 1 \
45
+ --evaluation_strategy "no" \
46
+ --save_strategy "steps" \
47
+ --save_steps 100 \
48
+ --save_total_limit 5 \
49
+ --learning_rate 1e-6 \
50
+ --weight_decay 0.05 \
51
+ --warmup_steps 100 \
52
+ --lr_scheduler_type "cosine" \
53
+ --logging_steps 1 \
54
+ --max_seq_length 80 \
55
+ --do_train True \
56
+ --optim adamw_torch \
57
+ --deepspeed "zero_stage1_config.json" \
58
+ --report_to "tensorboard"
VLMEvalKit_old/InternVL/internvl_g/shell/finetune/internvl_stage2_finetune_flickr_364_bs1024_ep10.sh ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ set -x
2
+
3
+ export VIT_LAYER_DECAY_RATE=0.9
4
+ export QLLAMA_LAYER_DECAY_RATE=0.9
5
+
6
+ PARTITION=${PARTITION:-"VC2"}
7
+ GPUS=${GPUS:-32}
8
+ GPUS_PER_NODE=${GPUS_PER_NODE:-8}
9
+ QUOTA_TYPE=${QUOTA_TYPE:-"reserved"}
10
+ NODES=$((GPUS / GPUS_PER_NODE))
11
+ CPUS_PER_TASK=${CPUS_PER_TASK:-10}
12
+ SRUN_ARGS=${SRUN_ARGS:-""}
13
+
14
+
15
+ export PYTHONPATH="${PYTHONPATH}:$(pwd)"
16
+
17
+ # number of gpus: 32
18
+ # batch size per gpu: 32
19
+ # gradient accumulation steps: 1
20
+ # total batch size: 1024
21
+ # epoch: 10
22
+ srun -p ${PARTITION} \
23
+ --gres=gpu:${GPUS_PER_NODE} \
24
+ --nodes=${NODES} \
25
+ --ntasks=${GPUS} \
26
+ --ntasks-per-node=${GPUS_PER_NODE} \
27
+ --cpus-per-task=${CPUS_PER_TASK} \
28
+ --kill-on-bad-exit=1 \
29
+ --quotatype=${QUOTA_TYPE} \
30
+ ${SRUN_ARGS} \
31
+ python -u internvl/train/internvl_stage2_finetune.py \
32
+ --dataset_name 'flickr30k_en_train' \
33
+ --model_name_or_path "./pretrained/InternVL-14B-224px" \
34
+ --output_dir "./work_dirs/internvl_stage2_finetune_flickr_364_bs1024_ep10" \
35
+ --overwrite_output_dir True \
36
+ --force_image_size 364 \
37
+ --drop_path_rate 0.3 \
38
+ --use_custom_trainer \
39
+ --dataloader_num_workers 2 \
40
+ --pad_to_max_length True \
41
+ --bf16 True \
42
+ --num_train_epochs 10 \
43
+ --per_device_train_batch_size 32 \
44
+ --gradient_accumulation_steps 1 \
45
+ --evaluation_strategy "no" \
46
+ --save_strategy "steps" \
47
+ --save_steps 100 \
48
+ --save_total_limit 5 \
49
+ --learning_rate 1e-6 \
50
+ --weight_decay 0.05 \
51
+ --warmup_steps 100 \
52
+ --lr_scheduler_type "cosine" \
53
+ --logging_steps 1 \
54
+ --max_seq_length 80 \
55
+ --do_train True \
56
+ --optim adamw_torch \
57
+ --deepspeed "zero_stage1_config.json" \
58
+ --report_to "tensorboard"
VLMEvalKit_old/InternVL/internvl_g/shell/lora_finetune/internvl_stage2_finetune_flickr_224_bs1024_ep10_lora16_4gpu.sh ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ set -x
2
+
3
+ GPUS=${GPUS:-4}
4
+ BATCH_SIZE=${BATCH_SIZE:-32}
5
+
6
+
7
+ export PYTHONPATH="${PYTHONPATH}:$(pwd)"
8
+ export MASTER_PORT=34229
9
+ export TF_CPP_MIN_LOG_LEVEL=3
10
+ export LAUNCHER=pytorch
11
+
12
+ OUTPUT_DIR='work_dirs/internvl_stage2_finetune_flickr_364_bs1024_ep10_lora_4gpu'
13
+
14
+ if [ ! -d "$OUTPUT_DIR" ]; then
15
+ mkdir -p "$OUTPUT_DIR"
16
+ fi
17
+
18
+ # number of gpus: 32
19
+ # batch size per gpu: 32
20
+ # gradient accumulation steps: 1
21
+ # total batch size: 1024
22
+ # epoch: 10
23
+ torchrun \
24
+ --nnodes=1 \
25
+ --node_rank=0 \
26
+ --master_addr=127.0.0.1 \
27
+ --nproc_per_node=${GPUS} \
28
+ --master_port=${MASTER_PORT} \
29
+ internvl/train/internvl_stage2_finetune.py \
30
+ --dataset_name 'flickr30k_en_train' \
31
+ --model_name_or_path "./pretrained/InternVL-14B-224px" \
32
+ --output_dir ${OUTPUT_DIR} \
33
+ --overwrite_output_dir True \
34
+ --freeze_model \
35
+ --freeze_vision_model \
36
+ --freeze_qllama \
37
+ --unfreeze_qllama_head \
38
+ --use_backbone_lora 16 \
39
+ --use_qllama_lora 16 \
40
+ --force_image_size 224 \
41
+ --drop_path_rate 0.0 \
42
+ --dataloader_num_workers 2 \
43
+ --pad_to_max_length True \
44
+ --bf16 True \
45
+ --num_train_epochs 10 \
46
+ --per_device_train_batch_size ${BATCH_SIZE} \
47
+ --gradient_accumulation_steps 1 \
48
+ --evaluation_strategy "no" \
49
+ --save_strategy "steps" \
50
+ --save_steps 100 \
51
+ --save_total_limit 5 \
52
+ --learning_rate 1e-6 \
53
+ --weight_decay 0.05 \
54
+ --warmup_steps 100 \
55
+ --lr_scheduler_type "cosine" \
56
+ --logging_steps 1 \
57
+ --max_seq_length 80 \
58
+ --do_train True \
59
+ --optim adamw_torch \
60
+ --deepspeed "zero_stage3_config.json" \
61
+ --report_to "tensorboard"
VLMEvalKit_old/InternVL/internvl_g/shell/lora_finetune/internvl_stage2_finetune_flickrcn_224_bs1024_ep10_lora16_4gpu.sh ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ set -x
2
+
3
+ GPUS=${GPUS:-4}
4
+ BATCH_SIZE=${BATCH_SIZE:-32}
5
+
6
+
7
+ export PYTHONPATH="${PYTHONPATH}:$(pwd)"
8
+ export MASTER_PORT=34229
9
+ export TF_CPP_MIN_LOG_LEVEL=3
10
+ export LAUNCHER=pytorch
11
+
12
+ OUTPUT_DIR='work_dirs/internvl_stage2_finetune_flickrcn_364_bs1024_ep10_lora_4gpu'
13
+
14
+ if [ ! -d "$OUTPUT_DIR" ]; then
15
+ mkdir -p "$OUTPUT_DIR"
16
+ fi
17
+
18
+ # number of gpus: 32
19
+ # batch size per gpu: 32
20
+ # gradient accumulation steps: 1
21
+ # total batch size: 1024
22
+ # epoch: 10
23
+ torchrun \
24
+ --nnodes=1 \
25
+ --node_rank=0 \
26
+ --master_addr=127.0.0.1 \
27
+ --nproc_per_node=${GPUS} \
28
+ --master_port=${MASTER_PORT} \
29
+ internvl/train/internvl_stage2_finetune.py \
30
+ --dataset_name 'flickr30k_cn_train' \
31
+ --model_name_or_path "./pretrained/InternVL-14B-224px" \
32
+ --output_dir ${OUTPUT_DIR} \
33
+ --overwrite_output_dir True \
34
+ --freeze_model \
35
+ --freeze_vision_model \
36
+ --freeze_qllama \
37
+ --unfreeze_qllama_head \
38
+ --use_backbone_lora 16 \
39
+ --use_qllama_lora 16 \
40
+ --force_image_size 224 \
41
+ --drop_path_rate 0.0 \
42
+ --dataloader_num_workers 2 \
43
+ --pad_to_max_length True \
44
+ --bf16 True \
45
+ --num_train_epochs 10 \
46
+ --per_device_train_batch_size ${BATCH_SIZE} \
47
+ --gradient_accumulation_steps 1 \
48
+ --evaluation_strategy "no" \
49
+ --save_strategy "steps" \
50
+ --save_steps 100 \
51
+ --save_total_limit 5 \
52
+ --learning_rate 1e-6 \
53
+ --weight_decay 0.05 \
54
+ --warmup_steps 100 \
55
+ --lr_scheduler_type "cosine" \
56
+ --logging_steps 1 \
57
+ --max_seq_length 80 \
58
+ --do_train True \
59
+ --optim adamw_torch \
60
+ --deepspeed "zero_stage3_config.json" \
61
+ --report_to "tensorboard"
VLMEvalKit_old/InternVL/segmentation/configs/_base_/datasets/ade20k_504x504.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # dataset settings
2
+ dataset_type = 'ADE20KDataset'
3
+ data_root = 'data/ade/ADEChallengeData2016'
4
+ img_norm_cfg = dict(
5
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
6
+ crop_size = (504, 504)
7
+ train_pipeline = [
8
+ dict(type='LoadImageFromFile'),
9
+ dict(type='LoadAnnotations', reduce_zero_label=True),
10
+ dict(type='Resize', img_scale=(2016, 504), ratio_range=(0.5, 2.0)),
11
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
12
+ dict(type='RandomFlip', prob=0.5),
13
+ dict(type='PhotoMetricDistortion'),
14
+ dict(type='Normalize', **img_norm_cfg),
15
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
16
+ dict(type='DefaultFormatBundle'),
17
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
18
+ ]
19
+ test_pipeline = [
20
+ dict(type='LoadImageFromFile'),
21
+ dict(
22
+ type='MultiScaleFlipAug',
23
+ img_scale=(2016, 504),
24
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
25
+ flip=False,
26
+ transforms=[
27
+ dict(type='SETR_Resize', keep_ratio=True,
28
+ crop_size=crop_size, setr_multi_scale=True),
29
+ dict(type='ResizeToMultiple', size_divisor=14),
30
+ dict(type='RandomFlip'),
31
+ dict(type='Normalize', **img_norm_cfg),
32
+ dict(type='ImageToTensor', keys=['img']),
33
+ dict(type='Collect', keys=['img']),
34
+ ])
35
+ ]
36
+ data = dict(
37
+ samples_per_gpu=4,
38
+ workers_per_gpu=4,
39
+ train=dict(
40
+ type=dataset_type,
41
+ data_root=data_root,
42
+ img_dir='images/training',
43
+ ann_dir='annotations/training',
44
+ pipeline=train_pipeline),
45
+ val=dict(
46
+ type=dataset_type,
47
+ data_root=data_root,
48
+ img_dir='images/validation',
49
+ ann_dir='annotations/validation',
50
+ pipeline=test_pipeline),
51
+ test=dict(
52
+ type=dataset_type,
53
+ data_root=data_root,
54
+ img_dir='images/validation',
55
+ ann_dir='annotations/validation',
56
+ pipeline=test_pipeline))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/datasets/ade20k_640x640.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # dataset settings
2
+ dataset_type = 'ADE20KDataset'
3
+ data_root = 'data/ade/ADEChallengeData2016'
4
+ img_norm_cfg = dict(
5
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
6
+ crop_size = (640, 640)
7
+ train_pipeline = [
8
+ dict(type='LoadImageFromFile'),
9
+ dict(type='LoadAnnotations', reduce_zero_label=True),
10
+ dict(type='Resize', img_scale=(2560, 640), ratio_range=(0.5, 2.0)),
11
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
12
+ dict(type='RandomFlip', prob=0.5),
13
+ dict(type='PhotoMetricDistortion'),
14
+ dict(type='Normalize', **img_norm_cfg),
15
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
16
+ dict(type='DefaultFormatBundle'),
17
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
18
+ ]
19
+ test_pipeline = [
20
+ dict(type='LoadImageFromFile'),
21
+ dict(
22
+ type='MultiScaleFlipAug',
23
+ img_scale=(2560, 640),
24
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
25
+ flip=False,
26
+ transforms=[
27
+ dict(type='Resize', keep_ratio=True),
28
+ dict(type='RandomFlip'),
29
+ dict(type='Normalize', **img_norm_cfg),
30
+ dict(type='ImageToTensor', keys=['img']),
31
+ dict(type='Collect', keys=['img']),
32
+ ])
33
+ ]
34
+ data = dict(
35
+ samples_per_gpu=4,
36
+ workers_per_gpu=4,
37
+ train=dict(
38
+ type=dataset_type,
39
+ data_root=data_root,
40
+ img_dir='images/training',
41
+ ann_dir='annotations/training',
42
+ pipeline=train_pipeline),
43
+ val=dict(
44
+ type=dataset_type,
45
+ data_root=data_root,
46
+ img_dir='images/validation',
47
+ ann_dir='annotations/validation',
48
+ pipeline=test_pipeline),
49
+ test=dict(
50
+ type=dataset_type,
51
+ data_root=data_root,
52
+ img_dir='images/validation',
53
+ ann_dir='annotations/validation',
54
+ pipeline=test_pipeline))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/datasets/cityscapes_832x832.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_ = './cityscapes.py'
2
+ img_norm_cfg = dict(
3
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
4
+ crop_size = (832, 832)
5
+ train_pipeline = [
6
+ dict(type='LoadImageFromFile'),
7
+ dict(type='LoadAnnotations'),
8
+ dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
9
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
10
+ dict(type='RandomFlip', prob=0.5),
11
+ dict(type='PhotoMetricDistortion'),
12
+ dict(type='Normalize', **img_norm_cfg),
13
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
14
+ dict(type='DefaultFormatBundle'),
15
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
16
+ ]
17
+ test_pipeline = [
18
+ dict(type='LoadImageFromFile'),
19
+ dict(
20
+ type='MultiScaleFlipAug',
21
+ img_scale=(2048, 1024),
22
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
23
+ flip=False,
24
+ transforms=[
25
+ dict(type='Resize', keep_ratio=True),
26
+ dict(type='RandomFlip'),
27
+ dict(type='Normalize', **img_norm_cfg),
28
+ dict(type='ImageToTensor', keys=['img']),
29
+ dict(type='Collect', keys=['img']),
30
+ ])
31
+ ]
32
+ data = dict(
33
+ train=dict(pipeline=train_pipeline),
34
+ val=dict(pipeline=test_pipeline),
35
+ test=dict(pipeline=test_pipeline))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/datasets/coco-stuff10k.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # dataset settings
2
+ dataset_type = 'COCOStuffDataset'
3
+ data_root = 'data/coco_stuff10k'
4
+ img_norm_cfg = dict(
5
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
6
+ crop_size = (512, 512)
7
+ train_pipeline = [
8
+ dict(type='LoadImageFromFile'),
9
+ dict(type='LoadAnnotations', reduce_zero_label=True),
10
+ dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
11
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
12
+ dict(type='RandomFlip', prob=0.5),
13
+ dict(type='PhotoMetricDistortion'),
14
+ dict(type='Normalize', **img_norm_cfg),
15
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
16
+ dict(type='DefaultFormatBundle'),
17
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
18
+ ]
19
+ test_pipeline = [
20
+ dict(type='LoadImageFromFile'),
21
+ dict(
22
+ type='MultiScaleFlipAug',
23
+ img_scale=(2048, 512),
24
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
25
+ flip=False,
26
+ transforms=[
27
+ dict(type='Resize', keep_ratio=True),
28
+ dict(type='RandomFlip'),
29
+ dict(type='Normalize', **img_norm_cfg),
30
+ dict(type='ImageToTensor', keys=['img']),
31
+ dict(type='Collect', keys=['img']),
32
+ ])
33
+ ]
34
+ data = dict(
35
+ samples_per_gpu=4,
36
+ workers_per_gpu=4,
37
+ train=dict(
38
+ type=dataset_type,
39
+ data_root=data_root,
40
+ reduce_zero_label=True,
41
+ img_dir='images/train2014',
42
+ ann_dir='annotations/train2014',
43
+ pipeline=train_pipeline),
44
+ val=dict(
45
+ type=dataset_type,
46
+ data_root=data_root,
47
+ reduce_zero_label=True,
48
+ img_dir='images/test2014',
49
+ ann_dir='annotations/test2014',
50
+ pipeline=test_pipeline),
51
+ test=dict(
52
+ type=dataset_type,
53
+ data_root=data_root,
54
+ reduce_zero_label=True,
55
+ img_dir='images/test2014',
56
+ ann_dir='annotations/test2014',
57
+ pipeline=test_pipeline))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/datasets/coco-stuff164k.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # dataset settings
2
+ dataset_type = 'COCOStuffDataset'
3
+ data_root = 'data/coco_stuff164k'
4
+ img_norm_cfg = dict(
5
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
6
+ crop_size = (512, 512)
7
+ train_pipeline = [
8
+ dict(type='LoadImageFromFile'),
9
+ dict(type='LoadAnnotations'),
10
+ dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
11
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
12
+ dict(type='RandomFlip', prob=0.5),
13
+ dict(type='PhotoMetricDistortion'),
14
+ dict(type='Normalize', **img_norm_cfg),
15
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
16
+ dict(type='DefaultFormatBundle'),
17
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
18
+ ]
19
+ test_pipeline = [
20
+ dict(type='LoadImageFromFile'),
21
+ dict(
22
+ type='MultiScaleFlipAug',
23
+ img_scale=(2048, 512),
24
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
25
+ flip=False,
26
+ transforms=[
27
+ dict(type='Resize', keep_ratio=True),
28
+ dict(type='RandomFlip'),
29
+ dict(type='Normalize', **img_norm_cfg),
30
+ dict(type='ImageToTensor', keys=['img']),
31
+ dict(type='Collect', keys=['img']),
32
+ ])
33
+ ]
34
+ data = dict(
35
+ samples_per_gpu=4,
36
+ workers_per_gpu=4,
37
+ train=dict(
38
+ type=dataset_type,
39
+ data_root=data_root,
40
+ img_dir='images/train2017',
41
+ ann_dir='annotations/train2017',
42
+ pipeline=train_pipeline),
43
+ val=dict(
44
+ type=dataset_type,
45
+ data_root=data_root,
46
+ img_dir='images/val2017',
47
+ ann_dir='annotations/val2017',
48
+ pipeline=test_pipeline),
49
+ test=dict(
50
+ type=dataset_type,
51
+ data_root=data_root,
52
+ img_dir='images/val2017',
53
+ ann_dir='annotations/val2017',
54
+ pipeline=test_pipeline))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/datasets/hrf.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # dataset settings
2
+ dataset_type = 'HRFDataset'
3
+ data_root = 'data/HRF'
4
+ img_norm_cfg = dict(
5
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
6
+ img_scale = (2336, 3504)
7
+ crop_size = (256, 256)
8
+ train_pipeline = [
9
+ dict(type='LoadImageFromFile'),
10
+ dict(type='LoadAnnotations'),
11
+ dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
12
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
13
+ dict(type='RandomFlip', prob=0.5),
14
+ dict(type='PhotoMetricDistortion'),
15
+ dict(type='Normalize', **img_norm_cfg),
16
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
17
+ dict(type='DefaultFormatBundle'),
18
+ dict(type='Collect', keys=['img', 'gt_semantic_seg'])
19
+ ]
20
+ test_pipeline = [
21
+ dict(type='LoadImageFromFile'),
22
+ dict(
23
+ type='MultiScaleFlipAug',
24
+ img_scale=img_scale,
25
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
26
+ flip=False,
27
+ transforms=[
28
+ dict(type='Resize', keep_ratio=True),
29
+ dict(type='RandomFlip'),
30
+ dict(type='Normalize', **img_norm_cfg),
31
+ dict(type='ImageToTensor', keys=['img']),
32
+ dict(type='Collect', keys=['img'])
33
+ ])
34
+ ]
35
+
36
+ data = dict(
37
+ samples_per_gpu=4,
38
+ workers_per_gpu=4,
39
+ train=dict(
40
+ type='RepeatDataset',
41
+ times=40000,
42
+ dataset=dict(
43
+ type=dataset_type,
44
+ data_root=data_root,
45
+ img_dir='images/training',
46
+ ann_dir='annotations/training',
47
+ pipeline=train_pipeline)),
48
+ val=dict(
49
+ type=dataset_type,
50
+ data_root=data_root,
51
+ img_dir='images/validation',
52
+ ann_dir='annotations/validation',
53
+ pipeline=test_pipeline),
54
+ test=dict(
55
+ type=dataset_type,
56
+ data_root=data_root,
57
+ img_dir='images/validation',
58
+ ann_dir='annotations/validation',
59
+ pipeline=test_pipeline))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/ann_r50-d8.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model settings
2
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
3
+ model = dict(
4
+ type='EncoderDecoder',
5
+ pretrained='open-mmlab://resnet50_v1c',
6
+ backbone=dict(
7
+ type='ResNetV1c',
8
+ depth=50,
9
+ num_stages=4,
10
+ out_indices=(0, 1, 2, 3),
11
+ dilations=(1, 1, 2, 4),
12
+ strides=(1, 2, 1, 1),
13
+ norm_cfg=norm_cfg,
14
+ norm_eval=False,
15
+ style='pytorch',
16
+ contract_dilation=True),
17
+ decode_head=dict(
18
+ type='ANNHead',
19
+ in_channels=[1024, 2048],
20
+ in_index=[2, 3],
21
+ channels=512,
22
+ project_channels=256,
23
+ query_scales=(1, ),
24
+ key_pool_scales=(1, 3, 6, 8),
25
+ dropout_ratio=0.1,
26
+ num_classes=19,
27
+ norm_cfg=norm_cfg,
28
+ align_corners=False,
29
+ loss_decode=dict(
30
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
31
+ auxiliary_head=dict(
32
+ type='FCNHead',
33
+ in_channels=1024,
34
+ in_index=2,
35
+ channels=256,
36
+ num_convs=1,
37
+ concat_input=False,
38
+ dropout_ratio=0.1,
39
+ num_classes=19,
40
+ norm_cfg=norm_cfg,
41
+ align_corners=False,
42
+ loss_decode=dict(
43
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
44
+ # model training and testing settings
45
+ train_cfg=dict(),
46
+ test_cfg=dict(mode='whole'))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/ccnet_r50-d8.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model settings
2
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
3
+ model = dict(
4
+ type='EncoderDecoder',
5
+ pretrained='open-mmlab://resnet50_v1c',
6
+ backbone=dict(
7
+ type='ResNetV1c',
8
+ depth=50,
9
+ num_stages=4,
10
+ out_indices=(0, 1, 2, 3),
11
+ dilations=(1, 1, 2, 4),
12
+ strides=(1, 2, 1, 1),
13
+ norm_cfg=norm_cfg,
14
+ norm_eval=False,
15
+ style='pytorch',
16
+ contract_dilation=True),
17
+ decode_head=dict(
18
+ type='CCHead',
19
+ in_channels=2048,
20
+ in_index=3,
21
+ channels=512,
22
+ recurrence=2,
23
+ dropout_ratio=0.1,
24
+ num_classes=19,
25
+ norm_cfg=norm_cfg,
26
+ align_corners=False,
27
+ loss_decode=dict(
28
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
29
+ auxiliary_head=dict(
30
+ type='FCNHead',
31
+ in_channels=1024,
32
+ in_index=2,
33
+ channels=256,
34
+ num_convs=1,
35
+ concat_input=False,
36
+ dropout_ratio=0.1,
37
+ num_classes=19,
38
+ norm_cfg=norm_cfg,
39
+ align_corners=False,
40
+ loss_decode=dict(
41
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
42
+ # model training and testing settings
43
+ train_cfg=dict(),
44
+ test_cfg=dict(mode='whole'))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/cgnet.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model settings
2
+ norm_cfg = dict(type='SyncBN', eps=1e-03, requires_grad=True)
3
+ model = dict(
4
+ type='EncoderDecoder',
5
+ backbone=dict(
6
+ type='CGNet',
7
+ norm_cfg=norm_cfg,
8
+ in_channels=3,
9
+ num_channels=(32, 64, 128),
10
+ num_blocks=(3, 21),
11
+ dilations=(2, 4),
12
+ reductions=(8, 16)),
13
+ decode_head=dict(
14
+ type='FCNHead',
15
+ in_channels=256,
16
+ in_index=2,
17
+ channels=256,
18
+ num_convs=0,
19
+ concat_input=False,
20
+ dropout_ratio=0,
21
+ num_classes=19,
22
+ norm_cfg=norm_cfg,
23
+ loss_decode=dict(
24
+ type='CrossEntropyLoss',
25
+ use_sigmoid=False,
26
+ loss_weight=1.0,
27
+ class_weight=[
28
+ 2.5959933, 6.7415504, 3.5354059, 9.8663225, 9.690899, 9.369352,
29
+ 10.289121, 9.953208, 4.3097677, 9.490387, 7.674431, 9.396905,
30
+ 10.347791, 6.3927646, 10.226669, 10.241062, 10.280587,
31
+ 10.396974, 10.055647
32
+ ])),
33
+ # model training and testing settings
34
+ train_cfg=dict(sampler=None),
35
+ test_cfg=dict(mode='whole'))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/danet_r50-d8.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model settings
2
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
3
+ model = dict(
4
+ type='EncoderDecoder',
5
+ pretrained='open-mmlab://resnet50_v1c',
6
+ backbone=dict(
7
+ type='ResNetV1c',
8
+ depth=50,
9
+ num_stages=4,
10
+ out_indices=(0, 1, 2, 3),
11
+ dilations=(1, 1, 2, 4),
12
+ strides=(1, 2, 1, 1),
13
+ norm_cfg=norm_cfg,
14
+ norm_eval=False,
15
+ style='pytorch',
16
+ contract_dilation=True),
17
+ decode_head=dict(
18
+ type='DAHead',
19
+ in_channels=2048,
20
+ in_index=3,
21
+ channels=512,
22
+ pam_channels=64,
23
+ dropout_ratio=0.1,
24
+ num_classes=19,
25
+ norm_cfg=norm_cfg,
26
+ align_corners=False,
27
+ loss_decode=dict(
28
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
29
+ auxiliary_head=dict(
30
+ type='FCNHead',
31
+ in_channels=1024,
32
+ in_index=2,
33
+ channels=256,
34
+ num_convs=1,
35
+ concat_input=False,
36
+ dropout_ratio=0.1,
37
+ num_classes=19,
38
+ norm_cfg=norm_cfg,
39
+ align_corners=False,
40
+ loss_decode=dict(
41
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
42
+ # model training and testing settings
43
+ train_cfg=dict(),
44
+ test_cfg=dict(mode='whole'))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/deeplabv3_r50-d8.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model settings
2
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
3
+ model = dict(
4
+ type='EncoderDecoder',
5
+ pretrained='open-mmlab://resnet50_v1c',
6
+ backbone=dict(
7
+ type='ResNetV1c',
8
+ depth=50,
9
+ num_stages=4,
10
+ out_indices=(0, 1, 2, 3),
11
+ dilations=(1, 1, 2, 4),
12
+ strides=(1, 2, 1, 1),
13
+ norm_cfg=norm_cfg,
14
+ norm_eval=False,
15
+ style='pytorch',
16
+ contract_dilation=True),
17
+ decode_head=dict(
18
+ type='ASPPHead',
19
+ in_channels=2048,
20
+ in_index=3,
21
+ channels=512,
22
+ dilations=(1, 12, 24, 36),
23
+ dropout_ratio=0.1,
24
+ num_classes=19,
25
+ norm_cfg=norm_cfg,
26
+ align_corners=False,
27
+ loss_decode=dict(
28
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
29
+ auxiliary_head=dict(
30
+ type='FCNHead',
31
+ in_channels=1024,
32
+ in_index=2,
33
+ channels=256,
34
+ num_convs=1,
35
+ concat_input=False,
36
+ dropout_ratio=0.1,
37
+ num_classes=19,
38
+ norm_cfg=norm_cfg,
39
+ align_corners=False,
40
+ loss_decode=dict(
41
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
42
+ # model training and testing settings
43
+ train_cfg=dict(),
44
+ test_cfg=dict(mode='whole'))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/deeplabv3_unet_s5-d16.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model settings
2
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
3
+ model = dict(
4
+ type='EncoderDecoder',
5
+ pretrained=None,
6
+ backbone=dict(
7
+ type='UNet',
8
+ in_channels=3,
9
+ base_channels=64,
10
+ num_stages=5,
11
+ strides=(1, 1, 1, 1, 1),
12
+ enc_num_convs=(2, 2, 2, 2, 2),
13
+ dec_num_convs=(2, 2, 2, 2),
14
+ downsamples=(True, True, True, True),
15
+ enc_dilations=(1, 1, 1, 1, 1),
16
+ dec_dilations=(1, 1, 1, 1),
17
+ with_cp=False,
18
+ conv_cfg=None,
19
+ norm_cfg=norm_cfg,
20
+ act_cfg=dict(type='ReLU'),
21
+ upsample_cfg=dict(type='InterpConv'),
22
+ norm_eval=False),
23
+ decode_head=dict(
24
+ type='ASPPHead',
25
+ in_channels=64,
26
+ in_index=4,
27
+ channels=16,
28
+ dilations=(1, 12, 24, 36),
29
+ dropout_ratio=0.1,
30
+ num_classes=2,
31
+ norm_cfg=norm_cfg,
32
+ align_corners=False,
33
+ loss_decode=dict(
34
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
35
+ auxiliary_head=dict(
36
+ type='FCNHead',
37
+ in_channels=128,
38
+ in_index=3,
39
+ channels=64,
40
+ num_convs=1,
41
+ concat_input=False,
42
+ dropout_ratio=0.1,
43
+ num_classes=2,
44
+ norm_cfg=norm_cfg,
45
+ align_corners=False,
46
+ loss_decode=dict(
47
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
48
+ # model training and testing settings
49
+ train_cfg=dict(),
50
+ test_cfg=dict(mode='slide', crop_size=256, stride=170))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/dmnet_r50-d8.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model settings
2
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
3
+ model = dict(
4
+ type='EncoderDecoder',
5
+ pretrained='open-mmlab://resnet50_v1c',
6
+ backbone=dict(
7
+ type='ResNetV1c',
8
+ depth=50,
9
+ num_stages=4,
10
+ out_indices=(0, 1, 2, 3),
11
+ dilations=(1, 1, 2, 4),
12
+ strides=(1, 2, 1, 1),
13
+ norm_cfg=norm_cfg,
14
+ norm_eval=False,
15
+ style='pytorch',
16
+ contract_dilation=True),
17
+ decode_head=dict(
18
+ type='DMHead',
19
+ in_channels=2048,
20
+ in_index=3,
21
+ channels=512,
22
+ filter_sizes=(1, 3, 5, 7),
23
+ dropout_ratio=0.1,
24
+ num_classes=19,
25
+ norm_cfg=dict(type='SyncBN', requires_grad=True),
26
+ align_corners=False,
27
+ loss_decode=dict(
28
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
29
+ auxiliary_head=dict(
30
+ type='FCNHead',
31
+ in_channels=1024,
32
+ in_index=2,
33
+ channels=256,
34
+ num_convs=1,
35
+ concat_input=False,
36
+ dropout_ratio=0.1,
37
+ num_classes=19,
38
+ norm_cfg=norm_cfg,
39
+ align_corners=False,
40
+ loss_decode=dict(
41
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
42
+ # model training and testing settings
43
+ train_cfg=dict(),
44
+ test_cfg=dict(mode='whole'))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/emanet_r50-d8.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model settings
2
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
3
+ model = dict(
4
+ type='EncoderDecoder',
5
+ pretrained='open-mmlab://resnet50_v1c',
6
+ backbone=dict(
7
+ type='ResNetV1c',
8
+ depth=50,
9
+ num_stages=4,
10
+ out_indices=(0, 1, 2, 3),
11
+ dilations=(1, 1, 2, 4),
12
+ strides=(1, 2, 1, 1),
13
+ norm_cfg=norm_cfg,
14
+ norm_eval=False,
15
+ style='pytorch',
16
+ contract_dilation=True),
17
+ decode_head=dict(
18
+ type='EMAHead',
19
+ in_channels=2048,
20
+ in_index=3,
21
+ channels=256,
22
+ ema_channels=512,
23
+ num_bases=64,
24
+ num_stages=3,
25
+ momentum=0.1,
26
+ dropout_ratio=0.1,
27
+ num_classes=19,
28
+ norm_cfg=norm_cfg,
29
+ align_corners=False,
30
+ loss_decode=dict(
31
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
32
+ auxiliary_head=dict(
33
+ type='FCNHead',
34
+ in_channels=1024,
35
+ in_index=2,
36
+ channels=256,
37
+ num_convs=1,
38
+ concat_input=False,
39
+ dropout_ratio=0.1,
40
+ num_classes=19,
41
+ norm_cfg=norm_cfg,
42
+ align_corners=False,
43
+ loss_decode=dict(
44
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
45
+ # model training and testing settings
46
+ train_cfg=dict(),
47
+ test_cfg=dict(mode='whole'))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/encnet_r50-d8.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model settings
2
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
3
+ model = dict(
4
+ type='EncoderDecoder',
5
+ pretrained='open-mmlab://resnet50_v1c',
6
+ backbone=dict(
7
+ type='ResNetV1c',
8
+ depth=50,
9
+ num_stages=4,
10
+ out_indices=(0, 1, 2, 3),
11
+ dilations=(1, 1, 2, 4),
12
+ strides=(1, 2, 1, 1),
13
+ norm_cfg=norm_cfg,
14
+ norm_eval=False,
15
+ style='pytorch',
16
+ contract_dilation=True),
17
+ decode_head=dict(
18
+ type='EncHead',
19
+ in_channels=[512, 1024, 2048],
20
+ in_index=(1, 2, 3),
21
+ channels=512,
22
+ num_codes=32,
23
+ use_se_loss=True,
24
+ add_lateral=False,
25
+ dropout_ratio=0.1,
26
+ num_classes=19,
27
+ norm_cfg=norm_cfg,
28
+ align_corners=False,
29
+ loss_decode=dict(
30
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
31
+ loss_se_decode=dict(
32
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.2)),
33
+ auxiliary_head=dict(
34
+ type='FCNHead',
35
+ in_channels=1024,
36
+ in_index=2,
37
+ channels=256,
38
+ num_convs=1,
39
+ concat_input=False,
40
+ dropout_ratio=0.1,
41
+ num_classes=19,
42
+ norm_cfg=norm_cfg,
43
+ align_corners=False,
44
+ loss_decode=dict(
45
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
46
+ # model training and testing settings
47
+ train_cfg=dict(),
48
+ test_cfg=dict(mode='whole'))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/erfnet_fcn.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model settings
2
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
3
+ model = dict(
4
+ type='EncoderDecoder',
5
+ pretrained=None,
6
+ backbone=dict(
7
+ type='ERFNet',
8
+ in_channels=3,
9
+ enc_downsample_channels=(16, 64, 128),
10
+ enc_stage_non_bottlenecks=(5, 8),
11
+ enc_non_bottleneck_dilations=(2, 4, 8, 16),
12
+ enc_non_bottleneck_channels=(64, 128),
13
+ dec_upsample_channels=(64, 16),
14
+ dec_stages_non_bottleneck=(2, 2),
15
+ dec_non_bottleneck_channels=(64, 16),
16
+ dropout_ratio=0.1,
17
+ init_cfg=None),
18
+ decode_head=dict(
19
+ type='FCNHead',
20
+ in_channels=16,
21
+ channels=128,
22
+ num_convs=1,
23
+ concat_input=False,
24
+ dropout_ratio=0.1,
25
+ num_classes=19,
26
+ norm_cfg=norm_cfg,
27
+ align_corners=False,
28
+ loss_decode=dict(
29
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
30
+ # model training and testing settings
31
+ train_cfg=dict(),
32
+ test_cfg=dict(mode='whole'))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/fast_scnn.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model settings
2
+ norm_cfg = dict(type='SyncBN', requires_grad=True, momentum=0.01)
3
+ model = dict(
4
+ type='EncoderDecoder',
5
+ backbone=dict(
6
+ type='FastSCNN',
7
+ downsample_dw_channels=(32, 48),
8
+ global_in_channels=64,
9
+ global_block_channels=(64, 96, 128),
10
+ global_block_strides=(2, 2, 1),
11
+ global_out_channels=128,
12
+ higher_in_channels=64,
13
+ lower_in_channels=128,
14
+ fusion_out_channels=128,
15
+ out_indices=(0, 1, 2),
16
+ norm_cfg=norm_cfg,
17
+ align_corners=False),
18
+ decode_head=dict(
19
+ type='DepthwiseSeparableFCNHead',
20
+ in_channels=128,
21
+ channels=128,
22
+ concat_input=False,
23
+ num_classes=19,
24
+ in_index=-1,
25
+ norm_cfg=norm_cfg,
26
+ align_corners=False,
27
+ loss_decode=dict(
28
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1)),
29
+ auxiliary_head=[
30
+ dict(
31
+ type='FCNHead',
32
+ in_channels=128,
33
+ channels=32,
34
+ num_convs=1,
35
+ num_classes=19,
36
+ in_index=-2,
37
+ norm_cfg=norm_cfg,
38
+ concat_input=False,
39
+ align_corners=False,
40
+ loss_decode=dict(
41
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
42
+ dict(
43
+ type='FCNHead',
44
+ in_channels=64,
45
+ channels=32,
46
+ num_convs=1,
47
+ num_classes=19,
48
+ in_index=-3,
49
+ norm_cfg=norm_cfg,
50
+ concat_input=False,
51
+ align_corners=False,
52
+ loss_decode=dict(
53
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
54
+ ],
55
+ # model training and testing settings
56
+ train_cfg=dict(),
57
+ test_cfg=dict(mode='whole'))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/fastfcn_r50-d32_jpu_psp.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model settings
2
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
3
+ model = dict(
4
+ type='EncoderDecoder',
5
+ pretrained='open-mmlab://resnet50_v1c',
6
+ backbone=dict(
7
+ type='ResNetV1c',
8
+ depth=50,
9
+ num_stages=4,
10
+ dilations=(1, 1, 2, 4),
11
+ strides=(1, 2, 2, 2),
12
+ out_indices=(1, 2, 3),
13
+ norm_cfg=norm_cfg,
14
+ norm_eval=False,
15
+ style='pytorch',
16
+ contract_dilation=True),
17
+ neck=dict(
18
+ type='JPU',
19
+ in_channels=(512, 1024, 2048),
20
+ mid_channels=512,
21
+ start_level=0,
22
+ end_level=-1,
23
+ dilations=(1, 2, 4, 8),
24
+ align_corners=False,
25
+ norm_cfg=norm_cfg),
26
+ decode_head=dict(
27
+ type='PSPHead',
28
+ in_channels=2048,
29
+ in_index=2,
30
+ channels=512,
31
+ pool_scales=(1, 2, 3, 6),
32
+ dropout_ratio=0.1,
33
+ num_classes=19,
34
+ norm_cfg=norm_cfg,
35
+ align_corners=False,
36
+ loss_decode=dict(
37
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
38
+ auxiliary_head=dict(
39
+ type='FCNHead',
40
+ in_channels=1024,
41
+ in_index=1,
42
+ channels=256,
43
+ num_convs=1,
44
+ concat_input=False,
45
+ dropout_ratio=0.1,
46
+ num_classes=19,
47
+ norm_cfg=norm_cfg,
48
+ align_corners=False,
49
+ loss_decode=dict(
50
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
51
+ # model training and testing settings
52
+ train_cfg=dict(),
53
+ test_cfg=dict(mode='whole'))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/gcnet_r50-d8.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model settings
2
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
3
+ model = dict(
4
+ type='EncoderDecoder',
5
+ pretrained='open-mmlab://resnet50_v1c',
6
+ backbone=dict(
7
+ type='ResNetV1c',
8
+ depth=50,
9
+ num_stages=4,
10
+ out_indices=(0, 1, 2, 3),
11
+ dilations=(1, 1, 2, 4),
12
+ strides=(1, 2, 1, 1),
13
+ norm_cfg=norm_cfg,
14
+ norm_eval=False,
15
+ style='pytorch',
16
+ contract_dilation=True),
17
+ decode_head=dict(
18
+ type='GCHead',
19
+ in_channels=2048,
20
+ in_index=3,
21
+ channels=512,
22
+ ratio=1 / 4.,
23
+ pooling_type='att',
24
+ fusion_types=('channel_add', ),
25
+ dropout_ratio=0.1,
26
+ num_classes=19,
27
+ norm_cfg=norm_cfg,
28
+ align_corners=False,
29
+ loss_decode=dict(
30
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
31
+ auxiliary_head=dict(
32
+ type='FCNHead',
33
+ in_channels=1024,
34
+ in_index=2,
35
+ channels=256,
36
+ num_convs=1,
37
+ concat_input=False,
38
+ dropout_ratio=0.1,
39
+ num_classes=19,
40
+ norm_cfg=norm_cfg,
41
+ align_corners=False,
42
+ loss_decode=dict(
43
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
44
+ # model training and testing settings
45
+ train_cfg=dict(),
46
+ test_cfg=dict(mode='whole'))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/isanet_r50-d8.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model settings
2
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
3
+ model = dict(
4
+ type='EncoderDecoder',
5
+ pretrained='open-mmlab://resnet50_v1c',
6
+ backbone=dict(
7
+ type='ResNetV1c',
8
+ depth=50,
9
+ num_stages=4,
10
+ out_indices=(0, 1, 2, 3),
11
+ dilations=(1, 1, 2, 4),
12
+ strides=(1, 2, 1, 1),
13
+ norm_cfg=norm_cfg,
14
+ norm_eval=False,
15
+ style='pytorch',
16
+ contract_dilation=True),
17
+ decode_head=dict(
18
+ type='ISAHead',
19
+ in_channels=2048,
20
+ in_index=3,
21
+ channels=512,
22
+ isa_channels=256,
23
+ down_factor=(8, 8),
24
+ dropout_ratio=0.1,
25
+ num_classes=19,
26
+ norm_cfg=norm_cfg,
27
+ align_corners=False,
28
+ loss_decode=dict(
29
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
30
+ auxiliary_head=dict(
31
+ type='FCNHead',
32
+ in_channels=1024,
33
+ in_index=2,
34
+ channels=256,
35
+ num_convs=1,
36
+ concat_input=False,
37
+ dropout_ratio=0.1,
38
+ num_classes=19,
39
+ norm_cfg=norm_cfg,
40
+ align_corners=False,
41
+ loss_decode=dict(
42
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
43
+ # model training and testing settings
44
+ train_cfg=dict(),
45
+ test_cfg=dict(mode='whole'))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/lraspp_m-v3-d8.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model settings
2
+ norm_cfg = dict(type='SyncBN', eps=0.001, requires_grad=True)
3
+ model = dict(
4
+ type='EncoderDecoder',
5
+ backbone=dict(
6
+ type='MobileNetV3',
7
+ arch='large',
8
+ out_indices=(1, 3, 16),
9
+ norm_cfg=norm_cfg),
10
+ decode_head=dict(
11
+ type='LRASPPHead',
12
+ in_channels=(16, 24, 960),
13
+ in_index=(0, 1, 2),
14
+ channels=128,
15
+ input_transform='multiple_select',
16
+ dropout_ratio=0.1,
17
+ num_classes=19,
18
+ norm_cfg=norm_cfg,
19
+ act_cfg=dict(type='ReLU'),
20
+ align_corners=False,
21
+ loss_decode=dict(
22
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
23
+ # model training and testing settings
24
+ train_cfg=dict(),
25
+ test_cfg=dict(mode='whole'))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/nonlocal_r50-d8.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model settings
2
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
3
+ model = dict(
4
+ type='EncoderDecoder',
5
+ pretrained='open-mmlab://resnet50_v1c',
6
+ backbone=dict(
7
+ type='ResNetV1c',
8
+ depth=50,
9
+ num_stages=4,
10
+ out_indices=(0, 1, 2, 3),
11
+ dilations=(1, 1, 2, 4),
12
+ strides=(1, 2, 1, 1),
13
+ norm_cfg=norm_cfg,
14
+ norm_eval=False,
15
+ style='pytorch',
16
+ contract_dilation=True),
17
+ decode_head=dict(
18
+ type='NLHead',
19
+ in_channels=2048,
20
+ in_index=3,
21
+ channels=512,
22
+ dropout_ratio=0.1,
23
+ reduction=2,
24
+ use_scale=True,
25
+ mode='embedded_gaussian',
26
+ num_classes=19,
27
+ norm_cfg=norm_cfg,
28
+ align_corners=False,
29
+ loss_decode=dict(
30
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
31
+ auxiliary_head=dict(
32
+ type='FCNHead',
33
+ in_channels=1024,
34
+ in_index=2,
35
+ channels=256,
36
+ num_convs=1,
37
+ concat_input=False,
38
+ dropout_ratio=0.1,
39
+ num_classes=19,
40
+ norm_cfg=norm_cfg,
41
+ align_corners=False,
42
+ loss_decode=dict(
43
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
44
+ # model training and testing settings
45
+ train_cfg=dict(),
46
+ test_cfg=dict(mode='whole'))
VLMEvalKit_old/InternVL/segmentation/configs/_base_/models/pointrend_r50.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model settings
2
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
3
+ model = dict(
4
+ type='CascadeEncoderDecoder',
5
+ num_stages=2,
6
+ pretrained='open-mmlab://resnet50_v1c',
7
+ backbone=dict(
8
+ type='ResNetV1c',
9
+ depth=50,
10
+ num_stages=4,
11
+ out_indices=(0, 1, 2, 3),
12
+ dilations=(1, 1, 1, 1),
13
+ strides=(1, 2, 2, 2),
14
+ norm_cfg=norm_cfg,
15
+ norm_eval=False,
16
+ style='pytorch',
17
+ contract_dilation=True),
18
+ neck=dict(
19
+ type='FPN',
20
+ in_channels=[256, 512, 1024, 2048],
21
+ out_channels=256,
22
+ num_outs=4),
23
+ decode_head=[
24
+ dict(
25
+ type='FPNHead',
26
+ in_channels=[256, 256, 256, 256],
27
+ in_index=[0, 1, 2, 3],
28
+ feature_strides=[4, 8, 16, 32],
29
+ channels=128,
30
+ dropout_ratio=-1,
31
+ num_classes=19,
32
+ norm_cfg=norm_cfg,
33
+ align_corners=False,
34
+ loss_decode=dict(
35
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
36
+ dict(
37
+ type='PointHead',
38
+ in_channels=[256],
39
+ in_index=[0],
40
+ channels=256,
41
+ num_fcs=3,
42
+ coarse_pred_each_layer=True,
43
+ dropout_ratio=-1,
44
+ num_classes=19,
45
+ align_corners=False,
46
+ loss_decode=dict(
47
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
48
+ ],
49
+ # model training and testing settings
50
+ train_cfg=dict(
51
+ num_points=2048, oversample_ratio=3, importance_sample_ratio=0.75),
52
+ test_cfg=dict(
53
+ mode='whole',
54
+ subdivision_steps=2,
55
+ subdivision_num_points=8196,
56
+ scale_factor=2))