tuandunghcmut committed
Commit 74c960e · verified
1 Parent(s): f40d9b1

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50)
  1. EAGLE/eagle/model/language_model/__init__.py +0 -0
  2. EAGLE/eagle/model/language_model/eagle_llama.py +173 -0
  3. EAGLE/eagle/model/multimodal_encoder/__init__.py +0 -0
  4. EAGLE/eagle/model/multimodal_encoder/clip_encoder.py +89 -0
  5. EAGLE/eagle/model/multimodal_encoder/convnext_encoder.py +141 -0
  6. EAGLE/eagle/model/multimodal_encoder/hr_clip_encoder.py +175 -0
  7. EAGLE/eagle/model/multimodal_encoder/pix2struct_encoder.py +146 -0
  8. EAGLE/eagle/model/multimodal_encoder/vision_models/__init__.py +0 -0
  9. EAGLE/eagle/model/multimodal_encoder/vision_models/convnext.py +1108 -0
  10. EAGLE/eagle/model/multimodal_encoder/vision_models/eva_vit.py +1235 -0
  11. EAGLE/eagle/model/multimodal_projector/__init__.py +0 -0
  12. EAGLE/eagle/model/multimodal_projector/builder.py +50 -0
  13. EAGLE/lmms_eval/api/__init__.py +0 -0
  14. EAGLE/lmms_eval/api/filter.py +53 -0
  15. EAGLE/lmms_eval/api/instance.py +29 -0
  16. EAGLE/lmms_eval/api/metrics.py +431 -0
  17. EAGLE/lmms_eval/api/model.py +203 -0
  18. EAGLE/lmms_eval/api/registry.py +139 -0
  19. EAGLE/lmms_eval/api/samplers.py +94 -0
  20. EAGLE/lmms_eval/api/task.py +1118 -0
  21. EAGLE/lmms_eval/filters/__init__.py +44 -0
  22. EAGLE/lmms_eval/filters/decontamination.py +23 -0
  23. EAGLE/lmms_eval/filters/extraction.py +60 -0
  24. EAGLE/lmms_eval/filters/selection.py +48 -0
  25. EAGLE/lmms_eval/filters/transformation.py +48 -0
  26. EAGLE/lmms_eval/models/__init__.py +16 -0
  27. EAGLE/lmms_eval/models/eagle.py +415 -0
  28. EAGLE/lmms_eval/models/gpt4v.py +129 -0
  29. EAGLE/lmms_eval/tasks/__init__.py +155 -0
  30. EAGLE/lmms_eval/tasks/_task_utils/file_utils.py +8 -0
  31. EAGLE/lmms_eval/tasks/_task_utils/gpt_eval_utils.py +0 -0
  32. EAGLE/lmms_eval/tasks/_task_utils/vqa_eval_metric.py +213 -0
  33. EAGLE/lmms_eval/tasks/cmmmu/_cmmmu.yaml +4 -0
  34. EAGLE/lmms_eval/tasks/cmmmu/_default_template_cmmmu_yaml +8 -0
  35. EAGLE/lmms_eval/tasks/cmmmu/cmmmu_test.yaml +12 -0
  36. EAGLE/lmms_eval/tasks/cmmmu/cmmmu_val.yaml +15 -0
  37. EAGLE/lmms_eval/tasks/cmmmu/utils.py +421 -0
  38. EAGLE/lmms_eval/tasks/gqa/gqa.yaml +32 -0
  39. EAGLE/lmms_eval/tasks/gqa/utils.py +23 -0
  40. EAGLE/lmms_eval/tasks/llava-in-the-wild/llava-in-the-wild.yaml +39 -0
  41. EAGLE/lmms_eval/tasks/llava-in-the-wild/rule.json +11 -0
  42. EAGLE/lmms_eval/tasks/llava-in-the-wild/utils.py +197 -0
  43. EAGLE/lmms_eval/tasks/mmbench/_default_template_mmbench_cn_yaml +22 -0
  44. EAGLE/lmms_eval/tasks/mmbench/_default_template_mmbench_en_yaml +25 -0
  45. EAGLE/lmms_eval/tasks/mmbench/cc_utils.py +109 -0
  46. EAGLE/lmms_eval/tasks/mmbench/cn_utils.py +127 -0
  47. EAGLE/lmms_eval/tasks/mmbench/en_utils.py +126 -0
  48. EAGLE/lmms_eval/tasks/mmbench/mmbench.yaml +11 -0
  49. EAGLE/lmms_eval/tasks/mmbench/mmbench_cc.yaml +34 -0
  50. EAGLE/lmms_eval/tasks/mmbench/mmbench_cn.yaml +9 -0
EAGLE/eagle/model/language_model/__init__.py ADDED
File without changes
EAGLE/eagle/model/language_model/eagle_llama.py ADDED
@@ -0,0 +1,173 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Copyright 2023 Haotian Liu
16
+ #
17
+ # Licensed under the Apache License, Version 2.0 (the "License");
18
+ # you may not use this file except in compliance with the License.
19
+ # You may obtain a copy of the License at
20
+ #
21
+ # http://www.apache.org/licenses/LICENSE-2.0
22
+ #
23
+ # Unless required by applicable law or agreed to in writing, software
24
+ # distributed under the License is distributed on an "AS IS" BASIS,
25
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26
+ # See the License for the specific language governing permissions and
27
+ # limitations under the License.
28
+
29
+
30
+ from typing import List, Optional, Tuple, Union
31
+
32
+ import torch
33
+ import torch.nn as nn
34
+
35
+ from transformers import AutoConfig, AutoModelForCausalLM, \
36
+ LlamaConfig, LlamaModel, LlamaForCausalLM
37
+
38
+ from transformers.modeling_outputs import CausalLMOutputWithPast
39
+ from transformers.generation.utils import GenerateOutput
40
+
41
+ from ..eagle_arch import EagleMetaModel, EagleMetaForCausalLM
42
+
43
+
44
+ class EagleConfig(LlamaConfig):
45
+ model_type = "eagle_llama"
46
+
47
+
48
+ class EagleLlamaModel(EagleMetaModel, LlamaModel):
49
+ config_class = EagleConfig
50
+
51
+ def __init__(self, config: LlamaConfig):
52
+ super(EagleLlamaModel, self).__init__(config)
53
+
54
+
55
+ class EagleLlamaForCausalLM(LlamaForCausalLM, EagleMetaForCausalLM):
56
+ config_class = EagleConfig
57
+
58
+ def __init__(self, config):
59
+ super(LlamaForCausalLM, self).__init__(config)
60
+ self.model = EagleLlamaModel(config)
61
+ self.pretraining_tp = config.pretraining_tp
62
+ self.vocab_size = config.vocab_size
63
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
64
+
65
+ # Initialize weights and apply final processing
66
+ self.post_init()
67
+
68
+ def get_model(self):
69
+ return self.model
70
+
71
+ def forward(
72
+ self,
73
+ input_ids: torch.LongTensor = None,
74
+ attention_mask: Optional[torch.Tensor] = None,
75
+ position_ids: Optional[torch.LongTensor] = None,
76
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
77
+ inputs_embeds: Optional[torch.FloatTensor] = None,
78
+ labels: Optional[torch.LongTensor] = None,
79
+ use_cache: Optional[bool] = None,
80
+ output_attentions: Optional[bool] = None,
81
+ output_hidden_states: Optional[bool] = None,
82
+ images: Optional[torch.FloatTensor] = None,
83
+ image_sizes: Optional[List[List[int]]] = None,
84
+ return_dict: Optional[bool] = None,
85
+ **kwargs  # for llama3: newer transformers versions additionally pass a `cache_position` argument
86
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
87
+
88
+ if inputs_embeds is None:
89
+ (
90
+ input_ids,
91
+ position_ids,
92
+ attention_mask,
93
+ past_key_values,
94
+ inputs_embeds,
95
+ labels
96
+ ) = self.prepare_inputs_labels_for_multimodal(
97
+ input_ids,
98
+ position_ids,
99
+ attention_mask,
100
+ past_key_values,
101
+ labels,
102
+ images,
103
+ image_sizes
104
+ )
105
+
106
+ return super().forward(
107
+ input_ids=input_ids,
108
+ attention_mask=attention_mask,
109
+ position_ids=position_ids,
110
+ past_key_values=past_key_values,
111
+ inputs_embeds=inputs_embeds,
112
+ labels=labels,
113
+ use_cache=use_cache,
114
+ output_attentions=output_attentions,
115
+ output_hidden_states=output_hidden_states,
116
+ return_dict=return_dict
117
+ )
118
+
119
+ @torch.no_grad()
120
+ def generate(
121
+ self,
122
+ inputs: Optional[torch.Tensor] = None,
123
+ images: Optional[torch.Tensor] = None,
124
+ image_sizes: Optional[torch.Tensor] = None,
125
+ **kwargs,
126
+ ) -> Union[GenerateOutput, torch.LongTensor]:
127
+ position_ids = kwargs.pop("position_ids", None)
128
+ attention_mask = kwargs.pop("attention_mask", None)
129
+ if "inputs_embeds" in kwargs:
130
+ raise NotImplementedError("`inputs_embeds` is not supported")
131
+
132
+ if images is not None:
133
+ (
134
+ inputs,
135
+ position_ids,
136
+ attention_mask,
137
+ _,
138
+ inputs_embeds,
139
+ _
140
+ ) = self.prepare_inputs_labels_for_multimodal(
141
+ inputs,
142
+ position_ids,
143
+ attention_mask,
144
+ None,
145
+ None,
146
+ images,
147
+ image_sizes=image_sizes
148
+ )
149
+ else:
150
+ inputs_embeds = self.get_model().embed_tokens(inputs)
151
+
152
+ return super().generate(
153
+ position_ids=position_ids,
154
+ attention_mask=attention_mask,
155
+ inputs_embeds=inputs_embeds,
156
+ **kwargs
157
+ )
158
+
159
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
160
+ inputs_embeds=None, **kwargs):
161
+ images = kwargs.pop("images", None)
162
+ image_sizes = kwargs.pop("image_sizes", None)
163
+ inputs = super().prepare_inputs_for_generation(
164
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
165
+ )
166
+ if images is not None:
167
+ inputs['images'] = images
168
+ if image_sizes is not None:
169
+ inputs['image_sizes'] = image_sizes
170
+ return inputs
171
+
172
+ AutoConfig.register("eagle_llama", EagleConfig)
173
+ AutoModelForCausalLM.register(EagleConfig, EagleLlamaForCausalLM)
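Since eagle_llama.py registers the `eagle_llama` architecture with `AutoConfig`/`AutoModelForCausalLM`, a checkpoint carrying this config can be loaded through the standard auto classes. A minimal usage sketch (the checkpoint path, image size, and dummy tensor below are placeholders, not part of this commit):

```python
# Hedged sketch: load a registered "eagle_llama" checkpoint via the auto classes.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Importing the module above runs the AutoConfig/AutoModelForCausalLM registration.
from eagle.model.language_model.eagle_llama import EagleLlamaForCausalLM  # noqa: F401

path = "path/to/eagle-llama-checkpoint"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path)

input_ids = tokenizer("Describe the image.", return_tensors="pt").input_ids
images = torch.randn(1, 3, 448, 448)  # placeholder pixel values

# EagleLlamaForCausalLM.generate folds `images` into inputs_embeds via
# prepare_inputs_labels_for_multimodal before delegating to the base generate.
output_ids = model.generate(input_ids, images=images, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```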
EAGLE/eagle/model/multimodal_encoder/__init__.py ADDED
File without changes
EAGLE/eagle/model/multimodal_encoder/clip_encoder.py ADDED
@@ -0,0 +1,89 @@
1
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
6
+
7
+
8
+ class CLIPVisionTower(nn.Module):
9
+ def __init__(self, vision_tower, args, delay_load=False):
10
+ super().__init__()
11
+
12
+ self.is_loaded = False
13
+
14
+ self.vision_tower_name = vision_tower
15
+ self.select_layer = args.mm_vision_select_layer
16
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
17
+
18
+ if not delay_load:
19
+ self.load_model()
20
+ elif getattr(args, 'unfreeze_mm_vision_tower', False):
21
+ self.load_model()
22
+ else:
23
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
24
+
25
+ def load_model(self, device_map=None):
26
+ if self.is_loaded:
27
+ print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
28
+ return
29
+
30
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
31
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
32
+ self.vision_tower.requires_grad_(False)
33
+
34
+ self.is_loaded = True
35
+
36
+ def feature_select(self, image_forward_outs):
37
+ image_features = image_forward_outs.hidden_states[self.select_layer]
38
+ if self.select_feature == 'patch':
39
+ image_features = image_features[:, 1:]
40
+ elif self.select_feature == 'cls_patch':
41
+ image_features = image_features
42
+ else:
43
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
44
+ return image_features
45
+
46
+ @torch.no_grad()
47
+ def forward(self, images):
48
+ if type(images) is list:
49
+ image_features = []
50
+ for image in images:
51
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
52
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
53
+ image_features.append(image_feature)
54
+ else:
55
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
56
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
57
+
58
+ return image_features
59
+
60
+ @property
61
+ def dummy_feature(self):
62
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
63
+
64
+ @property
65
+ def dtype(self):
66
+ return self.vision_tower.dtype
67
+
68
+ @property
69
+ def device(self):
70
+ return self.vision_tower.device
71
+
72
+ @property
73
+ def config(self):
74
+ if self.is_loaded:
75
+ return self.vision_tower.config
76
+ else:
77
+ return self.cfg_only
78
+
79
+ @property
80
+ def hidden_size(self):
81
+ return self.config.hidden_size
82
+
83
+ @property
84
+ def num_patches_per_side(self):
85
+ return self.config.image_size // self.config.patch_size
86
+
87
+ @property
88
+ def num_patches(self):
89
+ return (self.config.image_size // self.config.patch_size) ** 2
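The tower is constructed from a Hugging Face CLIP checkpoint name plus an args object supplying `mm_vision_select_layer` and, optionally, `mm_vision_select_feature`. A rough instantiation sketch (the checkpoint name and layer index are illustrative, not taken from this commit):

```python
# Sketch only: instantiate the CLIPVisionTower defined above and run a batch through it.
from types import SimpleNamespace

import torch

from eagle.model.multimodal_encoder.clip_encoder import CLIPVisionTower

args = SimpleNamespace(mm_vision_select_layer=-2, mm_vision_select_feature="patch")
tower = CLIPVisionTower("openai/clip-vit-large-patch14-336", args)  # loads immediately

pixels = torch.randn(2, 3, 336, 336).to(device=tower.device, dtype=tower.dtype)
features = tower(pixels)  # hidden_states[-2] with the CLS token dropped
print(features.shape, tower.num_patches)  # (2, 576, 1024); 576 patches for a 336px/14 model
```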
EAGLE/eagle/model/multimodal_encoder/convnext_encoder.py ADDED
@@ -0,0 +1,141 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+ # This file is modified from https://github.com/luogen1996/LLaVA-HR
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from transformers import CLIPImageProcessor
21
+ from .vision_models.convnext import convnext_xxlarge
22
+ from torch.utils.checkpoint import checkpoint
23
+
24
+ cfg={
25
+ "crop_size": 256,
26
+ "do_center_crop": True,
27
+ "do_normalize": True,
28
+ "do_resize": True,
29
+ "feature_extractor_type": "CLIPFeatureExtractor",
30
+ "image_mean": [
31
+ 0.48145466,
32
+ 0.4578275,
33
+ 0.40821073
34
+ ],
35
+ "image_std": [
36
+ 0.26862954,
37
+ 0.26130258,
38
+ 0.27577711
39
+ ],
40
+ "resample": 3,
41
+ "size": 256
42
+ }
43
+
44
+ class ConvNextVisionTower(nn.Module):
45
+ def __init__(self, vision_tower, args, delay_load=False):
46
+ super().__init__()
47
+
48
+ self.is_loaded = False
49
+ self.freeze_vision=args.freeze_vision
50
+ self.input_image_size=args.input_image_size
51
+ self.vision_tower_name = vision_tower
52
+ self.select_layer = -1 # hardcode
53
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
54
+
55
+ self.load_model()
56
+
57
+ def load_model(self):
58
+ self.image_processor = CLIPImageProcessor(**cfg)
59
+ if 'xxlarge' in self.vision_tower_name:
60
+ self.vision_tower = convnext_xxlarge(self.vision_tower_name)
61
+ setattr(self.vision_tower, 'hidden_size', 3072)
62
+ else:
63
+ raise NotImplementedError
64
+
65
+ if self.freeze_vision:
66
+ self.vision_tower.requires_grad_(False)
67
+
68
+ # Hardcode
69
+ for s in self.vision_tower.stages:
70
+ s.grad_checkpointing = True
71
+
72
+ if self.input_image_size is not None:
73
+ self.image_processor.size=self.input_image_size
74
+ self.image_processor.crop_size={
75
+ 'height':self.input_image_size,
76
+ 'width': self.input_image_size
77
+ }
78
+
79
+ self.is_loaded = True
80
+
81
+ def feature_select(self, image_forward_outs):
82
+ image_features = image_forward_outs[self.select_layer]
83
+ return image_features
84
+
85
+ def forward_features(self, x):
86
+ x = self.vision_tower.stem(x)
87
+ image_forward_out=[]
88
+ for blk in self.vision_tower.stages:
89
+ x = blk(x)
90
+ b,c,h,w=x.shape
91
+ image_forward_out.append(x.view(b,c,-1).transpose(1,2))
92
+ return image_forward_out
93
+
94
+ def forward(self, images):
95
+ if self.freeze_vision:
96
+ with torch.no_grad():
97
+ image_features = self._forward_images(images)
98
+ else:
99
+ image_features = self._forward_images(images)
100
+
101
+ return image_features
102
+
103
+ def _forward_images(self, images):
104
+
105
+ image_forward_outs = self.forward_features(images.to(device=self.device, dtype=self.dtype))
106
+ image_features = self.feature_select(image_forward_outs)
107
+
108
+ return image_features
109
+
110
+ @property
111
+ def dummy_feature(self):
112
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
113
+
114
+ @property
115
+ def dtype(self):
116
+ return next(self.vision_tower.parameters()).dtype
117
+
118
+ @property
119
+ def device(self):
120
+ return next(self.vision_tower.parameters()).device
121
+
122
+ @property
123
+ def config(self):
124
+ raise NotImplementedError
125
+ pass
126
+
127
+ @property
128
+ def num_attention_heads(self):
129
+ # as constant
130
+ return 16
131
+ @property
132
+ def num_layers(self):
133
+ # as constant
134
+ return 4
135
+ @property
136
+ def hidden_size(self):
137
+ return self.vision_tower.hidden_size
138
+
139
+ @property
140
+ def num_patches(self):
141
+ return (cfg['image_size'] // self.patch_embed.patch_size[0]) ** 2
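ConvNextVisionTower flattens each stage's feature map to `(B, H*W, C)` and, with the hard-coded `select_layer = -1`, returns only the final stage (stride 32 after the stride-4 stem and three stride-2 stages). A shape-oriented sketch, assuming a 1024-pixel input and an illustrative checkpoint name containing `xxlarge` so the branch above is taken:

```python
# Sketch (assumed checkpoint name and input size): ConvNeXt tower output shape.
from types import SimpleNamespace

import torch

from eagle.model.multimodal_encoder.convnext_encoder import ConvNextVisionTower

args = SimpleNamespace(freeze_vision=True, input_image_size=1024,
                       mm_vision_select_feature="patch")
tower = ConvNextVisionTower("convnext_xxlarge.clip_laion2b_soup", args)

x = torch.randn(1, 3, 1024, 1024).to(device=tower.device, dtype=tower.dtype)
features = tower(x)  # last stage at stride 32 -> (1, 32 * 32, 3072)
print(features.shape)
```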
EAGLE/eagle/model/multimodal_encoder/hr_clip_encoder.py ADDED
@@ -0,0 +1,175 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+ # Mostly copy-paste from LLaVA-HR
17
+ # https://github.com/luogen1996/LLaVA-HR
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ from torch.utils.checkpoint import checkpoint
22
+
23
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
24
+
25
+ import math
26
+ import torch
27
+ import torch.nn.functional as F
28
+ from typing import List, Optional
29
+
30
+
31
+ def forward_embeddings(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
32
+ batch_size = pixel_values.shape[0]
33
+ target_dtype = self.patch_embedding.weight.dtype
34
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
35
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
36
+
37
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
38
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
39
+ position_embeddings = self.position_embedding(self.position_ids)
40
+
41
+ if position_embeddings.shape[1]!=embeddings.shape[1]:
42
+ position_embeddings=resample_pos_embed(position_embeddings,embeddings.shape[1])
43
+
44
+ embeddings = embeddings + position_embeddings
45
+ return embeddings
46
+
47
+
48
+ def resample_pos_embed(
49
+ posemb,
50
+ new_size: int,
51
+ num_prefix_tokens: int = 1,
52
+ interpolation: str = 'bicubic',
53
+ antialias: bool = True,
54
+ verbose: bool = False,
55
+ ):
56
+ new_size=[int(math.sqrt(new_size-num_prefix_tokens)),int(math.sqrt(new_size-num_prefix_tokens))]
57
+ num_pos_tokens = posemb.shape[1] - num_prefix_tokens
58
+ old_size = int(math.sqrt(num_pos_tokens))
59
+ bs=posemb.shape[0]
60
+
61
+ if num_prefix_tokens:
62
+ posemb_prefix, posemb = posemb[:,:num_prefix_tokens], posemb[:,num_prefix_tokens:]
63
+ else:
64
+ posemb_prefix, posemb = None, posemb
65
+
66
+ # do the interpolation
67
+ embed_dim = posemb.shape[-1]
68
+ orig_dtype = posemb.dtype
69
+ posemb = posemb.float() # interpolate needs float32
70
+ posemb = posemb.reshape(bs, old_size, old_size, -1).permute(0, 3, 1, 2)
71
+ posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
72
+ posemb = posemb.permute(0, 2, 3, 1).reshape(bs, -1, embed_dim)
73
+ posemb = posemb.to(dtype=orig_dtype)
74
+
75
+ # add back extra (class, etc) prefix tokens
76
+ if posemb_prefix is not None:
77
+ posemb = torch.cat([posemb_prefix, posemb],1)
78
+
79
+ if not torch.jit.is_scripting() and verbose:
80
+ print(f'Resized position embedding: {old_size} to {new_size}.')
81
+
82
+ return posemb
83
+
84
+ class HRCLIPVisionTower(nn.Module):
85
+ def __init__(self, vision_tower, args, delay_load=False):
86
+ super().__init__()
87
+
88
+ self.is_loaded = False
89
+ self.freeze_vision=args.freeze_vision
90
+ self.input_image_size=args.input_image_size
91
+ self.vision_tower_name = vision_tower
92
+ self.select_layer = args.mm_vision_select_layer
93
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
94
+
95
+ if not delay_load:
96
+ self.load_model()
97
+ else:
98
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
99
+
100
+
101
+ def load_model(self):
102
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
103
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
104
+ # checkpointing for clip
105
+ self.vision_tower.vision_model.encoder.gradient_checkpointing =True
106
+
107
+ if self.freeze_vision:
108
+ self.vision_tower.requires_grad_(False)
109
+
110
+ cls_=self.vision_tower.vision_model.embeddings
111
+ bound_method = forward_embeddings.__get__(cls_, cls_.__class__)
112
+ setattr(cls_, 'forward', bound_method)
113
+
114
+ if self.input_image_size is not None:
115
+ self.image_processor.size=self.input_image_size
116
+ self.image_processor.crop_size={
117
+ 'height':self.input_image_size,
118
+ 'width': self.input_image_size
119
+ }
120
+
121
+ self.is_loaded = True
122
+
123
+ def forward(self, x):
124
+ # 448 image input
125
+ blks = self.vision_tower.vision_model.encoder.layers
126
+ x = self.vision_tower.vision_model.embeddings(x)
127
+ x = self.vision_tower.vision_model.pre_layrnorm(x[:, 1:])
128
+
129
+ # inference of fast branch
130
+ for blk in blks:
131
+ if self.training:
132
+ x=checkpoint(
133
+ blk.__call__,
134
+ x,
135
+ None,
136
+ None
137
+ )[0]
138
+ else:
139
+ x = blk(x, None, None)[0]
140
+
141
+ return x
142
+
143
+ @property
144
+ def dummy_feature(self):
145
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
146
+
147
+ @property
148
+ def dtype(self):
149
+ return self.vision_tower.dtype
150
+
151
+ @property
152
+ def device(self):
153
+ return self.vision_tower.device
154
+
155
+
156
+ @property
157
+ def num_attention_heads(self):
158
+ return self.config.num_attention_heads
159
+ @property
160
+ def num_layers(self):
161
+ return self.config.num_hidden_layers
162
+ @property
163
+ def config(self):
164
+ if self.is_loaded:
165
+ return self.vision_tower.config
166
+ else:
167
+ return self.cfg_only
168
+
169
+ @property
170
+ def hidden_size(self):
171
+ return self.config.hidden_size
172
+
173
+ @property
174
+ def num_patches(self):
175
+ return (self.config.image_size // self.config.patch_size) ** 2
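The helper `resample_pos_embed` bicubically interpolates the CLIP position embeddings whenever the token count implied by the input resolution differs from the pretrained grid (e.g. feeding 448-pixel crops to a 336-pixel CLIP). A small stand-alone sketch of that resampling (grid sizes chosen for illustration):

```python
# Sketch: resample CLIP position embeddings from a 24x24 grid (336px / patch 14)
# to a 32x32 grid (448px / patch 14), as forward_embeddings does above.
import torch

from eagle.model.multimodal_encoder.hr_clip_encoder import resample_pos_embed

posemb = torch.randn(1, 1 + 24 * 24, 1024)      # [CLS] + 576 patch positions
resampled = resample_pos_embed(posemb, new_size=1 + 32 * 32)
print(resampled.shape)                           # torch.Size([1, 1025, 1024])
```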
EAGLE/eagle/model/multimodal_encoder/pix2struct_encoder.py ADDED
@@ -0,0 +1,146 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import re
18
+ from PIL import Image
19
+ import torch
20
+ import torch.nn as nn
21
+ from transformers import AutoModel, CLIPImageProcessor
22
+ from PIL import Image
23
+ import requests
24
+ import torch.nn.functional as F
25
+ from transformers import AutoProcessor, Pix2StructVisionModel, Pix2StructProcessor, Pix2StructForConditionalGeneration
26
+
27
+ cfg={
28
+ "crop_size": 256,
29
+ "do_center_crop": True,
30
+ "do_normalize": True,
31
+ "do_resize": True,
32
+ "feature_extractor_type": "CLIPFeatureExtractor",
33
+ "image_mean": [
34
+ 0.48145466,
35
+ 0.4578275,
36
+ 0.40821073
37
+ ],
38
+ "image_std": [
39
+ 0.26862954,
40
+ 0.26130258,
41
+ 0.27577711
42
+ ],
43
+ "resample": 3,
44
+ "size": 256
45
+ }
46
+
47
+ '''
48
+ Pixel2Struct-Large Model (pretrained version)
49
+ '''
50
+ class Pix2StructLargeVisionTower(nn.Module):
51
+ def __init__(self, vision_tower, args, delay_load=False):
52
+ super().__init__()
53
+
54
+ self.is_loaded = False
55
+ self.vision_tower_name = vision_tower
56
+ self.do_resize = args.do_resize
57
+ self.de_normalize = args.de_normalize # de-normalize the input image and perform preprocessing with pix2struct processor
58
+ self.select_layer = args.mm_vision_select_layer # NOTE: not implemented yet, this parameter has no effect
59
+ self.input_image_size = args.input_image_size
60
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
61
+ self.freeze_vision = args.freeze_vision
62
+
63
+ self.args = args
64
+ if not self.is_loaded:
65
+ self.load_model()
66
+
67
+ def load_model(self):
68
+ if self.is_loaded:
69
+ return
70
+ whole_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-large")
71
+ self.vision_tower = whole_model.encoder
72
+ self.pix2struct_processor = AutoProcessor.from_pretrained("google/pix2struct-large")
73
+ self.pix2struct_processor.image_processor.is_vqa = False
74
+
75
+ self.image_processor = CLIPImageProcessor(**cfg)
76
+ if self.input_image_size is not None:
77
+ self.image_processor.size=self.input_image_size
78
+ self.image_processor.crop_size={
79
+ 'height':self.input_image_size,
80
+ 'width': self.input_image_size
81
+ }
82
+
83
+ if self.freeze_vision:
84
+ self.vision_tower.requires_grad_(False)
85
+
86
+ self.image_mean = torch.tensor(self.image_processor.image_mean).view(1, 3, 1, 1)
87
+ self.image_std = torch.tensor(self.image_processor.image_std).view(1, 3, 1, 1)
88
+
89
+ self.is_loaded = True
90
+
91
+ def feature_select(self, image_forward_outs):
92
+ image_features = image_forward_outs.hidden_states[self.select_layer] # [bs, n, c], cls at idx=0
93
+ if self.select_feature == 'patch':
94
+ image_features = image_features[:, 1:]
95
+ elif self.select_feature == 'cls_patch':
96
+ image_features = image_features
97
+ else:
98
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
99
+ return image_features
100
+
101
+ # @torch.no_grad()
102
+ def forward(self, images):
103
+
104
+ if self.de_normalize:
105
+ mean = self.image_mean.clone().view(1, 3, 1, 1).to(dtype=images.dtype, device=images.device)
106
+ std = self.image_std.clone().view(1, 3, 1, 1).to(dtype=images.dtype, device=images.device)
107
+ x = (images * std + mean) * 255.0
108
+ x = self.pix2struct_processor(images=x.float(), return_tensors="pt")
109
+
110
+ image_features = self.vision_tower(**(x.to(device=self.device, dtype=self.dtype))).last_hidden_state
111
+ bs, n, c = image_features.shape
112
+ image_features = image_features[:, :2025, :] # HARD CODE
113
+
114
+ if self.do_resize:
115
+ image_features = image_features.transpose(1,2).reshape(bs, c, 45, 45) # HARD CODE
116
+ image_features = F.interpolate(image_features.float(), size=(32, 32), mode='bilinear', align_corners=True).to(dtype=image_features.dtype) # HARD CODE
117
+ return image_features
118
+ else:
119
+ return image_features
120
+
121
+
122
+ @property
123
+ def dummy_feature(self):
124
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
125
+
126
+ @property
127
+ def dtype(self):
128
+ return next(self.vision_tower.parameters()).dtype
129
+
130
+ @property
131
+ def device(self):
132
+ return next(self.vision_tower.parameters()).device
133
+
134
+ @property
135
+ def config(self):
136
+ return self.vision_tower.config
137
+
138
+ @property
139
+ def hidden_size(self):
140
+ # Hard code
141
+ hidden_dim = 1536
142
+ return hidden_dim
143
+
144
+ @property
145
+ def num_patches(self):
146
+ return self.config['num_patches']
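When `de_normalize` is set, the tower first undoes the CLIP normalization so the raw 0-255 pixels can be re-preprocessed by the Pix2Struct processor; the inverse transform is simply `x * std + mean` scaled by 255. A minimal sketch of that step (mean/std values taken from the cfg dict above; the input tensor is a stand-in):

```python
# Sketch of the de-normalization used in Pix2StructLargeVisionTower.forward.
import torch

image_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1)
image_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1)

normalized = torch.rand(1, 3, 256, 256)                      # CLIP-normalized stand-in
pixels_255 = (normalized * image_std + image_mean) * 255.0   # back to the raw pixel range
```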
EAGLE/eagle/model/multimodal_encoder/vision_models/__init__.py ADDED
File without changes
EAGLE/eagle/model/multimodal_encoder/vision_models/convnext.py ADDED
@@ -0,0 +1,1108 @@
1
+ """ ConvNeXt
2
+
3
+ Papers:
4
+ * `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
5
+ @Article{liu2022convnet,
6
+ author = {Zhuang Liu and Hanzi Mao and Chao-Yuan Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
7
+ title = {A ConvNet for the 2020s},
8
+ journal = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
9
+ year = {2022},
10
+ }
11
+
12
+ * `ConvNeXt-V2 - Co-designing and Scaling ConvNets with Masked Autoencoders` - https://arxiv.org/abs/2301.00808
13
+ @article{Woo2023ConvNeXtV2,
14
+ title={ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders},
15
+ author={Sanghyun Woo, Shoubhik Debnath, Ronghang Hu, Xinlei Chen, Zhuang Liu, In So Kweon and Saining Xie},
16
+ year={2023},
17
+ journal={arXiv preprint arXiv:2301.00808},
18
+ }
19
+
20
+ Original code and weights from:
21
+ * https://github.com/facebookresearch/ConvNeXt, original copyright below
22
+ * https://github.com/facebookresearch/ConvNeXt-V2, original copyright below
23
+
24
+ Model defs atto, femto, pico, nano and _ols / _hnf variants are timm originals.
25
+
26
+ Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman
27
+ """
28
+ # ConvNeXt
29
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
30
+ # All rights reserved.
31
+ # This source code is licensed under the MIT license
32
+
33
+ # ConvNeXt-V2
34
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
35
+ # All rights reserved.
36
+ # This source code is licensed under the license found in the
37
+ # LICENSE file in the root directory of this source tree (Attribution-NonCommercial 4.0 International (CC BY-NC 4.0))
38
+ # No code was used directly from ConvNeXt-V2, however the weights are CC BY-NC 4.0 so beware if using commercially.
39
+
40
+ from collections import OrderedDict
41
+ from functools import partial
42
+ from typing import Callable, Optional, Tuple, Union
43
+
44
+ import torch
45
+ import torch.nn as nn
46
+
47
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
48
+ from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \
49
+ LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
50
+ from timm.layers import NormMlpClassifierHead, ClassifierHead
51
+ from timm.models._builder import build_model_with_cfg
52
+ from timm.models._manipulate import named_apply, checkpoint_seq
53
+ from timm.models._registry import generate_default_cfgs, register_model, register_model_deprecations
54
+
55
+ __all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this
56
+
57
+
58
+ class Downsample(nn.Module):
59
+
60
+ def __init__(self, in_chs, out_chs, stride=1, dilation=1):
61
+ super().__init__()
62
+ avg_stride = stride if dilation == 1 else 1
63
+ if stride > 1 or dilation > 1:
64
+ avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
65
+ self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
66
+ else:
67
+ self.pool = nn.Identity()
68
+
69
+ if in_chs != out_chs:
70
+ self.conv = create_conv2d(in_chs, out_chs, 1, stride=1)
71
+ else:
72
+ self.conv = nn.Identity()
73
+
74
+ def forward(self, x):
75
+ x = self.pool(x)
76
+ x = self.conv(x)
77
+ return x
78
+
79
+
80
+ class ConvNeXtBlock(nn.Module):
81
+ """ ConvNeXt Block
82
+ There are two equivalent implementations:
83
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
84
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
85
+
86
+ Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
87
+ choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
88
+ is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
89
+ """
90
+
91
+ def __init__(
92
+ self,
93
+ in_chs: int,
94
+ out_chs: Optional[int] = None,
95
+ kernel_size: int = 7,
96
+ stride: int = 1,
97
+ dilation: Union[int, Tuple[int, int]] = (1, 1),
98
+ mlp_ratio: float = 4,
99
+ conv_mlp: bool = False,
100
+ conv_bias: bool = True,
101
+ use_grn: bool = False,
102
+ ls_init_value: Optional[float] = 1e-6,
103
+ act_layer: Union[str, Callable] = 'gelu',
104
+ norm_layer: Optional[Callable] = None,
105
+ drop_path: float = 0.,
106
+ ):
107
+ """
108
+
109
+ Args:
110
+ in_chs: Block input channels.
111
+ out_chs: Block output channels (same as in_chs if None).
112
+ kernel_size: Depthwise convolution kernel size.
113
+ stride: Stride of depthwise convolution.
114
+ dilation: Tuple specifying input and output dilation of block.
115
+ mlp_ratio: MLP expansion ratio.
116
+ conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True.
117
+ conv_bias: Apply bias for all convolution (linear) layers.
118
+ use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2)
119
+ ls_init_value: Layer-scale init values, layer-scale applied if not None.
120
+ act_layer: Activation layer.
121
+ norm_layer: Normalization layer (defaults to LN if not specified).
122
+ drop_path: Stochastic depth probability.
123
+ """
124
+ super().__init__()
125
+ out_chs = out_chs or in_chs
126
+ dilation = to_ntuple(2)(dilation)
127
+ act_layer = get_act_layer(act_layer)
128
+ if not norm_layer:
129
+ norm_layer = LayerNorm2d if conv_mlp else LayerNorm
130
+ mlp_layer = partial(GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp)
131
+ self.use_conv_mlp = conv_mlp
132
+ self.conv_dw = create_conv2d(
133
+ in_chs,
134
+ out_chs,
135
+ kernel_size=kernel_size,
136
+ stride=stride,
137
+ dilation=dilation[0],
138
+ depthwise=True,
139
+ bias=conv_bias,
140
+ )
141
+ self.norm = norm_layer(out_chs)
142
+ self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
143
+ self.weight = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None
144
+ if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
145
+ self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0])
146
+ else:
147
+ self.shortcut = nn.Identity()
148
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
149
+
150
+ def forward(self, x):
151
+ shortcut = x
152
+ x = self.conv_dw(x)
153
+ if self.use_conv_mlp:
154
+ x = self.norm(x)
155
+ x = self.mlp(x)
156
+ else:
157
+ x = x.permute(0, 2, 3, 1)
158
+ x = self.norm(x)
159
+ x = self.mlp(x)
160
+ x = x.permute(0, 3, 1, 2)
161
+ if self.weight is not None:
162
+ x = x.mul(self.weight.reshape(1, -1, 1, 1))
163
+
164
+ x = self.drop_path(x) + self.shortcut(shortcut)
165
+ return x
166
+
167
+
168
+ class ConvNeXtStage(nn.Module):
169
+
170
+ def __init__(
171
+ self,
172
+ in_chs,
173
+ out_chs,
174
+ kernel_size=7,
175
+ stride=2,
176
+ depth=2,
177
+ dilation=(1, 1),
178
+ drop_path_rates=None,
179
+ ls_init_value=1.0,
180
+ conv_mlp=False,
181
+ conv_bias=True,
182
+ use_grn=False,
183
+ act_layer='gelu',
184
+ norm_layer=None,
185
+ norm_layer_cl=None
186
+ ):
187
+ super().__init__()
188
+ self.grad_checkpointing = False
189
+
190
+ if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]:
191
+ ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1
192
+ pad = 'same' if dilation[1] > 1 else 0 # same padding needed if dilation used
193
+ self.downsample = nn.Sequential(
194
+ norm_layer(in_chs),
195
+ create_conv2d(
196
+ in_chs,
197
+ out_chs,
198
+ kernel_size=ds_ks,
199
+ stride=stride,
200
+ dilation=dilation[0],
201
+ padding=pad,
202
+ bias=conv_bias,
203
+ ),
204
+ )
205
+ in_chs = out_chs
206
+ else:
207
+ self.downsample = nn.Identity()
208
+
209
+ drop_path_rates = drop_path_rates or [0.] * depth
210
+ stage_blocks = []
211
+ for i in range(depth):
212
+ stage_blocks.append(ConvNeXtBlock(
213
+ in_chs=in_chs,
214
+ out_chs=out_chs,
215
+ kernel_size=kernel_size,
216
+ dilation=dilation[1],
217
+ drop_path=drop_path_rates[i],
218
+ ls_init_value=ls_init_value,
219
+ conv_mlp=conv_mlp,
220
+ conv_bias=conv_bias,
221
+ use_grn=use_grn,
222
+ act_layer=act_layer,
223
+ norm_layer=norm_layer if conv_mlp else norm_layer_cl,
224
+ ))
225
+ in_chs = out_chs
226
+ self.blocks = nn.Sequential(*stage_blocks)
227
+
228
+ def forward(self, x):
229
+ x = self.downsample(x)
230
+ if self.grad_checkpointing and not torch.jit.is_scripting():
231
+ x = checkpoint_seq(self.blocks, x)
232
+ else:
233
+ x = self.blocks(x)
234
+ return x
235
+
236
+
237
+ class ConvNeXt(nn.Module):
238
+ r""" ConvNeXt
239
+ A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
240
+ """
241
+
242
+ def __init__(
243
+ self,
244
+ in_chans: int = 3,
245
+ num_classes: int = 1000,
246
+ global_pool: str = 'avg',
247
+ output_stride: int = 32,
248
+ depths: Tuple[int, ...] = (3, 3, 9, 3),
249
+ dims: Tuple[int, ...] = (96, 192, 384, 768),
250
+ kernel_sizes: Union[int, Tuple[int, ...]] = 7,
251
+ ls_init_value: Optional[float] = 1e-6,
252
+ stem_type: str = 'patch',
253
+ patch_size: int = 4,
254
+ head_init_scale: float = 1.,
255
+ head_norm_first: bool = False,
256
+ head_hidden_size: Optional[int] = None,
257
+ conv_mlp: bool = False,
258
+ conv_bias: bool = True,
259
+ use_grn: bool = False,
260
+ act_layer: Union[str, Callable] = 'gelu',
261
+ norm_layer: Optional[Union[str, Callable]] = None,
262
+ norm_eps: Optional[float] = None,
263
+ drop_rate: float = 0.,
264
+ drop_path_rate: float = 0.,
265
+ ):
266
+ """
267
+ Args:
268
+ in_chans: Number of input image channels.
269
+ num_classes: Number of classes for classification head.
270
+ global_pool: Global pooling type.
271
+ output_stride: Output stride of network, one of (8, 16, 32).
272
+ depths: Number of blocks at each stage.
273
+ dims: Feature dimension at each stage.
274
+ kernel_sizes: Depthwise convolution kernel-sizes for each stage.
275
+ ls_init_value: Init value for Layer Scale, disabled if None.
276
+ stem_type: Type of stem.
277
+ patch_size: Stem patch size for patch stem.
278
+ head_init_scale: Init scaling value for classifier weights and biases.
279
+ head_norm_first: Apply normalization before global pool + head.
280
+ head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False.
281
+ conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last.
282
+ conv_bias: Use bias layers w/ all convolutions.
283
+ use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP.
284
+ act_layer: Activation layer type.
285
+ norm_layer: Normalization layer type.
286
+ drop_rate: Head pre-classifier dropout rate.
287
+ drop_path_rate: Stochastic depth drop rate.
288
+ """
289
+ super().__init__()
290
+ assert output_stride in (8, 16, 32)
291
+ kernel_sizes = to_ntuple(4)(kernel_sizes)
292
+ if norm_layer is None:
293
+ norm_layer = LayerNorm2d
294
+ norm_layer_cl = norm_layer if conv_mlp else LayerNorm
295
+ if norm_eps is not None:
296
+ norm_layer = partial(norm_layer, eps=norm_eps)
297
+ norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
298
+ else:
299
+ assert conv_mlp,\
300
+ 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
301
+ norm_layer_cl = norm_layer
302
+ if norm_eps is not None:
303
+ norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
304
+
305
+ self.num_classes = num_classes
306
+ self.drop_rate = drop_rate
307
+ self.feature_info = []
308
+
309
+ assert stem_type in ('patch', 'overlap', 'overlap_tiered')
310
+ if stem_type == 'patch':
311
+ # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
312
+ self.stem = nn.Sequential(
313
+ nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias),
314
+ norm_layer(dims[0]),
315
+ )
316
+ stem_stride = patch_size
317
+ else:
318
+ mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0]
319
+ self.stem = nn.Sequential(
320
+ nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias),
321
+ nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias),
322
+ norm_layer(dims[0]),
323
+ )
324
+ stem_stride = 4
325
+
326
+ self.stages = nn.Sequential()
327
+ dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
328
+ stages = []
329
+ prev_chs = dims[0]
330
+ curr_stride = stem_stride
331
+ dilation = 1
332
+ # 4 feature resolution stages, each consisting of multiple residual blocks
333
+ for i in range(4):
334
+ stride = 2 if curr_stride == 2 or i > 0 else 1
335
+ if curr_stride >= output_stride and stride > 1:
336
+ dilation *= stride
337
+ stride = 1
338
+ curr_stride *= stride
339
+ first_dilation = 1 if dilation in (1, 2) else 2
340
+ out_chs = dims[i]
341
+ stages.append(ConvNeXtStage(
342
+ prev_chs,
343
+ out_chs,
344
+ kernel_size=kernel_sizes[i],
345
+ stride=stride,
346
+ dilation=(first_dilation, dilation),
347
+ depth=depths[i],
348
+ drop_path_rates=dp_rates[i],
349
+ ls_init_value=ls_init_value,
350
+ conv_mlp=conv_mlp,
351
+ conv_bias=conv_bias,
352
+ use_grn=use_grn,
353
+ act_layer=act_layer,
354
+ norm_layer=norm_layer,
355
+ norm_layer_cl=norm_layer_cl,
356
+ ))
357
+ prev_chs = out_chs
358
+ # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
359
+ self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
360
+ self.stages = nn.Sequential(*stages)
361
+ self.num_features = prev_chs
362
+
363
+ # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
364
+ # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
365
+ if head_norm_first:
366
+ assert not head_hidden_size
367
+ self.norm_pre = norm_layer(self.num_features)
368
+ self.head = ClassifierHead(
369
+ self.num_features,
370
+ num_classes,
371
+ pool_type=global_pool,
372
+ drop_rate=self.drop_rate,
373
+ )
374
+ else:
375
+ self.norm_pre = nn.Identity()
376
+ self.head = NormMlpClassifierHead(
377
+ self.num_features,
378
+ num_classes,
379
+ hidden_size=head_hidden_size,
380
+ pool_type=global_pool,
381
+ drop_rate=self.drop_rate,
382
+ norm_layer=norm_layer,
383
+ act_layer='gelu',
384
+ )
385
+ named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
386
+
387
+ @torch.jit.ignore
388
+ def group_matcher(self, coarse=False):
389
+ return dict(
390
+ stem=r'^stem',
391
+ blocks=r'^stages\.(\d+)' if coarse else [
392
+ (r'^stages\.(\d+)\.downsample', (0,)), # blocks
393
+ (r'^stages\.(\d+)\.blocks\.(\d+)', None),
394
+ (r'^norm_pre', (99999,))
395
+ ]
396
+ )
397
+
398
+ @torch.jit.ignore
399
+ def set_grad_checkpointing(self, enable=True):
400
+ for s in self.stages:
401
+ s.grad_checkpointing = enable
402
+
403
+ @torch.jit.ignore
404
+ def get_classifier(self):
405
+ return self.head.fc
406
+
407
+ def reset_classifier(self, num_classes=0, global_pool=None):
408
+ self.head.reset(num_classes, global_pool)
409
+
410
+ def forward_features(self, x):
411
+ x = self.stem(x)
412
+ x = self.stages(x)
413
+ x = self.norm_pre(x)
414
+ return x
415
+
416
+ def forward_head(self, x, pre_logits: bool = False):
417
+ return self.head(x, pre_logits=True) if pre_logits else self.head(x)
418
+
419
+ def forward(self, x):
420
+ x = self.forward_features(x)
421
+ x = self.forward_head(x)
422
+ return x
423
+
424
+
425
+ def _init_weights(module, name=None, head_init_scale=1.0):
426
+ if isinstance(module, nn.Conv2d):
427
+ trunc_normal_(module.weight, std=.02)
428
+ if module.bias is not None:
429
+ nn.init.zeros_(module.bias)
430
+ elif isinstance(module, nn.Linear):
431
+ trunc_normal_(module.weight, std=.02)
432
+ nn.init.zeros_(module.bias)
433
+ if name and 'head.' in name:
434
+ module.weight.data.mul_(head_init_scale)
435
+ module.bias.data.mul_(head_init_scale)
436
+
437
+
438
+ def checkpoint_filter_fn(state_dict, model):
439
+ """ Remap FB checkpoints -> timm """
440
+ if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict:
441
+ out_dict={}
442
+ out_dict = {k.replace('gamma', 'weight'): v for k, v in state_dict.items()}
443
+ return out_dict # non-FB checkpoint
444
+ if 'model' in state_dict:
445
+ state_dict = state_dict['model']
446
+
447
+ out_dict = {}
448
+ if 'visual.trunk.stem.0.weight' in state_dict:
449
+ out_dict = {k.replace('visual.trunk.', '').replace('gamma', 'weight'): v for k, v in state_dict.items() if
450
+ k.startswith('visual.trunk.')}
451
+
452
+ if 'visual.head.proj.weight' in state_dict:
453
+ out_dict['head.fc.weight'] = state_dict['visual.head.proj.weight']
454
+ out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
455
+ elif 'visual.head.mlp.fc1.weight' in state_dict:
456
+ out_dict['head.pre_logits.fc.weight'] = state_dict['visual.head.mlp.fc1.weight']
457
+ out_dict['head.pre_logits.fc.bias'] = state_dict['visual.head.mlp.fc1.bias']
458
+ out_dict['head.fc.weight'] = state_dict['visual.head.mlp.fc2.weight']
459
+ out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.mlp.fc2.weight'].shape[0])
460
+ return out_dict
461
+
462
+ import re
463
+ for k, v in state_dict.items():
464
+ k = k.replace('downsample_layers.0.', 'stem.')
465
+ k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
466
+ k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k)
467
+ k = k.replace('dwconv', 'conv_dw')
468
+ k = k.replace('pwconv', 'mlp.fc')
469
+ if 'grn' in k:
470
+ k = k.replace('grn.beta', 'mlp.grn.bias')
471
+ k = k.replace('grn.gamma', 'mlp.grn.weight')
472
+ v = v.reshape(v.shape[-1])
473
+ k = k.replace('head.', 'head.fc.')
474
+ if k.startswith('norm.'):
475
+ k = k.replace('norm', 'head.norm')
476
+ if v.ndim == 2 and 'head' not in k:
477
+ model_shape = model.state_dict()[k].shape
478
+ v = v.reshape(model_shape)
479
+ k=k.replace('gamma','weight')
480
+ out_dict[k] = v
481
+
482
+ return out_dict
483
+
484
+
485
+ def _create_convnext(variant, pretrained=False, **kwargs):
486
+ if kwargs.get('pretrained_cfg', '') == 'fcmae':
487
+ # NOTE fcmae pretrained weights have no classifier or final norm-layer (`head.norm`)
488
+ # This is workaround loading with num_classes=0 w/o removing norm-layer.
489
+ kwargs.setdefault('pretrained_strict', False)
490
+
491
+ model = build_model_with_cfg(
492
+ ConvNeXt, variant, pretrained,
493
+ pretrained_filter_fn=checkpoint_filter_fn,
494
+ feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
495
+ **kwargs)
496
+ return model
497
+
498
+
499
+ def _cfg(url='', **kwargs):
500
+ return {
501
+ 'url': url,
502
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
503
+ 'crop_pct': 0.875, 'interpolation': 'bicubic',
504
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
505
+ 'first_conv': 'stem.0', 'classifier': 'head.fc',
506
+ **kwargs
507
+ }
508
+
509
+
510
+ def _cfgv2(url='', **kwargs):
511
+ return {
512
+ 'url': url,
513
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
514
+ 'crop_pct': 0.875, 'interpolation': 'bicubic',
515
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
516
+ 'first_conv': 'stem.0', 'classifier': 'head.fc',
517
+ 'license': 'cc-by-nc-4.0', 'paper_ids': 'arXiv:2301.00808',
518
+ 'paper_name': 'ConvNeXt-V2: Co-designing and Scaling ConvNets with Masked Autoencoders',
519
+ 'origin_url': 'https://github.com/facebookresearch/ConvNeXt-V2',
520
+ **kwargs
521
+ }
522
+
523
+
524
+ default_cfgs = generate_default_cfgs({
525
+ # timm specific variants
526
+ 'convnext_tiny.in12k_ft_in1k': _cfg(
527
+ hf_hub_id='timm/',
528
+ crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
529
+ 'convnext_small.in12k_ft_in1k': _cfg(
530
+ hf_hub_id='timm/',
531
+ crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
532
+
533
+ 'convnext_atto.d2_in1k': _cfg(
534
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
535
+ hf_hub_id='timm/',
536
+ test_input_size=(3, 288, 288), test_crop_pct=0.95),
537
+ 'convnext_atto_ols.a2_in1k': _cfg(
538
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth',
539
+ hf_hub_id='timm/',
540
+ test_input_size=(3, 288, 288), test_crop_pct=0.95),
541
+ 'convnext_femto.d1_in1k': _cfg(
542
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
543
+ hf_hub_id='timm/',
544
+ test_input_size=(3, 288, 288), test_crop_pct=0.95),
545
+ 'convnext_femto_ols.d1_in1k': _cfg(
546
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth',
547
+ hf_hub_id='timm/',
548
+ test_input_size=(3, 288, 288), test_crop_pct=0.95),
549
+ 'convnext_pico.d1_in1k': _cfg(
550
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth',
551
+ hf_hub_id='timm/',
552
+ test_input_size=(3, 288, 288), test_crop_pct=0.95),
553
+ 'convnext_pico_ols.d1_in1k': _cfg(
554
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth',
555
+ hf_hub_id='timm/',
556
+ crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
557
+ 'convnext_nano.in12k_ft_in1k': _cfg(
558
+ hf_hub_id='timm/',
559
+ crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
560
+ 'convnext_nano.d1h_in1k': _cfg(
561
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth',
562
+ hf_hub_id='timm/',
563
+ crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
564
+ 'convnext_nano_ols.d1h_in1k': _cfg(
565
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth',
566
+ hf_hub_id='timm/',
567
+ crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
568
+ 'convnext_tiny_hnf.a2h_in1k': _cfg(
569
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
570
+ hf_hub_id='timm/',
571
+ crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
572
+
573
+ 'convnext_tiny.in12k_ft_in1k_384': _cfg(
574
+ hf_hub_id='timm/',
575
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
576
+ 'convnext_small.in12k_ft_in1k_384': _cfg(
577
+ hf_hub_id='timm/',
578
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
579
+
580
+ 'convnext_nano.in12k': _cfg(
581
+ hf_hub_id='timm/',
582
+ crop_pct=0.95, num_classes=11821),
583
+ 'convnext_tiny.in12k': _cfg(
584
+ hf_hub_id='timm/',
585
+ crop_pct=0.95, num_classes=11821),
586
+ 'convnext_small.in12k': _cfg(
587
+ hf_hub_id='timm/',
588
+ crop_pct=0.95, num_classes=11821),
589
+
590
+ 'convnext_tiny.fb_in22k_ft_in1k': _cfg(
591
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth',
592
+ hf_hub_id='timm/',
593
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
594
+ 'convnext_small.fb_in22k_ft_in1k': _cfg(
595
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth',
596
+ hf_hub_id='timm/',
597
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
598
+ 'convnext_base.fb_in22k_ft_in1k': _cfg(
599
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth',
600
+ hf_hub_id='timm/',
601
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
602
+ 'convnext_large.fb_in22k_ft_in1k': _cfg(
603
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth',
604
+ hf_hub_id='timm/',
605
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
606
+ 'convnext_xlarge.fb_in22k_ft_in1k': _cfg(
607
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth',
608
+ hf_hub_id='timm/',
609
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
610
+
611
+ 'convnext_tiny.fb_in1k': _cfg(
612
+ url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
613
+ hf_hub_id='timm/',
614
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
615
+ 'convnext_small.fb_in1k': _cfg(
616
+ url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
617
+ hf_hub_id='timm/',
618
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
619
+ 'convnext_base.fb_in1k': _cfg(
620
+ url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
621
+ hf_hub_id='timm/',
622
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
623
+ 'convnext_large.fb_in1k': _cfg(
624
+ url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
625
+ hf_hub_id='timm/',
626
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
627
+
628
+ 'convnext_tiny.fb_in22k_ft_in1k_384': _cfg(
629
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
630
+ hf_hub_id='timm/',
631
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
632
+ 'convnext_small.fb_in22k_ft_in1k_384': _cfg(
633
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth',
634
+ hf_hub_id='timm/',
635
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
636
+ 'convnext_base.fb_in22k_ft_in1k_384': _cfg(
637
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth',
638
+ hf_hub_id='timm/',
639
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
640
+ 'convnext_large.fb_in22k_ft_in1k_384': _cfg(
641
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth',
642
+ hf_hub_id='timm/',
643
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
644
+ 'convnext_xlarge.fb_in22k_ft_in1k_384': _cfg(
645
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth',
646
+ hf_hub_id='timm/',
647
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
648
+
649
+ 'convnext_tiny.fb_in22k': _cfg(
650
+ url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
651
+ hf_hub_id='timm/',
652
+ num_classes=21841),
653
+ 'convnext_small.fb_in22k': _cfg(
654
+ url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
655
+ hf_hub_id='timm/',
656
+ num_classes=21841),
657
+ 'convnext_base.fb_in22k': _cfg(
658
+ url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
659
+ hf_hub_id='timm/',
660
+ num_classes=21841),
661
+ 'convnext_large.fb_in22k': _cfg(
662
+ url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
663
+ hf_hub_id='timm/',
664
+ num_classes=21841),
665
+ 'convnext_xlarge.fb_in22k': _cfg(
666
+ url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
667
+ hf_hub_id='timm/',
668
+ num_classes=21841),
669
+
670
+ 'convnextv2_nano.fcmae_ft_in22k_in1k': _cfgv2(
671
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_224_ema.pt',
672
+ hf_hub_id='timm/',
673
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
674
+ 'convnextv2_nano.fcmae_ft_in22k_in1k_384': _cfgv2(
675
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_384_ema.pt',
676
+ hf_hub_id='timm/',
677
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
678
+ 'convnextv2_tiny.fcmae_ft_in22k_in1k': _cfgv2(
679
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_224_ema.pt",
680
+ hf_hub_id='timm/',
681
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
682
+ 'convnextv2_tiny.fcmae_ft_in22k_in1k_384': _cfgv2(
683
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_384_ema.pt",
684
+ hf_hub_id='timm/',
685
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
686
+ 'convnextv2_base.fcmae_ft_in22k_in1k': _cfgv2(
687
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_224_ema.pt",
688
+ hf_hub_id='timm/',
689
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
690
+ 'convnextv2_base.fcmae_ft_in22k_in1k_384': _cfgv2(
691
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_384_ema.pt",
692
+ hf_hub_id='timm/',
693
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
694
+ 'convnextv2_large.fcmae_ft_in22k_in1k': _cfgv2(
695
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_224_ema.pt",
696
+ hf_hub_id='timm/',
697
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
698
+ 'convnextv2_large.fcmae_ft_in22k_in1k_384': _cfgv2(
699
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_384_ema.pt",
700
+ hf_hub_id='timm/',
701
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
702
+ 'convnextv2_huge.fcmae_ft_in22k_in1k_384': _cfgv2(
703
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_384_ema.pt",
704
+ hf_hub_id='timm/',
705
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
706
+ 'convnextv2_huge.fcmae_ft_in22k_in1k_512': _cfgv2(
707
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_512_ema.pt",
708
+ hf_hub_id='timm/',
709
+ input_size=(3, 512, 512), pool_size=(15, 15), crop_pct=1.0, crop_mode='squash'),
710
+
711
+ 'convnextv2_atto.fcmae_ft_in1k': _cfgv2(
712
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_atto_1k_224_ema.pt',
713
+ hf_hub_id='timm/',
714
+ test_input_size=(3, 288, 288), test_crop_pct=0.95),
715
+ 'convnextv2_femto.fcmae_ft_in1k': _cfgv2(
716
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_femto_1k_224_ema.pt',
717
+ hf_hub_id='timm/',
718
+ test_input_size=(3, 288, 288), test_crop_pct=0.95),
719
+ 'convnextv2_pico.fcmae_ft_in1k': _cfgv2(
720
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_pico_1k_224_ema.pt',
721
+ hf_hub_id='timm/',
722
+ test_input_size=(3, 288, 288), test_crop_pct=0.95),
723
+ 'convnextv2_nano.fcmae_ft_in1k': _cfgv2(
724
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_nano_1k_224_ema.pt',
725
+ hf_hub_id='timm/',
726
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
727
+ 'convnextv2_tiny.fcmae_ft_in1k': _cfgv2(
728
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_tiny_1k_224_ema.pt",
729
+ hf_hub_id='timm/',
730
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
731
+ 'convnextv2_base.fcmae_ft_in1k': _cfgv2(
732
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_base_1k_224_ema.pt",
733
+ hf_hub_id='timm/',
734
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
735
+ 'convnextv2_large.fcmae_ft_in1k': _cfgv2(
736
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_large_1k_224_ema.pt",
737
+ hf_hub_id='timm/',
738
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
739
+ 'convnextv2_huge.fcmae_ft_in1k': _cfgv2(
740
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_huge_1k_224_ema.pt",
741
+ hf_hub_id='timm/',
742
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
743
+
744
+ 'convnextv2_atto.fcmae': _cfgv2(
745
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_atto_1k_224_fcmae.pt',
746
+ hf_hub_id='timm/',
747
+ num_classes=0),
748
+ 'convnextv2_femto.fcmae': _cfgv2(
749
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_femto_1k_224_fcmae.pt',
750
+ hf_hub_id='timm/',
751
+ num_classes=0),
752
+ 'convnextv2_pico.fcmae': _cfgv2(
753
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_pico_1k_224_fcmae.pt',
754
+ hf_hub_id='timm/',
755
+ num_classes=0),
756
+ 'convnextv2_nano.fcmae': _cfgv2(
757
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_nano_1k_224_fcmae.pt',
758
+ hf_hub_id='timm/',
759
+ num_classes=0),
760
+ 'convnextv2_tiny.fcmae': _cfgv2(
761
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_tiny_1k_224_fcmae.pt",
762
+ hf_hub_id='timm/',
763
+ num_classes=0),
764
+ 'convnextv2_base.fcmae': _cfgv2(
765
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_base_1k_224_fcmae.pt",
766
+ hf_hub_id='timm/',
767
+ num_classes=0),
768
+ 'convnextv2_large.fcmae': _cfgv2(
769
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_large_1k_224_fcmae.pt",
770
+ hf_hub_id='timm/',
771
+ num_classes=0),
772
+ 'convnextv2_huge.fcmae': _cfgv2(
773
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_huge_1k_224_fcmae.pt",
774
+ hf_hub_id='timm/',
775
+ num_classes=0),
776
+
777
+ 'convnextv2_small.untrained': _cfg(),
778
+
779
+ # CLIP weights, fine-tuned on in1k or in12k + in1k
780
+ 'convnext_base.clip_laion2b_augreg_ft_in12k_in1k': _cfg(
781
+ hf_hub_id='timm/',
782
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
783
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
784
+ 'convnext_base.clip_laion2b_augreg_ft_in12k_in1k_384': _cfg(
785
+ hf_hub_id='timm/',
786
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
787
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
788
+ 'convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_320': _cfg(
789
+ hf_hub_id='timm/',
790
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
791
+ input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0),
792
+ 'convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384': _cfg(
793
+ hf_hub_id='timm/',
794
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
795
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
796
+
797
+ 'convnext_base.clip_laion2b_augreg_ft_in1k': _cfg(
798
+ hf_hub_id='timm/',
799
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
800
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
801
+ 'convnext_base.clip_laiona_augreg_ft_in1k_384': _cfg(
802
+ hf_hub_id='timm/',
803
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
804
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
805
+ 'convnext_large_mlp.clip_laion2b_augreg_ft_in1k': _cfg(
806
+ hf_hub_id='timm/',
807
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
808
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0
809
+ ),
810
+ 'convnext_large_mlp.clip_laion2b_augreg_ft_in1k_384': _cfg(
811
+ hf_hub_id='timm/',
812
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
813
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'
814
+ ),
815
+ 'convnext_xxlarge.clip_laion2b_soup_ft_in1k': _cfg(
816
+ hf_hub_id='timm/',
817
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
818
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
819
+
820
+ 'convnext_base.clip_laion2b_augreg_ft_in12k': _cfg(
821
+ hf_hub_id='timm/',
822
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
823
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
824
+ 'convnext_large_mlp.clip_laion2b_soup_ft_in12k_320': _cfg(
825
+ hf_hub_id='timm/',
826
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
827
+ input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0),
828
+ 'convnext_large_mlp.clip_laion2b_augreg_ft_in12k_384': _cfg(
829
+ hf_hub_id='timm/',
830
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
831
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
832
+ 'convnext_large_mlp.clip_laion2b_soup_ft_in12k_384': _cfg(
833
+ hf_hub_id='timm/',
834
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
835
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
836
+ 'convnext_xxlarge.clip_laion2b_soup_ft_in12k': _cfg(
837
+ hf_hub_id='timm/',
838
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
839
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
840
+
841
+ # CLIP original image tower weights
842
+ 'convnext_base.clip_laion2b': _cfg(
843
+ hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K',
844
+ hf_hub_filename='open_clip_pytorch_model.bin',
845
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
846
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
847
+ 'convnext_base.clip_laion2b_augreg': _cfg(
848
+ hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg',
849
+ hf_hub_filename='open_clip_pytorch_model.bin',
850
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
851
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
852
+ 'convnext_base.clip_laiona': _cfg(
853
+ hf_hub_id='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K',
854
+ hf_hub_filename='open_clip_pytorch_model.bin',
855
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
856
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
857
+ 'convnext_base.clip_laiona_320': _cfg(
858
+ hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K',
859
+ hf_hub_filename='open_clip_pytorch_model.bin',
860
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
861
+ input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
862
+ 'convnext_base.clip_laiona_augreg_320': _cfg(
863
+ hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg',
864
+ hf_hub_filename='open_clip_pytorch_model.bin',
865
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
866
+ input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
867
+ 'convnext_large_mlp.clip_laion2b_augreg': _cfg(
868
+ hf_hub_id='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg',
869
+ hf_hub_filename='open_clip_pytorch_model.bin',
870
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
871
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=768),
872
+ 'convnext_large_mlp.clip_laion2b_ft_320': _cfg(
873
+ hf_hub_id='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft',
874
+ hf_hub_filename='open_clip_pytorch_model.bin',
875
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
876
+ input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768),
877
+ 'convnext_large_mlp.clip_laion2b_ft_soup_320': _cfg(
878
+ hf_hub_id='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup',
879
+ hf_hub_filename='open_clip_pytorch_model.bin',
880
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
881
+ input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768),
882
+ 'convnext_xxlarge.clip_laion2b_soup': _cfg(
883
+ hf_hub_id='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup',
884
+ hf_hub_filename='open_clip_pytorch_model.bin',
885
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
886
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),
887
+ 'convnext_xxlarge.clip_laion2b_rewind': _cfg(
888
+ hf_hub_id='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind',
889
+ hf_hub_filename='open_clip_pytorch_model.bin',
890
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
891
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),
892
+ })
893
+
894
+
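+ # Each tag above (architecture.pretraining) pairs a checkpoint (`url` / `hf_hub_id`) with its
+ # preprocessing: `input_size`/`test_input_size` are train/eval resolutions, `crop_pct`/`crop_mode`
+ # control eval cropping, and `num_classes=0` marks head-less FCMAE (feature-only) weights.
+ # Usage sketch (assumes the @register_model registrations below are enabled, as in upstream timm):
+ # model = timm.create_model('convnext_large.fb_in22k_ft_in1k_384', pretrained=True)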
895
+ # @register_model
896
+ # def convnext_atto(pretrained=False, **kwargs) -> ConvNeXt:
897
+ # # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M)
898
+ # model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True)
899
+ # model = _create_convnext('convnext_atto', pretrained=pretrained, **dict(model_args, **kwargs))
900
+ # return model
901
+
902
+
903
+ # @register_model
904
+ # def convnext_atto_ols(pretrained=False, **kwargs) -> ConvNeXt:
905
+ # # timm femto variant with overlapping 3x3 conv stem, wider than non-ols femto above, current param count 3.7M
906
+ # model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, stem_type='overlap_tiered')
907
+ # model = _create_convnext('convnext_atto_ols', pretrained=pretrained, **dict(model_args, **kwargs))
908
+ # return model
909
+
910
+
911
+ # @register_model
912
+ # def convnext_femto(pretrained=False, **kwargs) -> ConvNeXt:
913
+ # # timm femto variant
914
+ # model_args = dict(depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True)
915
+ # model = _create_convnext('convnext_femto', pretrained=pretrained, **dict(model_args, **kwargs))
916
+ # return model
917
+
918
+
919
+ # @register_model
920
+ # def convnext_femto_ols(pretrained=False, **kwargs) -> ConvNeXt:
921
+ # # timm femto variant
922
+ # model_args = dict(depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, stem_type='overlap_tiered')
923
+ # model = _create_convnext('convnext_femto_ols', pretrained=pretrained, **dict(model_args, **kwargs))
924
+ # return model
925
+
926
+
927
+ # @register_model
928
+ # def convnext_pico(pretrained=False, **kwargs) -> ConvNeXt:
929
+ # # timm pico variant
930
+ # model_args = dict(depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True)
931
+ # model = _create_convnext('convnext_pico', pretrained=pretrained, **dict(model_args, **kwargs))
932
+ # return model
933
+
934
+
935
+ # @register_model
936
+ # def convnext_pico_ols(pretrained=False, **kwargs) -> ConvNeXt:
937
+ # # timm nano variant with overlapping 3x3 conv stem
938
+ # model_args = dict(depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, stem_type='overlap_tiered')
939
+ # model = _create_convnext('convnext_pico_ols', pretrained=pretrained, **dict(model_args, **kwargs))
940
+ # return model
941
+
942
+
943
+ # @register_model
944
+ # def convnext_nano(pretrained=False, **kwargs) -> ConvNeXt:
945
+ # # timm nano variant with standard stem and head
946
+ # model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True)
947
+ # model = _create_convnext('convnext_nano', pretrained=pretrained, **dict(model_args, **kwargs))
948
+ # return model
949
+
950
+
951
+ # @register_model
952
+ # def convnext_nano_ols(pretrained=False, **kwargs) -> ConvNeXt:
953
+ # # experimental nano variant with overlapping conv stem
954
+ # model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, stem_type='overlap')
955
+ # model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **dict(model_args, **kwargs))
956
+ # return model
957
+
958
+
959
+ # @register_model
960
+ # def convnext_tiny_hnf(pretrained=False, **kwargs) -> ConvNeXt:
961
+ # # experimental tiny variant with norm before pooling in head (head norm first)
962
+ # model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True)
963
+ # model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **dict(model_args, **kwargs))
964
+ # return model
965
+
966
+
967
+ # @register_model
968
+ # def convnext_tiny(pretrained=False, **kwargs) -> ConvNeXt:
969
+ # model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768))
970
+ # model = _create_convnext('convnext_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
971
+ # return model
972
+
973
+
974
+ # @register_model
975
+ # def convnext_small(pretrained=False, **kwargs) -> ConvNeXt:
976
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768])
977
+ # model = _create_convnext('convnext_small', pretrained=pretrained, **dict(model_args, **kwargs))
978
+ # return model
979
+
980
+ # @register_model
981
+ # def convnext_base_clip(pretrained='', **kwargs) -> ConvNeXt:
982
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024])
983
+ # model = _create_convnext(pretrained, pretrained=True, **dict(model_args, **kwargs))
984
+ # return model
985
+
986
+ # @register_model
987
+ # def convnext_base(pretrained=False, **kwargs) -> ConvNeXt:
988
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024])
989
+ # model = _create_convnext('convnext_base', pretrained=pretrained, **dict(model_args, **kwargs))
990
+ # return model
991
+
992
+
993
+ # @register_model
994
+ # def convnext_large(pretrained=False, **kwargs) -> ConvNeXt:
995
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536])
996
+ # model = _create_convnext('convnext_large', pretrained=pretrained, **dict(model_args, **kwargs))
997
+ # return model
998
+
999
+
1000
+ # @register_model
1001
+ # def convnext_large_mlp(pretrained=False, **kwargs) -> ConvNeXt:
1002
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], head_hidden_size=1536)
1003
+ # model = _create_convnext('convnext_large_mlp', pretrained=pretrained, **dict(model_args, **kwargs))
1004
+ # return model
1005
+
1006
+
1007
+ # @register_model
1008
+ # def convnext_xlarge(pretrained=False, **kwargs) -> ConvNeXt:
1009
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048])
1010
+ # model = _create_convnext('convnext_xlarge', pretrained=pretrained, **dict(model_args, **kwargs))
1011
+ # return model
1012
+
1013
+
1014
+ # @register_model
1015
+ def convnext_xxlarge(pretrained=False, **kwargs) -> ConvNeXt:
1016
+ model_args = dict(depths=[3, 4, 30, 3], dims=[384, 768, 1536, 3072], norm_eps=kwargs.pop('norm_eps', 1e-5))
1017
+ model = _create_convnext('convnext_xxlarge', pretrained=pretrained, **dict(model_args, **kwargs))
1018
+ return model
1019
+
1020
+
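+ # Usage sketch for the variant kept active above (assumes the timm-style ConvNeXt API defined
+ # earlier in this file):
+ # backbone = convnext_xxlarge(pretrained=False)  # depths (3, 4, 30, 3), dims 384 -> 3072
+ # feats = backbone.forward_features(torch.randn(1, 3, 256, 256))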
1021
+ # @register_model
1022
+ # def convnextv2_atto(pretrained=False, **kwargs) -> ConvNeXt:
1023
+ # # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M)
1024
+ # model_args = dict(
1025
+ # depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), use_grn=True, ls_init_value=None, conv_mlp=True)
1026
+ # model = _create_convnext('convnextv2_atto', pretrained=pretrained, **dict(model_args, **kwargs))
1027
+ # return model
1028
+
1029
+
1030
+ # @register_model
1031
+ # def convnextv2_femto(pretrained=False, **kwargs) -> ConvNeXt:
1032
+ # # timm femto variant
1033
+ # model_args = dict(
1034
+ # depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), use_grn=True, ls_init_value=None, conv_mlp=True)
1035
+ # model = _create_convnext('convnextv2_femto', pretrained=pretrained, **dict(model_args, **kwargs))
1036
+ # return model
1037
+
1038
+
1039
+ # @register_model
1040
+ # def convnextv2_pico(pretrained=False, **kwargs) -> ConvNeXt:
1041
+ # # timm pico variant
1042
+ # model_args = dict(
1043
+ # depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), use_grn=True, ls_init_value=None, conv_mlp=True)
1044
+ # model = _create_convnext('convnextv2_pico', pretrained=pretrained, **dict(model_args, **kwargs))
1045
+ # return model
1046
+
1047
+
1048
+ # @register_model
1049
+ # def convnextv2_nano(pretrained=False, **kwargs) -> ConvNeXt:
1050
+ # # timm nano variant with standard stem and head
1051
+ # model_args = dict(
1052
+ # depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), use_grn=True, ls_init_value=None, conv_mlp=True)
1053
+ # model = _create_convnext('convnextv2_nano', pretrained=pretrained, **dict(model_args, **kwargs))
1054
+ # return model
1055
+
1056
+
1057
+ # @register_model
1058
+ # def convnextv2_tiny(pretrained=False, **kwargs) -> ConvNeXt:
1059
+ # model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), use_grn=True, ls_init_value=None)
1060
+ # model = _create_convnext('convnextv2_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
1061
+ # return model
1062
+
1063
+
1064
+ # @register_model
1065
+ # def convnextv2_small(pretrained=False, **kwargs) -> ConvNeXt:
1066
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], use_grn=True, ls_init_value=None)
1067
+ # model = _create_convnext('convnextv2_small', pretrained=pretrained, **dict(model_args, **kwargs))
1068
+ # return model
1069
+
1070
+
1071
+ # @register_model
1072
+ # def convnextv2_base(pretrained=False, **kwargs) -> ConvNeXt:
1073
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], use_grn=True, ls_init_value=None)
1074
+ # model = _create_convnext('convnextv2_base', pretrained=pretrained, **dict(model_args, **kwargs))
1075
+ # return model
1076
+
1077
+
1078
+ # @register_model
1079
+ # def convnextv2_large(pretrained=False, **kwargs) -> ConvNeXt:
1080
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], use_grn=True, ls_init_value=None)
1081
+ # model = _create_convnext('convnextv2_large', pretrained=pretrained, **dict(model_args, **kwargs))
1082
+ # return model
1083
+
1084
+
1085
+ # @register_model
1086
+ # def convnextv2_huge(pretrained=False, **kwargs) -> ConvNeXt:
1087
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], use_grn=True, ls_init_value=None)
1088
+ # model = _create_convnext('convnextv2_huge', pretrained=pretrained, **dict(model_args, **kwargs))
1089
+ # return model
1090
+
1091
+
1092
+ # register_model_deprecations(__name__, {
1093
+ # 'convnext_tiny_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k',
1094
+ # 'convnext_small_in22ft1k': 'convnext_small.fb_in22k_ft_in1k',
1095
+ # 'convnext_base_in22ft1k': 'convnext_base.fb_in22k_ft_in1k',
1096
+ # 'convnext_large_in22ft1k': 'convnext_large.fb_in22k_ft_in1k',
1097
+ # 'convnext_xlarge_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k',
1098
+ # 'convnext_tiny_384_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k_384',
1099
+ # 'convnext_small_384_in22ft1k': 'convnext_small.fb_in22k_ft_in1k_384',
1100
+ # 'convnext_base_384_in22ft1k': 'convnext_base.fb_in22k_ft_in1k_384',
1101
+ # 'convnext_large_384_in22ft1k': 'convnext_large.fb_in22k_ft_in1k_384',
1102
+ # 'convnext_xlarge_384_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k_384',
1103
+ # 'convnext_tiny_in22k': 'convnext_tiny.fb_in22k',
1104
+ # 'convnext_small_in22k': 'convnext_small.fb_in22k',
1105
+ # 'convnext_base_in22k': 'convnext_base.fb_in22k',
1106
+ # 'convnext_large_in22k': 'convnext_large.fb_in22k',
1107
+ # 'convnext_xlarge_in22k': 'convnext_xlarge.fb_in22k',
1108
+ # })
EAGLE/eagle/model/multimodal_encoder/vision_models/eva_vit.py ADDED
@@ -0,0 +1,1235 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # This file is modified from https://github.com/baaivision/EVA
16
+
17
+
18
+ import os
19
+ import fvcore.nn.weight_init as weight_init
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ import math
24
+ import numpy as np
25
+ import logging
26
+ from functools import partial
27
+ from scipy import interpolate
28
+ from math import pi
29
+ from einops import rearrange, repeat
30
+ import warnings
31
+ from PIL import Image
32
+ import torch.utils.checkpoint as cp
33
+ from transformers import CLIPImageProcessor
34
+ # from ..utils.attention import FlashAttention, FlashMHA
35
+ # try:
36
+ # import xformers.ops as xops
37
+ # except:
38
+ # pass
39
+
40
+ logger = logging.getLogger(__name__)
41
+ BatchNorm2d = torch.nn.BatchNorm2d
42
+
43
+ class Conv2d(torch.nn.Conv2d):
44
+ """
45
+ A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
46
+ """
47
+
48
+ def __init__(self, *args, **kwargs):
49
+ """
50
+ Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:
51
+ Args:
52
+ norm (nn.Module, optional): a normalization layer
53
+ activation (callable(Tensor) -> Tensor): a callable activation function
54
+ It assumes that norm layer is used before activation.
55
+ """
56
+ norm = kwargs.pop("norm", None)
57
+ activation = kwargs.pop("activation", None)
58
+ super().__init__(*args, **kwargs)
59
+
60
+ self.norm = norm
61
+ self.activation = activation
62
+
63
+ def forward(self, x):
64
+ # torchscript does not support SyncBatchNorm yet
65
+ # https://github.com/pytorch/pytorch/issues/40507
66
+ # and we skip these codes in torchscript since:
67
+ # 1. currently we only support torchscript in evaluation mode
68
+ # 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or
69
+ # later version, `Conv2d` in these PyTorch versions has already supported empty inputs.
70
+ if not torch.jit.is_scripting():
71
+ with warnings.catch_warnings(record=True):
72
+ if x.numel() == 0 and self.training:
73
+ # https://github.com/pytorch/pytorch/issues/12013
74
+ assert not isinstance(
75
+ self.norm, torch.nn.SyncBatchNorm
76
+ ), "SyncBatchNorm does not support empty inputs!"
77
+
78
+ x = F.conv2d(
79
+ x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
80
+ )
81
+ if self.norm is not None:
82
+ x = self.norm(x)
83
+ if self.activation is not None:
84
+ x = self.activation(x)
85
+ return x
86
+
87
+
88
+ def window_partition(x, window_size):
89
+ """
90
+ Partition into non-overlapping windows with padding if needed.
91
+ Args:
92
+ x (tensor): input tokens with [B, H, W, C].
93
+ window_size (int): window size.
94
+ Returns:
95
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
96
+ (Hp, Wp): padded height and width before partition
97
+ """
98
+ B, H, W, C = x.shape
99
+
100
+ pad_h = (window_size - H % window_size) % window_size
101
+ pad_w = (window_size - W % window_size) % window_size
102
+ if pad_h > 0 or pad_w > 0:
103
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
104
+ Hp, Wp = H + pad_h, W + pad_w
105
+
106
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
107
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
108
+ return windows, (Hp, Wp)
109
+
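+ # Shape sketch: for x of shape (2, 50, 50, 256) and window_size=16, H and W are padded to
+ # Hp = Wp = 64, giving windows of shape (2 * 4 * 4, 16, 16, 256) and pad_hw = (64, 64);
+ # window_unpartition below inverts this exactly:
+ # w, pad_hw = window_partition(x, 16)
+ # assert torch.equal(window_unpartition(w, 16, pad_hw, (50, 50)), x)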
110
+
111
+ def window_unpartition(windows, window_size, pad_hw, hw):
112
+ """
113
+ Window unpartition into original sequences and remove padding.
114
+ Args:
115
+ windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
116
+ window_size (int): window size.
117
+ pad_hw (Tuple): padded height and width (Hp, Wp).
118
+ hw (Tuple): original height and width (H, W) before padding.
119
+ Returns:
120
+ x: unpartitioned sequences with [B, H, W, C].
121
+ """
122
+ Hp, Wp = pad_hw
123
+ H, W = hw
124
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
125
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
126
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
127
+
128
+ if Hp > H or Wp > W:
129
+ x = x[:, :H, :W, :].contiguous()
130
+ return x
131
+
132
+
133
+ def get_rel_pos(q_size, k_size, rel_pos):
134
+ """
135
+ Get relative positional embeddings according to the relative positions of
136
+ query and key sizes.
137
+ Args:
138
+ q_size (int): size of query q.
139
+ k_size (int): size of key k.
140
+ rel_pos (Tensor): relative position embeddings (L, C).
141
+ Returns:
142
+ Extracted positional embeddings according to relative positions.
143
+ """
144
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
145
+ use_log_interpolation = True
146
+
147
+ # Interpolate rel pos if needed.
148
+ if rel_pos.shape[0] != max_rel_dist:
149
+ if not use_log_interpolation:
150
+ # Interpolate rel pos.
151
+ rel_pos_resized = F.interpolate(
152
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
153
+ size=max_rel_dist,
154
+ mode="linear",
155
+ )
156
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
157
+ else:
158
+ src_size = rel_pos.shape[0]
159
+ dst_size = max_rel_dist
160
+
161
+ # q = 1.13492
162
+ q = 1.0903078
163
+ dis = []
164
+
165
+ cur = 1
166
+ for i in range(src_size // 2):
167
+ dis.append(cur)
168
+ cur += q ** (i + 1)
169
+
170
+ r_ids = [-_ for _ in reversed(dis)]
171
+ x = r_ids + [0] + dis
172
+ t = dst_size // 2.0
173
+ dx = np.arange(-t, t + 0.1, 1.0)
174
+ all_rel_pos_bias = []
175
+ for i in range(rel_pos.shape[1]):
176
+ z = rel_pos[:, i].view(src_size).cpu().float().numpy()
177
+ f = interpolate.interp1d(x, z, kind='cubic', fill_value="extrapolate")
178
+ all_rel_pos_bias.append(
179
+ torch.Tensor(f(dx)).contiguous().view(-1, 1).to(rel_pos.device))
180
+ rel_pos_resized = torch.cat(all_rel_pos_bias, dim=-1)
181
+ else:
182
+ rel_pos_resized = rel_pos
183
+
184
+ # Scale the coords with short length if shapes for q and k are different.
185
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
186
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
187
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
188
+
189
+ return rel_pos_resized[relative_coords.long()]
190
+
191
+
192
+ def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
193
+ """
194
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
195
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
196
+ Args:
197
+ attn (Tensor): attention map.
198
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
199
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
200
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
201
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
202
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
203
+ Returns:
204
+ attn (Tensor): attention map with added relative positional embeddings.
205
+ """
206
+ q_h, q_w = q_size
207
+ k_h, k_w = k_size
208
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
209
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
210
+
211
+ B, _, dim = q.shape
212
+ r_q = q.reshape(B, q_h, q_w, dim)
213
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
214
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
215
+
216
+ attn = (
217
+ attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
218
+ ).view(B, q_h * q_w, k_h * k_w)
219
+
220
+ return attn
221
+
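+ # Shape sketch: with q_size = k_size = (14, 14) and head dim 64, Rh and Rw are (14, 14, 64),
+ # rel_h / rel_w are (B, 14, 14, 14), and the bias added to attn is the axial sum
+ # rel_h[..., :, None] + rel_w[..., None, :] broadcast over (k_h, k_w), i.e. a decomposed
+ # (height + width) approximation of a full 2-D relative position bias.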
222
+
223
+ def get_abs_pos(abs_pos, has_cls_token, hw):
224
+ """
225
+ Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
226
+ dimension for the original embeddings.
227
+ Args:
228
+ abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
229
+ has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
230
+ hw (Tuple): size of input image tokens.
231
+ Returns:
232
+ Absolute positional embeddings after processing with shape (1, H, W, C)
233
+ """
234
+ h, w = hw
235
+ if has_cls_token:
236
+ abs_pos = abs_pos[:, 1:]
237
+ xy_num = abs_pos.shape[1]
238
+ size = int(math.sqrt(xy_num))
239
+ assert size * size == xy_num
240
+
241
+ if size != h or size != w:
242
+ original_datatype = abs_pos.dtype
243
+ new_abs_pos = F.interpolate(
244
+ abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2).float(), # bf16 is not implemented
245
+ size=(h, w),
246
+ mode="bicubic",
247
+ align_corners=False,
248
+ ).to(original_datatype)
249
+
250
+ return new_abs_pos.permute(0, 2, 3, 1)
251
+ else:
252
+ return abs_pos.reshape(1, h, w, -1)
253
+
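+ # Example: a 224 / 16 = 14x14 pretraining grid (196 tokens + cls) fed a 1024 / 16 = 64x64 token
+ # grid is bicubically resized from (1, 14, 14, C) to (1, 64, 64, C); the cls-token embedding is
+ # dropped first, since this detection-style backbone keeps only the patch tokens.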
254
+
255
+ class PatchEmbed(nn.Module):
256
+ """
257
+ Image to Patch Embedding.
258
+ """
259
+
260
+ def __init__(
261
+ self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768
262
+ ):
263
+ """
264
+ Args:
265
+ kernel_size (Tuple): kernel size of the projection layer.
266
+ stride (Tuple): stride of the projection layer.
267
+ padding (Tuple): padding size of the projection layer.
268
+ in_chans (int): Number of input image channels.
269
+ embed_dim (int): Patch embedding dimension.
270
+ """
271
+ super().__init__()
272
+
273
+ self.proj = nn.Conv2d(
274
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
275
+ )
276
+
277
+ def forward(self, x):
278
+ x = self.proj(x)
279
+ # B C H W -> B H W C
280
+ x = x.permute(0, 2, 3, 1)
281
+ return x
282
+
283
+
284
+ def broadcat(tensors, dim = -1):
285
+ num_tensors = len(tensors)
286
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
287
+ assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
288
+ shape_len = list(shape_lens)[0]
289
+ dim = (dim + shape_len) if dim < 0 else dim
290
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
291
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
292
+ assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatenation'
293
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
294
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
295
+ expanded_dims.insert(dim, (dim, dims[dim]))
296
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
297
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
298
+ return torch.cat(tensors, dim = dim)
299
+
300
+
301
+
302
+ def rotate_half(x):
303
+ x = rearrange(x, '... (d r) -> ... d r', r = 2)
304
+ x1, x2 = x.unbind(dim = -1)
305
+ x = torch.stack((-x2, x1), dim = -1)
306
+ return rearrange(x, '... d r -> ... (d r)')
307
+
308
+
309
+
310
+ class VisionRotaryEmbedding(nn.Module):
311
+ def __init__(
312
+ self,
313
+ dim,
314
+ pt_seq_len,
315
+ ft_seq_len=None,
316
+ custom_freqs = None,
317
+ freqs_for = 'lang',
318
+ theta = 10000,
319
+ max_freq = 10,
320
+ num_freqs = 1,
321
+ ):
322
+ super().__init__()
323
+ if custom_freqs:
324
+ freqs = custom_freqs
325
+ elif freqs_for == 'lang':
326
+ freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
327
+ elif freqs_for == 'pixel':
328
+ freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
329
+ elif freqs_for == 'constant':
330
+ freqs = torch.ones(num_freqs).float()
331
+ else:
332
+ raise ValueError(f'unknown modality {freqs_for}')
333
+
334
+ if ft_seq_len is None: ft_seq_len = pt_seq_len
335
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
336
+
337
+ freqs_h = torch.einsum('..., f -> ... f', t, freqs)
338
+ freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)
339
+
340
+ freqs_w = torch.einsum('..., f -> ... f', t, freqs)
341
+ freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)
342
+
343
+ freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1)
344
+
345
+ self.register_buffer("freqs_cos", freqs.cos())
346
+ self.register_buffer("freqs_sin", freqs.sin())
347
+
348
+ # print('======== shape of rope freq', self.freqs_cos.shape, '========')
349
+
350
+ def forward(self, t, start_index = 0):
351
+ rot_dim = self.freqs_cos.shape[-1]
352
+ end_index = start_index + rot_dim
353
+ assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
354
+ t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
355
+ t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
356
+ return torch.cat((t_left, t, t_right), dim = -1)
357
+
358
+
359
+
360
+
361
+ class VisionRotaryEmbeddingFast(nn.Module):
362
+ def __init__(
363
+ self,
364
+ dim,
365
+ pt_seq_len=16,
366
+ ft_seq_len=None,
367
+ custom_freqs = None,
368
+ freqs_for = 'lang',
369
+ theta = 10000,
370
+ max_freq = 10,
371
+ num_freqs = 1,
372
+ ):
373
+ super().__init__()
374
+ if custom_freqs:
375
+ freqs = custom_freqs
376
+ elif freqs_for == 'lang':
377
+ freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
378
+ elif freqs_for == 'pixel':
379
+ freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
380
+ elif freqs_for == 'constant':
381
+ freqs = torch.ones(num_freqs).float()
382
+ else:
383
+ raise ValueError(f'unknown modality {freqs_for}')
384
+
385
+ if ft_seq_len is None: ft_seq_len = pt_seq_len
386
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
387
+
388
+ freqs = torch.einsum('..., f -> ... f', t, freqs)
389
+ freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
390
+ freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)
391
+
392
+ freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
393
+ freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
394
+
395
+ self.register_buffer("freqs_cos", freqs_cos)
396
+ self.register_buffer("freqs_sin", freqs_sin)
397
+
398
+ # print('======== shape of rope freq', self.freqs_cos.shape, '========')
399
+
400
+ def forward(self, t): return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
401
+
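+ # Rotary sketch: the fast variant precomputes freqs_cos / freqs_sin flattened over the
+ # (ft_seq_len x ft_seq_len) token grid with last dim 2 * `dim`, so constructing it with
+ # dim = head_dim // 2 lets Attention.forward below rotate full q / k of shape
+ # (B, num_heads, H * W, head_dim) via q = rope(q), k = rope(k).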
402
+
403
+ class FrozenBatchNorm2d(nn.Module):
404
+ """
405
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
406
+ It contains non-trainable buffers called
407
+ "weight" and "bias", "running_mean", "running_var",
408
+ initialized to perform identity transformation.
409
+ The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
410
+ which are computed from the original four parameters of BN.
411
+ The affine transform `x * weight + bias` will perform the equivalent
412
+ computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
413
+ When loading a backbone model from Caffe2, "running_mean" and "running_var"
414
+ will be left unchanged as identity transformation.
415
+ Other pre-trained backbone models may contain all 4 parameters.
416
+ The forward is implemented by `F.batch_norm(..., training=False)`.
417
+ """
418
+
419
+ _version = 3
420
+
421
+ def __init__(self, num_features, eps=1e-5):
422
+ super().__init__()
423
+ self.num_features = num_features
424
+ self.eps = eps
425
+ self.register_buffer("weight", torch.ones(num_features))
426
+ self.register_buffer("bias", torch.zeros(num_features))
427
+ self.register_buffer("running_mean", torch.zeros(num_features))
428
+ self.register_buffer("running_var", torch.ones(num_features) - eps)
429
+
430
+ def forward(self, x):
431
+ if x.requires_grad:
432
+ # When gradients are needed, F.batch_norm will use extra memory
433
+ # because its backward op computes gradients for weight/bias as well.
434
+ scale = self.weight * (self.running_var + self.eps).rsqrt()
435
+ bias = self.bias - self.running_mean * scale
436
+ scale = scale.reshape(1, -1, 1, 1)
437
+ bias = bias.reshape(1, -1, 1, 1)
438
+ out_dtype = x.dtype # may be half
439
+ return x * scale.to(out_dtype) + bias.to(out_dtype)
440
+ else:
441
+ # When gradients are not needed, F.batch_norm is a single fused op
442
+ # and provides more optimization opportunities.
443
+ return F.batch_norm(
444
+ x,
445
+ self.running_mean,
446
+ self.running_var,
447
+ self.weight,
448
+ self.bias,
449
+ training=False,
450
+ eps=self.eps,
451
+ )
452
+
453
+ def _load_from_state_dict(
454
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
455
+ ):
456
+ version = local_metadata.get("version", None)
457
+
458
+ if version is None or version < 2:
459
+ # No running_mean/var in early versions
460
+ # This will silence the warnings
461
+ if prefix + "running_mean" not in state_dict:
462
+ state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
463
+ if prefix + "running_var" not in state_dict:
464
+ state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)
465
+
466
+ super()._load_from_state_dict(
467
+ state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
468
+ )
469
+
470
+ def __repr__(self):
471
+ return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)
472
+
473
+ @classmethod
474
+ def convert_frozen_batchnorm(cls, module):
475
+ """
476
+ Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.
477
+ Args:
478
+ module (torch.nn.Module):
479
+ Returns:
480
+ If module is BatchNorm/SyncBatchNorm, returns a new module.
481
+ Otherwise, in-place convert module and return it.
482
+ Similar to convert_sync_batchnorm in
483
+ https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
484
+ """
485
+ bn_module = nn.modules.batchnorm
486
+ bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
487
+ res = module
488
+ if isinstance(module, bn_module):
489
+ res = cls(module.num_features)
490
+ if module.affine:
491
+ res.weight.data = module.weight.data.clone().detach()
492
+ res.bias.data = module.bias.data.clone().detach()
493
+ res.running_mean.data = module.running_mean.data
494
+ res.running_var.data = module.running_var.data
495
+ res.eps = module.eps
496
+ else:
497
+ for name, child in module.named_children():
498
+ new_child = cls.convert_frozen_batchnorm(child)
499
+ if new_child is not child:
500
+ res.add_module(name, new_child)
501
+ return res
502
+
503
+ class LayerNorm(nn.Module):
504
+ """
505
+ A LayerNorm variant, popularized by Transformers, that performs point-wise mean and
506
+ variance normalization over the channel dimension for inputs that have shape
507
+ (batch_size, channels, height, width).
508
+ https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa B950
509
+ """
510
+
511
+ def __init__(self, normalized_shape, eps=1e-6):
512
+ super().__init__()
513
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
514
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
515
+ self.eps = eps
516
+ self.normalized_shape = (normalized_shape,)
517
+
518
+ def forward(self, x):
519
+ u = x.mean(1, keepdim=True)
520
+ s = (x - u).pow(2).mean(1, keepdim=True)
521
+ x = (x - u) / torch.sqrt(s + self.eps)
522
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
523
+ return x
524
+
525
+
526
+ class CNNBlockBase(nn.Module):
527
+ """
528
+ A CNN block is assumed to have input channels, output channels and a stride.
529
+ The input and output of `forward()` method must be NCHW tensors.
530
+ The method can perform arbitrary computation but must match the given
531
+ channels and stride specification.
532
+ Attribute:
533
+ in_channels (int):
534
+ out_channels (int):
535
+ stride (int):
536
+ """
537
+
538
+ def __init__(self, in_channels, out_channels, stride):
539
+ """
540
+ The `__init__` method of any subclass should also contain these arguments.
541
+ Args:
542
+ in_channels (int):
543
+ out_channels (int):
544
+ stride (int):
545
+ """
546
+ super().__init__()
547
+ self.in_channels = in_channels
548
+ self.out_channels = out_channels
549
+ self.stride = stride
550
+
551
+ def freeze(self):
552
+ """
553
+ Make this block not trainable.
554
+ This method sets all parameters to `requires_grad=False`,
555
+ and converts all BatchNorm layers to FrozenBatchNorm
556
+ Returns:
557
+ the block itself
558
+ """
559
+ for p in self.parameters():
560
+ p.requires_grad = False
561
+ FrozenBatchNorm2d.convert_frozen_batchnorm(self)
562
+ return self
563
+
564
+ def get_norm(norm, out_channels):
565
+ """
566
+ Args:
567
+ norm (str or callable): either one of BN, SyncBN, FrozenBN, GN;
568
+ or a callable that takes a channel number and returns
569
+ the normalization layer as a nn.Module.
570
+ Returns:
571
+ nn.Module or None: the normalization layer
572
+ """
573
+ if norm is None:
574
+ return None
575
+ if isinstance(norm, str):
576
+ if len(norm) == 0:
577
+ return None
578
+ norm = {
579
+ "BN": BatchNorm2d,
580
+ # Fixed in https://github.com/pytorch/pytorch/pull/36382
581
+ "SyncBN": nn.SyncBatchNorm,
582
+ "FrozenBN": FrozenBatchNorm2d,
583
+ "GN": lambda channels: nn.GroupNorm(32, channels),
584
+ # for debugging:
585
+ "nnSyncBN": nn.SyncBatchNorm,
586
+ "LN": lambda channels: LayerNorm(channels)
587
+ }[norm]
588
+ return norm(out_channels)
589
+
590
+ class DropPath(nn.Module):
591
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
592
+ """
593
+
594
+ def __init__(self, drop_prob=None):
595
+ super(DropPath, self).__init__()
596
+ self.drop_prob = drop_prob
597
+
598
+ def forward(self, x):
599
+ if self.drop_prob == 0. or not self.training:
600
+ return x
601
+ keep_prob = 1 - self.drop_prob
602
+ # work with diff dim tensors, not just 2D ConvNets
603
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
604
+ random_tensor = keep_prob + \
605
+ torch.rand(shape, dtype=x.dtype, device=x.device)
606
+ random_tensor.floor_() # binarize
607
+ output = x.div(keep_prob) * random_tensor
608
+ return output
609
+
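+ # DropPath note: each sample's residual branch is kept with probability 1 - drop_prob and the
+ # survivors are rescaled by 1 / keep_prob, so the output matches x in expectation; e.g. with
+ # drop_prob=0.1 a path survives with p=0.9 and surviving activations are multiplied by 1/0.9.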
610
+
611
+
612
+ class SwiGLU(nn.Module):
613
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
614
+ norm_layer=nn.LayerNorm, subln=False
615
+ ):
616
+ super().__init__()
617
+ out_features = out_features or in_features
618
+ hidden_features = hidden_features or in_features
619
+
620
+ self.w1 = nn.Linear(in_features, hidden_features)
621
+ self.w2 = nn.Linear(in_features, hidden_features)
622
+
623
+ self.act = act_layer()
624
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
625
+ self.w3 = nn.Linear(hidden_features, out_features)
626
+
627
+ self.drop = nn.Dropout(drop)
628
+
629
+ def forward(self, x):
630
+ x1 = self.w1(x)
631
+ x2 = self.w2(x)
632
+ hidden = self.act(x1) * x2
633
+ x = self.ffn_ln(hidden)
634
+ x = self.w3(x)
635
+ x = self.drop(x)
636
+ return x
637
+
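+ # SwiGLU note: the block computes w3(ffn_ln(SiLU(w1(x)) * w2(x))) followed by dropout, i.e. a
+ # gated MLP where the SiLU(w1(x)) branch gates w2(x); with subln=True the extra LayerNorm
+ # (ffn_ln) is applied to the gated hidden state before the output projection.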
638
+
639
+ class Attention(nn.Module):
640
+ def __init__(
641
+ self,
642
+ dim,
643
+ num_heads=8,
644
+ qkv_bias=True,
645
+ qk_scale=None,
646
+ attn_head_dim=None,
647
+ norm_layer=nn.LayerNorm,
648
+ rope=None,
649
+ xattn=True,
650
+ subln=False
651
+ ):
652
+ super().__init__()
653
+ self.num_heads = num_heads
654
+ head_dim = dim // num_heads
655
+ if attn_head_dim is not None:
656
+ head_dim = attn_head_dim
657
+ all_head_dim = head_dim * self.num_heads
658
+ self.scale = qk_scale or head_dim ** -0.5
659
+
660
+ self.subln = subln
661
+ self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
662
+ self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
663
+ self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
664
+
665
+ if qkv_bias:
666
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
667
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
668
+ else:
669
+ self.q_bias = None
670
+ self.v_bias = None
671
+
672
+ self.rope = rope
673
+ self.xattn = xattn
674
+ self.proj = nn.Linear(all_head_dim, dim)
675
+ self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
676
+
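+ # NOTE: the xattn=True path below relies on FlashAttention from the import commented out at the
+ # top of this file (`from ..utils.attention import FlashAttention`); without that dependency,
+ # construct Attention with xattn=False to use the plain softmax attention path instead.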
677
+ if self.xattn:
678
+ factory_kwargs = {'device': 'cuda', 'dtype': torch.float16}
679
+ self.inner_attn = FlashAttention(attention_dropout=0.0, **factory_kwargs)
680
+
681
+ def forward(self, x):
682
+ B, H, W, C = x.shape
683
+ x = x.view(B, -1, C)
684
+ N = H * W
685
+
686
+ q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
687
+ k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
688
+ v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
689
+
690
+ q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
691
+ k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
692
+ v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
693
+
694
+ ## rope
695
+ q = self.rope(q).type_as(v)
696
+ k = self.rope(k).type_as(v)
697
+
698
+ if self.xattn:
699
+ q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
700
+ k = k.permute(0, 2, 1, 3)
701
+ v = v.permute(0, 2, 1, 3)
702
+
703
+ kv = torch.stack([k, v], dim=2)
704
+ x, attn_weights = self.inner_attn(q, kv, key_padding_mask=None, causal=False)
705
+ # x = xops.memory_efficient_attention(q, k, v)
706
+ x = x.reshape(B, N, -1)
707
+ x = self.inner_attn_ln(x)
708
+ else:
709
+ q = q * self.scale
710
+ attn = (q @ k.transpose(-2, -1))
711
+ attn = attn.softmax(dim=-1).type_as(x)
712
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
713
+ x = self.inner_attn_ln(x)
714
+
715
+ x = self.proj(x)
716
+ x = x.view(B, H, W, C)
717
+
718
+ return x
719
+
720
+
721
+ class ResBottleneckBlock(CNNBlockBase):
722
+ """
723
+ The standard bottleneck residual block without the last activation layer.
724
+ It contains 3 conv layers with kernels 1x1, 3x3, 1x1.
725
+ """
726
+
727
+ def __init__(
728
+ self,
729
+ in_channels,
730
+ out_channels,
731
+ bottleneck_channels,
732
+ norm="LN",
733
+ act_layer=nn.GELU,
734
+ ):
735
+ """
736
+ Args:
737
+ in_channels (int): Number of input channels.
738
+ out_channels (int): Number of output channels.
739
+ bottleneck_channels (int): number of output channels for the 3x3
740
+ "bottleneck" conv layers.
741
+ norm (str or callable): normalization for all conv layers.
742
+ See :func:`layers.get_norm` for supported format.
743
+ act_layer (callable): activation for all conv layers.
744
+ """
745
+ super().__init__(in_channels, out_channels, 1)
746
+
747
+ self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False)
748
+ self.norm1 = get_norm(norm, bottleneck_channels)
749
+ self.act1 = act_layer()
750
+
751
+ self.conv2 = Conv2d(
752
+ bottleneck_channels,
753
+ bottleneck_channels,
754
+ 3,
755
+ padding=1,
756
+ bias=False,
757
+ )
758
+ self.norm2 = get_norm(norm, bottleneck_channels)
759
+ self.act2 = act_layer()
760
+
761
+ self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False)
762
+ self.norm3 = get_norm(norm, out_channels)
763
+
764
+ for layer in [self.conv1, self.conv2, self.conv3]:
765
+ weight_init.c2_msra_fill(layer)
766
+ for layer in [self.norm1, self.norm2]:
767
+ layer.weight.data.fill_(1.0)
768
+ layer.bias.data.zero_()
769
+ # zero init last norm layer.
770
+ self.norm3.weight.data.zero_()
771
+ self.norm3.bias.data.zero_()
772
+
773
+ def forward(self, x):
774
+ out = x
775
+ for layer in self.children():
776
+ out = layer(out)
777
+
778
+ out = x + out
779
+ return out
780
+
781
+
782
+ class Block(nn.Module):
783
+ """Transformer blocks with support of window attention and residual propagation blocks"""
784
+
785
+ def __init__(
786
+ self,
787
+ dim,
788
+ num_heads,
789
+ mlp_ratio=4*2/3,
790
+ qkv_bias=True,
791
+ drop_path=0.0,
792
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
793
+ window_size=0,
794
+ use_residual_block=False,
795
+ rope=None,
796
+ xattn=True,
797
+ subln=False,
798
+ # with_cp=True,
799
+ ):
800
+ """
801
+ Args:
802
+ dim (int): Number of input channels.
803
+ num_heads (int): Number of attention heads in each ViT block.
804
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
805
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
806
+ drop_path (float): Stochastic depth rate.
807
+ norm_layer (nn.Module): Normalization layer.
808
+ act_layer (nn.Module): Activation layer.
809
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
810
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
811
+ window_size (int): Window size for window attention blocks. If it equals 0, window
812
+ attention is not used.
813
+ use_residual_block (bool): If True, use a residual block after the MLP block.
814
+ input_size (int or None): Input resolution for calculating the relative positional
815
+ parameter size.
816
+ """
817
+ super().__init__()
818
+ self.norm1 = norm_layer(dim)
819
+ self.attn = Attention(
820
+ dim,
821
+ num_heads=num_heads,
822
+ qkv_bias=qkv_bias,
823
+ rope=rope,
824
+ xattn=xattn,
825
+ subln=subln
826
+ )
827
+
828
+
829
+ # self.with_cp = with_cp
830
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
831
+ self.norm2 = norm_layer(dim)
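+ # SwiGLU FFN uses three projections, so mlp_ratio = 4*2/3 (hidden dim ~ 8/3 * dim) keeps its parameter count on par with a standard 4x two-layer GELU MLP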
832
+ self.mlp = SwiGLU(
833
+ in_features=dim,
834
+ hidden_features=int(dim * mlp_ratio),
835
+ subln=True,
836
+ norm_layer=norm_layer,
837
+ )
838
+
839
+ self.window_size = window_size
840
+
841
+ self.use_residual_block = use_residual_block
842
+ if use_residual_block:
843
+ # Use a residual block with bottleneck channel as dim // 2
844
+ self.residual = ResBottleneckBlock(
845
+ in_channels=dim,
846
+ out_channels=dim,
847
+ bottleneck_channels=dim // 2,
848
+ norm="LN",
849
+ )
850
+
851
+ def _forward(self, x):
852
+ shortcut = x
853
+ x = self.norm1(x)
854
+
855
+ # Window partition
856
+ if self.window_size > 0:
857
+ H, W = x.shape[1], x.shape[2]
858
+ x, pad_hw = window_partition(x, self.window_size)
859
+
860
+ x = self.attn(x)
861
+
862
+ # Reverse window partition
863
+ if self.window_size > 0:
864
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
865
+
866
+ x = shortcut + self.drop_path(x)
867
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
868
+
869
+ if self.use_residual_block:
870
+ x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
871
+
872
+ return x
873
+
874
+ def forward(self, x, with_cp=False):
875
+ # if self.with_cp and self.training:
876
+ if with_cp:
877
+ x = cp.checkpoint(self._forward, x)
878
+ else:
879
+ x = self._forward(x)
880
+ return x
881
+
882
+ #@BACKBONES.register_module()
883
+ class EVAViT(nn.Module):
884
+ """
885
+ This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`.
886
+ "Exploring Plain Vision Transformer Backbones for Object Detection",
887
+ https://arxiv.org/abs/2203.16527
888
+ """
889
+
890
+ def __init__(
891
+ self,
892
+ img_size=1024,
893
+ patch_size=16,
894
+ in_chans=3,
895
+ embed_dim=768,
896
+ depth=12,
897
+ num_heads=12,
898
+ mlp_ratio=4*2/3,
899
+ qkv_bias=True,
900
+ drop_path_rate=0.0,
901
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
902
+ act_layer=nn.GELU,
903
+ use_abs_pos=True,
904
+ use_rel_pos=False,
905
+ # sim_fpn=None,
906
+ rope=True,
907
+ pt_hw_seq_len=16,
908
+ intp_freq=True,
909
+ window_size=0,
910
+ global_window_size=0,
911
+ window_block_indexes=(),
912
+ residual_block_indexes=(),
913
+ pretrain_img_size=224,
914
+ pretrain_use_cls_token=True,
915
+ out_feature="last_feat",
916
+ subln=False,
917
+ xattn=True,
918
+ # with_cp=True,
919
+ frozen=False,
920
+ ):
921
+ """
922
+ Args:
923
+ img_size (int): Input image size.
924
+ patch_size (int): Patch size.
925
+ in_chans (int): Number of input image channels.
926
+ embed_dim (int): Patch embedding dimension.
927
+ depth (int): Depth of ViT.
928
+ num_heads (int): Number of attention heads in each ViT block.
929
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
930
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
931
+ drop_path_rate (float): Stochastic depth rate.
932
+ norm_layer (nn.Module): Normalization layer.
933
+ act_layer (nn.Module): Activation layer.
934
+ use_abs_pos (bool): If True, use absolute positional embeddings.
935
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
936
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
937
+ window_size (int): Window size for window attention blocks.
938
+ window_block_indexes (list): Indexes for blocks using window attention.
939
+ residual_block_indexes (list): Indexes for blocks using conv propagation.
940
+ use_act_checkpoint (bool): If True, use activation checkpointing.
941
+ pretrain_img_size (int): input image size for pretraining models.
942
+ pretrain_use_cls_token (bool): If True, pretraining models use a class token.
943
+ out_feature (str): name of the feature from the last block.
944
+ """
945
+ super().__init__()
946
+ self.pretrain_use_cls_token = pretrain_use_cls_token
947
+ self.patch_embed = PatchEmbed(
948
+ kernel_size=(patch_size, patch_size),
949
+ stride=(patch_size, patch_size),
950
+ in_chans=in_chans,
951
+ embed_dim=embed_dim,
952
+ )
953
+ self.frozen = frozen
954
+ self.gradient_checkpointing = False
955
+
956
+ if use_abs_pos:
957
+ # Initialize absolute positional embedding with pretrain image size.
958
+ num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
959
+ num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
960
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
961
+ else:
962
+ self.pos_embed = None
963
+
964
+ half_head_dim = embed_dim // num_heads // 2
965
+ hw_seq_len = img_size // patch_size
966
+
967
+ self.rope_win = VisionRotaryEmbeddingFast(
968
+ dim=half_head_dim,
969
+ pt_seq_len=pt_hw_seq_len,
970
+ ft_seq_len=window_size if intp_freq else None,
971
+ )
972
+ self.rope_glb = VisionRotaryEmbeddingFast(
973
+ dim=half_head_dim,
974
+ pt_seq_len=pt_hw_seq_len,
975
+ ft_seq_len=hw_seq_len if intp_freq else None,
976
+ )
977
+
978
+ # stochastic depth decay rule
979
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
980
+
981
+ self.blocks = nn.ModuleList()
982
+ for i in range(depth):
983
+ block = Block(
984
+ dim=embed_dim,
985
+ num_heads=num_heads,
986
+ mlp_ratio=mlp_ratio,
987
+ qkv_bias=qkv_bias,
988
+ drop_path=dpr[i],
989
+ norm_layer=norm_layer,
990
+ window_size=window_size if i in window_block_indexes else global_window_size,
991
+ use_residual_block=i in residual_block_indexes,
992
+ rope=self.rope_win if i in window_block_indexes else self.rope_glb,
993
+ xattn=xattn,
994
+ subln=subln,
995
+ # with_cp=with_cp,
996
+ )
997
+
998
+ self.blocks.append(block)
999
+
1000
+ self._out_feature_channels = {out_feature: embed_dim}
1001
+ self._out_feature_strides = {out_feature: patch_size}
1002
+ self._out_features = [out_feature]
1003
+
1004
+ if self.pos_embed is not None:
1005
+ nn.init.normal_(self.pos_embed, std=0.02)
1006
+
1007
+ self._freeze_stages()
1008
+
1009
+ def _freeze_stages(self):
1010
+ if self.frozen:
1011
+ self.eval()
1012
+ for m in self.parameters():
1013
+ m.requires_grad = False
1014
+
1015
+ def forward(self, x):
1016
+ x = self.patch_embed(x)
1017
+ if self.pos_embed is not None:
1018
+ x = x + get_abs_pos(
1019
+ self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
1020
+ )
1021
+
1022
+ for blk in self.blocks:
1023
+ x = blk(x, with_cp=self.gradient_checkpointing) # b, h, w, c
1024
+ x = x.permute(0, 3, 1, 2) # b, c, h, w
1025
+
1026
+ return x
1027
+
1028
+
1029
+ class EVAVITVisionTower(nn.Module):
1030
+ def __init__(self, vision_tower, args, delay_load=False):
1031
+ super().__init__()
1032
+
1033
+ self.is_loaded = False
1034
+ self.vision_tower_name = vision_tower
1035
+ self.select_layer = args.mm_vision_select_layer # NOTE: not implemented yet, this parameter has no effect
1036
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
1037
+
1038
+ self.args = args
1039
+ self.vision_tower, vision_tower_config = build_eva_vit(args=args,
1040
+ model_name=vision_tower,
1041
+ image_size=args.input_image_size
1042
+ )
1043
+ self.input_image_size=args.input_image_size
1044
+ self.vision_tower.config = vision_tower_config
1045
+ self.freeze_vision = args.freeze_vision
1046
+
1047
+ if not self.is_loaded:
1048
+ self.load_model()
1049
+
1050
+
1051
+ def load_model(self):
1052
+ if self.is_loaded:
1053
+ return
1054
+
1055
+ # hardcode
1056
+ self.image_processor = CLIPImageProcessor(crop_size={"height": self.args.input_image_size, "width": self.args.input_image_size},
1057
+ size={'shortest_edge': self.args.input_image_size},
1058
+ image_mean=[0.48145466, 0.4578275, 0.40821073],
1059
+ image_std=[0.26862954, 0.26130258, 0.27577711])
1060
+
1061
+ # load weights
1062
+ if self.args.vision_tower_pretrained_from is not None:
1063
+ if not os.path.exists(self.args.vision_tower_pretrained_from):
1064
+ import warnings
1065
+ warnings.warn("The vision tower weights for EVA-02 vision tower does not exists, this will cause problem if you are training the model from scratch!")
1066
+ self.is_loaded = True
1067
+ return
1068
+
1069
+ pretrained_params = torch.load(self.args.vision_tower_pretrained_from)
1070
+ if 'ema_state' in pretrained_params:
1071
+ pretrained_params = pretrained_params['ema_state']
1072
+ elif 'module' in pretrained_params:
1073
+ pretrained_params = pretrained_params['module']
1074
+
1075
+ from collections import OrderedDict
1076
+ new_params = OrderedDict()
1077
+
1078
+ kw = ""
1079
+ if "det" in self.args.vision_tower_pretrained_from.lower():
1080
+ kw = "backbone.net."
1081
+ elif "clip" in self.args.vision_tower_pretrained_from.lower():
1082
+ kw = "visual."
1083
+
1084
+ for k, v in pretrained_params.items():
1085
+ if len(kw) > 0:
1086
+ if kw in k and ("rope" not in k):
1087
+ new_params[k.replace(kw, "")] = v
1088
+ else:
1089
+ if "rope" not in k:
1090
+ new_params[k] = v
1091
+
1092
+ incompatiblekeys = self.vision_tower.load_state_dict(new_params, strict=False)
1093
+ for k in incompatiblekeys[0]:
1094
+ if "rope" not in k:
1095
+ warnings.warn(f"Find incompatible keys {k} in state dict.")
1096
+
1097
+
1098
+ if self.freeze_vision:
1099
+ self.vision_tower.requires_grad_(False)
1100
+
1101
+ self.is_loaded = True
1102
+
1103
+
1104
+ # @torch.no_grad()
1105
+ def forward(self, images):
1106
+ if type(images) is list:
1107
+ image_features = []
1108
+ for image in images:
1109
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
1110
+ image_feature = image_forward_out.flatten(2,3).transpose(1,2) # b, n, c
1111
+ image_features.append(image_feature)
+ return image_features  # list inputs: return one (b, n, c) feature tensor per image
1112
+ else:
1113
+ image_forward_out = self.vision_tower(images.to(device=self.device, dtype=self.dtype))
1114
+
1115
+ return image_forward_out
1116
+
1117
+ @property
1118
+ def dummy_feature(self):
1119
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
1120
+
1121
+ @property
1122
+ def dtype(self):
1123
+ return next(self.vision_tower.parameters()).dtype
1124
+
1125
+ @property
1126
+ def device(self):
1127
+ return next(self.vision_tower.parameters()).device
1128
+
1129
+ @property
1130
+ def config(self):
1131
+ return self.vision_tower.config
1132
+
1133
+ @property
1134
+ def hidden_size(self):
1135
+ #return self.config.hidden_size
1136
+ return self.config['hidden_dim']
1137
+
1138
+ @property
1139
+ def num_patches(self):
1140
+ # return (self.config.image_size // self.config.patch_size) ** 2
1141
+ return self.config['num_patches']
1142
+
1143
+
1144
+ def build_eva_vit(args,
1145
+ model_name=None,
1146
+ image_size=224,
1147
+ window_attn=True
1148
+ ):
1149
+
1150
+ if "336" in args.vision_tower_pretrained_from:
1151
+ pretrained_image_size = 336
1152
+ else:
1153
+ pretrained_image_size = 224
1154
+
1155
+ if "clip" in args.vision_tower_pretrained_from.lower():
1156
+ subln = True
1157
+ else:
1158
+ subln = False
1159
+
1160
+ if model_name == 'eva02-l-16':
1161
+ # Shilong suggested using this checkpoint: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/det/eva02_L_coco_det_sys_o365.pth
1162
+ if window_attn:
1163
+ window_block_indexes = (list(range(0, 2)) + list(range(3, 5)) + list(range(6, 8)) + list(range(9, 11)) + list(range(12, 14)) + list(range(15, 17)) + list(range(18, 20)) + list(range(21, 23)))
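+ # blocks 2, 5, 8, ..., 23 keep global attention; all remaining blocks use 16x16 window attention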
1164
+ else:
1165
+ window_block_indexes = ()
1166
+
1167
+ model = EVAViT(
1168
+ img_size=image_size,
1169
+ patch_size=16,
1170
+ window_size=16,
1171
+ in_chans=3,
1172
+ embed_dim=1024,
1173
+ depth=24,
1174
+ num_heads=16,
1175
+ mlp_ratio=4*2/3,
1176
+ window_block_indexes = window_block_indexes,
1177
+ qkv_bias=True,
1178
+ drop_path_rate=0.0,
1179
+ xattn=False,
1180
+ # with_cp=False,
1181
+ # frozen=True,
1182
+ )
1183
+ # image_size = 224 # HARDCODE
1184
+ eva_config = dict(image_size=image_size,
1185
+ patch_size=16,
1186
+ window_size=16,
1187
+ hidden_dim=1024,
1188
+ depth=24,
1189
+ num_heads=16,
1190
+ window_block_indexes=window_block_indexes,
1191
+ num_patches=image_size ** 2 // 16 ** 2,
1192
+ pretrained_from=args.vision_tower_pretrained_from
1193
+ )
1194
+
1195
+ elif model_name == 'eva02-l-14':
1196
+ # Shilong suggested using this checkpoint: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/det/eva02_L_coco_det_sys_o365.pth
1197
+ if window_attn:
1198
+ window_block_indexes = (list(range(0, 2)) + list(range(3, 5)) + list(range(6, 8)) + list(range(9, 11)) + list(range(12, 14)) + list(range(15, 17)) + list(range(18, 20)) + list(range(21, 23)))
1199
+ else:
1200
+ window_block_indexes = ()
1201
+
1202
+ model = EVAViT(
1203
+ img_size=image_size,
1204
+ pretrain_img_size=pretrained_image_size,
1205
+ patch_size=14,
1206
+ window_size=16,
1207
+ in_chans=3,
1208
+ embed_dim=1024,
1209
+ depth=24,
1210
+ num_heads=16,
1211
+ mlp_ratio=4*2/3,
1212
+ window_block_indexes = window_block_indexes,
1213
+ qkv_bias=True,
1214
+ drop_path_rate=0.0,
1215
+ xattn=False,
1216
+ # with_cp=False,
1217
+ subln=subln,
1218
+ # frozen=True,
1219
+ )
1220
+ # image_size = 224 # HARDCODE
1221
+ eva_config = dict(image_size=image_size,
1222
+ patch_size=14,
1223
+ window_size=16,
1224
+ hidden_dim=1024,
1225
+ depth=24,
1226
+ num_heads=16,
1227
+ window_block_indexes=window_block_indexes,
1228
+ num_patches=image_size ** 2 // 14 ** 2,
1229
+ pretrained_from=args.vision_tower_pretrained_from
1230
+ )
1231
+
1232
+ else:
1233
+ raise NotImplementedError
1234
+
1235
+ return model, eva_config
EAGLE/eagle/model/multimodal_projector/__init__.py ADDED
File without changes
EAGLE/eagle/model/multimodal_projector/builder.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import re
4
+
5
+ class IdentityMap(nn.Module):
6
+ def __init__(self):
7
+ super().__init__()
8
+
9
+ def forward(self, x, *args, **kwargs):
10
+ return x
11
+
12
+ @property
13
+ def config(self):
14
+ return {"mm_projector_type": 'identity'}
15
+
16
+
17
+ class SimpleResBlock(nn.Module):
18
+ def __init__(self, channels):
19
+ super().__init__()
20
+ self.pre_norm = nn.LayerNorm(channels)
21
+
22
+ self.proj = nn.Sequential(
23
+ nn.Linear(channels, channels),
24
+ nn.GELU(),
25
+ nn.Linear(channels, channels)
26
+ )
27
+ def forward(self, x):
28
+ x = self.pre_norm(x)
29
+ return x + self.proj(x)
30
+
31
+
32
+ def build_vision_projector(config, delay_load=False, fpn_input_dim=[], **kwargs):
33
+ projector_type = getattr(config, 'mm_projector_type', 'linear')
34
+
35
+ if projector_type == 'linear':
36
+ return nn.Linear(config.mm_hidden_size, config.hidden_size)
37
+
38
+ mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
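+ # e.g. 'mlp2x_gelu' builds Linear(mm_hidden_size, hidden_size) -> GELU -> Linear(hidden_size, hidden_size)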
39
+ if mlp_gelu_match:
40
+ mlp_depth = int(mlp_gelu_match.group(1))
41
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
42
+ for _ in range(1, mlp_depth):
43
+ modules.append(nn.GELU())
44
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
45
+ return nn.Sequential(*modules)
46
+
47
+ if projector_type == 'identity':
48
+ return IdentityMap()
49
+
50
+ raise ValueError(f'Unknown projector type: {projector_type}')
EAGLE/lmms_eval/api/__init__.py ADDED
File without changes
EAGLE/lmms_eval/api/filter.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List
3
+
4
+ from lmms_eval.api.instance import Instance
5
+ from datasets import Dataset
6
+
7
+
8
+ class Filter:
9
+ """
10
+ Filter classes operate on a per-task level.
11
+ They take all model outputs (`instance.resps` for all `task.instances`)
12
+ across all instances of a task, and perform operations.
13
+ In a single run, one can configure any number of separate filters or lists of filters.
14
+
15
+ """
16
+
17
+ def __init__(self, *args, **kwargs) -> None:
18
+ """
19
+ Can define custom behavior here, if an individual instantiation of a Filter class should have state.
20
+ """
21
+
22
+ def apply(self, resps, docs):
23
+ """
24
+ Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
25
+ Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
26
+ if passed [<inst.resps for instance 0>, <inst.resps for instance 1>], it should return
27
+ [<filtered resps for instance 0>, <filtered resps for instance 1>]
28
+ """
29
+ return resps
30
+
31
+
32
+ @dataclass
33
+ class FilterEnsemble:
34
+ """
35
+ FilterEnsemble creates a pipeline applying multiple filters.
36
+ Its intended usage is to stack multiple post-processing steps in order.
37
+ `task.apply_filters` should use a list of FilterEnsemble classes that it stores, to apply each
38
+ pipeline separately.
39
+ """
40
+
41
+ name: str
42
+ filters: List[Filter]
43
+
44
+ def apply(self, instances: List[Instance], docs: List[Dataset]) -> None:
45
+ resps = [inst.resps for inst in instances] # operate just on the model responses
46
+ for f in self.filters:
47
+ # apply filters in sequence
48
+ resps = f.apply(resps, docs)
49
+
50
+ # add the end results after filtering to filtered_requests of their respective source instances.
51
+ # has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
52
+ for inst, resp in zip(instances, resps):
53
+ inst.filtered_resps[self.name] = resp
EAGLE/lmms_eval/api/instance.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Literal, Tuple
3
+
4
+
5
+ @dataclass
6
+ class Instance:
7
+ request_type: Literal["loglikelihood", "generate_until"]
8
+ arguments: tuple
9
+ idx: int
10
+ metadata: Tuple[str, int, int] = field(default_factory=lambda: (None, None, None)) # TODO: better typehints here
11
+ resps: list = field(default_factory=list)
12
+ filtered_resps: dict = field(default_factory=dict)
13
+
14
+ # initialized after init
15
+ task_name: str = None
16
+ doc_id: str = None
17
+ repeats: str = None
18
+ doc: dict = None
19
+
20
+ def __post_init__(self) -> None:
21
+ # unpack metadata field
22
+ self.task_name, self.doc_id, self.repeats = self.metadata
23
+
24
+ @property
25
+ def args(self):
26
+ """
27
+ Returns (string,) where `string` is the string to calculate loglikelihood over
28
+ """
29
+ return self.arguments if isinstance(self.arguments, tuple) else (self.arguments,)
EAGLE/lmms_eval/api/metrics.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from collections.abc import Iterable
3
+
4
+ import numpy as np
5
+ import sacrebleu
6
+ import sklearn.metrics
7
+ import random
8
+ import evaluate
9
+ import torch
10
+
11
+ from lmms_eval.api.registry import register_metric, register_aggregation
12
+
13
+ import logging
14
+
15
+ eval_logger = logging.getLogger("lmms-eval")
16
+
17
+
18
+ # Register Aggregations First
19
+ @register_aggregation("mean")
20
+ def mean(arr):
21
+ return sum(arr) / len(arr)
22
+
23
+
24
+ @register_aggregation("median")
25
+ def median(arr):
26
+ return arr[len(arr) // 2]
27
+
28
+
29
+ # Certain metrics must be calculated across all documents in a benchmark.
30
+ # We use them as aggregation metrics, paired with no-op passthrough metric fns.
31
+ @register_aggregation("perplexity")
32
+ def perplexity(items):
33
+ # return math.exp(-mean(items))
34
+ items = torch.exp(torch.tensor(items)).tolist()
35
+ return sum(items) / len(items)
36
+
37
+
38
+ @register_aggregation("weighted_perplexity")
39
+ def weighted_perplexity(items):
40
+ return math.exp(-weighted_mean(items))
41
+
42
+
43
+ @register_aggregation("bits_per_byte")
44
+ def bits_per_byte(items):
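+ # converts the (negative) weighted-mean log-likelihood from nats to bits per byte by dividing by ln(2)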
45
+ return -weighted_mean(items) / math.log(2)
46
+
47
+
48
+ @register_aggregation("f1")
49
+ def f1_score(items):
50
+ unzipped_list = list(zip(*items))
51
+ golds = unzipped_list[0]
52
+ preds = unzipped_list[1]
53
+ fscore = sklearn.metrics.f1_score(golds, preds)
54
+
55
+ return np.max(fscore)
56
+
57
+
58
+ @register_aggregation("matthews_corrcoef")
59
+ def matthews_corrcoef(items):
60
+ unzipped_list = list(zip(*items))
61
+ golds = unzipped_list[0]
62
+ preds = unzipped_list[1]
63
+ # print(preds)
64
+ return sklearn.metrics.matthews_corrcoef(golds, preds)
65
+
66
+
67
+ @register_aggregation("bleu")
68
+ def bleu(items):
69
+ """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
70
+ for evaluating a generated sentence to a reference sentence. It counts matching
71
+ n-grams in the candidate translation to n-grams in the reference text, where
72
+ 1-gram or unigram would be each token and a bigram comparison would be each
73
+ word pair. The comparison is made regardless of word order
74
+ Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
75
+ Paper: https://www.aclweb.org/anthology/P02-1040/
76
+
77
+ Higher is better
78
+ """
79
+ refs = list(zip(*items))[0]
80
+ preds = list(zip(*items))[1]
81
+ refs, preds = _sacreformat(refs, preds)
82
+ return sacrebleu.corpus_bleu(preds, refs).score
83
+
84
+
85
+ @register_aggregation("chrf")
86
+ def chrf(items):
87
+ """chrF++ is a tool for automatic evaluation of machine translation output
88
+ based on character n-gram precision and recall enhanced with word n-grams.
89
+ Source: https://github.com/m-popovic/chrF
90
+ Paper: https://www.aclweb.org/anthology/W15-3049.pdf
91
+
92
+ Higher is better # TODO I think
93
+ """
94
+ refs = list(zip(*items))[0]
95
+ preds = list(zip(*items))[1]
96
+ refs, preds = _sacreformat(refs, preds)
97
+ return sacrebleu.corpus_chrf(preds, refs).score
98
+
99
+
100
+ @register_aggregation("ter")
101
+ def ter(items):
102
+ """Translation Error Rate is an error metric for machine translation that
103
+ measures the number of edits required to change a system output into one
104
+ of the references
105
+ Source: http://www.cs.umd.edu/~snover/tercom/
106
+ Paper: http://mt-archive.info/AMTA-2006-Snover.pdf
107
+
108
+ Lower is better
109
+ """
110
+ refs = list(zip(*items))[0]
111
+ preds = list(zip(*items))[1]
112
+ refs, preds = _sacreformat(refs, preds)
113
+ return sacrebleu.corpus_ter(preds, refs).score
114
+
115
+
116
+ @register_metric(
117
+ metric="acc",
118
+ higher_is_better=True,
119
+ output_type=["loglikelihood", "multiple_choice"],
120
+ aggregation="mean",
121
+ )
122
+ def acc_fn(items): # This is a passthrough function
123
+ return items
124
+
125
+
126
+ @register_metric(
127
+ metric="acc_norm",
128
+ higher_is_better=True,
129
+ output_type=["loglikelihood", "multiple_choice"],
130
+ aggregation="mean",
131
+ )
132
+ def acc_norm_fn(items): # This is a passthrough function
133
+ return items
134
+
135
+
136
+ @register_metric(
137
+ metric="acc_mutual_info",
138
+ higher_is_better=True,
139
+ output_type="multiple_choice",
140
+ aggregation="mean",
141
+ )
142
+ def acc_mutual_info_fn(items): # This is a passthrough function
143
+ return items
144
+
145
+
146
+ exact_match = evaluate.load("exact_match")
147
+
148
+
149
+ @register_metric(
150
+ metric="exact_match",
151
+ higher_is_better=True,
152
+ output_type="generate_until",
153
+ aggregation="mean",
154
+ )
155
+ def exact_match_fn(**kwargs):
156
+ return exact_match.compute(**kwargs)
157
+
158
+
159
+ @register_metric(
160
+ metric="perplexity",
161
+ higher_is_better=False,
162
+ output_type="loglikelihood",
163
+ aggregation="perplexity",
164
+ )
165
+ def perplexity_fn(items): # This is a passthrough function
166
+ return items
167
+
168
+
169
+ def levenshtein_distance(s1, s2):
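+ # iterative two-row dynamic-programming edit distance (unit cost for insertions, deletions, and substitutions)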
170
+ if len(s1) > len(s2):
171
+ s1, s2 = s2, s1
172
+
173
+ distances = range(len(s1) + 1)
174
+ for i2, c2 in enumerate(s2):
175
+ distances_ = [i2 + 1]
176
+ for i1, c1 in enumerate(s1):
177
+ if c1 == c2:
178
+ distances_.append(distances[i1])
179
+ else:
180
+ distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
181
+ distances = distances_
182
+ return distances[-1]
183
+
184
+
185
+ @register_metric(
186
+ metric="anls",
187
+ higher_is_better=True,
188
+ output_type="generate_until",
189
+ aggregation="mean",
190
+ )
191
+ def anls(
192
+ references,
193
+ predictions,
194
+ thresh_hold=0.5,
195
+ ): # This is a passthrough function
196
+ """https://github.com/QwenLM/Qwen-VL/blob/master/eval_mm/infographicsvqa_eval.py"""
197
+ values = []
198
+ for answer in references:
199
+ # preprocess both the answers - gt and prediction
200
+ gt_answer = " ".join(answer.strip().lower().split())
201
+ det_answer = " ".join(predictions[0].strip().lower().split())
202
+
203
+ # dist = levenshtein_distance(answer.lower(), detObject['answer'].lower())
204
+ dist = levenshtein_distance(gt_answer, det_answer)
205
+ length = max(len(answer.upper()), len(predictions[0].upper()))
206
+ values.append(0.0 if length == 0 else float(dist) / float(length))
207
+
208
+ question_result = 1 - min(values)
209
+
210
+ if question_result < thresh_hold:
211
+ question_result = 0
212
+ return {"anls": question_result}
213
+
214
+
215
+ def pop_stddev(arr):
216
+ mu = mean(arr)
217
+ return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))
218
+
219
+
220
+ def sample_stddev(arr):
221
+ mu = mean(arr)
222
+ return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))
223
+
224
+
225
+ def mean_stderr(arr):
226
+ return sample_stddev(arr) / math.sqrt(len(arr))
227
+
228
+
229
+ @register_metric(
230
+ metric="mcc",
231
+ higher_is_better=True,
232
+ output_type="multiple_choice",
233
+ aggregation="matthews_corrcoef",
234
+ )
235
+ def mcc_fn(items): # This is a passthrough function
236
+ return items
237
+
238
+
239
+ @register_metric(
240
+ metric="f1",
241
+ higher_is_better=True,
242
+ output_type="multiple_choice",
243
+ aggregation="f1",
244
+ )
245
+ def f1_fn(items): # This is a passthrough function
246
+ return items
247
+
248
+
249
+ @register_metric(
250
+ metric="bleu",
251
+ higher_is_better=True,
252
+ output_type="generate_until",
253
+ aggregation="bleu",
254
+ )
255
+ def bleu_fn(items): # This is a passthrough function
256
+ return items
257
+
258
+
259
+ @register_metric(
260
+ metric="chrf",
261
+ higher_is_better=True,
262
+ output_type="generate_until",
263
+ aggregation="chrf",
264
+ )
265
+ def chrf_fn(items): # This is a passthrough function
266
+ return items
267
+
268
+
269
+ @register_metric(
270
+ metric="ter",
271
+ higher_is_better=True,
272
+ output_type="generate_until",
273
+ aggregation="ter",
274
+ )
275
+ def ter_fn(items): # This is a passthrough function
276
+ return items
277
+
278
+
279
+ @register_metric(
280
+ metric="acc_all",
281
+ higher_is_better=True,
282
+ output_type="loglikelihood",
283
+ aggregation="mean",
284
+ )
285
+ def acc_all(items):
286
+ # Only count as correct if all answers are labeled correctly for each question
287
+ question_scoring_dict = {}
288
+ preds = list(zip(*items))[0]
289
+ docs = list(zip(*items))[1]
290
+
291
+ for doc, pred in zip(docs, preds):
292
+ paragraph_id = doc["idx"]["paragraph"]
293
+ question_id = doc["idx"]["question"]
294
+ if (paragraph_id, question_id) not in question_scoring_dict:
295
+ question_scoring_dict[(paragraph_id, question_id)] = []
296
+
297
+ gold_label = doc["label"] == 1
298
+
299
+ question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
300
+ acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
301
+ return acc
302
+
303
+
304
+ def acc_all_stderr(items):
305
+ # Only count as correct if all answers are labeled correctly for each question
306
+ question_scoring_dict = {}
307
+ preds = list(zip(*items))[0]
308
+ docs = list(zip(*items))[1]
309
+
310
+ for doc, pred in zip(docs, preds):
311
+ question_id = doc["idx"]["question"]
312
+ if question_id not in question_scoring_dict:
313
+ question_scoring_dict[question_id] = []
314
+
315
+ gold_label = doc["label"] == 1
316
+ question_scoring_dict[question_id].append(gold_label == pred)
317
+
318
+ acc = mean_stderr([int(all(x)) for x in question_scoring_dict.values()])
319
+ return acc
320
+
321
+
322
+ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
323
+ """Compute max metric between prediction and each ground truth."""
324
+ scores_for_ground_truths = []
325
+ for ground_truth in ground_truths:
326
+ score = metric_fn(prediction, ground_truth)
327
+ scores_for_ground_truths.append(score)
328
+ return max(scores_for_ground_truths)
329
+
330
+
331
+ def weighted_mean(items):
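+ # items are (value, weight) pairs, e.g. (loglikelihood, byte/word count) when used for the weighted perplexity aggregation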
332
+ a, b = zip(*items)
333
+ return sum(a) / sum(b)
334
+
335
+
336
+ def is_non_str_iterable(obj):
337
+ return isinstance(obj, Iterable) and not isinstance(obj, str)
338
+
339
+
340
+ def _sacreformat(refs, preds):
341
+ """Format refs and preds for sacrebleu corpus calculation. It is very particular"""
342
+ # Sacrebleu expects (List[str], List[List[str])
343
+ # e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])
344
+
345
+ # Note [ref1_stream] is the first reference for each pred.
346
+ # So lists are size N and (M, N) for N preds and M possible refs for each pred
347
+ # This is a different order of dimensions than I would expect
348
+
349
+ # We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
350
+ # Must become List[List[str]] with the inner list corresponding to preds
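+ # e.g. refs=[["r1a", "r1b"], ["r2a", "r2b"]] (two preds, two refs each) -> [("r1a", "r2a"), ("r1b", "r2b")] after the transpose below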
351
+ if not is_non_str_iterable(refs):
352
+ refs = list(refs)
353
+ if not is_non_str_iterable(refs[0]):
354
+ refs = [[ref] for ref in refs]
355
+ refs = list(zip(*refs))
356
+ # Note: the number of refs in each ref list must match the number of preds
357
+
358
+ # We expect preds to be List[str] or List[List[str]]. Must become List[str]
359
+ if not is_non_str_iterable(preds):
360
+ preds = list(preds)
361
+ if is_non_str_iterable(preds[0]):
362
+ assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
363
+ preds = [pred[0] for pred in preds]
364
+
365
+ return refs, preds
366
+
367
+
368
+ # stderr stuff
369
+
370
+
371
+ class _bootstrap_internal:
372
+ def __init__(self, f, n) -> None:
373
+ self.f = f
374
+ self.n = n
375
+
376
+ def __call__(self, v):
377
+ i, xs = v
378
+ rnd = random.Random()
379
+ rnd.seed(i)
380
+ res = []
381
+ for _ in range(self.n):
382
+ res.append(self.f(rnd.choices(xs, k=len(xs))))
383
+ return res
384
+
385
+
386
+ def bootstrap_stderr(f, xs, iters):
387
+ import multiprocessing as mp
388
+
389
+ pool = mp.Pool(mp.cpu_count())
390
+ # this gives a biased estimate of the stderr (i.e w/ the mean, it gives something
391
+ # equivalent to stderr calculated without Bessel's correction in the stddev.
392
+ # Unfortunately, I haven't been able to figure out what the right correction is
393
+ # to make the bootstrap unbiased - i considered multiplying by sqrt(n/(n-1)) but
394
+ # that would be ad-hoc and I can't prove that that would actually be an unbiased estimator)
395
+ # Thankfully, shouldn't matter because our samples are pretty big usually anyways
396
+ res = []
397
+ chunk_size = min(1000, iters)
398
+ from tqdm import tqdm
399
+
400
+ print("bootstrapping for stddev:", f.__name__)
401
+ for bootstrap in tqdm(
402
+ pool.imap(
403
+ _bootstrap_internal(f, chunk_size),
404
+ [(i, xs) for i in range(iters // chunk_size)],
405
+ ),
406
+ total=iters // chunk_size,
407
+ ):
408
+ # sample w replacement
409
+ res.extend(bootstrap)
410
+
411
+ pool.close()
412
+ return sample_stddev(res)
413
+
414
+
415
+ def stderr_for_metric(metric, bootstrap_iters):
416
+ bootstrappable = [
417
+ median,
418
+ matthews_corrcoef,
419
+ f1_score,
420
+ perplexity,
421
+ bleu,
422
+ chrf,
423
+ ter,
424
+ ]
425
+
426
+ if metric in bootstrappable:
427
+ return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)
428
+
429
+ stderr = {mean: mean_stderr, acc_all: acc_all_stderr}
430
+
431
+ return stderr.get(metric, None)
EAGLE/lmms_eval/api/model.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ import os
3
+
4
+ from typing import Union, List, Tuple, Optional, Type, TypeVar
5
+ from sqlitedict import SqliteDict
6
+ import json
7
+ import hashlib
8
+ from lmms_eval.api.instance import Instance
9
+ from tqdm import tqdm
10
+ from lmms_eval import utils
11
+ import logging
12
+
13
+ eval_logger = logging.getLogger("lmms-eval")
14
+
15
+ T = TypeVar("T", bound="lmms")
16
+
17
+
18
+ class lmms(abc.ABC):
19
+ def __init__(self) -> None:
20
+ """Defines the interface that should be implemented by all lmms subclasses.
21
+ lmms subclasses are assumed to take image-text pairs as input and yield strings as output
22
+ (inputs/outputs should be tokenization-agnostic.)
23
+ """
24
+ # set rank and world size to a single process, by default.
25
+ self._rank = 0
26
+ self._world_size = 1
27
+ self.cache_hook = CacheHook(None)
28
+ self.task_dict = {}
29
+
30
+ @abc.abstractmethod
31
+ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
32
+ """Compute log-likelihood of generating a continuation from a context.
33
+ Downstream tasks should attempt to use loglikelihood instead of other
34
+ LMM calls whenever possible.
35
+
36
+ :param requests: list[Instance]
37
+ A list of Instance objects, with property `args` which returns a tuple (context, continuation).
38
+ `context: str`
39
+ Context string. Implementations of LMM must be able to handle an
40
+ empty context string.
41
+ `continuation: str`
42
+ The continuation over which log likelihood will be calculated. If
43
+ there is a word boundary, the space should be in the continuation.
44
+ For example, context="hello" continuation=" world" is correct.
45
+ 'visual_list: list[dict]'
46
+ Visual input to the model. Can be None.
47
+
48
+ :return: list[tuple[float, bool]]
49
+ A list of pairs (logprob, isgreedy)
50
+ `logprob: float`
51
+ The log probability of `continuation`.
52
+ `isgreedy`:
53
+ Whether `continuation` would be generated by greedy sampling from `context`.
54
+ """
55
+ pass
56
+
57
+ # TODO: Add an optional max length
58
+ @abc.abstractmethod
59
+ def generate_until(self, requests) -> List[str]:
60
+ """Generate greedily until a stopping sequence
61
+
62
+ :param requests: list[Instance]
63
+ A list of Instance objects with property `args` which returns a tuple (context, until).
64
+ context: str
65
+ Context string
66
+ generation_kwargs: dict
67
+ Generation Kwargs
68
+ 'visual_list: list[dict]'
69
+ Visual input to the model. Can be None.
70
+ :return: list[str]
71
+ A list of strings continuation
72
+ continuation: str
73
+ The generated continuation.
74
+ """
75
+ pass
76
+
77
+ @classmethod
78
+ def create_from_arg_string(cls: Type[T], arg_string: str, additional_config: Optional[dict] = None) -> T:
79
+ """
80
+ Creates an instance of the LMM class using the given argument string and additional config.
81
+
82
+ Parameters:
83
+ - arg_string: A string containing arguments in the format key1=value1,key2=value2.
84
+ - additional_config: Optional dictionary containing additional configuration parameters.
85
+
86
+ Returns:
87
+ - Instance of the LMM class.
88
+ """
89
+ additional_config = {} if additional_config is None else additional_config
90
+ args = utils.simple_parse_args_string(arg_string)
91
+ args2 = {k: v for k, v in additional_config.items() if v is not None}
92
+ return cls(**args, **args2)
93
+
94
+ @property
95
+ def rank(self):
96
+ # used in the case of parallelism. Hardcoded to
97
+ # ensure no errors arise using API models which do
98
+ # not support multi-device parallelism nor expect it.
99
+ return self._rank
100
+
101
+ @property
102
+ def world_size(self):
103
+ # used in the case of parallelism. Hardcoded to
104
+ # ensure no errors arise using API models which do
105
+ # not support multi-device parallelism nor expect it.
106
+ return self._world_size
107
+
108
+ def set_cache_hook(self, cache_hook) -> None:
109
+ self.cache_hook = cache_hook
110
+
111
+
112
+ ### SQLite-based caching of LMM responses
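+ # cache keys are the sha256 hex digest of the JSON-serialized (method name, request args)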
113
+ def hash_args(attr, args):
114
+ dat = json.dumps([attr] + list(args))
115
+ return hashlib.sha256(dat.encode("utf-8")).hexdigest()
116
+
117
+
118
+ class CacheHook:
119
+ def __init__(self, cachinglm) -> None:
120
+ if cachinglm is None:
121
+ self.dbdict = None
122
+ return
123
+
124
+ self.dbdict = cachinglm.dbdict
125
+
126
+ def add_partial(self, attr, req, res) -> None:
127
+ if self.dbdict is None:
128
+ return
129
+ hsh = hash_args(attr, req)
130
+ self.dbdict[hsh] = res
131
+
132
+
133
+ class CachingLMM:
134
+ def __init__(self, lm, cache_db) -> None:
135
+ """LMM wrapper that returns cached results if they exist, and uses the underlying LMM if not.
136
+
137
+ :param lm: LMM
138
+ Underlying LMM
139
+ :param cache_db: str
140
+ Path to cache db
141
+ """
142
+ self.lm = lm
143
+ self.cache_db = cache_db
144
+ if os.path.dirname(cache_db):
145
+ os.makedirs(os.path.dirname(cache_db), exist_ok=True)
146
+ self.dbdict = SqliteDict(cache_db, autocommit=True)
147
+
148
+ # add hook to lm
149
+ lm.set_cache_hook(self.get_cache_hook())
150
+
151
+ def __getattr__(self, attr):
152
+ lm_attr = getattr(self.lm, attr)
153
+ if not callable(lm_attr):
154
+ return lm_attr
155
+
156
+ def fn(requests):
157
+ res = []
158
+ remaining_reqs = []
159
+ warned = False
160
+ # figure out which ones are cached and which ones are new
161
+ eval_logger.info(f"Loading '{attr}' responses from cache '{self.cache_db}' where possible...")
162
+ for req in tqdm(requests):
163
+ hsh = hash_args(attr, req.args)
164
+ if attr == "generate_until" and req.args[1].get("do_sample", False):
165
+ # when we are doing non-greedy generation, don't use the cache
166
+ # (else every "randomly sampled" generation would be identical for repeats > 1).
167
+ if not warned:
168
+ eval_logger.warning(f"Arguments to lm.generate_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests.")
169
+ warned = True
170
+ res.append(None)
171
+ remaining_reqs.append(req)
172
+ elif hsh in self.dbdict:
173
+ ob = self.dbdict[hsh]
174
+
175
+ assert ob is not None
176
+
177
+ res.append(ob)
178
+ else:
179
+ res.append(None)
180
+ remaining_reqs.append(req)
181
+
182
+ # actually run the LMM on the requests that do not have cached results
183
+ rem_res = getattr(self.lm, attr)(remaining_reqs)
184
+
185
+ # stick the new ones back into the list and also cache any of the new ones
186
+ resptr = 0
187
+ for req, r in zip(remaining_reqs, rem_res):
188
+ while res[resptr] is not None:
189
+ resptr += 1
190
+
191
+ res[resptr] = r
192
+
193
+ # caching
194
+ hsh = hash_args(attr, req.args)
195
+ self.dbdict[hsh] = r
196
+ self.dbdict.commit()
197
+
198
+ return res
199
+
200
+ return fn
201
+
202
+ def get_cache_hook(self):
203
+ return CacheHook(self)
EAGLE/lmms_eval/api/registry.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lmms_eval.api.model import lmms
2
+
3
+ import logging
4
+
5
+ eval_logger = logging.getLogger("lmms-eval")
6
+
7
+ MODEL_REGISTRY = {}
8
+
9
+
10
+ def register_model(*names):
11
+ # either pass a list or a single alias.
12
+ # function receives them as a tuple of strings
13
+
14
+ def decorate(cls):
15
+ for name in names:
16
+ assert issubclass(cls, lmms), f"Model '{name}' ({cls.__name__}) must extend lmms class"
17
+
18
+ assert name not in MODEL_REGISTRY, f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
19
+
20
+ MODEL_REGISTRY[name] = cls
21
+ return cls
22
+
23
+ return decorate
24
+
25
+
26
+ def get_model(model_name):
27
+ try:
28
+ return MODEL_REGISTRY[model_name]
29
+ except KeyError:
30
+ raise ValueError(f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}")
31
+
32
+
33
+ TASK_REGISTRY = {} # Key: task name, Value: task ConfigurableTask class
34
+ GROUP_REGISTRY = {} # Key: group name, Value: list of task names or group names
35
+ ALL_TASKS = set() # Set of all task names and group names
36
+ func2task_index = {} # Key: task ConfigurableTask class, Value: task name
37
+
38
+
39
+ def register_task(name):
40
+ def decorate(fn):
41
+ assert name not in TASK_REGISTRY, f"task named '{name}' conflicts with existing registered task!"
42
+
43
+ TASK_REGISTRY[name] = fn
44
+ ALL_TASKS.add(name)
45
+ func2task_index[fn.__name__] = name
46
+ return fn
47
+
48
+ return decorate
49
+
50
+
51
+ def register_group(name):
52
+ def decorate(fn):
53
+ func_name = func2task_index[fn.__name__]
54
+ if name in GROUP_REGISTRY:
55
+ GROUP_REGISTRY[name].append(func_name)
56
+ else:
57
+ GROUP_REGISTRY[name] = [func_name]
58
+ ALL_TASKS.add(name)
59
+ return fn
60
+
61
+ return decorate
62
+
63
+
64
+ OUTPUT_TYPE_REGISTRY = {}
65
+ METRIC_REGISTRY = {}
66
+ METRIC_AGGREGATION_REGISTRY = {}
67
+ AGGREGATION_REGISTRY = {}
68
+ HIGHER_IS_BETTER_REGISTRY = {}
69
+
70
+ DEFAULT_METRIC_REGISTRY = {
71
+ "loglikelihood": [
72
+ "perplexity",
73
+ "acc",
74
+ ],
75
+ "multiple_choice": ["acc", "acc_norm"],
76
+ "generate_until": ["exact_match"],
77
+ }
78
+
79
+
80
+ def register_metric(**args):
81
+ # TODO: do we want to enforce a certain interface to registered metrics?
82
+ def decorate(fn):
83
+ assert "metric" in args
84
+ name = args["metric"]
85
+
86
+ for key, registry in [
87
+ ("metric", METRIC_REGISTRY),
88
+ ("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
89
+ ("aggregation", METRIC_AGGREGATION_REGISTRY),
90
+ ]:
91
+ if key in args:
92
+ value = args[key]
93
+ assert value not in registry, f"{key} named '{value}' conflicts with existing registered {key}!"
94
+
95
+ if key == "metric":
96
+ registry[name] = fn
97
+ elif key == "aggregation":
98
+ registry[name] = AGGREGATION_REGISTRY[value]
99
+ else:
100
+ registry[name] = value
101
+
102
+ return fn
103
+
104
+ return decorate
105
+
106
+
107
+ def register_aggregation(name):
108
+ def decorate(fn):
109
+ assert name not in AGGREGATION_REGISTRY, f"aggregation named '{name}' conflicts with existing registered aggregation!"
110
+
111
+ AGGREGATION_REGISTRY[name] = fn
112
+ return fn
113
+
114
+ return decorate
115
+
116
+
117
+ def get_aggregation(name):
118
+ try:
119
+ return AGGREGATION_REGISTRY[name]
120
+ except KeyError:
121
+ eval_logger.warning(
122
+ "{} not a registered aggregation metric!".format(name),
123
+ )
124
+
125
+
126
+ def get_metric_aggregation(name):
127
+ try:
128
+ return METRIC_AGGREGATION_REGISTRY[name]
129
+ except KeyError:
130
+ eval_logger.warning(
131
+ "{} metric is not assigned a default aggregation!".format(name),
132
+ )
133
+
134
+
135
+ def is_higher_better(metric_name):
136
+ try:
137
+ return HIGHER_IS_BETTER_REGISTRY[metric_name]
138
+ except KeyError:
139
+ eval_logger.warning(f"higher_is_better not specified for metric '{metric_name}'!")
EAGLE/lmms_eval/api/samplers.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class ContextSampler:
2
+ def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None:
3
+ self.rnd = rnd
4
+ assert self.rnd, "must pass rnd to FewShotSampler!"
5
+
6
+ self.task = task
7
+ self.config = task._config
8
+
9
+ self.target_delimiter = self.config.target_delimiter
10
+ self.fewshot_delimiter = self.config.fewshot_delimiter
11
+
12
+ self.doc_to_text = self.task.doc_to_text
13
+ self.doc_to_target = self.task.doc_to_target
14
+ self.doc_to_choice = self.task.doc_to_choice
15
+
16
+ self.docs = docs # HF dataset split, provided by task._fewshot_docs()
17
+ if fewshot_indices: # subset few-shot docs using the provided indices
18
+ self.docs = self.docs.select(fewshot_indices)
19
+
20
+ def get_context(self, doc, num_fewshot):
21
+ # draw an extra fewshot sample if using same split as evaluating on
22
+ n_samples = num_fewshot + 1 if self.config.fewshot_split == self.config.test_split else num_fewshot
23
+
24
+ # draw `n_samples` docs from fewshot_docs
25
+ fewshotex = self.sample(n_samples)
26
+
27
+ # get rid of the doc that's the one we're evaluating, if it's in the fewshot
28
+ # TODO: should we just stop people from using fewshot from same split as evaluating?
29
+ selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]
30
+
31
+ labeled_examples = (
32
+ self.fewshot_delimiter.join(
33
+ [
34
+ # TODO: is separating doc_to_text and doc_to_target by one space always desired?
35
+ (self.doc_to_text(doc) if (self.config.doc_to_choice is None or type(self.doc_to_text(doc)) is str) else self.doc_to_choice(doc)[self.doc_to_text(doc)])
36
+ + self.target_delimiter
37
+ + (
38
+ str(self.doc_to_target(doc)[0])
39
+ if type(self.doc_to_target(doc)) is list
40
+ else self.doc_to_target(doc) if (self.config.doc_to_choice is None or type(self.doc_to_target(doc)) is str) else str(self.doc_to_choice(doc)[self.doc_to_target(doc)])
41
+ )
42
+ for doc in selected_docs
43
+ ]
44
+ )
45
+ + self.fewshot_delimiter
46
+ )
47
+
48
+ return labeled_examples
49
+
50
+ def sample(self, n):
51
+ """
52
+ Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
53
+ """
54
+
55
+ return self.rnd.sample(self.docs, n)
56
+
57
+
58
+ class FirstNSampler(ContextSampler):
59
+ def sample(self, n) -> None:
60
+ """
61
+ Draw the first `n` samples in order from the specified split.
62
+ Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
63
+ """
64
+ assert n <= len(self.docs), f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available."
65
+ return self.docs[:n]
66
+
67
+
68
+ class BalancedSampler(ContextSampler):
69
+ def sample(self, n) -> None:
70
+ """
71
+ TODO: this should return approximately class-balanced samples from our fewshot examples.
72
+ TODO: what order should they be in? maybe random?
73
+ """
74
+
75
+ pass
76
+
77
+
78
+ class ManualSampler(ContextSampler):
79
+ def sample(self, n) -> None:
80
+ """ """
81
+ pass
82
+
83
+
84
+ SAMPLER_REGISTRY = {
85
+ "default": ContextSampler,
86
+ "first_n": FirstNSampler,
87
+ }
88
+
89
+
90
+ def get_sampler(name):
91
+ try:
92
+ return SAMPLER_REGISTRY[name]
93
+ except KeyError:
94
+ raise ValueError(f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported model names: {', '.join(SAMPLER_REGISTRY.keys())}")
EAGLE/lmms_eval/api/task.py ADDED
@@ -0,0 +1,1118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from dataclasses import dataclass, field, asdict
3
+
4
+ import itertools
5
+ import os
6
+ import re
7
+ import ast
8
+ import logging
9
+ import random
10
+ from tqdm import tqdm
11
+
12
+ import datasets
13
+ from datasets import Image, Sequence
14
+ import numpy as np
15
+ from PIL import ImageFile
16
+
17
+ from datasets import DownloadConfig
18
+ from typing import Union, List, Any
19
+ from collections.abc import Callable
20
+ from tenacity import retry, stop_after_attempt, wait_fixed
21
+
22
+ from lmms_eval import utils
23
+ from lmms_eval.api import samplers
24
+ from lmms_eval.api.instance import Instance
25
+
26
+ from lmms_eval.filters import build_filter_ensemble
27
+ from lmms_eval.api.registry import (
28
+ get_aggregation,
29
+ get_metric_aggregation,
30
+ is_higher_better,
31
+ DEFAULT_METRIC_REGISTRY,
32
+ METRIC_REGISTRY,
33
+ OUTPUT_TYPE_REGISTRY,
34
+ AGGREGATION_REGISTRY,
35
+ )
36
+
37
+ ALL_OUTPUT_TYPES = [
38
+ "loglikelihood",
39
+ "multiple_choice",
40
+ "generate_until",
41
+ ]
42
+
43
+
44
+ eval_logger = logging.getLogger("lmms-eval")
45
+
46
+ # HuggingfaceM4/NoCaps contains truncated image in test split
47
+ # Include this inside code block to avoid error
48
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
49
+
50
+
51
+ @dataclass
52
+ class TaskConfig(dict):
53
+ # task naming/registry
54
+ task: str = None
55
+ task_alias: str = None
56
+ group: Union[str, list] = None
57
+ group_alias: Union[str, list] = None
58
+ # HF dataset options.
59
+ # which dataset to use,
60
+ # and what splits for what purpose
61
+ dataset_path: str = None
62
+ dataset_name: str = None
63
+ dataset_kwargs: dict = None
64
+ training_split: str = None
65
+ validation_split: str = None
66
+ test_split: str = None
67
+ fewshot_split: str = None # TODO: assert that this is not None if num_fewshot > 0 (?); assert whether this is the same split as the one being evaluated (?)
68
+ # formatting / prompting options.
69
+ # see docs/advanced_task_guide.md for more info
70
+ process_docs: Callable = None
71
+ doc_to_visual: Union[Callable, str] = None
72
+ doc_to_text: Union[Callable, str] = None
73
+ doc_to_target: Union[Callable, str] = None
74
+ doc_to_choice: Union[Callable, str, dict, list] = None
75
+ process_results: Union[Callable, str] = None
76
+ use_prompt: str = None
77
+ description: str = ""
78
+ target_delimiter: str = " "
79
+ fewshot_delimiter: str = "\n\n"
80
+ fewshot_config: dict = None
81
+ # runtime configuration options
82
+ num_fewshot: int = None
83
+ # scoring options
84
+ metric_list: list = None
85
+ output_type: str = "generate_until"
86
+ generation_kwargs: dict = None
87
+ repeats: int = 1
88
+ filter_list: Union[str, list] = None
89
+ should_decontaminate: bool = False
90
+ doc_to_decontamination_query: str = None
91
+
92
+ metadata: Union[str, list] = None # by default, not used in the code. allows for users to pass arbitrary info to tasks
93
+
94
+ model_specific_prompt_kwargs: dict = None
95
+ model_specific_generation_kwargs: dict = None
96
+ model_specific_target_kwargs: dict = None
97
+
98
+ def __post_init__(self) -> None:
99
+ if self.dataset_path and os.path.exists(os.path.dirname(self.dataset_path)):
100
+ import inspect
101
+ from importlib import import_module
102
+
103
+ self.dataset_path = inspect.getfile(import_module(self.dataset_path))
104
+
105
+ if self.generation_kwargs is not None:
106
+ if self.output_type != "generate_until":
107
+ eval_logger.warning(f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!")
108
+ assert self.output_type != "generate_until"
109
+
110
+ if "temperature" in self.generation_kwargs:
111
+ self.generation_kwargs["temperature"] = float(self.generation_kwargs["temperature"])
112
+
113
+ if "until" not in self.generation_kwargs:
114
+ self.generation_kwargs["until"] = [self.fewshot_delimiter]
115
+ else:
116
+ if self.output_type == "generate_until":
117
+ # ensure that we greedily generate in absence of explicit arguments otherwise
118
+ self.generation_kwargs = {
119
+ "until": None if self.fewshot_delimiter is None else [self.fewshot_delimiter],
120
+ "do_sample": False,
121
+ }
122
+
123
+ # TODO: how to make TaskConfigs be de- and re-serializable, even when using the !function constructor?
124
+
125
+ def __getitem__(self, item):
126
+ return getattr(self, item)
127
+
128
+ def __setitem__(self, item, value):
129
+ return setattr(self, item, value)
130
+
131
+ def to_dict(self):
132
+ """dumps the current config as a dictionary object, as a printable format.
133
+ null fields will not be printed.
134
+ Used for dumping results alongside full task configuration
135
+
136
+ :return: dict
137
+ A printable dictionary version of the TaskConfig object.
138
+
139
+ # TODO: should any default value in the TaskConfig not be printed?
140
+ """
141
+ cfg_dict = asdict(self)
142
+ # remove values that are `None`
143
+ for k, v in list(cfg_dict.items()):
144
+ if v is None:
145
+ cfg_dict.pop(k)
146
+ elif isinstance(v, Callable):
147
+ # TODO: this should handle Promptsource template objects as a separate case?
148
+ cfg_dict[k] = str(v)
149
+ return cfg_dict
150
+
151
+
152
+ class Task(abc.ABC):
153
+ """A task represents an entire benchmark including its dataset, problems,
154
+ answers, and evaluation methods. See BoolQ for a simple example implementation
155
+
156
+ A `doc` can be any python object which represents one instance of evaluation.
157
+ This is usually a dictionary e.g.
158
+ {"question": ..., "answer": ...} or
159
+ {"question": ..., question, answer)
160
+ """
161
+
162
+ VERSION = None
163
+
164
+ # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
165
+ # or a path to a custom `datasets` loading script.
166
+ DATASET_PATH: str = None
167
+
168
+ # The name of a subset within `DATASET_PATH`.
169
+ DATASET_NAME: str = None
170
+
171
+ OUTPUT_TYPE: str = None
172
+
173
+ def __init__(
174
+ self,
175
+ data_dir=None,
176
+ cache_dir=None,
177
+ download_mode=None,
178
+ config=None,
179
+ ) -> None:
180
+ """
181
+ :param data_dir: str
182
+ Stores the path to a local folder containing the `Task`'s data files.
183
+ Use this to specify the path to manually downloaded data (usually when
184
+ the dataset is not publicly accessible).
185
+ :param cache_dir: str
186
+ The directory to read/write the `Task` dataset. This follows the
187
+ HuggingFace `datasets` API with the default cache directory located at:
188
+ `~/.cache/huggingface/datasets`
189
+ NOTE: You can change the cache location globally for a given process
190
+ to another directory:
191
+ `export HF_DATASETS_CACHE="/path/to/another/directory"`
192
+ :param download_mode: datasets.DownloadMode
193
+ How to treat pre-existing `Task` downloads and data.
194
+ - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
195
+ Reuse download and reuse dataset.
196
+ - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
197
+ Reuse download with fresh dataset.
198
+ - `datasets.DownloadMode.FORCE_REDOWNLOAD`
199
+ Fresh download and fresh dataset.
200
+ """
201
+ self.download(data_dir, cache_dir, download_mode)
202
+ self._training_docs = None
203
+ self._fewshot_docs = None
204
+ self._instances = None
205
+
206
+ self._config = TaskConfig({**config}) if config else TaskConfig()
207
+
208
+ self._filters = [build_filter_ensemble("none", [["take_first", None]])]
209
+
210
+ def download(self, data_dir=None, cache_dir=None, download_mode=None) -> None:
211
+ """Downloads and returns the task dataset.
212
+ Override this method to download the dataset from a custom API.
213
+
214
+ :param data_dir: str
215
+ Stores the path to a local folder containing the `Task`'s data files.
216
+ Use this to specify the path to manually downloaded data (usually when
217
+ the dataset is not publicly accessible).
218
+ :param cache_dir: str
219
+ The directory to read/write the `Task` dataset. This follows the
220
+ HuggingFace `datasets` API with the default cache directory located at:
221
+ `~/.cache/huggingface/datasets`
222
+ NOTE: You can change the cache location globally for a given process
223
+ by setting the shell environment variable, `HF_DATASETS_CACHE`,
224
+ to another directory:
225
+ `export HF_DATASETS_CACHE="/path/to/another/directory"`
226
+ :param download_mode: datasets.DownloadMode
227
+ How to treat pre-existing `Task` downloads and data.
228
+ - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
229
+ Reuse download and reuse dataset.
230
+ - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
231
+ Reuse download with fresh dataset.
232
+ - `datasets.DownloadMode.FORCE_REDOWNLOAD`
233
+ Fresh download and fresh dataset.
234
+ """
235
+ self.dataset = datasets.load_dataset(
236
+ path=self.DATASET_PATH,
237
+ name=self.DATASET_NAME,
238
+ data_dir=data_dir,
239
+ cache_dir=cache_dir,
240
+ download_mode=download_mode,
241
+ )
242
+ self.dataset_no_image = datasets.load_dataset(
243
+ path=self.DATASET_PATH,
244
+ name=self.DATASET_NAME,
245
+ data_dir=data_dir,
246
+ cache_dir=cache_dir,
247
+ download_mode=download_mode,
248
+ )
249
+ for doc_name in self.dataset_no_image:
250
+ remove_cols = []
251
+ features = self.dataset_no_image[doc_name].features
252
+ # If it is an Image instance or a Sequence of Image instance. Remove it
253
+ for feature in features:
254
+ if isinstance(features[feature], Image):
255
+ remove_cols.append(feature)
256
+ elif isinstance(features[feature], Sequence) and isinstance(features[feature].feature, Image):
257
+ remove_cols.append(feature)
258
+ for remove_col in remove_cols:
259
+ self.dataset_no_image[doc_name] = self.dataset_no_image[doc_name].remove_columns(remove_col)
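# For example (assumed schema), a dataset whose Features declare
#     {"image": Image(), "images": Sequence(feature=Image()), "question": Value("string")}
# would have "image" and "images" dropped from `dataset_no_image`, so text-only row access
# later (e.g. in fewshot_context) never has to decode pixels.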
260
+
261
+ @property
262
+ def config(self):
263
+ """Returns the TaskConfig associated with this class."""
264
+ return self._config
265
+
266
+ @abc.abstractmethod
267
+ def has_training_docs(self):
268
+ """Whether the task has a training set"""
269
+ pass
270
+
271
+ @abc.abstractmethod
272
+ def has_validation_docs(self):
273
+ """Whether the task has a validation set"""
274
+ pass
275
+
276
+ @abc.abstractmethod
277
+ def has_test_docs(self):
278
+ """Whether the task has a test set"""
279
+ pass
280
+
281
+ def training_docs(self):
282
+ """
283
+ :return: Iterable[obj]
284
+ An iterable of any object that doc_to_text can handle
285
+ """
286
+ return []
287
+
288
+ def validation_docs(self):
289
+ """
290
+ :return: Iterable[obj]
291
+ An iterable of any object that doc_to_text can handle
292
+ """
293
+ return []
294
+
295
+ def test_docs(self):
296
+ """
297
+ :return: Iterable[obj]
298
+ An iterable of any object that doc_to_text can handle
299
+ """
300
+ return []
301
+
302
+ def fewshot_docs(self):
303
+ """
304
+ :return: Iterable[obj]
305
+ An iterable of any object that doc_to_text can handle
306
+ """
307
+ if self.has_training_docs():
308
+ return self.training_docs()
309
+ elif self.has_validation_docs():
310
+ return self.validation_docs()
311
+ else:
312
+ if self.config.num_fewshot is not None:
313
+ eval_logger.warning("has_training_docs and has_validation_docs are False" ", using test_docs as fewshot_docs but this is not recommended.")
314
+ return self.test_docs()
315
+
316
+ def _process_doc(self, doc):
317
+ """
318
+ Override this to process (detokenize, strip, replace, etc.) individual
319
+ documents. This can be used in a map over documents of a data split.
320
+ E.g. `map(self._process_doc, self.dataset["validation"])`
321
+
322
+ :return: dict
323
+ The processed version of the specified `doc`.
324
+ """
325
+ return doc
326
+
327
+ @property
328
+ def instances(self):
329
+ """After calling `task.build_all_requests()`, tasks
330
+ maintain a list of the dataset instances which will be evaluated.
331
+ """
332
+ return self._instances
333
+
334
+ def fewshot_examples(self, k, rnd):
335
+ if self._training_docs is None:
336
+ self._training_docs = list(self.training_docs())
337
+
338
+ return rnd.sample(self._training_docs, k)
339
+
340
+ def doc_to_decontamination_query(self, doc) -> None:
341
+ print("Override doc_to_decontamination_query with document specific decontamination query.")
342
+ assert False
343
+
344
+ @abc.abstractmethod
345
+ def doc_to_text(self, doc):
346
+ pass
347
+
348
+ @abc.abstractmethod
349
+ def doc_to_target(self, doc):
350
+ pass
351
+
352
+ # @profile
353
+ def build_all_requests(self, limit=None, rank=None, world_size=None) -> None:
354
+ """Build a set of Instances for a task, and store them in task.instances"""
355
+ if self.has_test_docs():
356
+ docs = self.test_docs()
357
+ split = self.config.test_split
358
+ elif self.has_validation_docs():
359
+ docs = self.validation_docs()
360
+ split = self.config.validation_split
361
+ else:
362
+ assert False, f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
363
+
364
+ eval_logger.info(f"Building contexts for task {self.CONFIG.task} on rank {rank}...")
365
+ instances = []
366
+ doc_id_iterator = utils.create_iterator([i for i in range(len(docs))], rank, world_size, limit)
367
+ doc_id_iterator, doc_id_iterator_counting = itertools.tee(doc_id_iterator)
368
+ total_docs = sum(1 for _ in doc_id_iterator_counting)
369
+ pbar = tqdm(total=total_docs, desc=f"Building context", disable=(rank != 0))
370
+ for doc_id in doc_id_iterator:
371
+ # sample fewshot context #TODO: need to offset doc_id by rank now!
372
+ fewshot_ctx = self.fewshot_context(doc_id, 0 if self.config.num_fewshot is None else self.config.num_fewshot, self.config.training_split if self.has_training_docs() else split)
373
+
374
+ # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
375
+ inst = self.construct_requests(doc_id=doc_id, ctx=fewshot_ctx, metadata=(self.config["task"], doc_id, self.config.repeats), split=split)
376
+
377
+ if not isinstance(inst, list):
378
+ inst = [inst]
379
+
380
+ instances.extend(inst)
381
+ pbar.update(1)
382
+
383
+ pbar.close()
384
+ self._instances = instances
385
+ assert len(self._instances) != 0, "task.build_requests() did not find any docs!"
386
+
387
+ @abc.abstractmethod
388
+ def construct_requests(self, doc_id, ctx, **kwargs):
389
+ """Uses RequestFactory to construct Requests and returns an iterable of
390
+ Requests which will be sent to the LMM.
391
+
392
+ :param doc_id: int
393
+ The index of a document within `self.test_docs()` or `self.validation_docs()`,
394
+ whichever is the main split used.
395
+ :param ctx: str
396
+ The context string, generated by fewshot_context. This includes the natural
397
+ language description, as well as the few shot examples, and the question
398
+ part of the document for `doc`.
399
+ :param repeats: int
400
+ TODO: update this docstring
401
+ The number of times each instance in a dataset is inferred on. Defaults to 1,
402
+ can be increased for techniques like majority voting.
403
+ """
404
+ pass
405
+
406
+ @abc.abstractmethod
407
+ def process_results(self, doc, results):
408
+ """Take a single document and the LMM results and evaluates, returning a
409
+ dict where keys are the names of submetrics and values are the values of
410
+ the metric for that one document
411
+
412
+ :param doc:
413
+ The document as returned from training_docs, validation_docs, or test_docs.
414
+ :param results:
415
+ The results of the requests created in construct_requests.
416
+ """
417
+ pass
418
+
419
+ @abc.abstractmethod
420
+ def aggregation(self):
421
+ """
422
+ :returns: {str: [metric_score] -> float}
423
+ A dictionary where keys are the names of submetrics and values are
424
+ functions that aggregate a list of metric scores
425
+ """
426
+ pass
427
+
428
+ @abc.abstractmethod
429
+ def higher_is_better(self):
430
+ """
431
+ :returns: {str: bool}
432
+ A dictionary where keys are the names of submetrics and values are
433
+ whether a higher value of the submetric is better
434
+ """
435
+ pass
436
+
437
+ @classmethod
438
+ def count_bytes(cls, doc):
439
+ """Used for byte-level perplexity metrics in rolling loglikelihood"""
440
+ return len(doc.encode("utf-8"))
441
+
442
+ @utils.positional_deprecated
443
+ def fewshot_context(
444
+ self,
445
+ doc_id,
446
+ num_fewshot,
447
+ split,
448
+ rnd=random.Random(1234),
449
+ description=None,
450
+ ):
451
+ """Returns a fewshot context string that is made up of a prepended description
452
+ (if provided), the `num_fewshot` number of examples, and an appended prompt example.
453
+
454
+ :param doc_id: int
455
+ The document id as returned from training_docs, validation_docs, or test_docs.
456
+ :param num_fewshot: int
457
+ The number of fewshot examples to provide in the returned context string.
458
+ :param split: str
459
+ The split of the document to retrieve from the dataset
460
+ :param rnd: random.Random
461
+ The pseudo-random number generator used to randomly sample examples.
462
+ WARNING: this must not be None; it defaults to a generator seeded with 1234.
463
+ :param description: str
464
+ The task's description that will be prepended to the fewshot examples.
465
+ :returns: str
466
+ The fewshot context.
467
+ """
468
+ assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`"
469
+
470
+ description = description if description else ""
471
+ doc = self.dataset_no_image[split][doc_id]
472
+
473
+ if num_fewshot == 0:
474
+ labeled_examples = ""
475
+ else:
476
+ # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
477
+ if self.has_training_docs():
478
+ fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
479
+ else:
480
+ if self._fewshot_docs is None:
481
+ self._fewshot_docs = list(self.validation_docs() if self.has_validation_docs() else self.test_docs())
482
+
483
+ fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
484
+
485
+ # get rid of the doc that's the one we're evaluating, if it's in the fewshot
486
+ fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
487
+
488
+ labeled_examples = "\n\n".join([self.doc_to_text(doc) + self.doc_to_target(doc) for doc in fewshotex]) + "\n\n"
489
+
490
+ example = self.doc_to_text(doc)
491
+ return description + labeled_examples + example
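# Shape of the returned context (illustrative doc_to_text/doc_to_target outputs only):
#     "<description>" + "Q: ex1\nA: a1\n\nQ: ex2\nA: a2\n\n" + "Q: current question\nA:"
# i.e. each fewshot example is text+target concatenated, examples are joined with "\n\n",
# and the unanswered query is appended at the end.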
492
+
493
+ def apply_filters(self):
494
+ if hasattr(self, "_filters"):
495
+ for f in self._filters:
496
+ f.apply(self._instances, None)
497
+ else:
498
+ eval_logger.warning("No filter defined, passing through instances")
499
+ return self._instances
500
+
501
+ def dump_config(self) -> dict:
502
+ """Returns a dictionary representing the task's config.
503
+
504
+ :returns: dict
505
+ A dictionary representation of the task's configuration.
506
+ """
507
+ # TODO: this should only return the overrides applied to a non-YAML task's configuration.
508
+ # (num_fewshot)
509
+ return self.config.to_dict()
510
+
511
+
512
+ class ConfigurableTask(Task):
513
+ VERSION = "Yaml"
514
+ OUTPUT_TYPE = None
515
+ CONFIG = None
516
+
517
+ def __init__(self, model_name) -> None: # TODO no super() call here
518
+ # Get pre-configured attributes
519
+ self._config = self.CONFIG
520
+ # different model requires different prompt, we have to take those into account.
521
+
522
+ self.model_name = model_name
523
+ self._prepare_model_specific_config()
524
+
525
+ assert self.config.output_type in ALL_OUTPUT_TYPES
526
+ self.OUTPUT_TYPE = self.config.output_type
527
+
528
+ self.DATASET_PATH = self.config.dataset_path
529
+
530
+ if self.config.dataset_name is not None:
531
+ self.DATASET_NAME = self.config.dataset_name
532
+
533
+ self._prepare_metric_and_aggregation()
534
+
535
+ self.download(self.config.dataset_kwargs)
536
+ self._training_docs = None
537
+ self._fewshot_docs = None
538
+
539
+ if self.config.filter_list is not None:
540
+ self._filters = []
541
+ for filter_config in self.config.filter_list:
542
+ for filter_pipeline in filter_config:
543
+ filter_name = filter_config["name"]
544
+ filter_functions = filter_config["filter"]
545
+ components = []
546
+ for function in filter_functions:
547
+ kwargs = {key: function[key] for key in function if key != "function"}
548
+ components.append([function["function"], kwargs])
549
+ filter_pipeline = build_filter_ensemble(filter_name, components)
550
+ self._filters.append(filter_pipeline)
551
+ else:
552
+ self._filters = [build_filter_ensemble("none", [["take_first", None]])]
553
+ if self.config.fewshot_config is not None:
554
+ self.sampler = samplers.get_sampler(self.config.fewshot_config.get("sampler", "default") if self.config.fewshot_config else "default")(list(self.fewshot_docs()), self, rnd=random.Random(1234))
555
+
556
+ if self.has_test_docs():
557
+ self.task_docs = self.test_docs()
558
+ elif self.has_validation_docs():
559
+ self.task_docs = self.validation_docs()
560
+ else:
561
+ assert False, f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
562
+
563
+ # Test One Doc
564
+ self.features = list(self.task_docs.features.keys())
565
+ self.multiple_input = 0
566
+ self.multiple_target = 0
567
+ test_doc = self.task_docs[0]
568
+ test_text = self.doc_to_text(test_doc)
569
+ test_target = self.doc_to_target(test_doc)
570
+
571
+ if self.config.doc_to_choice is not None:
572
+ test_choice = self.doc_to_choice(test_doc)
573
+ if type(test_choice) is not list:
574
+ eval_logger.error("doc_to_choice must return list")
575
+ else:
576
+ num_choice = len(test_choice)
577
+
578
+ if type(test_text) is int:
579
+ self.multiple_input = num_choice
580
+ else:
581
+ test_choice = None
582
+
583
+ if type(test_target) is list:
584
+ self.multiple_target = len(test_target)
585
+ else:
586
+ if (type(test_target) is int) and (test_choice is not None):
587
+ test_target = test_choice[test_target]
588
+ else:
589
+ test_target = str(test_target)
590
+
591
+ if test_choice is not None:
592
+ check_choices = test_choice
593
+ else:
594
+ check_choices = [test_target]
595
+ if self.config.doc_to_choice is not None:
596
+ for choice in check_choices:
597
+ choice_has_whitespace = True if choice[0].isspace() else False
598
+ delimiter_has_whitespace = True if self.config.target_delimiter.rstrip() != self.config.target_delimiter else False
599
+
600
+ if delimiter_has_whitespace and choice_has_whitespace:
601
+ eval_logger.warning(f'Both target_delimiter and target choice: "{choice}" have whitespace')
602
+ elif (not delimiter_has_whitespace) and (not choice_has_whitespace):
603
+ eval_logger.warning(f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace')
604
+
605
+ def _prepare_model_specific_config(self):
606
+ self.model_specific_prompt_kwargs = self.config.model_specific_prompt_kwargs
607
+ if self.model_specific_prompt_kwargs is not None:
608
+ if self.model_name in self.model_specific_prompt_kwargs:
609
+ self.model_specific_prompt_kwargs = self.model_specific_prompt_kwargs[self.model_name]
610
+ else:
611
+ self.model_specific_prompt_kwargs = self.model_specific_prompt_kwargs.get("default", None)
612
+
613
+ self.model_specific_target_kwargs = self.config.model_specific_target_kwargs
614
+ if self.model_specific_target_kwargs is not None:
615
+ if self.model_name in self.model_specific_target_kwargs:
616
+ self.model_specific_target_kwargs = self.model_specific_target_kwargs[self.model_name]
617
+ else:
618
+ self.model_specific_target_kwargs = self.model_specific_target_kwargs.get("default", None)
619
+ self.model_specific_generation_kwargs = self.config.model_specific_generation_kwargs
620
+ if self.model_specific_generation_kwargs is not None:
621
+ if self.model_name in self.model_specific_generation_kwargs:
622
+ self.model_specific_generation_kwargs = self.model_specific_generation_kwargs[self.model_name]
623
+ else:
624
+ self.model_specific_generation_kwargs = self.model_specific_generation_kwargs.get("default", {})
625
+
626
+ self.config.generation_kwargs.update(self.model_specific_generation_kwargs)
627
+
628
+ def _prepare_metric_and_aggregation(self):
629
+ self._metric_fn_list = {}
630
+ self._metric_fn_kwargs = {}
631
+ self._aggregation_list = {}
632
+ self._higher_is_better = {}
633
+
634
+ if self.config.metric_list is None:
635
+ # TODO: handle this in TaskConfig.__post_init__ ?
636
+ _metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]
637
+
638
+ for metric_name in _metric_list:
639
+ self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name]
640
+ self._metric_fn_kwargs[metric_name] = {}
641
+ self._aggregation_list[metric_name] = get_metric_aggregation(metric_name)
642
+ self._higher_is_better[metric_name] = is_higher_better(metric_name)
643
+ else:
644
+ for metric_config in self.config.metric_list:
645
+ assert "metric" in metric_config
646
+ metric_name = metric_config["metric"]
647
+ kwargs = {key: metric_config[key] for key in metric_config if key not in ["metric", "aggregation", "higher_is_better"]}
648
+
649
+ if self.config.process_results is not None:
650
+ self._metric_fn_list[metric_name] = None
651
+ self._metric_fn_kwargs[metric_name] = {}
652
+ elif callable(metric_name):
653
+ metric_fn = metric_name.__call__
654
+ metric_name = metric_name.__name__
655
+ self._metric_fn_list[metric_name] = metric_fn
656
+ self._metric_fn_kwargs[metric_name] = kwargs
657
+ else:
658
+ self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name]
659
+ self._metric_fn_kwargs[metric_name] = kwargs
660
+
661
+ if "aggregation" in metric_config:
662
+ agg_name = metric_config["aggregation"]
663
+ if type(agg_name) == str:
664
+ self._aggregation_list[metric_name] = get_aggregation(agg_name)
665
+ elif callable(agg_name):
666
+ self._aggregation_list[metric_name] = metric_config["aggregation"]
667
+ else:
668
+ INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
669
+ metric_agg = get_metric_aggregation(metric_name)
670
+ eval_logger.warning(f"[Task: {self._config.task}] metric {metric_name} is defined, but aggregation is not. " f"using default " f"aggregation={INV_AGG_REGISTRY[metric_agg]}")
671
+ self._aggregation_list[metric_name] = metric_agg
672
+
673
+ if "higher_is_better" in metric_config:
674
+ self._higher_is_better[metric_name] = metric_config["higher_is_better"]
675
+ else:
676
+ eval_logger.warning(f"[Task: {self._config.task}] metric {metric_name} is defined, but higher_is_better is not. " f"using default " f"higher_is_better={is_higher_better(metric_name)}")
677
+ self._higher_is_better[metric_name] = is_higher_better(metric_name)
678
+
679
+ @retry(stop=stop_after_attempt(5), wait=wait_fixed(2))
680
+ def download(self, dataset_kwargs=None) -> None:
681
+ download_config = DownloadConfig()
682
+ download_config.max_retries = dataset_kwargs.get("max_retries", 3) if dataset_kwargs is not None else 3
683
+ download_config.num_proc = dataset_kwargs.get("num_proc", 8) if dataset_kwargs is not None else 8
684
+ self.dataset = datasets.load_dataset(
685
+ path=self.DATASET_PATH,
686
+ name=self.DATASET_NAME,
687
+ download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS,
688
+ **dataset_kwargs if dataset_kwargs is not None else {},
689
+ )
690
+ self.dataset_no_image = datasets.load_dataset(
691
+ path=self.DATASET_PATH,
692
+ name=self.DATASET_NAME,
693
+ download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS,
694
+ **dataset_kwargs if dataset_kwargs is not None else {},
695
+ )
696
+ for doc_name in self.dataset_no_image:
697
+ remove_cols = []
698
+ features = self.dataset_no_image[doc_name].features
699
+ # If it is an Image instance or a Sequence of Image instance. Remove it
700
+ for feature in features:
701
+ if isinstance(features[feature], Image):
702
+ remove_cols.append(feature)
703
+ elif isinstance(features[feature], Sequence) and isinstance(features[feature].feature, Image):
704
+ remove_cols.append(feature)
705
+ for remove_col in remove_cols:
706
+ self.dataset_no_image[doc_name] = self.dataset_no_image[doc_name].remove_columns(remove_col)
707
+
708
+ def has_training_docs(self) -> bool:
709
+ if self.config.training_split is not None:
710
+ return True
711
+ else:
712
+ return False
713
+
714
+ def has_validation_docs(self) -> bool:
715
+ if self.config.validation_split is not None:
716
+ return True
717
+ else:
718
+ return False
719
+
720
+ def has_test_docs(self) -> bool:
721
+ if self.config.test_split is not None:
722
+ return True
723
+ else:
724
+ return False
725
+
726
+ def training_docs(self) -> datasets.Dataset:
727
+ if self.has_training_docs():
728
+ if self.config.process_docs is not None:
729
+ return self.config.process_docs(self.dataset[self.config.training_split])
730
+ return self.dataset[self.config.training_split]
731
+
732
+ def validation_docs(self) -> datasets.Dataset:
733
+ if self.has_validation_docs():
734
+ if self.config.process_docs is not None:
735
+ return self.config.process_docs(self.dataset[self.config.validation_split])
736
+ return self.dataset[self.config.validation_split]
737
+
738
+ def test_docs(self) -> datasets.Dataset:
739
+ if self.has_test_docs():
740
+ if self.config.process_docs is not None:
741
+ return self.config.process_docs(self.dataset[self.config.test_split])
742
+ return self.dataset[self.config.test_split]
743
+
744
+ def fewshot_docs(self):
745
+ if self.config.fewshot_split is not None:
746
+ return self.dataset[self.config.fewshot_split]
747
+ else:
748
+ if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0):
749
+ eval_logger.warning(f"Task '{self.config.task}': " "num_fewshot > 0 but fewshot_split is None. " "using preconfigured rule.")
750
+ return super().fewshot_docs()
751
+
752
+ @utils.positional_deprecated
753
+ def fewshot_context(self, doc_id, num_fewshot, split):
754
+ """Returns a fewshot context string that is made up of a prepended description
755
+ (if provided), the `num_fewshot` number of examples, and an appended prompt example.
756
+
757
+ :param doc_id: int
758
+ The document id as returned from training_docs, validation_docs, or test_docs.
759
+ :param num_fewshot: int
760
+ The number of fewshot examples to provide in the returned context string.
761
+ :returns: str
762
+ The fewshot context.
763
+ """
764
+ doc = self.dataset_no_image[split][doc_id]
765
+ if num_fewshot == 0:
766
+ # always prepend the (possibly empty) task description
767
+ labeled_examples = self.config.description
768
+ else:
769
+ labeled_examples = self.config.description + self.sampler.get_context(doc, num_fewshot)
770
+ example = self.doc_to_text(doc)
771
+ if type(example) == str:
772
+ return labeled_examples + example
773
+ elif type(example) == list:
774
+ return [labeled_examples + ex for ex in example]
775
+ elif type(example) == int:
776
+ if self.config.doc_to_choice is not None:
777
+ choices = self.doc_to_choice(doc)
778
+ return labeled_examples + choices[example]
779
+ else:
780
+ return labeled_examples + str(example)
781
+
782
+ def apply_filters(self):
783
+ if hasattr(self, "_filters"):
784
+ for f in self._filters:
785
+ f.apply(self._instances, self.task_docs)
786
+ else:
787
+ eval_logger.warning("No filter defined, passing through instances")
788
+ return self._instances
789
+
790
+ def should_decontaminate(self):
791
+ return self.config.should_decontaminate
792
+
793
+ def doc_to_decontamination_query(self, doc):
794
+ if self.config.should_decontaminate:
795
+ if self.config.doc_to_decontamination_query is None:
796
+ return self.doc_to_text(doc)
797
+ else:
798
+ doc_to_decontamination_query = self.config.doc_to_decontamination_query
799
+ if doc_to_decontamination_query in self.features:
800
+ return doc[doc_to_decontamination_query]
801
+ elif callable(doc_to_decontamination_query):
802
+ return doc_to_decontamination_query(doc)
803
+ else:
804
+ return ast.literal_eval(utils.apply_template(self.config.doc_to_decontamination_query, doc))
805
+
806
+ def _process_doc(self, doc):
807
+ """
808
+ Override this to process (detokenize, strip, replace, etc.) individual
809
+ documents. This can be used in a map over documents of a data split.
810
+ E.g. `map(self._process_doc, self.dataset["validation"])`
811
+
812
+ :return: dict
813
+ The processed version of the specified `doc`.
814
+ """
815
+ return doc
816
+
817
+ def doc_to_text(self, doc):
818
+ doc_to_text = self.config.doc_to_text
819
+
820
+ if type(doc_to_text) == int:
821
+ return doc_to_text
822
+ elif type(doc_to_text) == str:
823
+ if doc_to_text in self.features:
824
+ # if self.config.doc_to_choice is not None:
825
+ # return self.doc_to_choice(doc)[doc[doc_to_text]]
826
+ # else:
827
+ return doc[doc_to_text]
828
+ else:
829
+ text_string = utils.apply_template(doc_to_text, doc)
830
+ if text_string.isdigit() and self._config.doc_to_choice is not None:
831
+ return ast.literal_eval(text_string)
832
+ else:
833
+ return text_string
834
+ elif callable(doc_to_text):
835
+ return (
836
+ doc_to_text(doc, self.model_specific_prompt_kwargs)
837
+ if self.model_specific_prompt_kwargs is not None
838
+ else doc_to_text(
839
+ doc,
840
+ )
841
+ )
842
+ # Used when applying a Promptsource template
843
+ elif hasattr(doc_to_text, "apply"):
844
+ applied_prompt = doc_to_text.apply(doc)
845
+ if len(applied_prompt) == 2:
846
+ return applied_prompt[0]
847
+ else:
848
+ eval_logger.warning("Applied prompt returns empty string")
849
+ return self.config.fewshot_delimiter
850
+ else:
851
+ print(type(doc_to_text))
852
+ raise TypeError
853
+
854
+ def doc_to_target(self, doc: dict) -> Union[int, str, list]:
855
+ doc_to_target = self.config.doc_to_target
856
+
857
+ if type(doc_to_target) == int:
858
+ return doc_to_target
859
+ elif type(doc_to_target) == str:
860
+ if doc_to_target in self.features:
861
+ # if self.config.doc_to_choice is not None:
862
+ # return self.doc_to_choice(doc)[doc[doc_to_target]]
863
+ # else:
864
+ return doc[doc_to_target]
865
+ else:
866
+ target_string = utils.apply_template(doc_to_target, doc)
867
+ if target_string.isdigit() and self._config.doc_to_choice is not None:
868
+ return ast.literal_eval(target_string)
869
+ elif len(target_string) >= 2 and (target_string[0] == "[") and (target_string[-1] == "]"):
870
+ try:
871
+ return ast.literal_eval(target_string)
872
+ except (SyntaxError, ValueError):
873
+ return target_string
874
+ else:
875
+ return target_string
876
+ elif type(doc_to_target) == list:
877
+ return doc_to_target
878
+ elif callable(doc_to_target):
879
+ return doc_to_target(doc, self.model_specific_target_kwargs) if self.model_specific_target_kwargs is not None else doc_to_target(doc)
880
+ # Used when applying a Promptsource template
881
+ elif hasattr(doc_to_target, "apply"):
882
+ applied_prompt = doc_to_target.apply(doc)
883
+ if len(applied_prompt) == 2:
884
+ return applied_prompt[1]
885
+ else:
886
+ eval_logger.warning("Applied prompt returns empty string")
887
+ return self.config.fewshot_delimiter
888
+ else:
889
+ raise TypeError
890
+
891
+ def doc_to_visual(self, doc: dict) -> Union[int, str, list]:
892
+ self.config.doc_to_visual
893
+ if type(self.config.doc_to_visual) == str:
894
+ assert self.config.doc_to_visual in self.features
895
+ # Single image. Still return a list for consistency.
896
+ return [doc[self.config.doc_to_visual]]
897
+ else:
898
+ assert callable(self.config.doc_to_visual)
899
+ return self.config.doc_to_visual(doc)
900
+
901
+ def doc_to_choice(self, doc: Any) -> List[str]:
902
+ if self.config.doc_to_choice is None:
903
+ eval_logger.error("doc_to_choice was called but not set in config")
904
+ else:
905
+ doc_to_choice = self.config.doc_to_choice
906
+
907
+ if type(doc_to_choice) == str:
908
+ if doc_to_choice in self.features:
909
+ return doc[doc_to_choice]
910
+ else:
911
+ return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
912
+ elif type(doc_to_choice) == list:
913
+ return doc_to_choice
914
+ elif type(doc_to_choice) == dict:
915
+ return list(doc_to_choice.values())
916
+ elif callable(doc_to_choice):
917
+ return doc_to_choice(doc)
918
+ elif hasattr(doc_to_choice, "get_answer_choices_list"):
919
+ return doc_to_choice.get_answer_choices_list(doc)
920
+ else:
921
+ raise TypeError
922
+
923
+ def construct_requests(self, doc_id: int, ctx: str, **kwargs) -> Union[List[Instance], Instance]:
924
+ split = kwargs.get("split")
925
+ kwargs.pop("split")
926
+ if self.OUTPUT_TYPE == "loglikelihood":
927
+ arguments = (ctx, self.doc_to_target, self.doc_to_visual, doc_id, self.config.task, split)
928
+ elif self.OUTPUT_TYPE == "multiple_choice":
929
+ doc = self.dataset[split][doc_id]
930
+ choices = self.doc_to_choice(doc)
931
+ target_delimiter = self.config.target_delimiter
932
+ if self.multiple_input:
933
+ # If there are multiple inputs, choices are placed in the ctx
934
+ cont = self.doc_to_target(doc)
935
+ arguments = [(ctx, f"{target_delimiter}{cont}", self.doc_to_visual, doc_id, self.config.task, split) for ctx in choices]
936
+ else:
937
+ # Otherwise they are placed in the continuation
938
+ arguments = [(ctx, f"{target_delimiter}{cont}", self.doc_to_visual, doc_id, self.config.task, split) for cont in choices]
939
+ request_list = [
940
+ Instance(
941
+ request_type="loglikelihood",
942
+ # doc=doc,
943
+ arguments=arg,
944
+ idx=i,
945
+ **kwargs,
946
+ )
947
+ for i, arg in enumerate(arguments)
948
+ ]
949
+ # TODO: we should raise a warning telling users this will at most ~2x runtime.
950
+ if "acc_mutual_info" in self._metric_fn_list.keys():
951
+ # if we are calculating multiple choice accuracy
952
+ # using mutual information instead of raw loglikelihood as metric, need unconditional lls.
953
+
954
+ # here mutual info refers to calculating
955
+ # log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
956
+ # in other words normalizing by subtracting the unconditional logprob of each choice.
957
+ request_list.extend(
958
+ [
959
+ Instance(
960
+ request_type="loglikelihood",
961
+ # doc=doc,
962
+ arguments=("", "{}".format(choice)),
963
+ idx=i,
964
+ **kwargs,
965
+ )
966
+ for i, choice in enumerate(choices)
967
+ ]
968
+ )
969
+ return request_list
970
+
971
+ elif self.OUTPUT_TYPE == "generate_until":
972
+ arguments = (ctx, self.config.generation_kwargs, self.doc_to_visual, doc_id, self.config.task, split)
973
+ return Instance(request_type=self.OUTPUT_TYPE, arguments=arguments, idx=0, **kwargs)
974
+
975
+ def process_results(self, doc, results):
976
+ if callable(self.config.process_results):
977
+ return self.config.process_results(doc, results)
978
+
979
+ result_dict = {}
980
+ use_metric = list(self._metric_fn_list.keys())
981
+ if self.OUTPUT_TYPE == "loglikelihood":
982
+ results = results[0]
983
+ ll, is_greedy = results
984
+ return {
985
+ **({"perplexity": ll} if "perplexity" in use_metric else {}),
986
+ **({"acc": int(is_greedy)} if "acc" in use_metric else {}),
987
+ }
988
+ elif self.OUTPUT_TYPE == "multiple_choice":
989
+ lls, is_greedy = zip(*results)
990
+
991
+ # retrieve choices in List[str] form, to compute choice lengths, etc.
992
+ choices = self.doc_to_choice(doc)
993
+ completion_len = np.array([float(len(i)) for i in choices])
994
+
995
+ if 2 * len(choices) == len(lls) and "acc_mutual_info" in self._metric_fn_list.keys():
996
+ # then we are doing mutual info.
997
+ # this stores the "dryrun" / unconditional answer loglikelihoods
998
+ lls_unconditional = lls[1::2]
999
+ assert len(lls_unconditional) == len(choices)
1000
+ # and this stores our "regular" conditional loglikelihoods
1001
+ lls = lls[::2]
1002
+
1003
+ pred = np.argmax(lls)
1004
+ pred_norm = np.argmax(lls / completion_len)
1005
+
1006
+ if self.multiple_input:
1007
+ gold = self.doc_to_text(doc)
1008
+ else:
1009
+ gold = self.doc_to_target(doc)
1010
+
1011
+ gold_index_error = False
1012
+ if type(gold) is list:
1013
+ gold = [i if i < len(choices) else -100 for i in gold]
1014
+ if -100 in gold:
1015
+ gold_index_error = True
1016
+ else:
1017
+ if type(gold) is int:
1018
+ gold = gold if gold < len(choices) else -100
1019
+ elif type(gold) is str:
1020
+ gold = choices.index(gold) if gold in choices else -100
1021
+
1022
+ if gold == -100:
1023
+ gold_index_error = True
1024
+
1025
+ if gold_index_error:
1026
+ eval_logger.warning(f"Label index was not in within range of available choices," f"Sample:\n\n{doc}\n\n")
1027
+
1028
+ if self.multiple_target:
1029
+ acc = 1.0 if pred in gold else 0.0
1030
+ acc_norm = 1.0 if pred_norm in gold else 0.0
1031
+ exact_match = int(any([is_greedy[i] if i != -100 else 0 for i in gold]))
1032
+ else:
1033
+ acc = 1.0 if pred == gold else 0.0
1034
+ acc_norm = 1.0 if pred_norm == gold else 0.0
1035
+ # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
1036
+ exact_match = int(is_greedy[gold]) if gold != -100 else 0
1037
+
1038
+ result_dict = {
1039
+ **({"acc": acc} if "acc" in use_metric else {}),
1040
+ **({"f1": (gold, pred)} if "f1" in use_metric else {}),
1041
+ **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
1042
+ **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
1043
+ **({"exact_match": exact_match} if "exact_match" in use_metric else {}),
1044
+ }
1045
+
1046
+ if "acc_mutual_info" in use_metric:
1047
+ lls_mutual_info = [ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)]
1048
+ acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
1049
+ result_dict["acc_mutual_info"] = acc_mutual_info
1050
+
1051
+ elif self.OUTPUT_TYPE == "generate_until":
1052
+ gold = self.doc_to_target(doc)
1053
+ result = results[0]
1054
+ if self.config.doc_to_choice is not None:
1055
+ # If you set doc_to_choice,
1056
+ # it assumes that doc_to_target returns a number.
1057
+ choices = self.doc_to_choice(doc)
1058
+ gold = choices[gold]
1059
+ # we expect multiple_targets to be a list.
1060
+ elif self.multiple_target:
1061
+ gold = list(gold)
1062
+ elif type(gold) != type(result):
1063
+ # cast gold to the same type as result
1064
+ gold = type(result)(gold)
1065
+
1066
+ for metric in self._metric_fn_list.keys():
1067
+ if self.multiple_target:
1068
+ # in the case where we have multiple targets,
1069
+ # return true if any are true
1070
+ # TODO: this may break for multiple_target, non zero-or-1 metrics
1071
+ scores = []
1072
+ if not isinstance(gold, list):
1073
+ # sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
1074
+ # print(gold)
1075
+ gold = [gold]
1076
+ for gold_option in gold:
1077
+ try:
1078
+ result_score = self._metric_fn_list[metric](
1079
+ references=[gold_option],
1080
+ predictions=[result],
1081
+ **self._metric_fn_kwargs[metric],
1082
+ )
1083
+ except TypeError: # TODO: this is hacky and I don't want to do it
1084
+ result_score = self._metric_fn_list[metric]([gold_option, result])
1085
+ if isinstance(result_score, dict):
1086
+ # TODO: this handles the case where HF evaluate returns a dict.
1087
+ result_score = result_score[metric]
1088
+ scores.append(result_score)
1089
+ if any(scores):
1090
+ result_score = 1.0
1091
+ else:
1092
+ result_score = 0.0
1093
+ else:
1094
+ try:
1095
+ result_score = self._metric_fn_list[metric](
1096
+ references=[gold],
1097
+ predictions=[result],
1098
+ **self._metric_fn_kwargs[metric],
1099
+ )
1100
+ except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
1101
+ result_score = self._metric_fn_list[metric]([gold, result])
1102
+ if isinstance(result_score, dict):
1103
+ # TODO: this handles the case where HF evaluate returns a dict.
1104
+ result_score = result_score[metric]
1105
+ result_dict[metric] = result_score
1106
+ else:
1107
+ raise ValueError(
1108
+ f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
1109
+ "'loglikelihood','generate_until' or 'multiple_choice'",
1110
+ )
1111
+
1112
+ return result_dict
1113
+
1114
+ def aggregation(self):
1115
+ return self._aggregation_list
1116
+
1117
+ def higher_is_better(self):
1118
+ return self._higher_is_better
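To make the `acc_mutual_info` branch of `process_results` concrete, here is a minimal numeric sketch; the log-likelihood values and the gold index are made up, only the normalization mirrors the code above.

import numpy as np

lls = [-2.1, -0.7, -3.5]                # log P(choice | ctx) from the conditional requests
lls_unconditional = [-1.9, -2.4, -3.0]  # log P(choice) from the extra ("", choice) requests
gold = 1                                # assumed index of the correct choice

# acc_mutual_info scores each choice by log P(choice|ctx) - log P(choice)
lls_mutual_info = [c - u for c, u in zip(lls, lls_unconditional)]
acc_mutual_info = 1.0 if int(np.argmax(lls_mutual_info)) == gold else 0.0  # -> 1.0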
EAGLE/lmms_eval/filters/__init__.py ADDED
@@ -0,0 +1,44 @@
1
+ from lmms_eval.api.filter import FilterEnsemble
2
+ from . import selection
3
+ from . import extraction
4
+ from . import transformation
5
+
6
+
7
+ FILTER_REGISTRY = {
8
+ "take_first": selection.TakeFirstFilter,
9
+ "regex": extraction.RegexFilter,
10
+ "majority_vote": selection.MajorityVoteFilter,
11
+ "take_first_k": selection.TakeKFilter,
12
+ "remove_whitespace": extraction.WhitespaceFilter,
13
+ "lowercase": transformation.LowercaseFilter,
14
+ "uppercase": transformation.UppercaseFilter,
15
+ "map": transformation.MapFilter,
16
+ # TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function
17
+ # that takes an input and returns a scalar and then should select the max reward,
18
+ # or should implement different filters for different ways of handling a reward model's inference.
19
+ # "arg_max": selection.ArgMaxFilter,
20
+ }
21
+
22
+
23
+ def get_filter(filter_name):
24
+ if filter_name in FILTER_REGISTRY:
25
+ return FILTER_REGISTRY[filter_name]
26
+ else:
27
+ return filter_name
28
+
29
+
30
+ def build_filter_ensemble(filter_name, components):
31
+ """
32
+ Create a filtering pipeline.
33
+ """
34
+ filters = []
35
+ for function, kwargs in components:
36
+ if kwargs is None:
37
+ f = get_filter(function)()
38
+ else:
39
+ # create a filter given its name in the registry
40
+ f = get_filter(function)(**kwargs) # TODO: pass kwargs to filters properly
41
+ # add the filter as a pipeline step
42
+ filters.append(f)
43
+
44
+ return FilterEnsemble(name=filter_name, filters=filters)
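A small usage sketch of `build_filter_ensemble`; the pipeline name and regex below are illustrative, but the component format matches what `ConfigurableTask` assembles from a YAML `filter_list` and what the default `[["take_first", None]]` pipeline uses.

from lmms_eval.filters import build_filter_ensemble

# Extract "#### <number>" style answers, then keep only the first response per doc.
ensemble = build_filter_ensemble(
    "strict-match",
    [
        ["regex", {"regex_pattern": r"#### (\-?[0-9\.\,]+)"}],
        ["take_first", None],
    ],
)
# ensemble.apply(instances, docs) then runs both filters, in order, over the model responses.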
EAGLE/lmms_eval/filters/decontamination.py ADDED
@@ -0,0 +1,23 @@
1
+ from lmms_eval.api.filter import Filter
2
+
3
+
4
+ class DecontaminationFilter(Filter):
5
+ """
6
+ A filter which evaluates
7
+ """
8
+
9
+ name = "track_decontamination"
10
+
11
+ def __init__(self, path) -> None:
12
+ """
13
+
14
+ TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").
15
+ should further cache result on a given (task_name, doc_id)
16
+ """
17
+ self._decontam_results = None
18
+
19
+ def apply(self, resps, docs) -> None:
20
+ """
21
+ Return {"no_contamination", "only_contamination"} keys for the 2 different subsets
22
+ """
23
+ pass
EAGLE/lmms_eval/filters/extraction.py ADDED
@@ -0,0 +1,60 @@
1
+ import re
2
+
3
+ from lmms_eval.api.filter import Filter
4
+
5
+
6
+ class RegexFilter(Filter):
7
+ """ """
8
+
9
+ def __init__(self, regex_pattern: str = r"#### (\-?[0-9\.\,]+)", fallback: str = "[invalid]") -> None:
10
+ """
11
+ pass a string `regex` to run `re.compile(r"regex")` on.
12
+ `fallback` defines the output returned if no matches for the regex are located.
13
+ """
14
+ self.regex_pattern = regex_pattern
15
+ self.regex = re.compile(regex_pattern)
16
+ self.fallback = fallback
17
+
18
+ def apply(self, resps, docs):
19
+ # here, we assume we have a list, in which each element is
20
+ # a list of model responses for some particular input/target pair.
21
+ # so we process each of these (same input/target response sets)
22
+ # independently (and keep them a list.)
23
+ def filter_set(inst):
24
+ filtered = []
25
+ for resp in inst:
26
+ match = self.regex.search(resp)
27
+ if match:
28
+ match = match.group(1).strip()
29
+ else:
30
+ match = self.fallback
31
+ filtered.append(match)
32
+ return filtered
33
+
34
+ # print(resps)
35
+ filtered_resps = list(map(lambda x: filter_set(x), resps))
36
+ # print(filtered_resps)
37
+
38
+ return filtered_resps
39
+
40
+
41
+ class WhitespaceFilter(Filter):
42
+ """ """
43
+
44
+ def __init__(self) -> None:
45
+ pass
46
+
47
+ def apply(self, resps, docs):
48
+ def filter_set(inst):
49
+ filtered_resp = []
50
+ for resp in inst:
51
+ if resp.startswith(" "):
52
+ resp = resp[1:]
53
+
54
+ filtered_resp.append(resp)
55
+
56
+ return filtered_resp
57
+
58
+ filtered_resps = [filter_set(resp) for resp in resps]
59
+
60
+ return filtered_resps
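For reference, a quick sketch of `RegexFilter` on two hypothetical response sets; the strings are made up, the behavior follows the class above.

f = RegexFilter(regex_pattern=r"#### (\-?[0-9\.\,]+)", fallback="[invalid]")
resps = [["The total is 12 apples. #### 12"], ["I am not sure."]]
print(f.apply(resps, docs=None))  # -> [['12'], ['[invalid]']]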
EAGLE/lmms_eval/filters/selection.py ADDED
@@ -0,0 +1,48 @@
1
+ from collections import Counter
2
+
3
+ from lmms_eval.api.filter import Filter
4
+
5
+
6
+ class TakeFirstFilter(Filter):
7
+ def __init__(self) -> None:
8
+ """
9
+ Can define custom behavior here, if an individual instantiation of a Filter class should have state.
10
+ """
11
+
12
+ def apply(self, resps, docs):
13
+ """
14
+ Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
15
+ """
16
+ return map(lambda r: r[0], resps)
17
+
18
+
19
+ class TakeKFilter(Filter):
20
+ def __init__(self, *args, **kwargs) -> None:
21
+ self.k = kwargs.pop("k")
22
+
23
+ super().__init__(*args, **kwargs)
24
+
25
+ def apply(self, resps, docs):
26
+ # check we have at least k responses per doc, else we can't take the first k
27
+ assert len(resps[0]) >= self.k, f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ."
28
+ return map(lambda r: r[: self.k], resps)
29
+
30
+
31
+ class MajorityVoteFilter(Filter):
32
+ def __init__(self) -> None:
33
+ """
34
+ Can define custom behavior here, if an individual instantiation of a Filter class should have state.
35
+ """
36
+
37
+ def apply(self, resps, docs):
38
+ """
39
+ Each entry of `resps` is a list of model responses.
40
+ We select the response that occurs most frequently in each entry of `resps`.
41
+ """
42
+
43
+ def select_majority(resp):
44
+ counts = Counter(resp)
45
+ vote = counts.most_common(1)[0][0]
46
+ return vote
47
+
48
+ return map(lambda r: [select_majority(r)], resps)
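A short sketch of `MajorityVoteFilter` on made-up repeated samples; note that `apply` returns a lazy `map`, so it is materialized here for printing.

mv = MajorityVoteFilter()
resps = [["B", "A", "B"], ["C", "C", "D"]]
print(list(mv.apply(resps, docs=None)))  # -> [['B'], ['C']]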
EAGLE/lmms_eval/filters/transformation.py ADDED
@@ -0,0 +1,48 @@
1
+ from lmms_eval.api.filter import Filter
2
+
3
+
4
+ class LowercaseFilter(Filter):
5
+ def __init__(self) -> None:
6
+ pass
7
+
8
+ def apply(self, resps, docs):
9
+ def filter_set(inst):
10
+ return [resp.lower() for resp in inst]
11
+
12
+ return [filter_set(resp) for resp in resps]
13
+
14
+
15
+ class UppercaseFilter(Filter):
16
+ def __init__(self) -> None:
17
+ pass
18
+
19
+ def apply(self, resps, docs):
20
+ def filter_set(inst):
21
+ return [resp.upper() for resp in inst]
22
+
23
+ return [filter_set(resp) for resp in resps]
24
+
25
+
26
+ class MapFilter(Filter):
27
+ def __init__(self, mapping_dict: dict = {}, default_value=None) -> None:
28
+ """
29
+ Initializes the MapFilter with a given mapping dictionary and default value.
30
+
31
+ Args:
32
+ - mapping_dict (dict): A dictionary containing the key-value mappings.
33
+ Default is an empty dictionary.
34
+ - default_value (Any): The value to be returned when a key is not found in the mapping_dict.
35
+ Default is None.
36
+
37
+ Example:
38
+ mapper = MapFilter({'A': 1, 'B': 2}, default_value=0)
39
+ """
40
+ assert isinstance(mapping_dict, dict), "Provided mapping_dict is not a dictionary"
41
+ self.mapping_dict = mapping_dict
42
+ self.default_value = default_value
43
+
44
+ def apply(self, resps, docs):
45
+ def filter_set(inst):
46
+ return [self.mapping_dict.get(resp, self.default_value) for resp in inst]
47
+
48
+ return [filter_set(resp) for resp in resps]
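Extending the docstring example above, a sketch of `MapFilter` over nested responses; the mapping and inputs are illustrative.

mapper = MapFilter({"A": 1, "B": 2}, default_value=0)
resps = [["A", "C"], ["B"]]
print(mapper.apply(resps, docs=None))  # -> [[1, 0], [2]]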
EAGLE/lmms_eval/models/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ import os
2
+
3
+ AVAILABLE_MODELS = {
4
+ "eagle": "Eagle",
5
+ }
6
+
7
+ for model_name, model_class in AVAILABLE_MODELS.items():
8
+ try:
9
+ exec(f"from .{model_name} import {model_class}")
10
+ except ImportError:
11
+ pass
12
+
13
+
14
+ import hf_transfer
15
+
16
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
EAGLE/lmms_eval/models/eagle.py ADDED
@@ -0,0 +1,415 @@
1
+ import torch
2
+ from PIL import Image
3
+
4
+ torch.backends.cuda.matmul.allow_tf32 = True
5
+
6
+ import logging
7
+ import copy
8
+ from tqdm import tqdm
9
+ from datetime import timedelta
10
+
11
+ from lmms_eval import utils
12
+ from lmms_eval.api.instance import Instance
13
+ from lmms_eval.api.model import lmms
14
+ from lmms_eval.api.registry import register_model
15
+ from lmms_eval.utils import stop_sequences_criteria
16
+
17
+ from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs
18
+ from accelerate.state import AcceleratorState
19
+ from typing import List, Optional, Union, Tuple
20
+ import warnings
21
+
22
+ warnings.filterwarnings("ignore")
23
+
24
+ eval_logger = logging.getLogger("lmms-eval")
25
+
26
+ try:
27
+ from eagle.model.builder import load_pretrained_model
28
+ from eagle.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
29
+ from eagle.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
30
+ from eagle.conversation import conv_templates, SeparatorStyle
31
+ except ImportError:
32
+ eval_logger.error("Please add a symbolic link pointing to the eagle folder of repo ")
33
+
34
+ from transformers.integrations.deepspeed import (
35
+ is_deepspeed_zero3_enabled,
36
+ set_hf_deepspeed_config,
37
+ unset_hf_deepspeed_config,
38
+ )
39
+
40
+ def resize_image_with_aspect_ratio(img, min_size):
41
+ """
42
+ Resize an image while maintaining its aspect ratio.
43
+
44
+ Parameters:
45
+ - image_path: str, path to the input image.
46
+ - min_size: int, the minimum size for the shortest side of the image.
47
+
48
+ Returns:
49
+ - resized_image: PIL.Image object, the resized image.
50
+ """
51
+ # Get the original dimensions of the image
52
+ original_width, original_height = img.size
53
+
54
+ # Determine the aspect ratio
55
+ aspect_ratio = original_width / original_height
56
+
57
+ # Calculate new dimensions based on the shortest side
58
+ if original_width < original_height:
59
+ new_width = min_size
60
+ new_height = int(min_size / aspect_ratio)
61
+ else:
62
+ new_height = min_size
63
+ new_width = int(min_size * aspect_ratio)
64
+
65
+ # Resize the image while maintaining aspect ratio
66
+ resized_image = img.resize((new_width, new_height), Image.LANCZOS)  # Image.ANTIALIAS is deprecated in recent Pillow
67
+
68
+ return resized_image
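# Quick sanity check of the branches above (illustrative numbers only): a 1200x800 image
# resized with min_size=448 keeps its 3:2 aspect ratio and comes back as 672x448:
#     resize_image_with_aspect_ratio(Image.new("RGB", (1200, 800)), 448).size  # -> (672, 448)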
69
+
70
+
71
+ @register_model("eagle")
72
+ class Eagle(lmms):
73
+ """
74
+ EAGLE Model
75
+ """
76
+
77
+ def __init__(
78
+ self,
79
+ pretrained: str = "NVEagle/Eagle-X5-7B",
80
+ truncation: Optional[bool] = True,
81
+ device: Optional[str] = "cuda",
82
+ dtype: Optional[Union[str, torch.dtype]] = "",
83
+ batch_size: Optional[Union[int, str]] = 1,
84
+ trust_remote_code: Optional[bool] = False,
85
+ revision=None,
86
+ use_flash_attention_2=True,
87
+ device_map="",
88
+ conv_template="vicuna_v1",
89
+ use_cache=True,
90
+ truncate_context=False,
91
+ **kwargs,
92
+ ) -> None:
93
+ super().__init__()
94
+ # Do not use kwargs for now
95
+ assert kwargs == {}, f"Unexpected kwargs: {kwargs}"
96
+
97
+ accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
98
+ accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
99
+ if accelerator.num_processes > 1 and device_map == "":
100
+ self._device = torch.device(f"cuda:{accelerator.local_process_index}")
101
+ self.device_map = f"cuda:{accelerator.local_process_index}"
102
+ else:
103
+ self._device = torch.device(device)
104
+ self.device_map = device_map
105
+
106
+ self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, get_model_name_from_path(pretrained), device_map=self.device_map, use_flash_attention_2=use_flash_attention_2)
107
+ self._config = self._model.config
108
+ self.model.eval()
109
+ self.model.tie_weights()
110
+ self.truncation = truncation
111
+ self.batch_size_per_gpu = int(batch_size)
112
+ self.conv_template = conv_template
113
+ self.use_cache = use_cache
114
+ self.truncate_context = truncate_context
115
+
116
+ if accelerator.num_processes > 1 and device_map == "":
117
+ assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
118
+ # If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model
119
+ # Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works
120
+ # I tried to set different parameters in the kwargs to let default zero 2 stage works, but it didn't work.
121
+ if accelerator.distributed_type == DistributedType.DEEPSPEED:
122
+ kwargs = {
123
+ "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
124
+ "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
125
+ }
126
+ AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs)
127
+ eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")
128
+
129
+ if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED:
130
+ self._model = accelerator.prepare(self.model)
131
+ else:
132
+ self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
133
+ self.accelerator = accelerator
134
+ if self.accelerator.is_local_main_process:
135
+ eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
136
+ self._rank = self.accelerator.local_process_index
137
+ self._world_size = self.accelerator.num_processes
138
+ elif accelerator.num_processes == 1 and device_map == "auto":
139
+ eval_logger.info(f"Using {accelerator.num_processes} devices with tensor parallelism")
140
+ self._rank = 0
141
+ self._world_size = 1
142
+ else:
143
+ eval_logger.info(f"Using single device: {self._device}")
144
+ self.model.to(self._device)
145
+ self._rank = 0
146
+ self._world_size = 1
147
+
148
+ @property
149
+ def config(self):
150
+ # return the associated transformers.AutoConfig for the given pretrained model.
151
+ return self._config
152
+
153
+ @property
154
+ def tokenizer(self):
155
+ return self._tokenizer
156
+
157
+ @property
158
+ def model(self):
159
+ # returns the model, unwrapping it if using Accelerate
160
+ if hasattr(self, "accelerator"):
161
+ return self.accelerator.unwrap_model(self._model)
162
+ else:
163
+ return self._model
164
+
165
+ @property
166
+ def eot_token_id(self):
167
+ # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
168
+ return self.tokenizer.eos_token_id
169
+
170
+ @property
171
+ def max_length(self):
172
+ return self._max_length
173
+
174
+ def pad_sequence(self, input_ids, batch_first, padding_value):
175
+ if self.tokenizer.padding_side == "left":
176
+ input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
177
+ input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value)
178
+ if self.tokenizer.padding_side == "left":
179
+ input_ids = torch.flip(input_ids, [1])
180
+ return input_ids
181
+
182
+ @property
183
+ def batch_size(self):
184
+ return self.batch_size_per_gpu
185
+
186
+ @property
187
+ def device(self):
188
+ return self._device
189
+
190
+ @property
191
+ def rank(self):
192
+ return self._rank
193
+
194
+ @property
195
+ def world_size(self):
196
+ return self._world_size
197
+
198
+ def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]:
199
+ """ """
200
+ add_special_tokens = False if add_special_tokens is None else add_special_tokens
201
+ encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
202
+ # left-truncate the encoded context to be at most `left_truncate_len` tokens long
203
+ if left_truncate_len:
204
+ encoding = encoding[-left_truncate_len:]
205
+ return encoding
206
+
207
+ def tok_decode(self, tokens):
208
+ return self.tokenizer.decode(tokens)
209
+
210
+ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
211
+ # TODO
212
+ res = []
213
+ pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
214
+
215
+ for contexts, doc_to_target, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
216
+ # encode, pad, and truncate contexts for this batch
217
+ if type(doc_to_target) == str:
218
+ continuation = doc_to_target
219
+ else:
220
+ continuation = doc_to_target(self.task_dict[task][split][doc_id])
221
+ visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
222
+ visuals = self.flatten(visuals)
223
+ if visuals:
224
+ image = process_images(visuals, self._image_processor, self._config)
225
+ if type(image) is list:
226
+ image = [_image.to(dtype=torch.float16, device=self.device) for _image in image]
227
+ else:
228
+ image = image.to(dtype=torch.float16, device=self.device)
229
+ else:
230
+ image = None
231
+
232
+ prompts_input = contexts[0]
233
+
234
+ if image is not None and len(image) != 0 and DEFAULT_IMAGE_TOKEN not in prompts_input:
235
+ """
236
+ Three senarios:
237
+ 1. No image, and there for, no image token should be added.
238
+ 2. image token is already specified in the context, so we don't need to add it.
239
+ 3. image token is not specified in the context and there is image inputs, so we need to add it. In this case, we add the image token at the beginning of the context and add a new line.
240
+ """
241
+ image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visuals)
242
+ image_tokens = " ".join(image_tokens)
243
+ prompts_input = image_tokens + "\n" + contexts[0]
244
+
245
+ conv = conv_templates[self.conv_template].copy()
246
+ conv.append_message(conv.roles[0], prompts_input)
247
+ conv.append_message(conv.roles[1], None)
248
+ prompt = conv.get_prompt()
249
+ pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
250
+ contxt_id = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device)
251
+ # Add the answer of the second role
252
+ conv.messages[1][1] = continuation
253
+
254
+ prompt = conv.get_prompt()
255
+ input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device)
256
+ labels = input_ids.clone()
257
+ # Context part no need to calculate for loss
258
+ labels[0, : contxt_id.shape[1]] = -100
259
+ with torch.inference_mode():
260
+ outputs = self.model(input_ids=input_ids, labels=labels, images=image, use_cache=True)
261
+ loss = outputs["loss"]
262
+ # loss = torch.exp(loss)
263
+ logits = outputs["logits"]
264
+ greedy_tokens = logits.argmax(dim=-1)
265
+ cont_toks = input_ids[:, contxt_id.shape[1] :] # [1, seq]
266
+ greedy_tokens = greedy_tokens[:, contxt_id.shape[1] : input_ids.shape[1]] # [1, seq]
267
+ max_equal = (greedy_tokens == cont_toks).all()
268
+ res.append((float(loss.item()), bool(max_equal)))
269
+ pbar.update(1)
270
+ pbar.close()
271
+ return res
272
+
273
+ def flatten(self, input):
274
+ new_list = []
275
+ for i in input:
276
+ for j in i:
277
+ new_list.append(j)
278
+ return new_list
279
+
280
+ def generate_until(self, requests: List[Instance]) -> List[str]:
281
+ res = []
282
+
283
+ def _collate(x):
284
+ # the negative sign on len(toks) sorts descending - this has a few advantages:
285
+ # - time estimates will always be over not underestimates, which is more useful for planning
286
+ # - to know the size of a batch when going through the list, you know the first one is always the batch
287
+ # padded context length. this is useful to simplify the batching logic and more importantly to make
288
+ # automatic adaptive batches much much easier to implement
289
+ # - any OOMs will happen right away rather than near the end
290
+ toks = self.tok_encode(x[0])
291
+ return -len(toks), x[0]
292
+
293
+ # we group requests by their generation_kwargs,
294
+ # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
295
+ # in the same batch.
296
+ re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
297
+ chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
298
+ num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1
299
+ pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")
300
+ for chunk in chunks:
301
+ contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
302
+ task = task[0]
303
+ split = split[0]
304
+ visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]
305
+ visuals = self.flatten(visuals)
306
+ # we assume all gen kwargs in the batch are the same
307
+ # this is safe to assume because the `grouper` object ensures it.
308
+ gen_kwargs = all_gen_kwargs[0]
309
+
310
+ # Set default values for until and max_new_tokens
311
+ until = [self.tok_decode(self.eot_token_id)]
312
+
313
+ # Update values from gen_kwargs if present
314
+ if "until" in gen_kwargs:
315
+ until = gen_kwargs.pop("until")
316
+ if isinstance(until, str):
317
+ until = [until]
318
+ elif not isinstance(until, list):
319
+ raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}")
320
+
321
+ if "image_aspect_ratio" in gen_kwargs.keys() and "image_aspect_ratio" not in self._config.__dict__:
322
+ # here we should pop it out of gen_kwargs so that it doesn't get passed to the model for next step of generation
323
+ self._config.image_aspect_ratio = gen_kwargs.pop("image_aspect_ratio")
324
+ eval_logger.info(f"Setting image aspect ratio: {self._config.image_aspect_ratio}")
325
+
326
+ if visuals:
327
+ image_tensor = process_images(visuals, self._image_processor, self._config)
328
+ if type(image_tensor) is list:
329
+ image_tensor = [_image.to(dtype=torch.float16, device=self.device) for _image in image_tensor]
330
+ else:
331
+ image_tensor = image_tensor.to(dtype=torch.float16, device=self.device)
332
+ else:
333
+ image_tensor = None
334
+
335
+ # prompts_input = contexts[0]
336
+
337
+ question_input = []
338
+
339
+ for visual, context in zip(visuals, contexts):
340
+ if image_tensor is not None and len(image_tensor) != 0 and DEFAULT_IMAGE_TOKEN not in context:
341
+ """
342
+ Three scenarios:
343
+ 1. No image, and therefore, no image token should be added.
344
+ 2. image token is already specified in the context, so we don't need to add it.
345
+ 3. The image token is not specified in the context but there are image inputs, so we need to add it. In this case, we add the image token at the beginning of the context followed by a new line.
346
+ """
347
+ image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visual) if isinstance(visual, list) else [DEFAULT_IMAGE_TOKEN]
348
+ image_tokens = " ".join(image_tokens)
349
+ question = image_tokens + "\n" + context
350
+ else:
351
+ question = context
352
+
353
+ conv = conv_templates[self.conv_template].copy()
354
+ conv.append_message(conv.roles[0], question)
355
+ conv.append_message(conv.roles[1], None)
356
+ prompt_question = conv.get_prompt()
357
+ question_input.append(prompt_question)
358
+
359
+ # The above for loop has a bug: when there are no visuals (e.g. pure text),
360
+ # the loop body never executes, so question_input ends up empty (because there are no visuals)
361
+ # and scenario 1 above is never handled; the text-only case is handled separately below.
362
+ if len(visuals) == 0:
363
+ for context in contexts:
364
+ question = context
365
+ conv = conv_templates[self.conv_template].copy()
366
+ conv.append_message(conv.roles[0], question)
367
+ conv.append_message(conv.roles[1], None)
368
+ prompt_question = conv.get_prompt()
369
+ question_input.append(prompt_question)
370
+
371
+ # input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device)
372
+ # preconfigure gen_kwargs with defaults
373
+ gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))]
374
+ if "max_new_tokens" not in gen_kwargs:
375
+ gen_kwargs["max_new_tokens"] = 1024
376
+ if "temperature" not in gen_kwargs:
377
+ gen_kwargs["temperature"] = 0
378
+ if "top_p" not in gen_kwargs:
379
+ gen_kwargs["top_p"] = None
380
+ if "num_beams" not in gen_kwargs:
381
+ gen_kwargs["num_beams"] = 1
382
+
383
+ input_ids_list = [tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") for prompt in question_input]
384
+ pad_token_ids = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
385
+ input_ids = self.pad_sequence(input_ids_list, batch_first=True, padding_value=pad_token_ids).to(self.device)
386
+ attention_masks = input_ids.ne(pad_token_ids).to(self.device)
387
+
388
+ try:
389
+ cont = self.model.generate(
390
+ input_ids,
391
+ attention_mask=attention_masks,
392
+ pad_token_id=pad_token_ids,
393
+ images=image_tensor,
394
+ image_sizes=gen_kwargs["image_sizes"],
395
+ do_sample=True if gen_kwargs["temperature"] > 0 else False,
396
+ temperature=gen_kwargs["temperature"],
397
+ top_p=gen_kwargs["top_p"],
398
+ num_beams=gen_kwargs["num_beams"],
399
+ max_new_tokens=gen_kwargs["max_new_tokens"],
400
+ use_cache=self.use_cache,
401
+ )
402
+ text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)
403
+ except Exception as e:
404
+ eval_logger.error(f"Error {e} in generating")
405
+ cont = ""
406
+ text_outputs = [""]
407
+
408
+ res.extend(text_outputs)
409
+ self.cache_hook.add_partial("generate_until", (context, gen_kwargs), text_outputs)
410
+ pbar.update(1)
411
+
412
+ res = re_ords.get_original(res)
413
+
414
+ pbar.close()
415
+ return res
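The `_collate` helper above sorts requests by descending token count so that the largest batch runs first: time estimates stay conservative and any OOM surfaces immediately. A minimal standalone sketch of the same ordering, using a whitespace split as a stand-in for `tok_encode`:

def collate_key(context: str):
    toks = context.split()  # stand-in for self.tok_encode(context)
    return -len(toks), context  # negative length puts the longest contexts first

def make_batches(contexts, batch_size):
    ordered = sorted(contexts, key=collate_key)
    return [ordered[i:i + batch_size] for i in range(0, len(ordered), batch_size)]

# make_batches(["a b c d", "a", "a b"], batch_size=2) -> [["a b c d", "a b"], ["a"]]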
EAGLE/lmms_eval/models/gpt4v.py ADDED
@@ -0,0 +1,129 @@
1
+ from io import BytesIO
2
+ from copy import deepcopy
3
+ import os
4
+ import base64
5
+ from typing import List, Tuple
6
+ from tqdm import tqdm
7
+ import requests as url_requests
8
+ import time
9
+ import logging
10
+
11
+ from lmms_eval.api.instance import Instance
12
+ from lmms_eval.api.model import lmms
13
+ from lmms_eval.api.registry import register_model
14
+ from lmms_eval import utils
15
+
16
+ from PIL import Image
17
+
18
+ API_TYPE = os.getenv("API_TYPE", "openai")
19
+ NUM_SECONDS_TO_SLEEP = 5
20
+ eval_logger = logging.getLogger("lmms-eval")
21
+
22
+ if API_TYPE == "openai":
23
+ API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
24
+ API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
25
+ headers = {
26
+ "Authorization": f"Bearer {API_KEY}",
27
+ "Content-Type": "application/json",
28
+ }
29
+ elif API_TYPE == "azure":
30
+ API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
31
+ API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
32
+ headers = {
33
+ "api-key": API_KEY,
34
+ "Content-Type": "application/json",
35
+ }
36
+
37
+
38
+ @register_model("gpt4V")
39
+ class GPT4V(lmms):
40
+ def __init__(self, **kwargs) -> None:
41
+ super().__init__()
42
+ # Manually set an image token for GPT4V so that we can search for it
43
+ # and split the text and image
44
+ # Here we just use the same token as llava for convenience
45
+ self.image_token = "<image>"
46
+
47
+ # Function to encode the image
48
+ def encode_image(self, image: Image):
49
+ output_buffer = BytesIO()
50
+ image.save(output_buffer, format="JPEG")
51
+ byte_data = output_buffer.getvalue()
52
+ base64_str = base64.b64encode(byte_data).decode("utf-8")
53
+ return base64_str
54
+
55
+ def flatten(self, input):
56
+ new_list = []
57
+ for i in input:
58
+ for j in i:
59
+ new_list.append(j)
60
+ return new_list
61
+
62
+ def generate_until(self, requests) -> List[str]:
63
+ res = []
64
+ pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
65
+
66
+ for contexts, gen_kwargs, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
67
+ # encode, pad, and truncate contexts for this batch
68
+ visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
69
+ visuals = self.flatten(visuals)
70
+ imgs = []
71
+ for visual in visuals:
72
+ img = self.encode_image(visual)
73
+ imgs.append(img)
74
+
75
+ payload = {"model": "gpt-4-vision-preview", "messages": []}
76
+ response_json = {"role": "user", "content": []}
77
+ # When there is no image token in the context, append the image to the text
78
+ if self.image_token not in contexts:
79
+ payload["messages"].append(deepcopy(response_json))
80
+ payload["messages"][0]["content"].append({"type": "text", "text": contexts})
81
+ for img in imgs:
82
+ payload["messages"][0]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img}"}})
83
+ else:
84
+ contexts = contexts.split(self.image_token)
85
+ for idx, img in enumerate(imgs):
86
+ payload["messages"].append(deepcopy(response_json))
87
+ payload["messages"][idx]["content"].append({"type": "text", "text": contexts[idx]})
88
+ payload["messages"][idx]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img}"}})
89
+
90
+ # If n image tokens are in the contexts
91
+ # contexts will be split into n+1 chunks
92
+ # Manually add it into the payload
93
+ payload["messages"].append(deepcopy(response_json))
94
+ payload["messages"][-1]["content"].append({"type": "text", "text": contexts[-1]})
95
+
96
+ if "max_new_tokens" not in gen_kwargs:
97
+ gen_kwargs["max_new_tokens"] = 1024
98
+ if "temperature" not in gen_kwargs:
99
+ gen_kwargs["temperature"] = 0
100
+ if "top_p" not in gen_kwargs:
101
+ gen_kwargs["top_p"] = None
102
+ if "num_beams" not in gen_kwargs:
103
+ gen_kwargs["num_beams"] = 1
104
+
105
+ # payload["max_tokens"] = gen_kwargs["max_new_tokens"]
106
+ # payload["temperature"] = gen_kwargs["temperature"]
107
+
108
+ for attempt in range(5):
109
+ try:
110
+ response = url_requests.post(API_URL, headers=headers, json=payload, timeout=20)
111
+ response_data = response.json()
112
+
113
+ content = response_data["choices"][0]["message"]["content"].strip()
114
+ break # If successful, break out of the loop
115
+
116
+ except Exception as e:
117
+ eval_logger.info(f"Attempt {attempt + 1} failed with error: {str(e)}")
118
+ if attempt < 5 - 1: # If we have retries left, sleep and then continue to next attempt
119
+ time.sleep(NUM_SECONDS_TO_SLEEP)
120
+ else: # If this was the last attempt, log and return empty
121
+ eval_logger.error(f"All 5 attempts failed. Last error message: {str(e)}")
122
+ content = ""
123
+ res.append(content)
124
+ pbar.update(1)
125
+ return res
126
+
127
+ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
128
+ # TODO
129
+ assert False, "GPT4V does not support loglikelihood"
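The generate_until loop above interleaves text chunks and base64-encoded images whenever the `<image>` token appears in the context. One way to express the same interleaving as a single user message (a sketch, assuming `imgs` holds base64 JPEG strings, one per image token):

def build_user_message(contexts: str, imgs, image_token="<image>"):
    chunks = contexts.split(image_token)  # n image tokens -> n + 1 text chunks
    content = []
    for text, img in zip(chunks, imgs):
        content.append({"type": "text", "text": text})
        content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img}"}})
    content.append({"type": "text", "text": chunks[-1]})  # trailing text after the last image
    return {"role": "user", "content": content}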
EAGLE/lmms_eval/tasks/__init__.py ADDED
@@ -0,0 +1,155 @@
1
+ import os
2
+ from typing import List, Union, Dict
3
+
4
+ from lmms_eval import utils
5
+
6
+ # from lmms_eval import prompts
7
+ from lmms_eval.api.task import TaskConfig, Task, ConfigurableTask
8
+ from lmms_eval.api.registry import (
9
+ register_task,
10
+ register_group,
11
+ TASK_REGISTRY,
12
+ GROUP_REGISTRY,
13
+ ALL_TASKS,
14
+ )
15
+
16
+ import logging
17
+
18
+ eval_logger = logging.getLogger("lmms-eval")
19
+
20
+
21
+ def register_configurable_task(config: Dict[str, str]) -> int:
22
+ SubClass = type(
23
+ config["task"] + "ConfigurableTask",
24
+ (ConfigurableTask,),
25
+ {"CONFIG": TaskConfig(**config)},
26
+ )
27
+
28
+ if "task" in config:
29
+ task_name = "{}".format(config["task"])
30
+ register_task(task_name)(SubClass)
31
+
32
+ if "group" in config:
33
+ if config["group"] == config["task"]:
34
+ raise ValueError("task and group name cannot be the same")
35
+ elif type(config["group"]) == str:
36
+ group_name = [config["group"]]
37
+ else:
38
+ group_name = config["group"]
39
+
40
+ for group in group_name:
41
+ register_group(group)(SubClass)
42
+
43
+ return 0
44
+
45
+
46
+ def register_configurable_group(config: Dict[str, str]) -> int:
47
+ group = config["group"]
48
+ task_list = config["task"]
49
+ task_names = utils.pattern_match(task_list, ALL_TASKS)
50
+ for task in task_names:
51
+ if (task in TASK_REGISTRY) or (task in GROUP_REGISTRY):
52
+ if group in GROUP_REGISTRY:
53
+ GROUP_REGISTRY[group].append(task)
54
+ else:
55
+ GROUP_REGISTRY[group] = [task]
56
+ ALL_TASKS.add(group)
57
+ return 0
58
+
59
+
60
+ def get_task_name_from_config(task_config: Dict[str, str]) -> str:
61
+ if "dataset_name" in task_config:
62
+ return "{dataset_path}_{dataset_name}".format(**task_config)
63
+ else:
64
+ return "{dataset_path}".format(**task_config)
65
+
66
+
67
+ def include_task_folder(task_dir: str, register_task: bool = True) -> None:
68
+ """
69
+ Calling this function walks `task_dir` and registers every task or group defined in a YAML config found there.
70
+ """
71
+ for root, subdirs, file_list in os.walk(task_dir):
72
+ # if (subdirs == [] or subdirs == ["__pycache__"]) and (len(file_list) > 0):
73
+ for f in file_list:
74
+ if f.endswith(".yaml"):
75
+ yaml_path = os.path.join(root, f)
76
+ try:
77
+ config = utils.load_yaml_config(yaml_path)
78
+
79
+ if "task" not in config:
80
+ continue
81
+
82
+ if register_task:
83
+ if type(config["task"]) == str:
84
+ register_configurable_task(config)
85
+ else:
86
+ if type(config["task"]) == list:
87
+ register_configurable_group(config)
88
+
89
+ # Log this silently and show it only when
90
+ # the user defines the appropriate verbosity.
91
+ except ModuleNotFoundError as e:
92
+ eval_logger.debug(f"{yaml_path}: {e}. Config will not be added to registry.")
93
+ except Exception as error:
94
+ import traceback
95
+
96
+ eval_logger.debug(f"Failed to load config in {yaml_path}. Config will not be added to registry\n" f"Error: {error}\n" f"Traceback: {traceback.format_exc()}")
97
+ return 0
98
+
99
+
100
+ def include_path(task_dir):
101
+ include_task_folder(task_dir)
102
+ # Register Benchmarks after all tasks have been added
103
+ include_task_folder(task_dir, register_task=False)
104
+ return 0
105
+
106
+
107
+ def initialize_tasks(verbosity="INFO"):
108
+ eval_logger.setLevel(getattr(logging, f"{verbosity}"))
109
+ task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
110
+ include_path(task_dir)
111
+
112
+
113
+ def get_task(task_name, model_name):
114
+ try:
115
+ return TASK_REGISTRY[task_name](model_name=model_name)
116
+ except KeyError:
117
+ eval_logger.info("Available tasks:")
118
+ eval_logger.info(list(TASK_REGISTRY) + list(GROUP_REGISTRY))
119
+ raise KeyError(f"Missing task {task_name}")
120
+
121
+
122
+ def get_task_name_from_object(task_object):
123
+ for name, class_ in TASK_REGISTRY.items():
124
+ if class_ is task_object:
125
+ return name
126
+
127
+ # TODO: scrap this
128
+ # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
129
+ return task_object.EVAL_HARNESS_NAME if hasattr(task_object, "EVAL_HARNESS_NAME") else type(task_object).__name__
130
+
131
+
132
+ # TODO: pass num_fewshot and other cmdline overrides in a better way
133
+ def get_task_dict(task_name_list: List[Union[str, Dict, Task]], model_name: str):
134
+ all_task_dict = {}
135
+
136
+ # Ensure task_name_list is a list to simplify processing
137
+ if not isinstance(task_name_list, list):
138
+ task_name_list = [task_name_list]
139
+
140
+ for task_element in task_name_list:
141
+ if isinstance(task_element, str) and task_element in GROUP_REGISTRY:
142
+ group_name = task_element
143
+ for task_name in GROUP_REGISTRY[task_element]:
144
+ if task_name not in all_task_dict:
145
+ # Recursively get the task dictionary for nested groups
146
+ task_obj = get_task_dict([task_name], model_name)
147
+ # Merge the dictionaries
148
+ all_task_dict.update({task_name: (group_name, task_obj.get(task_name, None))})
149
+ else:
150
+ task_name = task_element if isinstance(task_element, str) else task_element.EVAL_HARNESS_NAME
151
+ if task_name not in all_task_dict:
152
+ task_obj = get_task(task_name=task_name, model_name=model_name)
153
+ all_task_dict[task_name] = task_obj
154
+
155
+ return all_task_dict
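The registry above is populated once at startup and then queried by task name. A hypothetical driver using the functions defined in this module (the model name is illustrative):

from lmms_eval.tasks import initialize_tasks, get_task_dict

initialize_tasks(verbosity="INFO")  # walk the tasks folder and register every YAML config
task_dict = get_task_dict(["gqa"], model_name="eagle")  # "gqa" is one of the bundled tasks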
EAGLE/lmms_eval/tasks/_task_utils/file_utils.py ADDED
@@ -0,0 +1,8 @@
1
+ import os
2
+
3
+
4
+ def generate_submission_file(file_name, args, subpath="submissions"):
5
+ path = os.path.join(args.output_path, subpath)
6
+ os.makedirs(path, exist_ok=True)
7
+ path = os.path.join(path, file_name)
8
+ return os.path.abspath(path)
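A small usage sketch for `generate_submission_file`; the only requirement on `args` is an `output_path` attribute (the `SimpleNamespace` stand-in is illustrative):

from types import SimpleNamespace
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file

args = SimpleNamespace(output_path="./logs")
path = generate_submission_file("cmmmu_test_for_submission.jsonl", args)
# -> absolute path to ./logs/submissions/cmmmu_test_for_submission.jsonl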
EAGLE/lmms_eval/tasks/_task_utils/gpt_eval_utils.py ADDED
File without changes
EAGLE/lmms_eval/tasks/_task_utils/vqa_eval_metric.py ADDED
@@ -0,0 +1,213 @@
1
+ import re
2
+
3
+
4
+ class EvalAIAnswerProcessor:
5
+ """
6
+ Processes an answer similar to Eval AI
7
+ copied from
8
+ https://github.com/facebookresearch/mmf/blob/c46b3b3391275b4181567db80943473a89ab98ab/pythia/tasks/processors.py#L897
9
+ """
10
+
11
+ CONTRACTIONS = {
12
+ "aint": "ain't",
13
+ "arent": "aren't",
14
+ "cant": "can't",
15
+ "couldve": "could've",
16
+ "couldnt": "couldn't",
17
+ "couldn'tve": "couldn't've",
18
+ "couldnt've": "couldn't've",
19
+ "didnt": "didn't",
20
+ "doesnt": "doesn't",
21
+ "dont": "don't",
22
+ "hadnt": "hadn't",
23
+ "hadnt've": "hadn't've",
24
+ "hadn'tve": "hadn't've",
25
+ "hasnt": "hasn't",
26
+ "havent": "haven't",
27
+ "hed": "he'd",
28
+ "hed've": "he'd've",
29
+ "he'dve": "he'd've",
30
+ "hes": "he's",
31
+ "howd": "how'd",
32
+ "howll": "how'll",
33
+ "hows": "how's",
34
+ "Id've": "I'd've",
35
+ "I'dve": "I'd've",
36
+ "Im": "I'm",
37
+ "Ive": "I've",
38
+ "isnt": "isn't",
39
+ "itd": "it'd",
40
+ "itd've": "it'd've",
41
+ "it'dve": "it'd've",
42
+ "itll": "it'll",
43
+ "let's": "let's",
44
+ "maam": "ma'am",
45
+ "mightnt": "mightn't",
46
+ "mightnt've": "mightn't've",
47
+ "mightn'tve": "mightn't've",
48
+ "mightve": "might've",
49
+ "mustnt": "mustn't",
50
+ "mustve": "must've",
51
+ "neednt": "needn't",
52
+ "notve": "not've",
53
+ "oclock": "o'clock",
54
+ "oughtnt": "oughtn't",
55
+ "ow's'at": "'ow's'at",
56
+ "'ows'at": "'ow's'at",
57
+ "'ow'sat": "'ow's'at",
58
+ "shant": "shan't",
59
+ "shed've": "she'd've",
60
+ "she'dve": "she'd've",
61
+ "she's": "she's",
62
+ "shouldve": "should've",
63
+ "shouldnt": "shouldn't",
64
+ "shouldnt've": "shouldn't've",
65
+ "shouldn'tve": "shouldn't've",
66
+ "somebody'd": "somebodyd",
67
+ "somebodyd've": "somebody'd've",
68
+ "somebody'dve": "somebody'd've",
69
+ "somebodyll": "somebody'll",
70
+ "somebodys": "somebody's",
71
+ "someoned": "someone'd",
72
+ "someoned've": "someone'd've",
73
+ "someone'dve": "someone'd've",
74
+ "someonell": "someone'll",
75
+ "someones": "someone's",
76
+ "somethingd": "something'd",
77
+ "somethingd've": "something'd've",
78
+ "something'dve": "something'd've",
79
+ "somethingll": "something'll",
80
+ "thats": "that's",
81
+ "thered": "there'd",
82
+ "thered've": "there'd've",
83
+ "there'dve": "there'd've",
84
+ "therere": "there're",
85
+ "theres": "there's",
86
+ "theyd": "they'd",
87
+ "theyd've": "they'd've",
88
+ "they'dve": "they'd've",
89
+ "theyll": "they'll",
90
+ "theyre": "they're",
91
+ "theyve": "they've",
92
+ "twas": "'twas",
93
+ "wasnt": "wasn't",
94
+ "wed've": "we'd've",
95
+ "we'dve": "we'd've",
96
+ "weve": "we've",
97
+ "werent": "weren't",
98
+ "whatll": "what'll",
99
+ "whatre": "what're",
100
+ "whats": "what's",
101
+ "whatve": "what've",
102
+ "whens": "when's",
103
+ "whered": "where'd",
104
+ "wheres": "where's",
105
+ "whereve": "where've",
106
+ "whod": "who'd",
107
+ "whod've": "who'd've",
108
+ "who'dve": "who'd've",
109
+ "wholl": "who'll",
110
+ "whos": "who's",
111
+ "whove": "who've",
112
+ "whyll": "why'll",
113
+ "whyre": "why're",
114
+ "whys": "why's",
115
+ "wont": "won't",
116
+ "wouldve": "would've",
117
+ "wouldnt": "wouldn't",
118
+ "wouldnt've": "wouldn't've",
119
+ "wouldn'tve": "wouldn't've",
120
+ "yall": "y'all",
121
+ "yall'll": "y'all'll",
122
+ "y'allll": "y'all'll",
123
+ "yall'd've": "y'all'd've",
124
+ "y'alld've": "y'all'd've",
125
+ "y'all'dve": "y'all'd've",
126
+ "youd": "you'd",
127
+ "youd've": "you'd've",
128
+ "you'dve": "you'd've",
129
+ "youll": "you'll",
130
+ "youre": "you're",
131
+ "youve": "you've",
132
+ }
133
+
134
+ NUMBER_MAP = {
135
+ "none": "0",
136
+ "zero": "0",
137
+ "one": "1",
138
+ "two": "2",
139
+ "three": "3",
140
+ "four": "4",
141
+ "five": "5",
142
+ "six": "6",
143
+ "seven": "7",
144
+ "eight": "8",
145
+ "nine": "9",
146
+ "ten": "10",
147
+ }
148
+ ARTICLES = ["a", "an", "the"]
149
+ PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
150
+ COMMA_STRIP = re.compile(r"(?<=\d)(\,)+(?=\d)")
151
+ PUNCTUATIONS = [
152
+ ";",
153
+ r"/",
154
+ "[",
155
+ "]",
156
+ '"',
157
+ "{",
158
+ "}",
159
+ "(",
160
+ ")",
161
+ "=",
162
+ "+",
163
+ "\\",
164
+ "_",
165
+ "-",
166
+ ">",
167
+ "<",
168
+ "@",
169
+ "`",
170
+ ",",
171
+ "?",
172
+ "!",
173
+ ]
174
+
175
+ def __init__(self, *args, **kwargs):
176
+ pass
177
+
178
+ def word_tokenize(self, word):
179
+ word = word.lower()
180
+ word = word.replace(",", "").replace("?", "").replace("'s", " 's")
181
+ return word.strip()
182
+
183
+ def process_punctuation(self, in_text):
184
+ out_text = in_text
185
+ for p in self.PUNCTUATIONS:
186
+ if (p + " " in in_text or " " + p in in_text) or (re.search(self.COMMA_STRIP, in_text) is not None):
187
+ out_text = out_text.replace(p, "")
188
+ else:
189
+ out_text = out_text.replace(p, " ")
190
+ out_text = self.PERIOD_STRIP.sub("", out_text, re.UNICODE)
191
+ return out_text
192
+
193
+ def process_digit_article(self, in_text):
194
+ out_text = []
195
+ temp_text = in_text.lower().split()
196
+ for word in temp_text:
197
+ word = self.NUMBER_MAP.setdefault(word, word)
198
+ if word not in self.ARTICLES:
199
+ out_text.append(word)
200
+ else:
201
+ pass
202
+ for word_id, word in enumerate(out_text):
203
+ if word in self.CONTRACTIONS:
204
+ out_text[word_id] = self.CONTRACTIONS[word]
205
+ out_text = " ".join(out_text)
206
+ return out_text
207
+
208
+ def __call__(self, item):
209
+ item = self.word_tokenize(item)
210
+ item = item.replace("\n", " ").replace("\t", " ").strip()
211
+ item = self.process_punctuation(item)
212
+ item = self.process_digit_article(item)
213
+ return item
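For reference, the processor lowercases, strips punctuation and articles, maps number words to digits, and restores missing apostrophes in contractions, e.g.:

processor = EvalAIAnswerProcessor()
processor("Two, small dogs!")  # -> "2 small dogs"
processor("None of them")      # -> "0 of them"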
EAGLE/lmms_eval/tasks/cmmmu/_cmmmu.yaml ADDED
@@ -0,0 +1,4 @@
1
+ group: cmmmu
2
+ task:
3
+ - cmmmu_val
4
+ - cmmmu_test
EAGLE/lmms_eval/tasks/cmmmu/_default_template_cmmmu_yaml ADDED
@@ -0,0 +1,8 @@
1
+ dataset_path: lmms-lab/CMMMU
2
+ output_type: generate_until
3
+ doc_to_visual: !function utils.cmmmu_doc_to_visual
4
+ doc_to_text: !function utils.cmmmu_doc_to_text
5
+ doc_to_target: "answer"
6
+ generation_kwargs:
7
+ max_new_tokens: 16
8
+ image_aspect_ratio: original
EAGLE/lmms_eval/tasks/cmmmu/cmmmu_test.yaml ADDED
@@ -0,0 +1,12 @@
1
+ task: "cmmmu_test"
2
+ test_split: test
3
+ # The return value of process_results will be used by metrics
4
+ process_results: !function utils.cmmmu_process_test_results_for_submission
5
+ # Note that the metric name can be either a registered metric function (such as the case for GQA) or a key name returned by process_results
6
+ metric_list:
7
+ - metric: submission
8
+ aggregation: !function utils.cmmmu_test_aggregate_results_for_submission
9
+ higher_is_better: false
10
+ metadata:
11
+ - version: 0.0
12
+ include: _default_template_cmmmu_yaml
EAGLE/lmms_eval/tasks/cmmmu/cmmmu_val.yaml ADDED
@@ -0,0 +1,15 @@
1
+ task: "cmmmu_val"
2
+ test_split: val
3
+ # The return value of process_results will be used by metrics
4
+ process_results: !function utils.cmmmu_process_results
5
+ # Note that the metric name can be either a registered metric function (such as the case for GQA) or a key name returned by process_results
6
+ generation_kwargs:
7
+ max_new_tokens: 16
8
+ image_aspect_ratio: original
9
+ metric_list:
10
+ - metric: cmmmu_acc
11
+ aggregation: !function utils.cmmmu_aggregate_results
12
+ higher_is_better: true
13
+ metadata:
14
+ - version: 0.0
15
+ include: _default_template_cmmmu_yaml
EAGLE/lmms_eval/tasks/cmmmu/utils.py ADDED
@@ -0,0 +1,421 @@
1
+ from collections import defaultdict
2
+ import re
3
+ import random
4
+ import os
5
+ import json
6
+ import logging
7
+ from collections import Counter
8
+ from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
9
+
10
+ eval_logger = logging.getLogger("lmms-eval")
11
+
12
+ PROMPT = {
13
+ "task_instructions": [
14
+ "请回答以下多项选择题,并选出正确选项。这些题目可能包括单选和多选题型。如果所提供的信息不足以确定一个明确的答案,那么请根据可用的数据和你的判断来选择最可能正确的选项。",
15
+ "请回答以下判断题,并根据题目描述和所给的信息来判断问题中陈述的对错。如果信息不完整或不足以作出绝对判断,请运用你的逻辑推理和现有信息来做出最可能的判断。",
16
+ "请回答以下填空题,并根据题目的要求和所提供的信息来给出最恰当的答案。如果信息不足以确切回答,那么请依据现有的数据和你的推理能力来填写最合理的答案。",
17
+ ],
18
+ "multi_choice_example_format": ["问题:{}\n选项:\n{}\n正确答案:\n"],
19
+ "T/F_example_format": ["问题:{}\n正确答案:\n"],
20
+ "short_ans_example_format": ["问题:{}\n正确答案:\n"],
21
+ }
22
+
23
+
24
+ def construct_prompt(sample):
25
+ question = sample["question"]
26
+ task_instructions = PROMPT["task_instructions"]
27
+
28
+ if sample["type"] == "选择":
29
+ formatted_options = ""
30
+ start_chr = "A"
31
+ for i in range(1, 5):
32
+ formatted_options += f"({start_chr}) {sample[f'option{i}']}\n"
33
+ start_chr = chr(ord(start_chr) + 1)
34
+
35
+ current_example_template = PROMPT["multi_choice_example_format"][0]
36
+ current_example = current_example_template.format(question, formatted_options)
37
+ final_input_prompt = task_instructions[0] + "\n\n" + current_example
38
+
39
+ elif sample["type"] == "判断":
40
+ current_example_template = PROMPT["T/F_example_format"][0]
41
+ current_example = current_example_template.format(question)
42
+ final_input_prompt = task_instructions[1] + "\n\n" + current_example
43
+
44
+ else: # For fill in the blanks questions.
45
+ current_example_template = PROMPT["short_ans_example_format"][0]
46
+ current_example = current_example_template.format(question)
47
+ final_input_prompt = task_instructions[2] + "\n\n" + current_example
48
+
49
+ for i in range(1, 6):
50
+ final_input_prompt = final_input_prompt.replace(f'<img="{sample[f"image_{i}_filename"]}">', f"<图片 {i}>")
51
+
52
+ return final_input_prompt
53
+
54
+
55
+ def cmmmu_doc_to_text(doc):
56
+ return construct_prompt(doc)
57
+
58
+
59
+ def cmmmu_doc_to_visual(doc):
60
+ prompt = construct_prompt(doc)
61
+ image_tokens = re.findall(r"<图片 \d+>", prompt)
62
+ # Remove <>, replace spaces with _, and map 图片 -> image (e.g. "<图片 1>" -> "image_1")
63
+ image_tokens = [image_token.strip("<>").replace(" ", "_").replace("图片", "image") for image_token in image_tokens]
64
+ visual = [doc[image_token].convert("RGB") for image_token in image_tokens]
65
+ return visual
66
+
67
+
68
+ def cmmmu_process_results(doc, results):
69
+ pred = results[0]
70
+ if doc["type"] == "选择":
71
+ index2ans, all_choices = get_multi_choice_info([doc[f"option{i}"] for i in range(1, 5)])
72
+ parsed_pred = get_multi_choice_prediction(pred, all_choices, index2ans)
73
+ elif doc["type"] == "判断":
74
+ parsed_pred = get_TF_prediction(pred)
75
+ else:
76
+ parsed_pred = get_fill_blank_prediction(pred, doc["answer"])
77
+ return {"cmmmu_acc": {"id": doc["id"], "subdomain": doc["subcategory"], "question_type": doc["type"], "answer": doc["answer"], "parsed_pred": parsed_pred}}
78
+
79
+
80
+ def cmmmu_aggregate_results(results):
81
+ evaluation_result = {}
82
+ subset_to_eval_samples = defaultdict(list)
83
+ for result in results:
84
+ subset_to_eval_samples[result["subdomain"]].append(result)
85
+ for subset, sub_eval_samples in subset_to_eval_samples.items():
86
+ metric_dict = eval_cmmmu(sub_eval_samples)
87
+ evaluation_result[subset] = metric_dict
88
+
89
+ printable_results = {}
90
+ for domain, in_domain_cats in DOMAIN_CAT2SUB_CAT.items():
91
+ in_domain_cat_results = {}
92
+ for cat_name in in_domain_cats:
93
+ if cat_name in evaluation_result.keys():
94
+ in_domain_cat_results[cat_name] = evaluation_result[cat_name]
95
+ else:
96
+ pass
97
+ in_domain_ins_acc = calculate_ins_level_acc(in_domain_cat_results)
98
+ in_domain_data_num = sum([cat_results["entries_num"] for cat_results in in_domain_cat_results.values()])
99
+ printable_results["Overall-" + domain] = {
100
+ "num": int(in_domain_data_num),
101
+ "acc": round(in_domain_ins_acc, 3),
102
+ }
103
+ # add sub category
104
+ for cat_name, cat_results in in_domain_cat_results.items():
105
+ printable_results[cat_name] = {
106
+ "num": int(cat_results["entries_num"]),
107
+ "acc": round(cat_results["acc"], 3),
108
+ }
109
+ all_ins_acc = calculate_ins_level_acc(evaluation_result)
110
+ printable_results["Overall"] = {
111
+ "num": sum([cat_results["entries_num"] for cat_results in evaluation_result.values()]),
112
+ "acc": round(all_ins_acc, 3),
113
+ }
114
+ print(printable_results)
115
+ return printable_results["Overall"]["acc"]
116
+
117
+
118
+ def cmmmu_process_test_results_for_submission(doc, results):
119
+ response = results[0]
120
+ return {"submission": {"id": doc["id"], "type": doc["type"], "response": response}}
121
+
122
+
123
+ def cmmmu_test_aggregate_results_for_submission(results, args):
124
+ file = generate_submission_file("cmmmu_test_for_submission.jsonl", args)
125
+ with open(file, "w", encoding="utf8") as f:
126
+ for result in results:
127
+ json.dump(result, f, ensure_ascii=False)
128
+ f.write("\n")
129
+ eval_logger.info(f"Submission file saved to {file}")
130
+
131
+
132
+ ##################
133
+ # Helper functions
134
+ ##################
135
+
136
+ DOMAIN_CAT2SUB_CAT = {
137
+ "艺术与设计": ["艺术", "艺术理论", "设计", "音乐"],
138
+ "商业": ["会计", "经济", "金融", "管理", "营销"],
139
+ "科学": ["生物", "化学", "地理", "数学", "物理"],
140
+ "健康与医学": ["基础医学", "临床医学", "诊断学与实验室医学", "制药", "公共卫生"],
141
+ "人文社会科学": ["历史", "文献学", "社会学", "心理学"],
142
+ "技术与工程": ["农业", "建筑学", "计算机科学", "电子学", "能源和电力", "材料", "机械工程"],
143
+ }
144
+
145
+
146
+ def eval_cmmmu(entries):
147
+ correct_cnt = 0
148
+ for entry in entries:
149
+ parsed_pred = entry.get("parsed_pred", "")
150
+ correct = False
151
+ if entry.get("question_type") == "选择":
152
+ if parsed_pred == entry["answer"]:
153
+ correct_cnt += 1
154
+ correct = True
155
+
156
+ elif entry.get("question_type") == "填空":
157
+ norm_answers = normalize_str(entry["answer"], entry["answer"])
158
+
159
+ for pred in parsed_pred:
160
+ # already normalized
161
+ if isinstance(pred, str): # if it's a string, then find if ans in the pred_i
162
+ for norm_ans in norm_answers:
163
+ # only see if the string answer in the string pred
164
+ # print(norm_ans, pred)
165
+ if isinstance(norm_ans, str) and norm_ans in pred:
166
+ if not correct:
167
+ correct_cnt += 1
168
+ correct = True
169
+ break
170
+ else: # it's a number
171
+ if pred in norm_answers:
172
+ if not correct:
173
+ correct_cnt += 1
174
+ correct = True
175
+ break
176
+
177
+ else:
178
+ positive_keywords = ["正确", "对", "准确", "肯定", "对的"]
179
+ negative_keywords = ["不对", "错误", "不正确", "不准确", "不合适", "否定", "错的", "错"]
180
+ ambiguous_keywords = ["对错", "是否正确", "否正确", "或者", "是否", "正确性", "对不"]
181
+
182
+ def judge_similarity(pred_list, positive_keywords, negative_keywords):
183
+ positive_count = 0
184
+ negative_count = 0
185
+
186
+ for pred in pred_list:
187
+ if any(pos_word in pred for pos_word in positive_keywords):
188
+ positive_count += 1
189
+ elif any(neg_word in pred for neg_word in negative_keywords):
190
+ negative_count += 1
191
+
192
+ if positive_count > negative_count:
193
+ return "对"
194
+ elif negative_count > positive_count:
195
+ return "错"
196
+ else:
197
+ return random.choice(["对", "错"])
198
+
199
+ answer = entry["answer"]
200
+ parsed_pred = [word for word in parsed_pred if not any(ambiguous in word for ambiguous in ambiguous_keywords)]
201
+ result = judge_similarity(parsed_pred, positive_keywords, negative_keywords)
202
+ if result == answer:
203
+ correct_cnt += 1
204
+ correct = True
205
+ if correct:
206
+ entry["judge"] = "正确"
207
+ else:
208
+ entry["judge"] = "错误"
209
+
210
+ if len(entries) == 0:
211
+ print("entries_num == 0, please check your file")
212
+ results_count = {"correct_num": 0, "entries_num": 0, "acc": 0}
213
+ else:
214
+ results_count = {"correct_num": correct_cnt, "entries_num": len(entries), "acc": correct_cnt / len(entries)}
215
+
216
+ return results_count
217
+
218
+
219
+ def get_multi_choice_prediction(response, all_choices, index2ans):
220
+ for char in [",", ".", "!", "?", ";", ":", "'"]:
221
+ response = response.strip(char)
222
+ response = " " + response + " " # add space to avoid partial match
223
+
224
+ candidates = []
225
+
226
+ for choice in all_choices: # (A) (B) (C) (D)
227
+ # Add the choice to candidates each time it appears in the response
228
+ candidates.extend([choice for _ in range(response.count(f"({choice})"))])
229
+
230
+ if len(candidates) == 0:
231
+ for choice in all_choices: # A B C D
232
+ # Similarly, add the choice for each occurrence
233
+ candidates.extend([choice for _ in range(response.count(f"{choice}"))])
234
+
235
+ if len(candidates) == 0 and len(response.split()) >= 1:
236
+ for index, ans in index2ans.items():
237
+ # Add index for each occurrence of ans in response
238
+ candidates.extend([index for _ in range(response.count(ans))])
239
+
240
+ # if none of the above yields candidates, fall back to checking whether any option's full text appears in the response
241
+ if len(candidates) == 0 and len(response.split()) >= 1:
242
+ for index, ans in index2ans.items():
243
+ if ans in response:
244
+ candidates.append(index)
245
+ index_ans = False # it's content ans.
246
+
247
+ if len(candidates) == 0: # still not get answer, randomly choose one.
248
+ return random.choice(all_choices)
249
+ # return ''
250
+ else:
251
+ # Count the occurrence of each candidate
252
+ candidate_counts = Counter(candidates)
253
+
254
+ # Select the most frequent candidates
255
+ max_count = max(candidate_counts.values())
256
+ most_frequent_candidates = [c for c in all_choices if candidate_counts.get(c, 0) == max_count]
257
+
258
+ # Combine the most frequent candidates in ABCD order
259
+ return "".join(most_frequent_candidates)
260
+
261
+
262
+ def extract_numbers(string):
263
+ # Pattern for numbers with Chinese commas
264
+ pattern_commas = r"-?\d{1,3}(?:,\d{3})+"
265
+ # Pattern for scientific notation
266
+ pattern_scientific = r"-?\d+(?:\.\d+)?[eE][+-]?\d+"
267
+ # Pattern for simple numbers without Chinese commas
268
+ pattern_simple = r"-?(?:\d+\.\d+|\.\d+|\d+)(?![eE][+-]?\d+)(?!,\d)"
269
+
270
+ # Extract numbers with Chinese commas
271
+ numbers_with_commas = re.findall(pattern_commas, string)
272
+ # Extract numbers in scientific notation
273
+ numbers_scientific = re.findall(pattern_scientific, string)
274
+ # Extract simple numbers without Chinese commas
275
+ numbers_simple = re.findall(pattern_simple, string)
276
+
277
+ # Combine all extracted numbers
278
+ all_numbers = numbers_with_commas + numbers_scientific + numbers_simple
279
+ return all_numbers
280
+
281
+
282
+ def check_is_number(string):
283
+ try:
284
+ float(string.replace(",", ""))
285
+ return True
286
+ except ValueError:
287
+ # check if there's comma inside
288
+ return False
289
+
290
+
291
+ def count_letters(string):
292
+ return sum(c.isalpha() and "a" <= c <= "z" or "A" <= c <= "Z" for c in string)
293
+
294
+
295
+ def normalize_str(string, answer):
296
+ # check if characters in the string
297
+
298
+ # if number, numerize it.
299
+ if string is None:
300
+ return [string]
301
+ string = string.strip()
302
+
303
+ is_number = check_is_number(string)
304
+
305
+ if is_number:
306
+ string = string.replace(",", "")
307
+ string = float(string)
308
+ # leave 2 decimal
309
+ string = round(string, 2)
310
+ return [string]
311
+ else: # it's likely to be a string
312
+ if len(string) > len(answer) + 20 or count_letters(string) > count_letters(answer) + 2:
313
+ return []
314
+ return [string]
315
+
316
+
317
+ def get_fill_blank_prediction(response, answer):
318
+ """get the prediction from the generated response,
319
+ return a list of predicted strings or numbers"""
320
+
321
+ def get_key_subresponses(response):
322
+ key_responses = []
323
+ response = response.strip("。").strip()
324
+ sub_responses = re.split(r"。|\n", response)
325
+ indicators_of_keys = ["是", "为", "所以", "等于", "方案", "选择", "正确答案", "因此", "最后", "答案", "结果"]
326
+ key_responses = []
327
+ for index, resp in enumerate(sub_responses):
328
+ # if last one, accept it's an equation (the entire response can be just one sentence with equation)
329
+ if index == len(sub_responses) - 1:
330
+ indicators_of_keys.extend(["="])
331
+ shortest_key_response = None # the shortest response that may contain the answer (tail part of the response)
332
+ for indicator in indicators_of_keys:
333
+ if indicator in resp:
334
+ if not shortest_key_response:
335
+ shortest_key_response = resp.split(indicator)[-1].strip()
336
+ else:
337
+ if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
338
+ shortest_key_response = resp.split(indicator)[-1].strip()
339
+
340
+ if shortest_key_response:
341
+ # and it's not trivial
342
+ if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]:
343
+ key_responses.append(shortest_key_response)
344
+ if len(key_responses) == 0: # did not find any
345
+ return [response]
346
+ return key_responses
347
+
348
+ key_responses = get_key_subresponses(response)
349
+
350
+ pred_list = key_responses.copy() # keep the original string response
351
+ for resp in key_responses:
352
+ pred_list.extend(extract_numbers(resp))
353
+
354
+ tmp_pred_list = []
355
+ for i in range(len(pred_list)):
356
+ tmp_pred_list.extend(normalize_str(pred_list[i], answer))
357
+ pred_list = tmp_pred_list
358
+
359
+ # remove duplicates
360
+ pred_list = list(set(pred_list))
361
+
362
+ return pred_list
363
+
364
+
365
+ def get_TF_prediction(response):
366
+ """get the prediction from the generated response,
367
+ return a list of predicted strings or numbers"""
368
+
369
+ def get_key_subresponses(response):
370
+ key_responses = []
371
+ response = response.strip("。").strip()
372
+ sub_responses = re.split(r"。|\n", response)
373
+ indicators_of_keys = ["是", "为", "所以", "判断", "陈述", "说法", "表达", "答案", "结果"]
374
+ key_responses = []
375
+ for index, resp in enumerate(sub_responses):
376
+ shortest_key_response = None # the shortest response that may contain the answer (tail part of the response)
377
+ for indicator in indicators_of_keys:
378
+ if indicator in resp:
379
+ if not shortest_key_response:
380
+ shortest_key_response = resp.split(indicator)[-1].strip()
381
+ else:
382
+ if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
383
+ shortest_key_response = resp.split(indicator)[-1].strip()
384
+
385
+ if shortest_key_response:
386
+ # and it's not trivial
387
+ if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]:
388
+ key_responses.append(shortest_key_response)
389
+ if len(key_responses) == 0: # did not find any
390
+ return [response]
391
+ return key_responses
392
+
393
+ key_responses = get_key_subresponses(response)
394
+
395
+ pred_list = key_responses.copy() # keep the original string response
396
+ # remove duplicates
397
+ pred_list = list(set(pred_list))
398
+
399
+ return pred_list
400
+
401
+
402
+ def get_multi_choice_info(options):
403
+ start_chr = "A"
404
+ all_choices = []
405
+ index2ans = {}
406
+ for i, option in enumerate(options):
407
+ index2ans[chr(ord(start_chr) + i)] = option
408
+ all_choices.append(chr(ord(start_chr) + i))
409
+
410
+ return index2ans, all_choices
411
+
412
+
413
+ def calculate_ins_level_acc(results):
414
+ correct_sum = 0
415
+ entries_sum = 0
416
+ for cat_results in results.values():
417
+ correct_sum += cat_results["correct_num"]
418
+ entries_sum += cat_results["entries_num"]
419
+ if entries_sum == 0:
420
+ return 0
421
+ return correct_sum / entries_sum
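As an illustration of the multiple-choice parsing path (the option strings here are made up):

index2ans, all_choices = get_multi_choice_info(["牛顿", "爱因斯坦", "麦克斯韦", "玻尔"])
# all_choices == ["A", "B", "C", "D"]
get_multi_choice_prediction("正确答案是(B)。", all_choices, index2ans)  # -> "B"
# if no option letter or option text is found, a random choice is returned as a fallback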
EAGLE/lmms_eval/tasks/gqa/gqa.yaml ADDED
@@ -0,0 +1,32 @@
1
+ dataset_path: lmms-lab/GQA
2
+ dataset_name: testdev_balanced_instructions
3
+ dataset_kwargs:
4
+ token: True
5
+ task: "gqa"
6
+ test_split: testdev
7
+ output_type: generate_until
8
+ doc_to_visual: !function utils.gqa_doc_to_visual
9
+ doc_to_text: !function utils.gqa_doc_to_text
10
+ doc_to_target: "answer"
11
+ generation_kwargs:
12
+ max_new_tokens: 16
13
+ temperature: 0
14
+ top_p: 0
15
+ num_beams: 1
16
+ do_sample: false
17
+ metric_list:
18
+ - metric: exact_match
19
+ aggregation: mean
20
+ higher_is_better: true
21
+ ignore_case: true
22
+ ignore_punctuation: true
23
+ metadata:
24
+ - version: 0.0
25
+
26
+ model_specific_prompt_kwargs:
27
+ default:
28
+ pre_prompt: ""
29
+ post_prompt: "\nAnswer the question using a single word or phrase."
30
+ qwen_vl:
31
+ pre_prompt: ""
32
+ post_prompt: " Answer:"
EAGLE/lmms_eval/tasks/gqa/utils.py ADDED
@@ -0,0 +1,23 @@
1
+ from datasets import load_dataset
2
+
3
+ GQA_RAW_IMAGE_DATASET = None
4
+ GQA_ID2IMAGE = None
5
+
6
+
7
+ def gqa_doc_to_visual(doc):
8
+ global GQA_RAW_IMAGE_DATASET
9
+ global GQA_ID2IMAGE
10
+ if GQA_RAW_IMAGE_DATASET is None:
11
+ GQA_RAW_IMAGE_DATASET = load_dataset("lmms-lab/GQA", "testdev_balanced_images", split="testdev", token=True)
12
+ GQA_ID2IMAGE = {}
13
+ for row in GQA_RAW_IMAGE_DATASET:
14
+ GQA_ID2IMAGE[row["id"]] = row["image"].convert("RGB")
15
+ image = GQA_ID2IMAGE[doc["imageId"]]
16
+ return [image]
17
+
18
+
19
+ def gqa_doc_to_text(doc, model_specific_prompt_kwargs):
20
+ question = doc["question"]
21
+ pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
22
+ post_prompt = model_specific_prompt_kwargs["post_prompt"]
23
+ return f"{pre_prompt}{question}{post_prompt}"
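With the `default` prompt kwargs from `gqa.yaml`, `gqa_doc_to_text` simply wraps the question, e.g.:

doc = {"question": "What color is the bus?"}
kwargs = {"pre_prompt": "", "post_prompt": "\nAnswer the question using a single word or phrase."}
gqa_doc_to_text(doc, kwargs)
# -> "What color is the bus?\nAnswer the question using a single word or phrase."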
EAGLE/lmms_eval/tasks/llava-in-the-wild/llava-in-the-wild.yaml ADDED
@@ -0,0 +1,39 @@
1
+ dataset_path: lmms-lab/llava-bench-in-the-wild
2
+ dataset_kwargs:
3
+ token: True
4
+ task: "llava_in_the_wild"
5
+ test_split: train
6
+ output_type: generate_until
7
+ doc_to_visual: !function utils.llava_doc_to_visual
8
+ doc_to_text: !function utils.llava_doc_to_text
9
+ doc_to_target: "gpt_answer"
10
+ generation_kwargs:
11
+ until:
12
+ - "ASSISTANT:"
13
+ image_aspect_ratio: original
14
+ max_new_tokens: 1024
15
+ temperature: 0
16
+ top_p: 0
17
+ num_beams: 1
18
+ do_sample: false
19
+ process_results: !function utils.llava_process_results
20
+ metric_list:
21
+ - metric: gpt_eval_llava_all
22
+ aggregation: !function utils.llava_all_aggregation
23
+ higher_is_better: true
24
+ - metric: gpt_eval_llava_conv
25
+ aggregation: !function utils.llava_conv_aggregation
26
+ higher_is_better: true
27
+ - metric: gpt_eval_llava_detail
28
+ aggregation: !function utils.llava_detail_aggregation
29
+ higher_is_better: true
30
+ - metric: gpt_eval_llava_complex
31
+ aggregation: !function utils.llava_complex_aggregation
32
+ higher_is_better: true
33
+ metadata:
34
+ version: 0.0
35
+ gpt_eval_model_name: "gpt-4-0613"
36
+ model_specific_prompt_kwargs:
37
+ default:
38
+ pre_prompt: ""
39
+ post_prompt: ""
EAGLE/lmms_eval/tasks/llava-in-the-wild/rule.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "coding": {"role": "Assistant", "prompt": "Your task is to evaluate the coding abilities of the above two assistants. They have been asked to implement a program to solve a given problem. Please review their code submissions, paying close attention to their problem-solving approach, code structure, readability, and the inclusion of helpful comments.\n\nPlease ensure that the assistants' submissions:\n\n1. Correctly implement the given problem statement.\n2. Contain accurate and efficient code.\n3. Include clear and concise comments that explain the code's logic and functionality.\n4. Adhere to proper coding standards and best practices.\n\nOnce you have carefully reviewed both submissions, provide detailed feedback on their strengths and weaknesses, along with any suggestions for improvement. You should first output a single line containing two scores on the scale of 1-10 (1: no code/no sense; 10: perfect) for Assistant 1 and 2, respectively. Then give extra comments starting from the next line."},
3
+ "math": {"role": "Assistant", "prompt": "We would like to request your feedback on the mathematical proficiency of two AI assistants regarding the given user question.\nFirstly, please solve the problem independently, without referring to the answers provided by Assistant 1 and Assistant 2.\nAfterward, please examine the problem-solving process of Assistant 1 and Assistant 2 step-by-step to ensure their correctness, identifying any incorrect steps if present. Your evaluation should take into account not only the answer but also the problem-solving steps.\nFinally, please output a Python tuple containing two numerical scores for Assistant 1 and Assistant 2, ranging from 1 to 10, respectively. If applicable, explain the reasons for any variations in their scores and determine which assistant performed better."},
4
+ "default": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
5
+ "conv": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
6
+ "detail": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
7
+ "complex": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
8
+ "llava_bench_conv": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
9
+ "llava_bench_detail": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
10
+ "llava_bench_complex": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}
11
+ }
EAGLE/lmms_eval/tasks/llava-in-the-wild/utils.py ADDED
@@ -0,0 +1,197 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import requests
5
+ import numpy as np
6
+ import openai
7
+ from openai import OpenAI
8
+ import time
9
+ import yaml
10
+ from pathlib import Path
11
+ from copy import deepcopy
12
+
13
+ eval_logger = logging.getLogger("lmms-eval")
14
+ NUM_SECONDS_TO_SLEEP = 5
15
+
16
+ LLAVA_W_METRICS = ["gpt_eval_llava_conv", "gpt_eval_llava_detail", "gpt_eval_llava_complex"]
17
+
18
+ rule_dict = json.load(open(os.path.join(os.path.dirname(os.path.abspath(__file__)), "rule.json"), "r"))
19
+
20
+ with open(Path(__file__).parent / "llava-in-the-wild.yaml", "r") as f:
21
+ raw_data = f.readlines()
22
+ safe_data = []
23
+ for i, line in enumerate(raw_data):
24
+ # remove function definition since yaml load cannot handle it
25
+ if "!function" not in line:
26
+ safe_data.append(line)
27
+
28
+ config = yaml.safe_load("".join(safe_data))
29
+
30
+ GPT_EVAL_MODEL_NAME = config["metadata"]["gpt_eval_model_name"]
31
+
32
+ API_TYPE = os.getenv("API_TYPE", "openai")
33
+
34
+ if API_TYPE == "openai":
35
+ API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
36
+ API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
37
+ headers = {
38
+ "Authorization": f"Bearer {API_KEY}",
39
+ "Content-Type": "application/json",
40
+ }
41
+ elif API_TYPE == "azure":
42
+ API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
43
+ API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
44
+ headers = {
45
+ "api-key": API_KEY,
46
+ "Content-Type": "application/json",
47
+ }
48
+
49
+
50
+ def get_eval(content: str, max_tokens: int, retries: int = 5):
51
+ global headers
52
+
53
+ messages = [
54
+ {
55
+ "role": "system",
56
+ "content": "You are a helpful and precise assistant for checking the quality of the answer.",
57
+ },
58
+ {"role": "user", "content": content},
59
+ ]
60
+
61
+ payload = {
62
+ "model": GPT_EVAL_MODEL_NAME,
63
+ "messages": messages,
64
+ "temperature": 0.2,
65
+ "max_tokens": max_tokens,
66
+ }
67
+
68
+ for attempt in range(retries):
69
+ try:
70
+ response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
71
+ response.raise_for_status()
72
+ response_data = response.json()
73
+
74
+ content = response_data["choices"][0]["message"]["content"].strip()
75
+ if content != "":
76
+ return content, response_data["model"]
77
+ break  # response was empty; stop retrying and fall through to the final return
78
+
79
+ except Exception as e:
80
+ eval_logger.info(f"Attempt {attempt + 1} failed with error: {e}")
81
+ if attempt < retries - 1:  # if retries remain, sleep and try again; otherwise log and return below
82
+ time.sleep(NUM_SECONDS_TO_SLEEP)
83
+ else: # If this was the last attempt, log and return empty
84
+ eval_logger.error(f"All {retries} attempts failed. Last error message: {e}")
85
+ return "", ""
86
+ return "", ""
87
+
88
+
89
+ def parse_score(review):
90
+ try:
91
+ score_pair = review.split("\n")[0]
92
+ score_pair = score_pair.replace(",", " ")
93
+ sp = score_pair.split()  # split on any whitespace so "7, 8" and "7 8" both yield two scores
94
+ if len(sp) == 2:
95
+ return [float(sp[0]), float(sp[1])]
96
+ else:
97
+ eval_logger.debug(f"Can not split: {review}. Returning [-1, -1]")
98
+ return [-1, -1]
99
+ except Exception as e:
100
+ eval_logger.debug(f"Error: {e}. Returning [-1, -1]")
101
+ return [-1, -1]
102
+
103
+
104
+ def llava_doc_to_visual(doc):
105
+ return [doc["image"].convert("RGB")]
106
+
107
+
108
+ def llava_doc_to_text(doc, model_specific_prompt_kwargs=None):
109
+ if model_specific_prompt_kwargs is None:
110
+ model_specific_prompt_kwargs = {}
111
+ pre_prompt = model_specific_prompt_kwargs.get("pre_prompt", "")
112
+ post_prompt = model_specific_prompt_kwargs.get("post_prompt", "")
113
+ return f"{pre_prompt}{doc['question']}{post_prompt}"
114
+
115
+
116
+ def llava_process_results(doc, result):
117
+ """
118
+ Args:
119
+ doc: a instance of the eval dataset
120
+ results: [pred]
121
+ Returns:
122
+ a dictionary with key: metric name (in this case coco_bleu), value: metric value
123
+ """
124
+ try:
125
+ question = doc.get("question", "")
126
+ ans1 = doc.get("gpt_answer", "")
127
+ ans2 = result[0] if result else ""
128
+ captions = doc.get("caption", [])
129
+ context = "\n".join(captions) if isinstance(captions, list) else captions
130
+ category = "llava_bench_" + doc.get("category", "")
131
+ rule = rule_dict.get(category, {})
132
+ prompt = rule.get("prompt", "")
133
+ role = rule.get("role", "user")
134
+ content = f"[Context]\n{context}\n\n" f"[Question]\n{question}\n\n" f"[{role} 1]\n{ans1}\n\n[End of {role} 1]\n\n" f"[{role} 2]\n{ans2}\n\n[End of {role} 2]\n\n" f"[System]\n{prompt}\n\n"
135
+
136
+ review, model_name = get_eval(content, 1024)
137
+ scores = parse_score(review)
138
+ except Exception as e:
139
+ eval_logger.error(f"Error for Question ID: {doc.get('question_id', 'Unknown')}: {e}")
140
+ review = "Failed to Get a Proper Review."
141
+ model_name = "Failed Request"
142
+ scores = [-1, -1]
143
+
144
+ metric = f"gpt_eval_llava_{doc.get('category', 'all')}"
145
+ category_review_dict = {"question": question, "ans1": ans1, "ans2": ans2, "context": context, "category": category, "review": review, "scores": scores, "eval_model": model_name, "content": content}
146
+
147
+ non_category_review_dict = deepcopy(category_review_dict)
148
+ non_category_review_dict["scores"] = [-999, -999]
149
+
150
+ data_dict = {}
151
+ for m in LLAVA_W_METRICS:
152
+ if m == metric:
153
+ data_dict[m] = category_review_dict
154
+ else:
155
+ data_dict[m] = non_category_review_dict
156
+ data_dict["gpt_eval_llava_all"] = category_review_dict
157
+
158
+ # return {"gpt_eval_llava_all": review_dict}
159
+ return data_dict
160
+
161
+
162
+ def llava_conv_aggregation(results):
163
+ return llava_aggregation(results, "conv")
164
+
165
+
166
+ def llava_complex_aggregation(results):
167
+ return llava_aggregation(results, "complex")
168
+
169
+
170
+ def llava_detail_aggregation(results):
171
+ return llava_aggregation(results, "detail")
172
+
173
+
174
+ def llava_all_aggregation(results):
175
+ return llava_aggregation(results, "all")
176
+
177
+
178
+ def llava_aggregation(results, category):
179
+ try:
180
+ scores = []
181
+ for result in results:
182
+ if -999 in result["scores"]:
183
+ continue
184
+ scores.append(result["scores"])
185
+
186
+ stats = np.asarray(scores).mean(0).tolist()
187
+ stats = [round(x, 3) for x in stats]
188
+ # gpt4_score_percentage = stats[0] * 10
189
+ # model_score_percentage = stats[1] * 10
190
+ # eval_logger.info(f"Category: {category}")
191
+ # eval_logger.info(f"GPT4 Score: {gpt4_score_percentage:.1f}%")
192
+ # eval_logger.info(f"Model Score: {model_score_percentage:.1f}%")
193
+ # eval_logger.info("=========================")
194
+ return round(stats[1] / stats[0] * 100, 1)
195
+ except Exception as e:
196
+ eval_logger.info(f"Error in llava_aggregation: {e}, and in category: {category}")
197
+ return None
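Note: the value returned by llava_aggregation above is the model's mean review score relative to the GPT-4 reference answer's mean score, expressed as a percentage. A minimal standalone sketch with made-up score pairs (illustrative only, not part of the repository):

import numpy as np

def relative_score(score_pairs):
    # each pair is [reference_score, model_score], as produced by parse_score
    stats = np.asarray(score_pairs).mean(0)
    return round(stats[1] / stats[0] * 100, 1)

# three hypothetical reviews scored (reference, model)
print(relative_score([[9.0, 7.0], [8.0, 8.0], [10.0, 6.0]]))  # -> 77.8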
EAGLE/lmms_eval/tasks/mmbench/_default_template_mmbench_cn_yaml ADDED
@@ -0,0 +1,22 @@
1
+ dataset_path: lmms-lab/MMBench
2
+ dataset_kwargs:
3
+ token: True
4
+ doc_to_target: "answer"
5
+ dataset_name: "cn"
6
+ output_type: generate_until
7
+ doc_to_visual: !function cn_utils.mmbench_doc_to_visual
8
+ doc_to_text: !function cn_utils.mmbench_doc_to_text
9
+ generation_kwargs:
10
+ max_new_tokens: 256
11
+ temperature: 0
12
+ top_p: 0
13
+ num_beams: 1
14
+ do_sample: false
15
+ process_results: !function cn_utils.mmbench_process_results
16
+ model_specific_prompt_kwargs:
17
+ default:
18
+ pre_prompt: ""
19
+ post_prompt: "\n请直接使用所提供的选项字母作为答案回答。"
20
+ model_specific_generation_kwargs:
21
+ llava:
22
+ image_aspect_ratio: original
EAGLE/lmms_eval/tasks/mmbench/_default_template_mmbench_en_yaml ADDED
@@ -0,0 +1,25 @@
1
+ dataset_path: lmms-lab/MMBench
2
+ dataset_kwargs:
3
+ token: True
4
+ doc_to_target: "answer"
5
+ model_specific_prompt_kwargs:
6
+ default:
7
+ pre_prompt: ""
8
+ post_prompt: "\nAnswer with the option's letter from the given choices directly."
9
+ doc_to_visual: !function en_utils.mmbench_doc_to_visual
10
+ doc_to_text: !function en_utils.mmbench_doc_to_text
11
+ doc_to_target: "answer"
12
+ process_results: !function en_utils.mmbench_process_results
13
+ model_specific_generation_kwargs:
14
+ llava:
15
+ image_aspect_ratio: original
16
+ output_type: generate_until
17
+ dataset_name: "en"
18
+ generation_kwargs:
19
+ until:
20
+ - "ASSISTANT:"
21
+ max_new_tokens: 1024
22
+ temperature: 0
23
+ top_p: 0
24
+ num_beams: 1
25
+ do_sample: false
EAGLE/lmms_eval/tasks/mmbench/cc_utils.py ADDED
@@ -0,0 +1,109 @@
1
+ import logging
2
+ import yaml
3
+ import os
4
+ from pathlib import Path
5
+ import pandas as pd
6
+ import json
7
+
8
+ eval_logger = logging.getLogger("lmms-eval")
9
+ from lmms_eval.tasks.mmbench.mmbench_evals import MMBench_Evaluator
10
+ from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
11
+
12
+ with open(Path(__file__).parent / "mmbench.yaml", "r") as f:
13
+ raw_data = f.readlines()
14
+ safe_data = []
15
+ for i, line in enumerate(raw_data):
16
+ # remove function definition since yaml load cannot handle it
17
+ if "!function" not in line:
18
+ safe_data.append(line)
19
+
20
+ config = yaml.safe_load("".join(safe_data))
21
+
22
+ GPT_EVAL_MODEL_NAME = config["metadata"]["gpt_eval_model_name"]
23
+ API_TYPE = os.getenv("API_TYPE", "openai")
24
+
25
+ if API_TYPE == "openai":
26
+ API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
27
+ API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
28
+ elif API_TYPE == "azure":
29
+ API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
30
+ API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
31
+
32
+
33
+ mmbench_evaluator = MMBench_Evaluator(sys_prompt=config["metadata"]["sys_prompt"], API_KEY=API_KEY, API_URL=API_URL, model_version=GPT_EVAL_MODEL_NAME)
34
+
35
+
36
+ def mmbench_doc_to_visual(doc):
37
+ return [doc["image"].convert("RGB")]
38
+
39
+
40
+ def mmbench_cn_cc_doc_to_text(doc, model_specific_prompt_kwargs=None):
41
+ option_candidate = ["A", "B", "C", "D", "E"]
42
+ options_prompt, options_dict = mmbench_evaluator.create_options_prompt(doc, option_candidate)
43
+
44
+ data = {
45
+ # "img": doc["image"],
46
+ "question": doc["question"],
47
+ "answer": doc.get("answer", None),
48
+ "options": options_prompt,
49
+ "category": doc["category"],
50
+ "options_dict": options_dict,
51
+ "index": doc["index"],
52
+ "source": doc["source"],
53
+ }
54
+
55
+ query_prompt = f"{data['question']} {data['options']}"
56
+
57
+ if model_specific_prompt_kwargs:
58
+ query_prompt = f"{query_prompt}\n{model_specific_prompt_kwargs['post_prompt']}"
59
+
60
+ return query_prompt
61
+
62
+
63
+ def mmbench_cn_cc_process_results(doc, results):
64
+ model_response = results[0].strip()
65
+ data = {
66
+ "gpt_eval_score": {
67
+ "index": doc["index"],
68
+ "question": doc["question"],
69
+ "answer": doc["answer"],
70
+ "prediction": model_response,
71
+ "source": doc["source"],
72
+ "category": doc["category"],
73
+ },
74
+ "submission": {
75
+ "index": doc["index"],
76
+ "question": doc["question"],
77
+ "answer": doc["answer"],
78
+ "prediction": model_response,
79
+ "source": doc["source"],
80
+ "category": doc["category"],
81
+ },
82
+ }
83
+ option_candidate = ["A", "B", "C", "D", "E"]
84
+ for c in option_candidate:
85
+ data["submission"][c] = doc.get(c, "nan")
86
+ data["gpt_eval_score"][c] = doc.get(c, "nan")
87
+ return data
88
+
89
+
90
+ def mmbench_cn_cc_aggregate_dev_results_eval(results, args):
91
+ print(f"============= MMBench-CN(CC) Detailed Results =============")
92
+ overall_acc, category_acc, l2_category_acc = mmbench_evaluator.eval_result(results, eval_method="openai")
93
+ file = generate_submission_file("mmbench_cn_cc_results.json", args)
94
+ details_info = {
95
+ "overall_acc": overall_acc,
96
+ "category_acc": category_acc,
97
+ "l2_category_acc": l2_category_acc,
98
+ }
99
+ with open(file, "w") as f:
100
+ json.dump(details_info, f)
101
+ return overall_acc * 100
102
+
103
+
104
+ def mmbench_cn_cc_aggregate_results(results, args):
105
+ df = pd.DataFrame(results)
106
+ file = generate_submission_file("mmbench_cn_cc_results.xlsx", args)
107
+ with pd.ExcelWriter(file) as writer:
108
+ df.to_excel(writer, index=False)
109
+ eval_logger.info(f"Saved results to {file}")
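Note: each element handed to the two aggregation functions above is the per-document dict built by mmbench_cn_cc_process_results; pandas then flattens the list of "submission" dicts into the spreadsheet. A small sketch with dummy rows (illustrative only; the real aggregator writes .xlsx through pd.ExcelWriter):

import pandas as pd

rows = [
    {"index": 0, "question": "示例问题一", "answer": "A", "prediction": "A", "source": "cc", "category": "example", "A": "选项一", "B": "选项二"},
    {"index": 1, "question": "示例问题二", "answer": "B", "prediction": "A", "source": "cc", "category": "example", "A": "选项一", "B": "选项二"},
]
df = pd.DataFrame(rows)
print(df)  # generate_submission_file + pd.ExcelWriter save this frame as mmbench_cn_cc_results.xlsx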
EAGLE/lmms_eval/tasks/mmbench/cn_utils.py ADDED
@@ -0,0 +1,127 @@
1
+ import logging
2
+ import yaml
3
+ import os
4
+ from pathlib import Path
5
+ import pandas as pd
6
+ import json
7
+ from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
8
+
9
+ eval_logger = logging.getLogger("lmms-eval")
10
+ from lmms_eval.tasks.mmbench.mmbench_evals import MMBench_Evaluator
11
+ from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
12
+
13
+ with open(Path(__file__).parent / "mmbench.yaml", "r") as f:
14
+ raw_data = f.readlines()
15
+ safe_data = []
16
+ for i, line in enumerate(raw_data):
17
+ # remove function definition since yaml load cannot handle it
18
+ if "!function" not in line:
19
+ safe_data.append(line)
20
+
21
+ config = yaml.safe_load("".join(safe_data))
22
+
23
+ GPT_EVAL_MODEL_NAME = config["metadata"]["gpt_eval_model_name"]
24
+ API_TYPE = os.getenv("API_TYPE", "openai")
25
+
26
+ if API_TYPE == "openai":
27
+ API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
28
+ API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
29
+ elif API_TYPE == "azure":
30
+ API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
31
+ API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
32
+
33
+
34
+ mmbench_evaluator = MMBench_Evaluator(sys_prompt=config["metadata"]["sys_prompt"], API_KEY=API_KEY, API_URL=API_URL, model_version=GPT_EVAL_MODEL_NAME)
35
+
36
+
37
+ def mmbench_doc_to_visual(doc):
38
+ return [doc["image"].convert("RGB")]
39
+
40
+
41
+ def mmbench_doc_to_text(doc, model_specific_prompt_kwargs=None):
42
+ option_candidate = ["A", "B", "C", "D", "E"]
43
+ options_prompt, options_dict = mmbench_evaluator.create_options_prompt(doc, option_candidate)
44
+
45
+ data = {
46
+ # "img": doc["image"],
47
+ "question": doc["question"],
48
+ "answer": doc.get("answer", None),
49
+ "options": options_prompt,
50
+ "category": doc["category"],
51
+ "L2-category": doc["L2-category"],
52
+ "options_dict": options_dict,
53
+ "index": doc["index"],
54
+ "hint": doc["hint"],
55
+ "source": doc["source"],
56
+ "split": doc["split"],
57
+ }
58
+
59
+ query_prompt = f"{data['hint']} {data['question']} {data['options']}" if pd.notna(data["hint"]) else f"{data['question']} {data['options']}"
60
+
61
+ if model_specific_prompt_kwargs:
62
+ query_prompt = f"{query_prompt}\n{model_specific_prompt_kwargs['post_prompt']}"
63
+
64
+ return query_prompt
65
+
66
+
67
+ def mmbench_process_results(doc, results):
68
+ model_response = results[0].strip()
69
+ data = {
70
+ "gpt_eval_score": {
71
+ "index": doc["index"],
72
+ "question": doc["question"],
73
+ "answer": doc["answer"],
74
+ "prediction": model_response,
75
+ "hint": doc["hint"],
76
+ "source": doc["source"],
77
+ "split": doc["split"],
78
+ "category": doc["category"],
79
+ "L2-category": doc["L2-category"],
80
+ },
81
+ "submission": {
82
+ "index": doc["index"],
83
+ "question": doc["question"],
84
+ "answer": doc["answer"],
85
+ "prediction": model_response,
86
+ "hint": doc["hint"],
87
+ "source": doc["source"],
88
+ "split": doc["split"],
89
+ "category": doc["category"],
90
+ "L2-category": doc["L2-category"],
91
+ },
92
+ }
93
+ option_candidate = ["A", "B", "C", "D", "E"]
94
+ for c in option_candidate:
95
+ data["submission"][c] = doc.get(c, "nan")
96
+ data["gpt_eval_score"][c] = doc.get(c, "nan")
97
+ return data
98
+
99
+
100
+ def mmbench_aggregate_dev_results_eval(results, args):
101
+ print(f"============= MMBench-CN(Dev) Detailed Results =============")
102
+ overall_acc, category_acc, l2_category_acc = mmbench_evaluator.eval_result(results, eval_method="openai")
103
+ file = generate_submission_file("mmbench_cn_dev_results.json", args)
104
+ details_info = {
105
+ "overall_acc": overall_acc,
106
+ "category_acc": category_acc,
107
+ "l2_category_acc": l2_category_acc,
108
+ }
109
+ with open(file, "w") as f:
110
+ json.dump(details_info, f)
111
+ return overall_acc * 100
112
+
113
+
114
+ def mmbench_aggregate_dev_results(results, args):
115
+ df = pd.DataFrame(results)
116
+ excel_write_path = generate_submission_file("mmbench_cn_dev_results.xlsx", args)
117
+ with pd.ExcelWriter(excel_write_path) as writer:
118
+ df.to_excel(writer, index=False)
119
+ eval_logger.info(f"Saved results to {excel_write_path}")
120
+
121
+
122
+ def mmbench_aggregate_test_results(results, args):
123
+ df = pd.DataFrame(results)
124
+ excel_write_path = generate_submission_file("mmbench_cn_test_results.xlsx", args)
125
+ with pd.ExcelWriter(excel_write_path) as writer:
126
+ df.to_excel(writer, index=False)
127
+ eval_logger.info(f"Saved results to {excel_write_path}")
EAGLE/lmms_eval/tasks/mmbench/en_utils.py ADDED
@@ -0,0 +1,126 @@
1
+ import logging
2
+ import yaml
3
+ import os
4
+ from pathlib import Path
5
+ import pandas as pd
6
+ import json
7
+
8
+ eval_logger = logging.getLogger("lmms-eval")
9
+ from lmms_eval.tasks.mmbench.mmbench_evals import MMBench_Evaluator
10
+ from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
11
+
12
+ with open(Path(__file__).parent / "mmbench.yaml", "r") as f:
13
+ raw_data = f.readlines()
14
+ safe_data = []
15
+ for i, line in enumerate(raw_data):
16
+ # remove function definition since yaml load cannot handle it
17
+ if "!function" not in line:
18
+ safe_data.append(line)
19
+
20
+ config = yaml.safe_load("".join(safe_data))
21
+
22
+ GPT_EVAL_MODEL_NAME = config["metadata"]["gpt_eval_model_name"]
23
+ API_TYPE = os.getenv("API_TYPE", "openai")
24
+
25
+ if API_TYPE == "openai":
26
+ API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
27
+ API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
28
+ elif API_TYPE == "azure":
29
+ API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
30
+ API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
31
+
32
+
33
+ mmbench_evaluator = MMBench_Evaluator(sys_prompt=config["metadata"]["sys_prompt"], API_KEY=API_KEY, API_URL=API_URL, model_version=GPT_EVAL_MODEL_NAME)
34
+
35
+
36
+ def mmbench_doc_to_visual(doc):
37
+ return [doc["image"].convert("RGB")]
38
+
39
+
40
+ def mmbench_doc_to_text(doc, model_specific_prompt_kwargs=None):
41
+ option_candidate = ["A", "B", "C", "D", "E"]
42
+ options_prompt, options_dict = mmbench_evaluator.create_options_prompt(doc, option_candidate)
43
+
44
+ data = {
45
+ # "img": doc["image"],
46
+ "question": doc["question"],
47
+ "answer": doc.get("answer", None),
48
+ "options": options_prompt,
49
+ "category": doc["category"],
50
+ "L2-category": doc["L2-category"],
51
+ "options_dict": options_dict,
52
+ "index": doc["index"],
53
+ "hint": doc["hint"],
54
+ "source": doc["source"],
55
+ "split": doc["split"],
56
+ }
57
+
58
+ query_prompt = f"{data['hint']} {data['question']} {data['options']}" if pd.notna(data["hint"]) and data["hint"] != "nan" else f"{data['question']} {data['options']}"
59
+
60
+ if model_specific_prompt_kwargs:
61
+ query_prompt = f"{query_prompt}\n{model_specific_prompt_kwargs['post_prompt']}"
62
+
63
+ return query_prompt
64
+
65
+
66
+ def mmbench_process_results(doc, results):
67
+ model_response = results[0].strip()
68
+ data = {
69
+ "gpt_eval_score": {
70
+ "index": doc["index"],
71
+ "question": doc["question"],
72
+ "answer": doc["answer"],
73
+ "prediction": model_response,
74
+ "hint": doc["hint"],
75
+ "source": doc["source"],
76
+ "split": doc["split"],
77
+ "category": doc["category"],
78
+ "L2-category": doc["L2-category"],
79
+ },
80
+ "submission": {
81
+ "index": doc["index"],
82
+ "question": doc["question"],
83
+ "answer": doc["answer"],
84
+ "prediction": model_response,
85
+ "hint": doc["hint"],
86
+ "source": doc["source"],
87
+ "split": doc["split"],
88
+ "category": doc["category"],
89
+ "L2-category": doc["L2-category"],
90
+ },
91
+ }
92
+ option_candidate = ["A", "B", "C", "D", "E"]
93
+ for c in option_candidate:
94
+ data["submission"][c] = doc.get(c, "nan")
95
+ data["gpt_eval_score"][c] = doc.get(c, "nan")
96
+ return data
97
+
98
+
99
+ def mmbench_aggregate_dev_results_eval(results, args):
100
+ print(f"============= MMBench-EN(Dev) Detailed Results =============")
101
+ overall_acc, category_acc, l2_category_acc = mmbench_evaluator.eval_result(results, eval_method="openai")
102
+ file = generate_submission_file("mmbench_en_dev_results.json", args)
103
+ details_info = {
104
+ "overall_acc": overall_acc,
105
+ "category_acc": category_acc,
106
+ "l2_category_acc": l2_category_acc,
107
+ }
108
+ with open(file, "w") as f:
109
+ json.dump(details_info, f)
110
+ return overall_acc * 100
111
+
112
+
113
+ def mmbench_aggregate_dev_results_submission(results, args):
114
+ df = pd.DataFrame(results)
115
+ excel_write_path = generate_submission_file("mmbench_en_dev_results.xlsx", args)
116
+ with pd.ExcelWriter(excel_write_path) as writer:
117
+ df.to_excel(writer, index=False)
118
+ eval_logger.info(f"Saved results to {excel_write_path}")
119
+
120
+
121
+ def mmbench_aggregate_test_results(results, args):
122
+ df = pd.DataFrame(results)
123
+ excel_write_path = generate_submission_file("mmbench_en_test_results.xlsx", args)
124
+ with pd.ExcelWriter(excel_write_path) as writer:
125
+ df.to_excel(writer, index=False)
126
+ eval_logger.info(f"Saved results to {excel_write_path}")
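Note: mmbench_doc_to_text above delegates option formatting to MMBench_Evaluator.create_options_prompt (defined in mmbench_evals.py, which is not part of this diff). A rough standalone sketch of the query prompt it assembles, using a hypothetical helper rather than the real evaluator API:

def build_query(doc, post_prompt):
    letters = ["A", "B", "C", "D", "E"]
    options = "\n".join(f"{c}. {doc[c]}" for c in letters if doc.get(c) not in (None, "nan"))
    stem = f"{doc['hint']} {doc['question']}" if doc.get("hint") not in (None, "nan") else doc["question"]
    return f"{stem} {options}\n{post_prompt}"

print(build_query(
    {"question": "What is shown in the image?", "A": "a cat", "B": "a dog", "hint": "nan"},
    "Answer with the option's letter from the given choices directly.",
))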
EAGLE/lmms_eval/tasks/mmbench/mmbench.yaml ADDED
@@ -0,0 +1,11 @@
1
+ group: mmbench
2
+ task:
3
+ - mmbench_en_dev
4
+ - mmbench_en_test
5
+ - mmbench_cn_dev
6
+ - mmbench_cn_test
7
+ - mmbench_cn_cc
8
+ metadata:
9
+ version: 0.0
10
+ sys_prompt: "There are several options:"
11
+ gpt_eval_model_name: "gpt-3.5-turbo-0613"
EAGLE/lmms_eval/tasks/mmbench/mmbench_cc.yaml ADDED
@@ -0,0 +1,34 @@
1
+ dataset_path: lmms-lab/MMBench
2
+ dataset_name: cc
3
+ dataset_kwargs:
4
+ token: True
5
+ task: "mmbench_cn_cc"
6
+ test_split: test
7
+ output_type: generate_until
8
+ doc_to_visual: !function cc_utils.mmbench_doc_to_visual
9
+ doc_to_text: !function cc_utils.mmbench_cn_cc_doc_to_text
10
+ doc_to_target: "answer"
11
+ generation_kwargs:
12
+ max_new_tokens: 256
13
+ temperature: 0
14
+ top_p: 0
15
+ num_beams: 1
16
+ do_sample: false
17
+ process_results: !function cc_utils.mmbench_cn_cc_process_results
18
+ metric_list:
19
+ - metric: gpt_eval_score
20
+ aggregation: !function cc_utils.mmbench_cn_cc_aggregate_dev_results_eval
21
+ higher_is_better: true
22
+ - metric: submission
23
+ aggregation: !function cc_utils.mmbench_cn_cc_aggregate_results
24
+ metadata:
25
+ version: 0.0
26
+ gpt_eval_model_name: "gpt-3.5-turbo-0613"
27
+
28
+ model_specific_prompt_kwargs:
29
+ default:
30
+ pre_prompt: ""
31
+ post_prompt: "\n请直接使用所提供的选项字母作为答案回答。"
32
+ model_specific_generation_kwargs:
33
+ llava:
34
+ image_aspect_ratio: original
EAGLE/lmms_eval/tasks/mmbench/mmbench_cn.yaml ADDED
@@ -0,0 +1,9 @@
1
+ group: mmbench_cn
2
+ task:
3
+ - mmbench_cn_dev
4
+ - mmbench_cn_test
5
+ - mmbench_cn_cc
6
+ metadata:
7
+ version: 0.0
8
+ gpt_eval_model_name: "gpt-3.5-turbo-0613"
9
+ sys_prompt: "有如下几个选项:"