Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- EAGLE/eagle/model/language_model/__init__.py +0 -0
- EAGLE/eagle/model/language_model/eagle_llama.py +173 -0
- EAGLE/eagle/model/multimodal_encoder/__init__.py +0 -0
- EAGLE/eagle/model/multimodal_encoder/clip_encoder.py +89 -0
- EAGLE/eagle/model/multimodal_encoder/convnext_encoder.py +141 -0
- EAGLE/eagle/model/multimodal_encoder/hr_clip_encoder.py +175 -0
- EAGLE/eagle/model/multimodal_encoder/pix2struct_encoder.py +146 -0
- EAGLE/eagle/model/multimodal_encoder/vision_models/__init__.py +0 -0
- EAGLE/eagle/model/multimodal_encoder/vision_models/convnext.py +1108 -0
- EAGLE/eagle/model/multimodal_encoder/vision_models/eva_vit.py +1235 -0
- EAGLE/eagle/model/multimodal_projector/__init__.py +0 -0
- EAGLE/eagle/model/multimodal_projector/builder.py +50 -0
- EAGLE/lmms_eval/api/__init__.py +0 -0
- EAGLE/lmms_eval/api/filter.py +53 -0
- EAGLE/lmms_eval/api/instance.py +29 -0
- EAGLE/lmms_eval/api/metrics.py +431 -0
- EAGLE/lmms_eval/api/model.py +203 -0
- EAGLE/lmms_eval/api/registry.py +139 -0
- EAGLE/lmms_eval/api/samplers.py +94 -0
- EAGLE/lmms_eval/api/task.py +1118 -0
- EAGLE/lmms_eval/filters/__init__.py +44 -0
- EAGLE/lmms_eval/filters/decontamination.py +23 -0
- EAGLE/lmms_eval/filters/extraction.py +60 -0
- EAGLE/lmms_eval/filters/selection.py +48 -0
- EAGLE/lmms_eval/filters/transformation.py +48 -0
- EAGLE/lmms_eval/models/__init__.py +16 -0
- EAGLE/lmms_eval/models/eagle.py +415 -0
- EAGLE/lmms_eval/models/gpt4v.py +129 -0
- EAGLE/lmms_eval/tasks/__init__.py +155 -0
- EAGLE/lmms_eval/tasks/_task_utils/file_utils.py +8 -0
- EAGLE/lmms_eval/tasks/_task_utils/gpt_eval_utils.py +0 -0
- EAGLE/lmms_eval/tasks/_task_utils/vqa_eval_metric.py +213 -0
- EAGLE/lmms_eval/tasks/cmmmu/_cmmmu.yaml +4 -0
- EAGLE/lmms_eval/tasks/cmmmu/_default_template_cmmmu_yaml +8 -0
- EAGLE/lmms_eval/tasks/cmmmu/cmmmu_test.yaml +12 -0
- EAGLE/lmms_eval/tasks/cmmmu/cmmmu_val.yaml +15 -0
- EAGLE/lmms_eval/tasks/cmmmu/utils.py +421 -0
- EAGLE/lmms_eval/tasks/gqa/gqa.yaml +32 -0
- EAGLE/lmms_eval/tasks/gqa/utils.py +23 -0
- EAGLE/lmms_eval/tasks/llava-in-the-wild/llava-in-the-wild.yaml +39 -0
- EAGLE/lmms_eval/tasks/llava-in-the-wild/rule.json +11 -0
- EAGLE/lmms_eval/tasks/llava-in-the-wild/utils.py +197 -0
- EAGLE/lmms_eval/tasks/mmbench/_default_template_mmbench_cn_yaml +22 -0
- EAGLE/lmms_eval/tasks/mmbench/_default_template_mmbench_en_yaml +25 -0
- EAGLE/lmms_eval/tasks/mmbench/cc_utils.py +109 -0
- EAGLE/lmms_eval/tasks/mmbench/cn_utils.py +127 -0
- EAGLE/lmms_eval/tasks/mmbench/en_utils.py +126 -0
- EAGLE/lmms_eval/tasks/mmbench/mmbench.yaml +11 -0
- EAGLE/lmms_eval/tasks/mmbench/mmbench_cc.yaml +34 -0
- EAGLE/lmms_eval/tasks/mmbench/mmbench_cn.yaml +9 -0
EAGLE/eagle/model/language_model/__init__.py
ADDED
|
File without changes
|
EAGLE/eagle/model/language_model/eagle_llama.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# Copyright 2023 Haotian Liu
|
| 16 |
+
#
|
| 17 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 18 |
+
# you may not use this file except in compliance with the License.
|
| 19 |
+
# You may obtain a copy of the License at
|
| 20 |
+
#
|
| 21 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 22 |
+
#
|
| 23 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 24 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 25 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 26 |
+
# See the License for the specific language governing permissions and
|
| 27 |
+
# limitations under the License.
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
from typing import List, Optional, Tuple, Union
|
| 31 |
+
|
| 32 |
+
import torch
|
| 33 |
+
import torch.nn as nn
|
| 34 |
+
|
| 35 |
+
from transformers import AutoConfig, AutoModelForCausalLM, \
|
| 36 |
+
LlamaConfig, LlamaModel, LlamaForCausalLM
|
| 37 |
+
|
| 38 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 39 |
+
from transformers.generation.utils import GenerateOutput
|
| 40 |
+
|
| 41 |
+
from ..eagle_arch import EagleMetaModel, EagleMetaForCausalLM
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class EagleConfig(LlamaConfig):
|
| 45 |
+
model_type = "eagle_llama"
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class EagleLlamaModel(EagleMetaModel, LlamaModel):
|
| 49 |
+
config_class = EagleConfig
|
| 50 |
+
|
| 51 |
+
def __init__(self, config: LlamaConfig):
|
| 52 |
+
super(EagleLlamaModel, self).__init__(config)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class EagleLlamaForCausalLM(LlamaForCausalLM, EagleMetaForCausalLM):
|
| 56 |
+
config_class = EagleConfig
|
| 57 |
+
|
| 58 |
+
def __init__(self, config):
|
| 59 |
+
super(LlamaForCausalLM, self).__init__(config)
|
| 60 |
+
self.model = EagleLlamaModel(config)
|
| 61 |
+
self.pretraining_tp = config.pretraining_tp
|
| 62 |
+
self.vocab_size = config.vocab_size
|
| 63 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 64 |
+
|
| 65 |
+
# Initialize weights and apply final processing
|
| 66 |
+
self.post_init()
|
| 67 |
+
|
| 68 |
+
def get_model(self):
|
| 69 |
+
return self.model
|
| 70 |
+
|
| 71 |
+
def forward(
|
| 72 |
+
self,
|
| 73 |
+
input_ids: torch.LongTensor = None,
|
| 74 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 75 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 76 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 77 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 78 |
+
labels: Optional[torch.LongTensor] = None,
|
| 79 |
+
use_cache: Optional[bool] = None,
|
| 80 |
+
output_attentions: Optional[bool] = None,
|
| 81 |
+
output_hidden_states: Optional[bool] = None,
|
| 82 |
+
images: Optional[torch.FloatTensor] = None,
|
| 83 |
+
image_sizes: Optional[List[List[int]]] = None,
|
| 84 |
+
return_dict: Optional[bool] = None,
|
| 85 |
+
**kwargs # for llama3, upgrade the transformers and will receive an additional argument cache_position
|
| 86 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 87 |
+
|
| 88 |
+
if inputs_embeds is None:
|
| 89 |
+
(
|
| 90 |
+
input_ids,
|
| 91 |
+
position_ids,
|
| 92 |
+
attention_mask,
|
| 93 |
+
past_key_values,
|
| 94 |
+
inputs_embeds,
|
| 95 |
+
labels
|
| 96 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
| 97 |
+
input_ids,
|
| 98 |
+
position_ids,
|
| 99 |
+
attention_mask,
|
| 100 |
+
past_key_values,
|
| 101 |
+
labels,
|
| 102 |
+
images,
|
| 103 |
+
image_sizes
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
return super().forward(
|
| 107 |
+
input_ids=input_ids,
|
| 108 |
+
attention_mask=attention_mask,
|
| 109 |
+
position_ids=position_ids,
|
| 110 |
+
past_key_values=past_key_values,
|
| 111 |
+
inputs_embeds=inputs_embeds,
|
| 112 |
+
labels=labels,
|
| 113 |
+
use_cache=use_cache,
|
| 114 |
+
output_attentions=output_attentions,
|
| 115 |
+
output_hidden_states=output_hidden_states,
|
| 116 |
+
return_dict=return_dict
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
@torch.no_grad()
|
| 120 |
+
def generate(
|
| 121 |
+
self,
|
| 122 |
+
inputs: Optional[torch.Tensor] = None,
|
| 123 |
+
images: Optional[torch.Tensor] = None,
|
| 124 |
+
image_sizes: Optional[torch.Tensor] = None,
|
| 125 |
+
**kwargs,
|
| 126 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
| 127 |
+
position_ids = kwargs.pop("position_ids", None)
|
| 128 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
| 129 |
+
if "inputs_embeds" in kwargs:
|
| 130 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
| 131 |
+
|
| 132 |
+
if images is not None:
|
| 133 |
+
(
|
| 134 |
+
inputs,
|
| 135 |
+
position_ids,
|
| 136 |
+
attention_mask,
|
| 137 |
+
_,
|
| 138 |
+
inputs_embeds,
|
| 139 |
+
_
|
| 140 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
| 141 |
+
inputs,
|
| 142 |
+
position_ids,
|
| 143 |
+
attention_mask,
|
| 144 |
+
None,
|
| 145 |
+
None,
|
| 146 |
+
images,
|
| 147 |
+
image_sizes=image_sizes
|
| 148 |
+
)
|
| 149 |
+
else:
|
| 150 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
| 151 |
+
|
| 152 |
+
return super().generate(
|
| 153 |
+
position_ids=position_ids,
|
| 154 |
+
attention_mask=attention_mask,
|
| 155 |
+
inputs_embeds=inputs_embeds,
|
| 156 |
+
**kwargs
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
|
| 160 |
+
inputs_embeds=None, **kwargs):
|
| 161 |
+
images = kwargs.pop("images", None)
|
| 162 |
+
image_sizes = kwargs.pop("image_sizes", None)
|
| 163 |
+
inputs = super().prepare_inputs_for_generation(
|
| 164 |
+
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
|
| 165 |
+
)
|
| 166 |
+
if images is not None:
|
| 167 |
+
inputs['images'] = images
|
| 168 |
+
if image_sizes is not None:
|
| 169 |
+
inputs['image_sizes'] = image_sizes
|
| 170 |
+
return inputs
|
| 171 |
+
|
| 172 |
+
AutoConfig.register("eagle_llama", EagleConfig)
|
| 173 |
+
AutoModelForCausalLM.register(EagleConfig, EagleLlamaForCausalLM)
|
EAGLE/eagle/model/multimodal_encoder/__init__.py
ADDED
|
File without changes
|
EAGLE/eagle/model/multimodal_encoder/clip_encoder.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This file is modified from https://github.com/haotian-liu/LLaVA/
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class CLIPVisionTower(nn.Module):
|
| 9 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
| 10 |
+
super().__init__()
|
| 11 |
+
|
| 12 |
+
self.is_loaded = False
|
| 13 |
+
|
| 14 |
+
self.vision_tower_name = vision_tower
|
| 15 |
+
self.select_layer = args.mm_vision_select_layer
|
| 16 |
+
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
|
| 17 |
+
|
| 18 |
+
if not delay_load:
|
| 19 |
+
self.load_model()
|
| 20 |
+
elif getattr(args, 'unfreeze_mm_vision_tower', False):
|
| 21 |
+
self.load_model()
|
| 22 |
+
else:
|
| 23 |
+
self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
|
| 24 |
+
|
| 25 |
+
def load_model(self, device_map=None):
|
| 26 |
+
if self.is_loaded:
|
| 27 |
+
print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
|
| 28 |
+
return
|
| 29 |
+
|
| 30 |
+
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
|
| 31 |
+
self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
|
| 32 |
+
self.vision_tower.requires_grad_(False)
|
| 33 |
+
|
| 34 |
+
self.is_loaded = True
|
| 35 |
+
|
| 36 |
+
def feature_select(self, image_forward_outs):
|
| 37 |
+
image_features = image_forward_outs.hidden_states[self.select_layer]
|
| 38 |
+
if self.select_feature == 'patch':
|
| 39 |
+
image_features = image_features[:, 1:]
|
| 40 |
+
elif self.select_feature == 'cls_patch':
|
| 41 |
+
image_features = image_features
|
| 42 |
+
else:
|
| 43 |
+
raise ValueError(f'Unexpected select feature: {self.select_feature}')
|
| 44 |
+
return image_features
|
| 45 |
+
|
| 46 |
+
@torch.no_grad()
|
| 47 |
+
def forward(self, images):
|
| 48 |
+
if type(images) is list:
|
| 49 |
+
image_features = []
|
| 50 |
+
for image in images:
|
| 51 |
+
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
|
| 52 |
+
image_feature = self.feature_select(image_forward_out).to(image.dtype)
|
| 53 |
+
image_features.append(image_feature)
|
| 54 |
+
else:
|
| 55 |
+
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
|
| 56 |
+
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
| 57 |
+
|
| 58 |
+
return image_features
|
| 59 |
+
|
| 60 |
+
@property
|
| 61 |
+
def dummy_feature(self):
|
| 62 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
| 63 |
+
|
| 64 |
+
@property
|
| 65 |
+
def dtype(self):
|
| 66 |
+
return self.vision_tower.dtype
|
| 67 |
+
|
| 68 |
+
@property
|
| 69 |
+
def device(self):
|
| 70 |
+
return self.vision_tower.device
|
| 71 |
+
|
| 72 |
+
@property
|
| 73 |
+
def config(self):
|
| 74 |
+
if self.is_loaded:
|
| 75 |
+
return self.vision_tower.config
|
| 76 |
+
else:
|
| 77 |
+
return self.cfg_only
|
| 78 |
+
|
| 79 |
+
@property
|
| 80 |
+
def hidden_size(self):
|
| 81 |
+
return self.config.hidden_size
|
| 82 |
+
|
| 83 |
+
@property
|
| 84 |
+
def num_patches_per_side(self):
|
| 85 |
+
return self.config.image_size // self.config.patch_size
|
| 86 |
+
|
| 87 |
+
@property
|
| 88 |
+
def num_patches(self):
|
| 89 |
+
return (self.config.image_size // self.config.patch_size) ** 2
|
EAGLE/eagle/model/multimodal_encoder/convnext_encoder.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
# This file is modified from https://github.com/luogen1996/LLaVA-HR
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
from transformers import CLIPImageProcessor
|
| 21 |
+
from .vision_models.convnext import convnext_xxlarge
|
| 22 |
+
from torch.utils.checkpoint import checkpoint
|
| 23 |
+
|
| 24 |
+
cfg={
|
| 25 |
+
"crop_size": 256,
|
| 26 |
+
"do_center_crop": True,
|
| 27 |
+
"do_normalize": True,
|
| 28 |
+
"do_resize": True,
|
| 29 |
+
"feature_extractor_type": "CLIPFeatureExtractor",
|
| 30 |
+
"image_mean": [
|
| 31 |
+
0.48145466,
|
| 32 |
+
0.4578275,
|
| 33 |
+
0.40821073
|
| 34 |
+
],
|
| 35 |
+
"image_std": [
|
| 36 |
+
0.26862954,
|
| 37 |
+
0.26130258,
|
| 38 |
+
0.27577711
|
| 39 |
+
],
|
| 40 |
+
"resample": 3,
|
| 41 |
+
"size": 256
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
class ConvNextVisionTower(nn.Module):
|
| 45 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
| 46 |
+
super().__init__()
|
| 47 |
+
|
| 48 |
+
self.is_loaded = False
|
| 49 |
+
self.freeze_vision=args.freeze_vision
|
| 50 |
+
self.input_image_size=args.input_image_size
|
| 51 |
+
self.vision_tower_name = vision_tower
|
| 52 |
+
self.select_layer = -1 # hardcode
|
| 53 |
+
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
|
| 54 |
+
|
| 55 |
+
self.load_model()
|
| 56 |
+
|
| 57 |
+
def load_model(self):
|
| 58 |
+
self.image_processor = CLIPImageProcessor(**cfg)
|
| 59 |
+
if 'xxlarge' in self.vision_tower_name:
|
| 60 |
+
self.vision_tower = convnext_xxlarge(self.vision_tower_name)
|
| 61 |
+
setattr(self.vision_tower, 'hidden_size', 3072)
|
| 62 |
+
else:
|
| 63 |
+
raise NotImplementedError
|
| 64 |
+
|
| 65 |
+
if self.freeze_vision:
|
| 66 |
+
self.vision_tower.requires_grad_(False)
|
| 67 |
+
|
| 68 |
+
# Hardcode
|
| 69 |
+
for s in self.vision_tower.stages:
|
| 70 |
+
s.grad_checkpointing = True
|
| 71 |
+
|
| 72 |
+
if self.input_image_size is not None:
|
| 73 |
+
self.image_processor.size=self.input_image_size
|
| 74 |
+
self.image_processor.crop_size={
|
| 75 |
+
'height':self.input_image_size,
|
| 76 |
+
'width': self.input_image_size
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
self.is_loaded = True
|
| 80 |
+
|
| 81 |
+
def feature_select(self, image_forward_outs):
|
| 82 |
+
image_features = image_forward_outs[self.select_layer]
|
| 83 |
+
return image_features
|
| 84 |
+
|
| 85 |
+
def forward_features(self, x):
|
| 86 |
+
x = self.vision_tower.stem(x)
|
| 87 |
+
image_forward_out=[]
|
| 88 |
+
for blk in self.vision_tower.stages:
|
| 89 |
+
x = blk(x)
|
| 90 |
+
b,c,h,w=x.shape
|
| 91 |
+
image_forward_out.append(x.view(b,c,-1).transpose(1,2))
|
| 92 |
+
return image_forward_out
|
| 93 |
+
|
| 94 |
+
def forward(self, images):
|
| 95 |
+
if self.freeze_vision:
|
| 96 |
+
with torch.no_grad():
|
| 97 |
+
image_features = self._forward_images(images)
|
| 98 |
+
else:
|
| 99 |
+
image_features = self._forward_images(images)
|
| 100 |
+
|
| 101 |
+
return image_features
|
| 102 |
+
|
| 103 |
+
def _forward_images(self, images):
|
| 104 |
+
|
| 105 |
+
image_forward_outs = self.forward_features(images.to(device=self.device, dtype=self.dtype))
|
| 106 |
+
image_features = self.feature_select(image_forward_outs)
|
| 107 |
+
|
| 108 |
+
return image_features
|
| 109 |
+
|
| 110 |
+
@property
|
| 111 |
+
def dummy_feature(self):
|
| 112 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
| 113 |
+
|
| 114 |
+
@property
|
| 115 |
+
def dtype(self):
|
| 116 |
+
return next(self.vision_tower.parameters()).dtype
|
| 117 |
+
|
| 118 |
+
@property
|
| 119 |
+
def device(self):
|
| 120 |
+
return next(self.vision_tower.parameters()).device
|
| 121 |
+
|
| 122 |
+
@property
|
| 123 |
+
def config(self):
|
| 124 |
+
assert NotImplementedError
|
| 125 |
+
pass
|
| 126 |
+
|
| 127 |
+
@property
|
| 128 |
+
def num_attention_heads(self):
|
| 129 |
+
# as constant
|
| 130 |
+
return 16
|
| 131 |
+
@property
|
| 132 |
+
def num_layers(self):
|
| 133 |
+
# as constant
|
| 134 |
+
return 4
|
| 135 |
+
@property
|
| 136 |
+
def hidden_size(self):
|
| 137 |
+
return self.vision_tower.hidden_size
|
| 138 |
+
|
| 139 |
+
@property
|
| 140 |
+
def num_patches(self):
|
| 141 |
+
return (cfg['image_size'] // self.patch_embed.patch_size[0]) ** 2
|
EAGLE/eagle/model/multimodal_encoder/hr_clip_encoder.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
# Mostly copy-paste from LLaVA-HR
|
| 17 |
+
# https://github.com/luogen1996/LLaVA-HR
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
from torch.utils.checkpoint import checkpoint
|
| 22 |
+
|
| 23 |
+
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
|
| 24 |
+
|
| 25 |
+
import math
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn.functional as F
|
| 28 |
+
from typing import List, Optional
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def forward_embeddings(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
| 32 |
+
batch_size = pixel_values.shape[0]
|
| 33 |
+
target_dtype = self.patch_embedding.weight.dtype
|
| 34 |
+
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
|
| 35 |
+
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
| 36 |
+
|
| 37 |
+
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
| 38 |
+
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
| 39 |
+
position_embeddings = self.position_embedding(self.position_ids)
|
| 40 |
+
|
| 41 |
+
if position_embeddings.shape[1]!=embeddings.shape[1]:
|
| 42 |
+
position_embeddings=resample_pos_embed(position_embeddings,embeddings.shape[1])
|
| 43 |
+
|
| 44 |
+
embeddings = embeddings + position_embeddings
|
| 45 |
+
return embeddings
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def resample_pos_embed(
|
| 49 |
+
posemb,
|
| 50 |
+
new_size: int,
|
| 51 |
+
num_prefix_tokens: int = 1,
|
| 52 |
+
interpolation: str = 'bicubic',
|
| 53 |
+
antialias: bool = True,
|
| 54 |
+
verbose: bool = False,
|
| 55 |
+
):
|
| 56 |
+
new_size=[int(math.sqrt(new_size-num_prefix_tokens)),int(math.sqrt(new_size-num_prefix_tokens))]
|
| 57 |
+
num_pos_tokens = posemb.shape[1] - num_prefix_tokens
|
| 58 |
+
old_size = int(math.sqrt(num_pos_tokens))
|
| 59 |
+
bs=posemb.shape[0]
|
| 60 |
+
|
| 61 |
+
if num_prefix_tokens:
|
| 62 |
+
posemb_prefix, posemb = posemb[:,:num_prefix_tokens], posemb[:,num_prefix_tokens:]
|
| 63 |
+
else:
|
| 64 |
+
posemb_prefix, posemb = None, posemb
|
| 65 |
+
|
| 66 |
+
# do the interpolation
|
| 67 |
+
embed_dim = posemb.shape[-1]
|
| 68 |
+
orig_dtype = posemb.dtype
|
| 69 |
+
posemb = posemb.float() # interpolate needs float32
|
| 70 |
+
posemb = posemb.reshape(bs, old_size, old_size, -1).permute(0, 3, 1, 2)
|
| 71 |
+
posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
|
| 72 |
+
posemb = posemb.permute(0, 2, 3, 1).reshape(bs, -1, embed_dim)
|
| 73 |
+
posemb = posemb.to(dtype=orig_dtype)
|
| 74 |
+
|
| 75 |
+
# add back extra (class, etc) prefix tokens
|
| 76 |
+
if posemb_prefix is not None:
|
| 77 |
+
posemb = torch.cat([posemb_prefix, posemb],1)
|
| 78 |
+
|
| 79 |
+
if not torch.jit.is_scripting() and verbose:
|
| 80 |
+
print(f'Resized position embedding: {old_size} to {new_size}.')
|
| 81 |
+
|
| 82 |
+
return posemb
|
| 83 |
+
|
| 84 |
+
class HRCLIPVisionTower(nn.Module):
|
| 85 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
| 86 |
+
super().__init__()
|
| 87 |
+
|
| 88 |
+
self.is_loaded = False
|
| 89 |
+
self.freeze_vision=args.freeze_vision
|
| 90 |
+
self.input_image_size=args.input_image_size
|
| 91 |
+
self.vision_tower_name = vision_tower
|
| 92 |
+
self.select_layer = args.mm_vision_select_layer
|
| 93 |
+
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
|
| 94 |
+
|
| 95 |
+
if not delay_load:
|
| 96 |
+
self.load_model()
|
| 97 |
+
else:
|
| 98 |
+
self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def load_model(self):
|
| 102 |
+
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
|
| 103 |
+
self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
|
| 104 |
+
# checkpointing for clip
|
| 105 |
+
self.vision_tower.vision_model.encoder.gradient_checkpointing =True
|
| 106 |
+
|
| 107 |
+
if self.freeze_vision:
|
| 108 |
+
self.vision_tower.requires_grad_(False)
|
| 109 |
+
|
| 110 |
+
cls_=self.vision_tower.vision_model.embeddings
|
| 111 |
+
bound_method = forward_embeddings.__get__(cls_, cls_.__class__)
|
| 112 |
+
setattr(cls_, 'forward', bound_method)
|
| 113 |
+
|
| 114 |
+
if self.input_image_size is not None:
|
| 115 |
+
self.image_processor.size=self.input_image_size
|
| 116 |
+
self.image_processor.crop_size={
|
| 117 |
+
'height':self.input_image_size,
|
| 118 |
+
'width': self.input_image_size
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
self.is_loaded = True
|
| 122 |
+
|
| 123 |
+
def forward(self, x):
|
| 124 |
+
# 448 image input
|
| 125 |
+
blks = self.vision_tower.vision_model.encoder.layers
|
| 126 |
+
x = self.vision_tower.vision_model.embeddings(x)
|
| 127 |
+
x = self.vision_tower.vision_model.pre_layrnorm(x[:, 1:])
|
| 128 |
+
|
| 129 |
+
# inference of fast branch
|
| 130 |
+
for blk in blks:
|
| 131 |
+
if self.training:
|
| 132 |
+
x=checkpoint(
|
| 133 |
+
blk.__call__,
|
| 134 |
+
x,
|
| 135 |
+
None,
|
| 136 |
+
None
|
| 137 |
+
)[0]
|
| 138 |
+
else:
|
| 139 |
+
x = blk(x, None, None)[0]
|
| 140 |
+
|
| 141 |
+
return x
|
| 142 |
+
|
| 143 |
+
@property
|
| 144 |
+
def dummy_feature(self):
|
| 145 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
| 146 |
+
|
| 147 |
+
@property
|
| 148 |
+
def dtype(self):
|
| 149 |
+
return self.vision_tower.dtype
|
| 150 |
+
|
| 151 |
+
@property
|
| 152 |
+
def device(self):
|
| 153 |
+
return self.vision_tower.device
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
@property
|
| 157 |
+
def num_attention_heads(self):
|
| 158 |
+
return self.config.num_attention_heads
|
| 159 |
+
@property
|
| 160 |
+
def num_layers(self):
|
| 161 |
+
return self.config.num_hidden_layers
|
| 162 |
+
@property
|
| 163 |
+
def config(self):
|
| 164 |
+
if self.is_loaded:
|
| 165 |
+
return self.vision_tower.config
|
| 166 |
+
else:
|
| 167 |
+
return self.cfg_only
|
| 168 |
+
|
| 169 |
+
@property
|
| 170 |
+
def hidden_size(self):
|
| 171 |
+
return self.config.hidden_size
|
| 172 |
+
|
| 173 |
+
@property
|
| 174 |
+
def num_patches(self):
|
| 175 |
+
return (self.config.image_size // self.config.patch_size) ** 2
|
EAGLE/eagle/model/multimodal_encoder/pix2struct_encoder.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
import re
|
| 18 |
+
from PIL import Image
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
from transformers import AutoModel, CLIPImageProcessor
|
| 22 |
+
from PIL import Image
|
| 23 |
+
import requests
|
| 24 |
+
import torch.nn.functional as F
|
| 25 |
+
from transformers import AutoProcessor, Pix2StructVisionModel, Pix2StructProcessor, Pix2StructForConditionalGeneration
|
| 26 |
+
|
| 27 |
+
cfg={
|
| 28 |
+
"crop_size": 256,
|
| 29 |
+
"do_center_crop": True,
|
| 30 |
+
"do_normalize": True,
|
| 31 |
+
"do_resize": True,
|
| 32 |
+
"feature_extractor_type": "CLIPFeatureExtractor",
|
| 33 |
+
"image_mean": [
|
| 34 |
+
0.48145466,
|
| 35 |
+
0.4578275,
|
| 36 |
+
0.40821073
|
| 37 |
+
],
|
| 38 |
+
"image_std": [
|
| 39 |
+
0.26862954,
|
| 40 |
+
0.26130258,
|
| 41 |
+
0.27577711
|
| 42 |
+
],
|
| 43 |
+
"resample": 3,
|
| 44 |
+
"size": 256
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
'''
|
| 48 |
+
Pixel2Struct-Large Model (pretrained version)
|
| 49 |
+
'''
|
| 50 |
+
class Pix2StructLargeVisionTower(nn.Module):
|
| 51 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
| 52 |
+
super().__init__()
|
| 53 |
+
|
| 54 |
+
self.is_loaded = False
|
| 55 |
+
self.vision_tower_name = vision_tower
|
| 56 |
+
self.do_resize = args.do_resize
|
| 57 |
+
self.de_normalize = args.de_normalize # de-normalize the input image and perform preprocessing with pix2struct processor
|
| 58 |
+
self.select_layer = args.mm_vision_select_layer # NOTE: not implemented yet, this parameter has no effect
|
| 59 |
+
self.input_image_size = args.input_image_size
|
| 60 |
+
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
|
| 61 |
+
self.freeze_vision = args.freeze_vision
|
| 62 |
+
|
| 63 |
+
self.args = args
|
| 64 |
+
if not self.is_loaded:
|
| 65 |
+
self.load_model()
|
| 66 |
+
|
| 67 |
+
def load_model(self):
|
| 68 |
+
if self.is_loaded:
|
| 69 |
+
return
|
| 70 |
+
whole_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-large")
|
| 71 |
+
self.vision_tower = whole_model.encoder
|
| 72 |
+
self.pix2struct_processor = AutoProcessor.from_pretrained("google/pix2struct-large")
|
| 73 |
+
self.pix2struct_processor.image_processor.is_vqa = False
|
| 74 |
+
|
| 75 |
+
self.image_processor = CLIPImageProcessor(**cfg)
|
| 76 |
+
if self.input_image_size is not None:
|
| 77 |
+
self.image_processor.size=self.input_image_size
|
| 78 |
+
self.image_processor.crop_size={
|
| 79 |
+
'height':self.input_image_size,
|
| 80 |
+
'width': self.input_image_size
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
if self.freeze_vision:
|
| 84 |
+
self.vision_tower.requires_grad_(False)
|
| 85 |
+
|
| 86 |
+
self.image_mean = torch.tensor(self.image_processor.image_mean).view(1, 3, 1, 1)
|
| 87 |
+
self.image_std = torch.tensor(self.image_processor.image_std).view(1, 3, 1, 1)
|
| 88 |
+
|
| 89 |
+
self.is_loaded = True
|
| 90 |
+
|
| 91 |
+
def feature_select(self, image_forward_outs):
|
| 92 |
+
image_features = image_forward_outs.hidden_states[self.select_layer] # [bs, n, c], cls at idx=0
|
| 93 |
+
if self.select_feature == 'patch':
|
| 94 |
+
image_features = image_features[:, 1:]
|
| 95 |
+
elif self.select_feature == 'cls_patch':
|
| 96 |
+
image_features = image_features
|
| 97 |
+
else:
|
| 98 |
+
raise ValueError(f'Unexpected select feature: {self.select_feature}')
|
| 99 |
+
return image_features
|
| 100 |
+
|
| 101 |
+
# @torch.no_grad()
|
| 102 |
+
def forward(self, images):
|
| 103 |
+
|
| 104 |
+
if self.de_normalize:
|
| 105 |
+
mean = self.image_mean.clone().view(1, 3, 1, 1).to(dtype=images.dtype, device=images.device)
|
| 106 |
+
std = self.image_std.clone().view(1, 3, 1, 1).to(dtype=images.dtype, device=images.device)
|
| 107 |
+
x = (images * std + mean) * 255.0
|
| 108 |
+
x = self.pix2struct_processor(images=x.float(), return_tensors="pt")
|
| 109 |
+
|
| 110 |
+
image_features = self.vision_tower(**(x.to(device=self.device, dtype=self.dtype))).last_hidden_state
|
| 111 |
+
bs, n, c = image_features.shape
|
| 112 |
+
image_features = image_features[:, :2025, :] # HARD CODE
|
| 113 |
+
|
| 114 |
+
if self.do_resize:
|
| 115 |
+
image_features = image_features.transpose(1,2).reshape(bs, c, 45, 45) # HARD CODE
|
| 116 |
+
image_features = F.interpolate(image_features.float(), size=(32, 32), mode='bilinear', align_corners=True).to(dtype=image_features.dtype) # HARD CODE
|
| 117 |
+
return image_features
|
| 118 |
+
else:
|
| 119 |
+
return image_features
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@property
|
| 123 |
+
def dummy_feature(self):
|
| 124 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
| 125 |
+
|
| 126 |
+
@property
|
| 127 |
+
def dtype(self):
|
| 128 |
+
return next(self.vision_tower.parameters()).dtype
|
| 129 |
+
|
| 130 |
+
@property
|
| 131 |
+
def device(self):
|
| 132 |
+
return next(self.vision_tower.parameters()).device
|
| 133 |
+
|
| 134 |
+
@property
|
| 135 |
+
def config(self):
|
| 136 |
+
return self.vision_tower.config
|
| 137 |
+
|
| 138 |
+
@property
|
| 139 |
+
def hidden_size(self):
|
| 140 |
+
# Hard code
|
| 141 |
+
hidden_dim = 1536
|
| 142 |
+
return hidden_dim
|
| 143 |
+
|
| 144 |
+
@property
|
| 145 |
+
def num_patches(self):
|
| 146 |
+
return self.config['num_patches']
|
EAGLE/eagle/model/multimodal_encoder/vision_models/__init__.py
ADDED
|
File without changes
|
EAGLE/eagle/model/multimodal_encoder/vision_models/convnext.py
ADDED
|
@@ -0,0 +1,1108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" ConvNeXt
|
| 2 |
+
|
| 3 |
+
Papers:
|
| 4 |
+
* `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
|
| 5 |
+
@Article{liu2022convnet,
|
| 6 |
+
author = {Zhuang Liu and Hanzi Mao and Chao-Yuan Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
|
| 7 |
+
title = {A ConvNet for the 2020s},
|
| 8 |
+
journal = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
| 9 |
+
year = {2022},
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
* `ConvNeXt-V2 - Co-designing and Scaling ConvNets with Masked Autoencoders` - https://arxiv.org/abs/2301.00808
|
| 13 |
+
@article{Woo2023ConvNeXtV2,
|
| 14 |
+
title={ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders},
|
| 15 |
+
author={Sanghyun Woo, Shoubhik Debnath, Ronghang Hu, Xinlei Chen, Zhuang Liu, In So Kweon and Saining Xie},
|
| 16 |
+
year={2023},
|
| 17 |
+
journal={arXiv preprint arXiv:2301.00808},
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
Original code and weights from:
|
| 21 |
+
* https://github.com/facebookresearch/ConvNeXt, original copyright below
|
| 22 |
+
* https://github.com/facebookresearch/ConvNeXt-V2, original copyright below
|
| 23 |
+
|
| 24 |
+
Model defs atto, femto, pico, nano and _ols / _hnf variants are timm originals.
|
| 25 |
+
|
| 26 |
+
Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman
|
| 27 |
+
"""
|
| 28 |
+
# ConvNeXt
|
| 29 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 30 |
+
# All rights reserved.
|
| 31 |
+
# This source code is licensed under the MIT license
|
| 32 |
+
|
| 33 |
+
# ConvNeXt-V2
|
| 34 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 35 |
+
# All rights reserved.
|
| 36 |
+
# This source code is licensed under the license found in the
|
| 37 |
+
# LICENSE file in the root directory of this source tree (Attribution-NonCommercial 4.0 International (CC BY-NC 4.0))
|
| 38 |
+
# No code was used directly from ConvNeXt-V2, however the weights are CC BY-NC 4.0 so beware if using commercially.
|
| 39 |
+
|
| 40 |
+
from collections import OrderedDict
|
| 41 |
+
from functools import partial
|
| 42 |
+
from typing import Callable, Optional, Tuple, Union
|
| 43 |
+
|
| 44 |
+
import torch
|
| 45 |
+
import torch.nn as nn
|
| 46 |
+
|
| 47 |
+
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
| 48 |
+
from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \
|
| 49 |
+
LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
|
| 50 |
+
from timm.layers import NormMlpClassifierHead, ClassifierHead
|
| 51 |
+
from timm.models._builder import build_model_with_cfg
|
| 52 |
+
from timm.models._manipulate import named_apply, checkpoint_seq
|
| 53 |
+
from timm.models._registry import generate_default_cfgs, register_model, register_model_deprecations
|
| 54 |
+
|
| 55 |
+
__all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class Downsample(nn.Module):
|
| 59 |
+
|
| 60 |
+
def __init__(self, in_chs, out_chs, stride=1, dilation=1):
|
| 61 |
+
super().__init__()
|
| 62 |
+
avg_stride = stride if dilation == 1 else 1
|
| 63 |
+
if stride > 1 or dilation > 1:
|
| 64 |
+
avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
|
| 65 |
+
self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
|
| 66 |
+
else:
|
| 67 |
+
self.pool = nn.Identity()
|
| 68 |
+
|
| 69 |
+
if in_chs != out_chs:
|
| 70 |
+
self.conv = create_conv2d(in_chs, out_chs, 1, stride=1)
|
| 71 |
+
else:
|
| 72 |
+
self.conv = nn.Identity()
|
| 73 |
+
|
| 74 |
+
def forward(self, x):
|
| 75 |
+
x = self.pool(x)
|
| 76 |
+
x = self.conv(x)
|
| 77 |
+
return x
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class ConvNeXtBlock(nn.Module):
|
| 81 |
+
""" ConvNeXt Block
|
| 82 |
+
There are two equivalent implementations:
|
| 83 |
+
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
|
| 84 |
+
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
|
| 85 |
+
|
| 86 |
+
Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
|
| 87 |
+
choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
|
| 88 |
+
is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
def __init__(
|
| 92 |
+
self,
|
| 93 |
+
in_chs: int,
|
| 94 |
+
out_chs: Optional[int] = None,
|
| 95 |
+
kernel_size: int = 7,
|
| 96 |
+
stride: int = 1,
|
| 97 |
+
dilation: Union[int, Tuple[int, int]] = (1, 1),
|
| 98 |
+
mlp_ratio: float = 4,
|
| 99 |
+
conv_mlp: bool = False,
|
| 100 |
+
conv_bias: bool = True,
|
| 101 |
+
use_grn: bool = False,
|
| 102 |
+
ls_init_value: Optional[float] = 1e-6,
|
| 103 |
+
act_layer: Union[str, Callable] = 'gelu',
|
| 104 |
+
norm_layer: Optional[Callable] = None,
|
| 105 |
+
drop_path: float = 0.,
|
| 106 |
+
):
|
| 107 |
+
"""
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
in_chs: Block input channels.
|
| 111 |
+
out_chs: Block output channels (same as in_chs if None).
|
| 112 |
+
kernel_size: Depthwise convolution kernel size.
|
| 113 |
+
stride: Stride of depthwise convolution.
|
| 114 |
+
dilation: Tuple specifying input and output dilation of block.
|
| 115 |
+
mlp_ratio: MLP expansion ratio.
|
| 116 |
+
conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True.
|
| 117 |
+
conv_bias: Apply bias for all convolution (linear) layers.
|
| 118 |
+
use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2)
|
| 119 |
+
ls_init_value: Layer-scale init values, layer-scale applied if not None.
|
| 120 |
+
act_layer: Activation layer.
|
| 121 |
+
norm_layer: Normalization layer (defaults to LN if not specified).
|
| 122 |
+
drop_path: Stochastic depth probability.
|
| 123 |
+
"""
|
| 124 |
+
super().__init__()
|
| 125 |
+
out_chs = out_chs or in_chs
|
| 126 |
+
dilation = to_ntuple(2)(dilation)
|
| 127 |
+
act_layer = get_act_layer(act_layer)
|
| 128 |
+
if not norm_layer:
|
| 129 |
+
norm_layer = LayerNorm2d if conv_mlp else LayerNorm
|
| 130 |
+
mlp_layer = partial(GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp)
|
| 131 |
+
self.use_conv_mlp = conv_mlp
|
| 132 |
+
self.conv_dw = create_conv2d(
|
| 133 |
+
in_chs,
|
| 134 |
+
out_chs,
|
| 135 |
+
kernel_size=kernel_size,
|
| 136 |
+
stride=stride,
|
| 137 |
+
dilation=dilation[0],
|
| 138 |
+
depthwise=True,
|
| 139 |
+
bias=conv_bias,
|
| 140 |
+
)
|
| 141 |
+
self.norm = norm_layer(out_chs)
|
| 142 |
+
self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
|
| 143 |
+
self.weight = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None
|
| 144 |
+
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
|
| 145 |
+
self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0])
|
| 146 |
+
else:
|
| 147 |
+
self.shortcut = nn.Identity()
|
| 148 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 149 |
+
|
| 150 |
+
def forward(self, x):
|
| 151 |
+
shortcut = x
|
| 152 |
+
x = self.conv_dw(x)
|
| 153 |
+
if self.use_conv_mlp:
|
| 154 |
+
x = self.norm(x)
|
| 155 |
+
x = self.mlp(x)
|
| 156 |
+
else:
|
| 157 |
+
x = x.permute(0, 2, 3, 1)
|
| 158 |
+
x = self.norm(x)
|
| 159 |
+
x = self.mlp(x)
|
| 160 |
+
x = x.permute(0, 3, 1, 2)
|
| 161 |
+
if self.weight is not None:
|
| 162 |
+
x = x.mul(self.weight.reshape(1, -1, 1, 1))
|
| 163 |
+
|
| 164 |
+
x = self.drop_path(x) + self.shortcut(shortcut)
|
| 165 |
+
return x
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class ConvNeXtStage(nn.Module):
|
| 169 |
+
|
| 170 |
+
def __init__(
|
| 171 |
+
self,
|
| 172 |
+
in_chs,
|
| 173 |
+
out_chs,
|
| 174 |
+
kernel_size=7,
|
| 175 |
+
stride=2,
|
| 176 |
+
depth=2,
|
| 177 |
+
dilation=(1, 1),
|
| 178 |
+
drop_path_rates=None,
|
| 179 |
+
ls_init_value=1.0,
|
| 180 |
+
conv_mlp=False,
|
| 181 |
+
conv_bias=True,
|
| 182 |
+
use_grn=False,
|
| 183 |
+
act_layer='gelu',
|
| 184 |
+
norm_layer=None,
|
| 185 |
+
norm_layer_cl=None
|
| 186 |
+
):
|
| 187 |
+
super().__init__()
|
| 188 |
+
self.grad_checkpointing = False
|
| 189 |
+
|
| 190 |
+
if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]:
|
| 191 |
+
ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1
|
| 192 |
+
pad = 'same' if dilation[1] > 1 else 0 # same padding needed if dilation used
|
| 193 |
+
self.downsample = nn.Sequential(
|
| 194 |
+
norm_layer(in_chs),
|
| 195 |
+
create_conv2d(
|
| 196 |
+
in_chs,
|
| 197 |
+
out_chs,
|
| 198 |
+
kernel_size=ds_ks,
|
| 199 |
+
stride=stride,
|
| 200 |
+
dilation=dilation[0],
|
| 201 |
+
padding=pad,
|
| 202 |
+
bias=conv_bias,
|
| 203 |
+
),
|
| 204 |
+
)
|
| 205 |
+
in_chs = out_chs
|
| 206 |
+
else:
|
| 207 |
+
self.downsample = nn.Identity()
|
| 208 |
+
|
| 209 |
+
drop_path_rates = drop_path_rates or [0.] * depth
|
| 210 |
+
stage_blocks = []
|
| 211 |
+
for i in range(depth):
|
| 212 |
+
stage_blocks.append(ConvNeXtBlock(
|
| 213 |
+
in_chs=in_chs,
|
| 214 |
+
out_chs=out_chs,
|
| 215 |
+
kernel_size=kernel_size,
|
| 216 |
+
dilation=dilation[1],
|
| 217 |
+
drop_path=drop_path_rates[i],
|
| 218 |
+
ls_init_value=ls_init_value,
|
| 219 |
+
conv_mlp=conv_mlp,
|
| 220 |
+
conv_bias=conv_bias,
|
| 221 |
+
use_grn=use_grn,
|
| 222 |
+
act_layer=act_layer,
|
| 223 |
+
norm_layer=norm_layer if conv_mlp else norm_layer_cl,
|
| 224 |
+
))
|
| 225 |
+
in_chs = out_chs
|
| 226 |
+
self.blocks = nn.Sequential(*stage_blocks)
|
| 227 |
+
|
| 228 |
+
def forward(self, x):
|
| 229 |
+
x = self.downsample(x)
|
| 230 |
+
if self.grad_checkpointing and not torch.jit.is_scripting():
|
| 231 |
+
x = checkpoint_seq(self.blocks, x)
|
| 232 |
+
else:
|
| 233 |
+
x = self.blocks(x)
|
| 234 |
+
return x
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
class ConvNeXt(nn.Module):
|
| 238 |
+
r""" ConvNeXt
|
| 239 |
+
A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
|
| 240 |
+
"""
|
| 241 |
+
|
| 242 |
+
def __init__(
|
| 243 |
+
self,
|
| 244 |
+
in_chans: int = 3,
|
| 245 |
+
num_classes: int = 1000,
|
| 246 |
+
global_pool: str = 'avg',
|
| 247 |
+
output_stride: int = 32,
|
| 248 |
+
depths: Tuple[int, ...] = (3, 3, 9, 3),
|
| 249 |
+
dims: Tuple[int, ...] = (96, 192, 384, 768),
|
| 250 |
+
kernel_sizes: Union[int, Tuple[int, ...]] = 7,
|
| 251 |
+
ls_init_value: Optional[float] = 1e-6,
|
| 252 |
+
stem_type: str = 'patch',
|
| 253 |
+
patch_size: int = 4,
|
| 254 |
+
head_init_scale: float = 1.,
|
| 255 |
+
head_norm_first: bool = False,
|
| 256 |
+
head_hidden_size: Optional[int] = None,
|
| 257 |
+
conv_mlp: bool = False,
|
| 258 |
+
conv_bias: bool = True,
|
| 259 |
+
use_grn: bool = False,
|
| 260 |
+
act_layer: Union[str, Callable] = 'gelu',
|
| 261 |
+
norm_layer: Optional[Union[str, Callable]] = None,
|
| 262 |
+
norm_eps: Optional[float] = None,
|
| 263 |
+
drop_rate: float = 0.,
|
| 264 |
+
drop_path_rate: float = 0.,
|
| 265 |
+
):
|
| 266 |
+
"""
|
| 267 |
+
Args:
|
| 268 |
+
in_chans: Number of input image channels.
|
| 269 |
+
num_classes: Number of classes for classification head.
|
| 270 |
+
global_pool: Global pooling type.
|
| 271 |
+
output_stride: Output stride of network, one of (8, 16, 32).
|
| 272 |
+
depths: Number of blocks at each stage.
|
| 273 |
+
dims: Feature dimension at each stage.
|
| 274 |
+
kernel_sizes: Depthwise convolution kernel-sizes for each stage.
|
| 275 |
+
ls_init_value: Init value for Layer Scale, disabled if None.
|
| 276 |
+
stem_type: Type of stem.
|
| 277 |
+
patch_size: Stem patch size for patch stem.
|
| 278 |
+
head_init_scale: Init scaling value for classifier weights and biases.
|
| 279 |
+
head_norm_first: Apply normalization before global pool + head.
|
| 280 |
+
head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False.
|
| 281 |
+
conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last.
|
| 282 |
+
conv_bias: Use bias layers w/ all convolutions.
|
| 283 |
+
use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP.
|
| 284 |
+
act_layer: Activation layer type.
|
| 285 |
+
norm_layer: Normalization layer type.
|
| 286 |
+
drop_rate: Head pre-classifier dropout rate.
|
| 287 |
+
drop_path_rate: Stochastic depth drop rate.
|
| 288 |
+
"""
|
| 289 |
+
super().__init__()
|
| 290 |
+
assert output_stride in (8, 16, 32)
|
| 291 |
+
kernel_sizes = to_ntuple(4)(kernel_sizes)
|
| 292 |
+
if norm_layer is None:
|
| 293 |
+
norm_layer = LayerNorm2d
|
| 294 |
+
norm_layer_cl = norm_layer if conv_mlp else LayerNorm
|
| 295 |
+
if norm_eps is not None:
|
| 296 |
+
norm_layer = partial(norm_layer, eps=norm_eps)
|
| 297 |
+
norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
|
| 298 |
+
else:
|
| 299 |
+
assert conv_mlp,\
|
| 300 |
+
'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
|
| 301 |
+
norm_layer_cl = norm_layer
|
| 302 |
+
if norm_eps is not None:
|
| 303 |
+
norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
|
| 304 |
+
|
| 305 |
+
self.num_classes = num_classes
|
| 306 |
+
self.drop_rate = drop_rate
|
| 307 |
+
self.feature_info = []
|
| 308 |
+
|
| 309 |
+
assert stem_type in ('patch', 'overlap', 'overlap_tiered')
|
| 310 |
+
if stem_type == 'patch':
|
| 311 |
+
# NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
|
| 312 |
+
self.stem = nn.Sequential(
|
| 313 |
+
nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias),
|
| 314 |
+
norm_layer(dims[0]),
|
| 315 |
+
)
|
| 316 |
+
stem_stride = patch_size
|
| 317 |
+
else:
|
| 318 |
+
mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0]
|
| 319 |
+
self.stem = nn.Sequential(
|
| 320 |
+
nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias),
|
| 321 |
+
nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias),
|
| 322 |
+
norm_layer(dims[0]),
|
| 323 |
+
)
|
| 324 |
+
stem_stride = 4
|
| 325 |
+
|
| 326 |
+
self.stages = nn.Sequential()
|
| 327 |
+
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
|
| 328 |
+
stages = []
|
| 329 |
+
prev_chs = dims[0]
|
| 330 |
+
curr_stride = stem_stride
|
| 331 |
+
dilation = 1
|
| 332 |
+
# 4 feature resolution stages, each consisting of multiple residual blocks
|
| 333 |
+
for i in range(4):
|
| 334 |
+
stride = 2 if curr_stride == 2 or i > 0 else 1
|
| 335 |
+
if curr_stride >= output_stride and stride > 1:
|
| 336 |
+
dilation *= stride
|
| 337 |
+
stride = 1
|
| 338 |
+
curr_stride *= stride
|
| 339 |
+
first_dilation = 1 if dilation in (1, 2) else 2
|
| 340 |
+
out_chs = dims[i]
|
| 341 |
+
stages.append(ConvNeXtStage(
|
| 342 |
+
prev_chs,
|
| 343 |
+
out_chs,
|
| 344 |
+
kernel_size=kernel_sizes[i],
|
| 345 |
+
stride=stride,
|
| 346 |
+
dilation=(first_dilation, dilation),
|
| 347 |
+
depth=depths[i],
|
| 348 |
+
drop_path_rates=dp_rates[i],
|
| 349 |
+
ls_init_value=ls_init_value,
|
| 350 |
+
conv_mlp=conv_mlp,
|
| 351 |
+
conv_bias=conv_bias,
|
| 352 |
+
use_grn=use_grn,
|
| 353 |
+
act_layer=act_layer,
|
| 354 |
+
norm_layer=norm_layer,
|
| 355 |
+
norm_layer_cl=norm_layer_cl,
|
| 356 |
+
))
|
| 357 |
+
prev_chs = out_chs
|
| 358 |
+
# NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
|
| 359 |
+
self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
|
| 360 |
+
self.stages = nn.Sequential(*stages)
|
| 361 |
+
self.num_features = prev_chs
|
| 362 |
+
|
| 363 |
+
# if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
|
| 364 |
+
# otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
|
| 365 |
+
if head_norm_first:
|
| 366 |
+
assert not head_hidden_size
|
| 367 |
+
self.norm_pre = norm_layer(self.num_features)
|
| 368 |
+
self.head = ClassifierHead(
|
| 369 |
+
self.num_features,
|
| 370 |
+
num_classes,
|
| 371 |
+
pool_type=global_pool,
|
| 372 |
+
drop_rate=self.drop_rate,
|
| 373 |
+
)
|
| 374 |
+
else:
|
| 375 |
+
self.norm_pre = nn.Identity()
|
| 376 |
+
self.head = NormMlpClassifierHead(
|
| 377 |
+
self.num_features,
|
| 378 |
+
num_classes,
|
| 379 |
+
hidden_size=head_hidden_size,
|
| 380 |
+
pool_type=global_pool,
|
| 381 |
+
drop_rate=self.drop_rate,
|
| 382 |
+
norm_layer=norm_layer,
|
| 383 |
+
act_layer='gelu',
|
| 384 |
+
)
|
| 385 |
+
named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
|
| 386 |
+
|
| 387 |
+
@torch.jit.ignore
|
| 388 |
+
def group_matcher(self, coarse=False):
|
| 389 |
+
return dict(
|
| 390 |
+
stem=r'^stem',
|
| 391 |
+
blocks=r'^stages\.(\d+)' if coarse else [
|
| 392 |
+
(r'^stages\.(\d+)\.downsample', (0,)), # blocks
|
| 393 |
+
(r'^stages\.(\d+)\.blocks\.(\d+)', None),
|
| 394 |
+
(r'^norm_pre', (99999,))
|
| 395 |
+
]
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
@torch.jit.ignore
|
| 399 |
+
def set_grad_checkpointing(self, enable=True):
|
| 400 |
+
for s in self.stages:
|
| 401 |
+
s.grad_checkpointing = enable
|
| 402 |
+
|
| 403 |
+
@torch.jit.ignore
|
| 404 |
+
def get_classifier(self):
|
| 405 |
+
return self.head.fc
|
| 406 |
+
|
| 407 |
+
def reset_classifier(self, num_classes=0, global_pool=None):
|
| 408 |
+
self.head.reset(num_classes, global_pool)
|
| 409 |
+
|
| 410 |
+
def forward_features(self, x):
|
| 411 |
+
x = self.stem(x)
|
| 412 |
+
x = self.stages(x)
|
| 413 |
+
x = self.norm_pre(x)
|
| 414 |
+
return x
|
| 415 |
+
|
| 416 |
+
def forward_head(self, x, pre_logits: bool = False):
|
| 417 |
+
return self.head(x, pre_logits=True) if pre_logits else self.head(x)
|
| 418 |
+
|
| 419 |
+
def forward(self, x):
|
| 420 |
+
x = self.forward_features(x)
|
| 421 |
+
x = self.forward_head(x)
|
| 422 |
+
return x
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def _init_weights(module, name=None, head_init_scale=1.0):
|
| 426 |
+
if isinstance(module, nn.Conv2d):
|
| 427 |
+
trunc_normal_(module.weight, std=.02)
|
| 428 |
+
if module.bias is not None:
|
| 429 |
+
nn.init.zeros_(module.bias)
|
| 430 |
+
elif isinstance(module, nn.Linear):
|
| 431 |
+
trunc_normal_(module.weight, std=.02)
|
| 432 |
+
nn.init.zeros_(module.bias)
|
| 433 |
+
if name and 'head.' in name:
|
| 434 |
+
module.weight.data.mul_(head_init_scale)
|
| 435 |
+
module.bias.data.mul_(head_init_scale)
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
def checkpoint_filter_fn(state_dict, model):
|
| 439 |
+
""" Remap FB checkpoints -> timm """
|
| 440 |
+
if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict:
|
| 441 |
+
out_dict={}
|
| 442 |
+
out_dict = {k.replace('gamma', 'weight'): v for k, v in state_dict.items()}
|
| 443 |
+
return out_dict # non-FB checkpoint
|
| 444 |
+
if 'model' in state_dict:
|
| 445 |
+
state_dict = state_dict['model']
|
| 446 |
+
|
| 447 |
+
out_dict = {}
|
| 448 |
+
if 'visual.trunk.stem.0.weight' in state_dict:
|
| 449 |
+
out_dict = {k.replace('visual.trunk.', '').replace('gamma', 'weight'): v for k, v in state_dict.items() if
|
| 450 |
+
k.startswith('visual.trunk.')}
|
| 451 |
+
|
| 452 |
+
if 'visual.head.proj.weight' in state_dict:
|
| 453 |
+
out_dict['head.fc.weight'] = state_dict['visual.head.proj.weight']
|
| 454 |
+
out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
|
| 455 |
+
elif 'visual.head.mlp.fc1.weight' in state_dict:
|
| 456 |
+
out_dict['head.pre_logits.fc.weight'] = state_dict['visual.head.mlp.fc1.weight']
|
| 457 |
+
out_dict['head.pre_logits.fc.bias'] = state_dict['visual.head.mlp.fc1.bias']
|
| 458 |
+
out_dict['head.fc.weight'] = state_dict['visual.head.mlp.fc2.weight']
|
| 459 |
+
out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.mlp.fc2.weight'].shape[0])
|
| 460 |
+
return out_dict
|
| 461 |
+
|
| 462 |
+
import re
|
| 463 |
+
for k, v in state_dict.items():
|
| 464 |
+
k = k.replace('downsample_layers.0.', 'stem.')
|
| 465 |
+
k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
|
| 466 |
+
k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k)
|
| 467 |
+
k = k.replace('dwconv', 'conv_dw')
|
| 468 |
+
k = k.replace('pwconv', 'mlp.fc')
|
| 469 |
+
if 'grn' in k:
|
| 470 |
+
k = k.replace('grn.beta', 'mlp.grn.bias')
|
| 471 |
+
k = k.replace('grn.gamma', 'mlp.grn.weight')
|
| 472 |
+
v = v.reshape(v.shape[-1])
|
| 473 |
+
k = k.replace('head.', 'head.fc.')
|
| 474 |
+
if k.startswith('norm.'):
|
| 475 |
+
k = k.replace('norm', 'head.norm')
|
| 476 |
+
if v.ndim == 2 and 'head' not in k:
|
| 477 |
+
model_shape = model.state_dict()[k].shape
|
| 478 |
+
v = v.reshape(model_shape)
|
| 479 |
+
k=k.replace('gamma','weight')
|
| 480 |
+
out_dict[k] = v
|
| 481 |
+
|
| 482 |
+
return out_dict
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
def _create_convnext(variant, pretrained=False, **kwargs):
|
| 486 |
+
if kwargs.get('pretrained_cfg', '') == 'fcmae':
|
| 487 |
+
# NOTE fcmae pretrained weights have no classifier or final norm-layer (`head.norm`)
|
| 488 |
+
# This is workaround loading with num_classes=0 w/o removing norm-layer.
|
| 489 |
+
kwargs.setdefault('pretrained_strict', False)
|
| 490 |
+
|
| 491 |
+
model = build_model_with_cfg(
|
| 492 |
+
ConvNeXt, variant, pretrained,
|
| 493 |
+
pretrained_filter_fn=checkpoint_filter_fn,
|
| 494 |
+
feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
|
| 495 |
+
**kwargs)
|
| 496 |
+
return model
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
def _cfg(url='', **kwargs):
|
| 500 |
+
return {
|
| 501 |
+
'url': url,
|
| 502 |
+
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
| 503 |
+
'crop_pct': 0.875, 'interpolation': 'bicubic',
|
| 504 |
+
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
| 505 |
+
'first_conv': 'stem.0', 'classifier': 'head.fc',
|
| 506 |
+
**kwargs
|
| 507 |
+
}
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
def _cfgv2(url='', **kwargs):
|
| 511 |
+
return {
|
| 512 |
+
'url': url,
|
| 513 |
+
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
| 514 |
+
'crop_pct': 0.875, 'interpolation': 'bicubic',
|
| 515 |
+
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
| 516 |
+
'first_conv': 'stem.0', 'classifier': 'head.fc',
|
| 517 |
+
'license': 'cc-by-nc-4.0', 'paper_ids': 'arXiv:2301.00808',
|
| 518 |
+
'paper_name': 'ConvNeXt-V2: Co-designing and Scaling ConvNets with Masked Autoencoders',
|
| 519 |
+
'origin_url': 'https://github.com/facebookresearch/ConvNeXt-V2',
|
| 520 |
+
**kwargs
|
| 521 |
+
}
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
default_cfgs = generate_default_cfgs({
|
| 525 |
+
# timm specific variants
|
| 526 |
+
'convnext_tiny.in12k_ft_in1k': _cfg(
|
| 527 |
+
hf_hub_id='timm/',
|
| 528 |
+
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
| 529 |
+
'convnext_small.in12k_ft_in1k': _cfg(
|
| 530 |
+
hf_hub_id='timm/',
|
| 531 |
+
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
| 532 |
+
|
| 533 |
+
'convnext_atto.d2_in1k': _cfg(
|
| 534 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
|
| 535 |
+
hf_hub_id='timm/',
|
| 536 |
+
test_input_size=(3, 288, 288), test_crop_pct=0.95),
|
| 537 |
+
'convnext_atto_ols.a2_in1k': _cfg(
|
| 538 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth',
|
| 539 |
+
hf_hub_id='timm/',
|
| 540 |
+
test_input_size=(3, 288, 288), test_crop_pct=0.95),
|
| 541 |
+
'convnext_femto.d1_in1k': _cfg(
|
| 542 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
|
| 543 |
+
hf_hub_id='timm/',
|
| 544 |
+
test_input_size=(3, 288, 288), test_crop_pct=0.95),
|
| 545 |
+
'convnext_femto_ols.d1_in1k': _cfg(
|
| 546 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth',
|
| 547 |
+
hf_hub_id='timm/',
|
| 548 |
+
test_input_size=(3, 288, 288), test_crop_pct=0.95),
|
| 549 |
+
'convnext_pico.d1_in1k': _cfg(
|
| 550 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth',
|
| 551 |
+
hf_hub_id='timm/',
|
| 552 |
+
test_input_size=(3, 288, 288), test_crop_pct=0.95),
|
| 553 |
+
'convnext_pico_ols.d1_in1k': _cfg(
|
| 554 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth',
|
| 555 |
+
hf_hub_id='timm/',
|
| 556 |
+
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
| 557 |
+
'convnext_nano.in12k_ft_in1k': _cfg(
|
| 558 |
+
hf_hub_id='timm/',
|
| 559 |
+
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
| 560 |
+
'convnext_nano.d1h_in1k': _cfg(
|
| 561 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth',
|
| 562 |
+
hf_hub_id='timm/',
|
| 563 |
+
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
| 564 |
+
'convnext_nano_ols.d1h_in1k': _cfg(
|
| 565 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth',
|
| 566 |
+
hf_hub_id='timm/',
|
| 567 |
+
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
| 568 |
+
'convnext_tiny_hnf.a2h_in1k': _cfg(
|
| 569 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
|
| 570 |
+
hf_hub_id='timm/',
|
| 571 |
+
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
| 572 |
+
|
| 573 |
+
'convnext_tiny.in12k_ft_in1k_384': _cfg(
|
| 574 |
+
hf_hub_id='timm/',
|
| 575 |
+
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
| 576 |
+
'convnext_small.in12k_ft_in1k_384': _cfg(
|
| 577 |
+
hf_hub_id='timm/',
|
| 578 |
+
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
| 579 |
+
|
| 580 |
+
'convnext_nano.in12k': _cfg(
|
| 581 |
+
hf_hub_id='timm/',
|
| 582 |
+
crop_pct=0.95, num_classes=11821),
|
| 583 |
+
'convnext_tiny.in12k': _cfg(
|
| 584 |
+
hf_hub_id='timm/',
|
| 585 |
+
crop_pct=0.95, num_classes=11821),
|
| 586 |
+
'convnext_small.in12k': _cfg(
|
| 587 |
+
hf_hub_id='timm/',
|
| 588 |
+
crop_pct=0.95, num_classes=11821),
|
| 589 |
+
|
| 590 |
+
'convnext_tiny.fb_in22k_ft_in1k': _cfg(
|
| 591 |
+
url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth',
|
| 592 |
+
hf_hub_id='timm/',
|
| 593 |
+
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
| 594 |
+
'convnext_small.fb_in22k_ft_in1k': _cfg(
|
| 595 |
+
url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth',
|
| 596 |
+
hf_hub_id='timm/',
|
| 597 |
+
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
| 598 |
+
'convnext_base.fb_in22k_ft_in1k': _cfg(
|
| 599 |
+
url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth',
|
| 600 |
+
hf_hub_id='timm/',
|
| 601 |
+
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
| 602 |
+
'convnext_large.fb_in22k_ft_in1k': _cfg(
|
| 603 |
+
url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth',
|
| 604 |
+
hf_hub_id='timm/',
|
| 605 |
+
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
| 606 |
+
'convnext_xlarge.fb_in22k_ft_in1k': _cfg(
|
| 607 |
+
url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth',
|
| 608 |
+
hf_hub_id='timm/',
|
| 609 |
+
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
| 610 |
+
|
| 611 |
+
'convnext_tiny.fb_in1k': _cfg(
|
| 612 |
+
url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
|
| 613 |
+
hf_hub_id='timm/',
|
| 614 |
+
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
| 615 |
+
'convnext_small.fb_in1k': _cfg(
|
| 616 |
+
url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
|
| 617 |
+
hf_hub_id='timm/',
|
| 618 |
+
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
| 619 |
+
'convnext_base.fb_in1k': _cfg(
|
| 620 |
+
url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
|
| 621 |
+
hf_hub_id='timm/',
|
| 622 |
+
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
| 623 |
+
'convnext_large.fb_in1k': _cfg(
|
| 624 |
+
url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
|
| 625 |
+
hf_hub_id='timm/',
|
| 626 |
+
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
| 627 |
+
|
| 628 |
+
'convnext_tiny.fb_in22k_ft_in1k_384': _cfg(
|
| 629 |
+
url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
|
| 630 |
+
hf_hub_id='timm/',
|
| 631 |
+
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
| 632 |
+
'convnext_small.fb_in22k_ft_in1k_384': _cfg(
|
| 633 |
+
url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth',
|
| 634 |
+
hf_hub_id='timm/',
|
| 635 |
+
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
| 636 |
+
'convnext_base.fb_in22k_ft_in1k_384': _cfg(
|
| 637 |
+
url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth',
|
| 638 |
+
hf_hub_id='timm/',
|
| 639 |
+
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
| 640 |
+
'convnext_large.fb_in22k_ft_in1k_384': _cfg(
|
| 641 |
+
url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth',
|
| 642 |
+
hf_hub_id='timm/',
|
| 643 |
+
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
| 644 |
+
'convnext_xlarge.fb_in22k_ft_in1k_384': _cfg(
|
| 645 |
+
url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth',
|
| 646 |
+
hf_hub_id='timm/',
|
| 647 |
+
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
| 648 |
+
|
| 649 |
+
'convnext_tiny.fb_in22k': _cfg(
|
| 650 |
+
url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
|
| 651 |
+
hf_hub_id='timm/',
|
| 652 |
+
num_classes=21841),
|
| 653 |
+
'convnext_small.fb_in22k': _cfg(
|
| 654 |
+
url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
|
| 655 |
+
hf_hub_id='timm/',
|
| 656 |
+
num_classes=21841),
|
| 657 |
+
'convnext_base.fb_in22k': _cfg(
|
| 658 |
+
url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
|
| 659 |
+
hf_hub_id='timm/',
|
| 660 |
+
num_classes=21841),
|
| 661 |
+
'convnext_large.fb_in22k': _cfg(
|
| 662 |
+
url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
|
| 663 |
+
hf_hub_id='timm/',
|
| 664 |
+
num_classes=21841),
|
| 665 |
+
'convnext_xlarge.fb_in22k': _cfg(
|
| 666 |
+
url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
|
| 667 |
+
hf_hub_id='timm/',
|
| 668 |
+
num_classes=21841),
|
| 669 |
+
|
| 670 |
+
'convnextv2_nano.fcmae_ft_in22k_in1k': _cfgv2(
|
| 671 |
+
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_224_ema.pt',
|
| 672 |
+
hf_hub_id='timm/',
|
| 673 |
+
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
| 674 |
+
'convnextv2_nano.fcmae_ft_in22k_in1k_384': _cfgv2(
|
| 675 |
+
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_384_ema.pt',
|
| 676 |
+
hf_hub_id='timm/',
|
| 677 |
+
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
| 678 |
+
'convnextv2_tiny.fcmae_ft_in22k_in1k': _cfgv2(
|
| 679 |
+
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_224_ema.pt",
|
| 680 |
+
hf_hub_id='timm/',
|
| 681 |
+
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
| 682 |
+
'convnextv2_tiny.fcmae_ft_in22k_in1k_384': _cfgv2(
|
| 683 |
+
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_384_ema.pt",
|
| 684 |
+
hf_hub_id='timm/',
|
| 685 |
+
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
| 686 |
+
'convnextv2_base.fcmae_ft_in22k_in1k': _cfgv2(
|
| 687 |
+
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_224_ema.pt",
|
| 688 |
+
hf_hub_id='timm/',
|
| 689 |
+
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
| 690 |
+
'convnextv2_base.fcmae_ft_in22k_in1k_384': _cfgv2(
|
| 691 |
+
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_384_ema.pt",
|
| 692 |
+
hf_hub_id='timm/',
|
| 693 |
+
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
| 694 |
+
'convnextv2_large.fcmae_ft_in22k_in1k': _cfgv2(
|
| 695 |
+
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_224_ema.pt",
|
| 696 |
+
hf_hub_id='timm/',
|
| 697 |
+
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
| 698 |
+
'convnextv2_large.fcmae_ft_in22k_in1k_384': _cfgv2(
|
| 699 |
+
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_384_ema.pt",
|
| 700 |
+
hf_hub_id='timm/',
|
| 701 |
+
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
| 702 |
+
'convnextv2_huge.fcmae_ft_in22k_in1k_384': _cfgv2(
|
| 703 |
+
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_384_ema.pt",
|
| 704 |
+
hf_hub_id='timm/',
|
| 705 |
+
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
| 706 |
+
'convnextv2_huge.fcmae_ft_in22k_in1k_512': _cfgv2(
|
| 707 |
+
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_512_ema.pt",
|
| 708 |
+
hf_hub_id='timm/',
|
| 709 |
+
input_size=(3, 512, 512), pool_size=(15, 15), crop_pct=1.0, crop_mode='squash'),
|
| 710 |
+
|
| 711 |
+
'convnextv2_atto.fcmae_ft_in1k': _cfgv2(
|
| 712 |
+
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_atto_1k_224_ema.pt',
|
| 713 |
+
hf_hub_id='timm/',
|
| 714 |
+
test_input_size=(3, 288, 288), test_crop_pct=0.95),
|
| 715 |
+
'convnextv2_femto.fcmae_ft_in1k': _cfgv2(
|
| 716 |
+
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_femto_1k_224_ema.pt',
|
| 717 |
+
hf_hub_id='timm/',
|
| 718 |
+
test_input_size=(3, 288, 288), test_crop_pct=0.95),
|
| 719 |
+
'convnextv2_pico.fcmae_ft_in1k': _cfgv2(
|
| 720 |
+
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_pico_1k_224_ema.pt',
|
| 721 |
+
hf_hub_id='timm/',
|
| 722 |
+
test_input_size=(3, 288, 288), test_crop_pct=0.95),
|
| 723 |
+
'convnextv2_nano.fcmae_ft_in1k': _cfgv2(
|
| 724 |
+
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_nano_1k_224_ema.pt',
|
| 725 |
+
hf_hub_id='timm/',
|
| 726 |
+
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
| 727 |
+
'convnextv2_tiny.fcmae_ft_in1k': _cfgv2(
|
| 728 |
+
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_tiny_1k_224_ema.pt",
|
| 729 |
+
hf_hub_id='timm/',
|
| 730 |
+
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
| 731 |
+
'convnextv2_base.fcmae_ft_in1k': _cfgv2(
|
| 732 |
+
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_base_1k_224_ema.pt",
|
| 733 |
+
hf_hub_id='timm/',
|
| 734 |
+
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
| 735 |
+
'convnextv2_large.fcmae_ft_in1k': _cfgv2(
|
| 736 |
+
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_large_1k_224_ema.pt",
|
| 737 |
+
hf_hub_id='timm/',
|
| 738 |
+
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
| 739 |
+
'convnextv2_huge.fcmae_ft_in1k': _cfgv2(
|
| 740 |
+
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_huge_1k_224_ema.pt",
|
| 741 |
+
hf_hub_id='timm/',
|
| 742 |
+
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
| 743 |
+
|
| 744 |
+
'convnextv2_atto.fcmae': _cfgv2(
|
| 745 |
+
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_atto_1k_224_fcmae.pt',
|
| 746 |
+
hf_hub_id='timm/',
|
| 747 |
+
num_classes=0),
|
| 748 |
+
'convnextv2_femto.fcmae': _cfgv2(
|
| 749 |
+
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_femto_1k_224_fcmae.pt',
|
| 750 |
+
hf_hub_id='timm/',
|
| 751 |
+
num_classes=0),
|
| 752 |
+
'convnextv2_pico.fcmae': _cfgv2(
|
| 753 |
+
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_pico_1k_224_fcmae.pt',
|
| 754 |
+
hf_hub_id='timm/',
|
| 755 |
+
num_classes=0),
|
| 756 |
+
'convnextv2_nano.fcmae': _cfgv2(
|
| 757 |
+
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_nano_1k_224_fcmae.pt',
|
| 758 |
+
hf_hub_id='timm/',
|
| 759 |
+
num_classes=0),
|
| 760 |
+
'convnextv2_tiny.fcmae': _cfgv2(
|
| 761 |
+
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_tiny_1k_224_fcmae.pt",
|
| 762 |
+
hf_hub_id='timm/',
|
| 763 |
+
num_classes=0),
|
| 764 |
+
'convnextv2_base.fcmae': _cfgv2(
|
| 765 |
+
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_base_1k_224_fcmae.pt",
|
| 766 |
+
hf_hub_id='timm/',
|
| 767 |
+
num_classes=0),
|
| 768 |
+
'convnextv2_large.fcmae': _cfgv2(
|
| 769 |
+
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_large_1k_224_fcmae.pt",
|
| 770 |
+
hf_hub_id='timm/',
|
| 771 |
+
num_classes=0),
|
| 772 |
+
'convnextv2_huge.fcmae': _cfgv2(
|
| 773 |
+
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_huge_1k_224_fcmae.pt",
|
| 774 |
+
hf_hub_id='timm/',
|
| 775 |
+
num_classes=0),
|
| 776 |
+
|
| 777 |
+
'convnextv2_small.untrained': _cfg(),
|
| 778 |
+
|
| 779 |
+
# CLIP weights, fine-tuned on in1k or in12k + in1k
|
| 780 |
+
'convnext_base.clip_laion2b_augreg_ft_in12k_in1k': _cfg(
|
| 781 |
+
hf_hub_id='timm/',
|
| 782 |
+
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
| 783 |
+
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
|
| 784 |
+
'convnext_base.clip_laion2b_augreg_ft_in12k_in1k_384': _cfg(
|
| 785 |
+
hf_hub_id='timm/',
|
| 786 |
+
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
| 787 |
+
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
| 788 |
+
'convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_320': _cfg(
|
| 789 |
+
hf_hub_id='timm/',
|
| 790 |
+
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
| 791 |
+
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0),
|
| 792 |
+
'convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384': _cfg(
|
| 793 |
+
hf_hub_id='timm/',
|
| 794 |
+
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
| 795 |
+
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
| 796 |
+
|
| 797 |
+
'convnext_base.clip_laion2b_augreg_ft_in1k': _cfg(
|
| 798 |
+
hf_hub_id='timm/',
|
| 799 |
+
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
| 800 |
+
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
|
| 801 |
+
'convnext_base.clip_laiona_augreg_ft_in1k_384': _cfg(
|
| 802 |
+
hf_hub_id='timm/',
|
| 803 |
+
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
| 804 |
+
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
|
| 805 |
+
'convnext_large_mlp.clip_laion2b_augreg_ft_in1k': _cfg(
|
| 806 |
+
hf_hub_id='timm/',
|
| 807 |
+
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
| 808 |
+
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0
|
| 809 |
+
),
|
| 810 |
+
'convnext_large_mlp.clip_laion2b_augreg_ft_in1k_384': _cfg(
|
| 811 |
+
hf_hub_id='timm/',
|
| 812 |
+
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
| 813 |
+
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'
|
| 814 |
+
),
|
| 815 |
+
'convnext_xxlarge.clip_laion2b_soup_ft_in1k': _cfg(
|
| 816 |
+
hf_hub_id='timm/',
|
| 817 |
+
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
| 818 |
+
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
|
| 819 |
+
|
| 820 |
+
'convnext_base.clip_laion2b_augreg_ft_in12k': _cfg(
|
| 821 |
+
hf_hub_id='timm/',
|
| 822 |
+
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
|
| 823 |
+
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
|
| 824 |
+
'convnext_large_mlp.clip_laion2b_soup_ft_in12k_320': _cfg(
|
| 825 |
+
hf_hub_id='timm/',
|
| 826 |
+
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
|
| 827 |
+
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0),
|
| 828 |
+
'convnext_large_mlp.clip_laion2b_augreg_ft_in12k_384': _cfg(
|
| 829 |
+
hf_hub_id='timm/',
|
| 830 |
+
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
|
| 831 |
+
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
| 832 |
+
'convnext_large_mlp.clip_laion2b_soup_ft_in12k_384': _cfg(
|
| 833 |
+
hf_hub_id='timm/',
|
| 834 |
+
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
|
| 835 |
+
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
| 836 |
+
'convnext_xxlarge.clip_laion2b_soup_ft_in12k': _cfg(
|
| 837 |
+
hf_hub_id='timm/',
|
| 838 |
+
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
|
| 839 |
+
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
|
| 840 |
+
|
| 841 |
+
# CLIP original image tower weights
|
| 842 |
+
'convnext_base.clip_laion2b': _cfg(
|
| 843 |
+
hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K',
|
| 844 |
+
hf_hub_filename='open_clip_pytorch_model.bin',
|
| 845 |
+
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
| 846 |
+
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
|
| 847 |
+
'convnext_base.clip_laion2b_augreg': _cfg(
|
| 848 |
+
hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg',
|
| 849 |
+
hf_hub_filename='open_clip_pytorch_model.bin',
|
| 850 |
+
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
| 851 |
+
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
|
| 852 |
+
'convnext_base.clip_laiona': _cfg(
|
| 853 |
+
hf_hub_id='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K',
|
| 854 |
+
hf_hub_filename='open_clip_pytorch_model.bin',
|
| 855 |
+
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
| 856 |
+
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
|
| 857 |
+
'convnext_base.clip_laiona_320': _cfg(
|
| 858 |
+
hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K',
|
| 859 |
+
hf_hub_filename='open_clip_pytorch_model.bin',
|
| 860 |
+
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
| 861 |
+
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
|
| 862 |
+
'convnext_base.clip_laiona_augreg_320': _cfg(
|
| 863 |
+
hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg',
|
| 864 |
+
hf_hub_filename='open_clip_pytorch_model.bin',
|
| 865 |
+
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
| 866 |
+
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
|
| 867 |
+
'convnext_large_mlp.clip_laion2b_augreg': _cfg(
|
| 868 |
+
hf_hub_id='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg',
|
| 869 |
+
hf_hub_filename='open_clip_pytorch_model.bin',
|
| 870 |
+
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
| 871 |
+
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=768),
|
| 872 |
+
'convnext_large_mlp.clip_laion2b_ft_320': _cfg(
|
| 873 |
+
hf_hub_id='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft',
|
| 874 |
+
hf_hub_filename='open_clip_pytorch_model.bin',
|
| 875 |
+
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
| 876 |
+
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768),
|
| 877 |
+
'convnext_large_mlp.clip_laion2b_ft_soup_320': _cfg(
|
| 878 |
+
hf_hub_id='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup',
|
| 879 |
+
hf_hub_filename='open_clip_pytorch_model.bin',
|
| 880 |
+
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
| 881 |
+
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768),
|
| 882 |
+
'convnext_xxlarge.clip_laion2b_soup': _cfg(
|
| 883 |
+
hf_hub_id='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup',
|
| 884 |
+
hf_hub_filename='open_clip_pytorch_model.bin',
|
| 885 |
+
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
| 886 |
+
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),
|
| 887 |
+
'convnext_xxlarge.clip_laion2b_rewind': _cfg(
|
| 888 |
+
hf_hub_id='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind',
|
| 889 |
+
hf_hub_filename='open_clip_pytorch_model.bin',
|
| 890 |
+
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
| 891 |
+
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),
|
| 892 |
+
})
|
| 893 |
+
|
| 894 |
+
|
| 895 |
+
# @register_model
|
| 896 |
+
# def convnext_atto(pretrained=False, **kwargs) -> ConvNeXt:
|
| 897 |
+
# # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
|
| 898 |
+
# model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True)
|
| 899 |
+
# model = _create_convnext('convnext_atto', pretrained=pretrained, **dict(model_args, **kwargs))
|
| 900 |
+
# return model
|
| 901 |
+
|
| 902 |
+
|
| 903 |
+
# @register_model
|
| 904 |
+
# def convnext_atto_ols(pretrained=False, **kwargs) -> ConvNeXt:
|
| 905 |
+
# # timm femto variant with overlapping 3x3 conv stem, wider than non-ols femto above, current param count 3.7M
|
| 906 |
+
# model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, stem_type='overlap_tiered')
|
| 907 |
+
# model = _create_convnext('convnext_atto_ols', pretrained=pretrained, **dict(model_args, **kwargs))
|
| 908 |
+
# return model
|
| 909 |
+
|
| 910 |
+
|
| 911 |
+
# @register_model
|
| 912 |
+
# def convnext_femto(pretrained=False, **kwargs) -> ConvNeXt:
|
| 913 |
+
# # timm femto variant
|
| 914 |
+
# model_args = dict(depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True)
|
| 915 |
+
# model = _create_convnext('convnext_femto', pretrained=pretrained, **dict(model_args, **kwargs))
|
| 916 |
+
# return model
|
| 917 |
+
|
| 918 |
+
|
| 919 |
+
# @register_model
|
| 920 |
+
# def convnext_femto_ols(pretrained=False, **kwargs) -> ConvNeXt:
|
| 921 |
+
# # timm femto variant
|
| 922 |
+
# model_args = dict(depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, stem_type='overlap_tiered')
|
| 923 |
+
# model = _create_convnext('convnext_femto_ols', pretrained=pretrained, **dict(model_args, **kwargs))
|
| 924 |
+
# return model
|
| 925 |
+
|
| 926 |
+
|
| 927 |
+
# @register_model
|
| 928 |
+
# def convnext_pico(pretrained=False, **kwargs) -> ConvNeXt:
|
| 929 |
+
# # timm pico variant
|
| 930 |
+
# model_args = dict(depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True)
|
| 931 |
+
# model = _create_convnext('convnext_pico', pretrained=pretrained, **dict(model_args, **kwargs))
|
| 932 |
+
# return model
|
| 933 |
+
|
| 934 |
+
|
| 935 |
+
# @register_model
|
| 936 |
+
# def convnext_pico_ols(pretrained=False, **kwargs) -> ConvNeXt:
|
| 937 |
+
# # timm nano variant with overlapping 3x3 conv stem
|
| 938 |
+
# model_args = dict(depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, stem_type='overlap_tiered')
|
| 939 |
+
# model = _create_convnext('convnext_pico_ols', pretrained=pretrained, **dict(model_args, **kwargs))
|
| 940 |
+
# return model
|
| 941 |
+
|
| 942 |
+
|
| 943 |
+
# @register_model
|
| 944 |
+
# def convnext_nano(pretrained=False, **kwargs) -> ConvNeXt:
|
| 945 |
+
# # timm nano variant with standard stem and head
|
| 946 |
+
# model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True)
|
| 947 |
+
# model = _create_convnext('convnext_nano', pretrained=pretrained, **dict(model_args, **kwargs))
|
| 948 |
+
# return model
|
| 949 |
+
|
| 950 |
+
|
| 951 |
+
# @register_model
|
| 952 |
+
# def convnext_nano_ols(pretrained=False, **kwargs) -> ConvNeXt:
|
| 953 |
+
# # experimental nano variant with overlapping conv stem
|
| 954 |
+
# model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, stem_type='overlap')
|
| 955 |
+
# model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **dict(model_args, **kwargs))
|
| 956 |
+
# return model
|
| 957 |
+
|
| 958 |
+
|
| 959 |
+
# @register_model
|
| 960 |
+
# def convnext_tiny_hnf(pretrained=False, **kwargs) -> ConvNeXt:
|
| 961 |
+
# # experimental tiny variant with norm before pooling in head (head norm first)
|
| 962 |
+
# model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True)
|
| 963 |
+
# model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **dict(model_args, **kwargs))
|
| 964 |
+
# return model
|
| 965 |
+
|
| 966 |
+
|
| 967 |
+
# @register_model
|
| 968 |
+
# def convnext_tiny(pretrained=False, **kwargs) -> ConvNeXt:
|
| 969 |
+
# model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768))
|
| 970 |
+
# model = _create_convnext('convnext_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
|
| 971 |
+
# return model
|
| 972 |
+
|
| 973 |
+
|
| 974 |
+
# @register_model
|
| 975 |
+
# def convnext_small(pretrained=False, **kwargs) -> ConvNeXt:
|
| 976 |
+
# model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768])
|
| 977 |
+
# model = _create_convnext('convnext_small', pretrained=pretrained, **dict(model_args, **kwargs))
|
| 978 |
+
# return model
|
| 979 |
+
|
| 980 |
+
# @register_model
|
| 981 |
+
# def convnext_base_clip(pretrained='', **kwargs) -> ConvNeXt:
|
| 982 |
+
# model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024])
|
| 983 |
+
# model = _create_convnext(pretrained, pretrained=True, **dict(model_args, **kwargs))
|
| 984 |
+
# return model
|
| 985 |
+
|
| 986 |
+
# @register_model
|
| 987 |
+
# def convnext_base(pretrained=False, **kwargs) -> ConvNeXt:
|
| 988 |
+
# model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024])
|
| 989 |
+
# model = _create_convnext('convnext_base', pretrained=pretrained, **dict(model_args, **kwargs))
|
| 990 |
+
# return model
|
| 991 |
+
|
| 992 |
+
|
| 993 |
+
# @register_model
|
| 994 |
+
# def convnext_large(pretrained=False, **kwargs) -> ConvNeXt:
|
| 995 |
+
# model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536])
|
| 996 |
+
# model = _create_convnext('convnext_large', pretrained=pretrained, **dict(model_args, **kwargs))
|
| 997 |
+
# return model
|
| 998 |
+
|
| 999 |
+
|
| 1000 |
+
# @register_model
|
| 1001 |
+
# def convnext_large_mlp(pretrained=False, **kwargs) -> ConvNeXt:
|
| 1002 |
+
# model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], head_hidden_size=1536)
|
| 1003 |
+
# model = _create_convnext('convnext_large_mlp', pretrained=pretrained, **dict(model_args, **kwargs))
|
| 1004 |
+
# return model
|
| 1005 |
+
|
| 1006 |
+
|
| 1007 |
+
# @register_model
|
| 1008 |
+
# def convnext_xlarge(pretrained=False, **kwargs) -> ConvNeXt:
|
| 1009 |
+
# model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048])
|
| 1010 |
+
# model = _create_convnext('convnext_xlarge', pretrained=pretrained, **dict(model_args, **kwargs))
|
| 1011 |
+
# return model
|
| 1012 |
+
|
| 1013 |
+
|
| 1014 |
+
# @register_model
|
| 1015 |
+
def convnext_xxlarge(pretrained=False, **kwargs) -> ConvNeXt:
|
| 1016 |
+
model_args = dict(depths=[3, 4, 30, 3], dims=[384, 768, 1536, 3072], norm_eps=kwargs.pop('norm_eps', 1e-5))
|
| 1017 |
+
model = _create_convnext('convnext_xxlarge', pretrained=pretrained, **dict(model_args, **kwargs))
|
| 1018 |
+
return model
|
| 1019 |
+
|
| 1020 |
+
|
| 1021 |
+
# @register_model
|
| 1022 |
+
# def convnextv2_atto(pretrained=False, **kwargs) -> ConvNeXt:
|
| 1023 |
+
# # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
|
| 1024 |
+
# model_args = dict(
|
| 1025 |
+
# depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), use_grn=True, ls_init_value=None, conv_mlp=True)
|
| 1026 |
+
# model = _create_convnext('convnextv2_atto', pretrained=pretrained, **dict(model_args, **kwargs))
|
| 1027 |
+
# return model
|
| 1028 |
+
|
| 1029 |
+
|
| 1030 |
+
# @register_model
|
| 1031 |
+
# def convnextv2_femto(pretrained=False, **kwargs) -> ConvNeXt:
|
| 1032 |
+
# # timm femto variant
|
| 1033 |
+
# model_args = dict(
|
| 1034 |
+
# depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), use_grn=True, ls_init_value=None, conv_mlp=True)
|
| 1035 |
+
# model = _create_convnext('convnextv2_femto', pretrained=pretrained, **dict(model_args, **kwargs))
|
| 1036 |
+
# return model
|
| 1037 |
+
|
| 1038 |
+
|
| 1039 |
+
# @register_model
|
| 1040 |
+
# def convnextv2_pico(pretrained=False, **kwargs) -> ConvNeXt:
|
| 1041 |
+
# # timm pico variant
|
| 1042 |
+
# model_args = dict(
|
| 1043 |
+
# depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), use_grn=True, ls_init_value=None, conv_mlp=True)
|
| 1044 |
+
# model = _create_convnext('convnextv2_pico', pretrained=pretrained, **dict(model_args, **kwargs))
|
| 1045 |
+
# return model
|
| 1046 |
+
|
| 1047 |
+
|
| 1048 |
+
# @register_model
|
| 1049 |
+
# def convnextv2_nano(pretrained=False, **kwargs) -> ConvNeXt:
|
| 1050 |
+
# # timm nano variant with standard stem and head
|
| 1051 |
+
# model_args = dict(
|
| 1052 |
+
# depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), use_grn=True, ls_init_value=None, conv_mlp=True)
|
| 1053 |
+
# model = _create_convnext('convnextv2_nano', pretrained=pretrained, **dict(model_args, **kwargs))
|
| 1054 |
+
# return model
|
| 1055 |
+
|
| 1056 |
+
|
| 1057 |
+
# @register_model
|
| 1058 |
+
# def convnextv2_tiny(pretrained=False, **kwargs) -> ConvNeXt:
|
| 1059 |
+
# model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), use_grn=True, ls_init_value=None)
|
| 1060 |
+
# model = _create_convnext('convnextv2_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
|
| 1061 |
+
# return model
|
| 1062 |
+
|
| 1063 |
+
|
| 1064 |
+
# @register_model
|
| 1065 |
+
# def convnextv2_small(pretrained=False, **kwargs) -> ConvNeXt:
|
| 1066 |
+
# model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], use_grn=True, ls_init_value=None)
|
| 1067 |
+
# model = _create_convnext('convnextv2_small', pretrained=pretrained, **dict(model_args, **kwargs))
|
| 1068 |
+
# return model
|
| 1069 |
+
|
| 1070 |
+
|
| 1071 |
+
# @register_model
|
| 1072 |
+
# def convnextv2_base(pretrained=False, **kwargs) -> ConvNeXt:
|
| 1073 |
+
# model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], use_grn=True, ls_init_value=None)
|
| 1074 |
+
# model = _create_convnext('convnextv2_base', pretrained=pretrained, **dict(model_args, **kwargs))
|
| 1075 |
+
# return model
|
| 1076 |
+
|
| 1077 |
+
|
| 1078 |
+
# @register_model
|
| 1079 |
+
# def convnextv2_large(pretrained=False, **kwargs) -> ConvNeXt:
|
| 1080 |
+
# model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], use_grn=True, ls_init_value=None)
|
| 1081 |
+
# model = _create_convnext('convnextv2_large', pretrained=pretrained, **dict(model_args, **kwargs))
|
| 1082 |
+
# return model
|
| 1083 |
+
|
| 1084 |
+
|
| 1085 |
+
# @register_model
|
| 1086 |
+
# def convnextv2_huge(pretrained=False, **kwargs) -> ConvNeXt:
|
| 1087 |
+
# model_args = dict(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], use_grn=True, ls_init_value=None)
|
| 1088 |
+
# model = _create_convnext('convnextv2_huge', pretrained=pretrained, **dict(model_args, **kwargs))
|
| 1089 |
+
# return model
|
| 1090 |
+
|
| 1091 |
+
|
| 1092 |
+
# register_model_deprecations(__name__, {
|
| 1093 |
+
# 'convnext_tiny_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k',
|
| 1094 |
+
# 'convnext_small_in22ft1k': 'convnext_small.fb_in22k_ft_in1k',
|
| 1095 |
+
# 'convnext_base_in22ft1k': 'convnext_base.fb_in22k_ft_in1k',
|
| 1096 |
+
# 'convnext_large_in22ft1k': 'convnext_large.fb_in22k_ft_in1k',
|
| 1097 |
+
# 'convnext_xlarge_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k',
|
| 1098 |
+
# 'convnext_tiny_384_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k_384',
|
| 1099 |
+
# 'convnext_small_384_in22ft1k': 'convnext_small.fb_in22k_ft_in1k_384',
|
| 1100 |
+
# 'convnext_base_384_in22ft1k': 'convnext_base.fb_in22k_ft_in1k_384',
|
| 1101 |
+
# 'convnext_large_384_in22ft1k': 'convnext_large.fb_in22k_ft_in1k_384',
|
| 1102 |
+
# 'convnext_xlarge_384_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k_384',
|
| 1103 |
+
# 'convnext_tiny_in22k': 'convnext_tiny.fb_in22k',
|
| 1104 |
+
# 'convnext_small_in22k': 'convnext_small.fb_in22k',
|
| 1105 |
+
# 'convnext_base_in22k': 'convnext_base.fb_in22k',
|
| 1106 |
+
# 'convnext_large_in22k': 'convnext_large.fb_in22k',
|
| 1107 |
+
# 'convnext_xlarge_in22k': 'convnext_xlarge.fb_in22k',
|
| 1108 |
+
# })
|
EAGLE/eagle/model/multimodal_encoder/vision_models/eva_vit.py
ADDED
|
@@ -0,0 +1,1235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# This file is modified from https://github.com/baaivision/EVA
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
import fvcore.nn.weight_init as weight_init
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
import math
|
| 24 |
+
import numpy as np
|
| 25 |
+
import logging
|
| 26 |
+
from functools import partial
|
| 27 |
+
from scipy import interpolate
|
| 28 |
+
from math import pi
|
| 29 |
+
from einops import rearrange, repeat
|
| 30 |
+
import warnings
|
| 31 |
+
from PIL import Image
|
| 32 |
+
import torch.utils.checkpoint as cp
|
| 33 |
+
from transformers import CLIPImageProcessor
|
| 34 |
+
# from ..utils.attention import FlashAttention, FlashMHA
|
| 35 |
+
# try:
|
| 36 |
+
# import xformers.ops as xops
|
| 37 |
+
# except:
|
| 38 |
+
# pass
|
| 39 |
+
|
| 40 |
+
logger = logging.getLogger(__name__)
|
| 41 |
+
BatchNorm2d = torch.nn.BatchNorm2d
|
| 42 |
+
|
| 43 |
+
class Conv2d(torch.nn.Conv2d):
|
| 44 |
+
"""
|
| 45 |
+
A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
def __init__(self, *args, **kwargs):
|
| 49 |
+
"""
|
| 50 |
+
Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:
|
| 51 |
+
Args:
|
| 52 |
+
norm (nn.Module, optional): a normalization layer
|
| 53 |
+
activation (callable(Tensor) -> Tensor): a callable activation function
|
| 54 |
+
It assumes that norm layer is used before activation.
|
| 55 |
+
"""
|
| 56 |
+
norm = kwargs.pop("norm", None)
|
| 57 |
+
activation = kwargs.pop("activation", None)
|
| 58 |
+
super().__init__(*args, **kwargs)
|
| 59 |
+
|
| 60 |
+
self.norm = norm
|
| 61 |
+
self.activation = activation
|
| 62 |
+
|
| 63 |
+
def forward(self, x):
|
| 64 |
+
# torchscript does not support SyncBatchNorm yet
|
| 65 |
+
# https://github.com/pytorch/pytorch/issues/40507
|
| 66 |
+
# and we skip these codes in torchscript since:
|
| 67 |
+
# 1. currently we only support torchscript in evaluation mode
|
| 68 |
+
# 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or
|
| 69 |
+
# later version, `Conv2d` in these PyTorch versions has already supported empty inputs.
|
| 70 |
+
if not torch.jit.is_scripting():
|
| 71 |
+
with warnings.catch_warnings(record=True):
|
| 72 |
+
if x.numel() == 0 and self.training:
|
| 73 |
+
# https://github.com/pytorch/pytorch/issues/12013
|
| 74 |
+
assert not isinstance(
|
| 75 |
+
self.norm, torch.nn.SyncBatchNorm
|
| 76 |
+
), "SyncBatchNorm does not support empty inputs!"
|
| 77 |
+
|
| 78 |
+
x = F.conv2d(
|
| 79 |
+
x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
|
| 80 |
+
)
|
| 81 |
+
if self.norm is not None:
|
| 82 |
+
x = self.norm(x)
|
| 83 |
+
if self.activation is not None:
|
| 84 |
+
x = self.activation(x)
|
| 85 |
+
return x
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def window_partition(x, window_size):
|
| 89 |
+
"""
|
| 90 |
+
Partition into non-overlapping windows with padding if needed.
|
| 91 |
+
Args:
|
| 92 |
+
x (tensor): input tokens with [B, H, W, C].
|
| 93 |
+
window_size (int): window size.
|
| 94 |
+
Returns:
|
| 95 |
+
windows: windows after partition with [B * num_windows, window_size, window_size, C].
|
| 96 |
+
(Hp, Wp): padded height and width before partition
|
| 97 |
+
"""
|
| 98 |
+
B, H, W, C = x.shape
|
| 99 |
+
|
| 100 |
+
pad_h = (window_size - H % window_size) % window_size
|
| 101 |
+
pad_w = (window_size - W % window_size) % window_size
|
| 102 |
+
if pad_h > 0 or pad_w > 0:
|
| 103 |
+
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
|
| 104 |
+
Hp, Wp = H + pad_h, W + pad_w
|
| 105 |
+
|
| 106 |
+
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
|
| 107 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
| 108 |
+
return windows, (Hp, Wp)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def window_unpartition(windows, window_size, pad_hw, hw):
|
| 112 |
+
"""
|
| 113 |
+
Window unpartition into original sequences and removing padding.
|
| 114 |
+
Args:
|
| 115 |
+
x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
|
| 116 |
+
window_size (int): window size.
|
| 117 |
+
pad_hw (Tuple): padded height and width (Hp, Wp).
|
| 118 |
+
hw (Tuple): original height and width (H, W) before padding.
|
| 119 |
+
Returns:
|
| 120 |
+
x: unpartitioned sequences with [B, H, W, C].
|
| 121 |
+
"""
|
| 122 |
+
Hp, Wp = pad_hw
|
| 123 |
+
H, W = hw
|
| 124 |
+
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
|
| 125 |
+
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
|
| 126 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
|
| 127 |
+
|
| 128 |
+
if Hp > H or Wp > W:
|
| 129 |
+
x = x[:, :H, :W, :].contiguous()
|
| 130 |
+
return x
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def get_rel_pos(q_size, k_size, rel_pos):
|
| 134 |
+
"""
|
| 135 |
+
Get relative positional embeddings according to the relative positions of
|
| 136 |
+
query and key sizes.
|
| 137 |
+
Args:
|
| 138 |
+
q_size (int): size of query q.
|
| 139 |
+
k_size (int): size of key k.
|
| 140 |
+
rel_pos (Tensor): relative position embeddings (L, C).
|
| 141 |
+
Returns:
|
| 142 |
+
Extracted positional embeddings according to relative positions.
|
| 143 |
+
"""
|
| 144 |
+
max_rel_dist = int(2 * max(q_size, k_size) - 1)
|
| 145 |
+
use_log_interpolation = True
|
| 146 |
+
|
| 147 |
+
# Interpolate rel pos if needed.
|
| 148 |
+
if rel_pos.shape[0] != max_rel_dist:
|
| 149 |
+
if not use_log_interpolation:
|
| 150 |
+
# Interpolate rel pos.
|
| 151 |
+
rel_pos_resized = F.interpolate(
|
| 152 |
+
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
|
| 153 |
+
size=max_rel_dist,
|
| 154 |
+
mode="linear",
|
| 155 |
+
)
|
| 156 |
+
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
|
| 157 |
+
else:
|
| 158 |
+
src_size = rel_pos.shape[0]
|
| 159 |
+
dst_size = max_rel_dist
|
| 160 |
+
|
| 161 |
+
# q = 1.13492
|
| 162 |
+
q = 1.0903078
|
| 163 |
+
dis = []
|
| 164 |
+
|
| 165 |
+
cur = 1
|
| 166 |
+
for i in range(src_size // 2):
|
| 167 |
+
dis.append(cur)
|
| 168 |
+
cur += q ** (i + 1)
|
| 169 |
+
|
| 170 |
+
r_ids = [-_ for _ in reversed(dis)]
|
| 171 |
+
x = r_ids + [0] + dis
|
| 172 |
+
t = dst_size // 2.0
|
| 173 |
+
dx = np.arange(-t, t + 0.1, 1.0)
|
| 174 |
+
all_rel_pos_bias = []
|
| 175 |
+
for i in range(rel_pos.shape[1]):
|
| 176 |
+
z = rel_pos[:, i].view(src_size).cpu().float().numpy()
|
| 177 |
+
f = interpolate.interp1d(x, z, kind='cubic', fill_value="extrapolate")
|
| 178 |
+
all_rel_pos_bias.append(
|
| 179 |
+
torch.Tensor(f(dx)).contiguous().view(-1, 1).to(rel_pos.device))
|
| 180 |
+
rel_pos_resized = torch.cat(all_rel_pos_bias, dim=-1)
|
| 181 |
+
else:
|
| 182 |
+
rel_pos_resized = rel_pos
|
| 183 |
+
|
| 184 |
+
# Scale the coords with short length if shapes for q and k are different.
|
| 185 |
+
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
|
| 186 |
+
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
|
| 187 |
+
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
|
| 188 |
+
|
| 189 |
+
return rel_pos_resized[relative_coords.long()]
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
|
| 193 |
+
"""
|
| 194 |
+
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
|
| 195 |
+
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
|
| 196 |
+
Args:
|
| 197 |
+
attn (Tensor): attention map.
|
| 198 |
+
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
|
| 199 |
+
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
|
| 200 |
+
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
|
| 201 |
+
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
|
| 202 |
+
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
|
| 203 |
+
Returns:
|
| 204 |
+
attn (Tensor): attention map with added relative positional embeddings.
|
| 205 |
+
"""
|
| 206 |
+
q_h, q_w = q_size
|
| 207 |
+
k_h, k_w = k_size
|
| 208 |
+
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
|
| 209 |
+
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
|
| 210 |
+
|
| 211 |
+
B, _, dim = q.shape
|
| 212 |
+
r_q = q.reshape(B, q_h, q_w, dim)
|
| 213 |
+
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
|
| 214 |
+
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
|
| 215 |
+
|
| 216 |
+
attn = (
|
| 217 |
+
attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
|
| 218 |
+
).view(B, q_h * q_w, k_h * k_w)
|
| 219 |
+
|
| 220 |
+
return attn
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def get_abs_pos(abs_pos, has_cls_token, hw):
|
| 224 |
+
"""
|
| 225 |
+
Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
|
| 226 |
+
dimension for the original embeddings.
|
| 227 |
+
Args:
|
| 228 |
+
abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
|
| 229 |
+
has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
|
| 230 |
+
hw (Tuple): size of input image tokens.
|
| 231 |
+
Returns:
|
| 232 |
+
Absolute positional embeddings after processing with shape (1, H, W, C)
|
| 233 |
+
"""
|
| 234 |
+
h, w = hw
|
| 235 |
+
if has_cls_token:
|
| 236 |
+
abs_pos = abs_pos[:, 1:]
|
| 237 |
+
xy_num = abs_pos.shape[1]
|
| 238 |
+
size = int(math.sqrt(xy_num))
|
| 239 |
+
assert size * size == xy_num
|
| 240 |
+
|
| 241 |
+
if size != h or size != w:
|
| 242 |
+
original_datatype = abs_pos.dtype
|
| 243 |
+
new_abs_pos = F.interpolate(
|
| 244 |
+
abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2).float(), # bf16 is not implemented
|
| 245 |
+
size=(h, w),
|
| 246 |
+
mode="bicubic",
|
| 247 |
+
align_corners=False,
|
| 248 |
+
).to(original_datatype)
|
| 249 |
+
|
| 250 |
+
return new_abs_pos.permute(0, 2, 3, 1)
|
| 251 |
+
else:
|
| 252 |
+
return abs_pos.reshape(1, h, w, -1)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
class PatchEmbed(nn.Module):
|
| 256 |
+
"""
|
| 257 |
+
Image to Patch Embedding.
|
| 258 |
+
"""
|
| 259 |
+
|
| 260 |
+
def __init__(
|
| 261 |
+
self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768
|
| 262 |
+
):
|
| 263 |
+
"""
|
| 264 |
+
Args:
|
| 265 |
+
kernel_size (Tuple): kernel size of the projection layer.
|
| 266 |
+
stride (Tuple): stride of the projection layer.
|
| 267 |
+
padding (Tuple): padding size of the projection layer.
|
| 268 |
+
in_chans (int): Number of input image channels.
|
| 269 |
+
embed_dim (int): embed_dim (int): Patch embedding dimension.
|
| 270 |
+
"""
|
| 271 |
+
super().__init__()
|
| 272 |
+
|
| 273 |
+
self.proj = nn.Conv2d(
|
| 274 |
+
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
def forward(self, x):
|
| 278 |
+
x = self.proj(x)
|
| 279 |
+
# B C H W -> B H W C
|
| 280 |
+
x = x.permute(0, 2, 3, 1)
|
| 281 |
+
return x
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def broadcat(tensors, dim = -1):
|
| 285 |
+
num_tensors = len(tensors)
|
| 286 |
+
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
|
| 287 |
+
assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
|
| 288 |
+
shape_len = list(shape_lens)[0]
|
| 289 |
+
dim = (dim + shape_len) if dim < 0 else dim
|
| 290 |
+
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
|
| 291 |
+
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
| 292 |
+
assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation'
|
| 293 |
+
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
|
| 294 |
+
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
|
| 295 |
+
expanded_dims.insert(dim, (dim, dims[dim]))
|
| 296 |
+
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
|
| 297 |
+
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
|
| 298 |
+
return torch.cat(tensors, dim = dim)
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def rotate_half(x):
|
| 303 |
+
x = rearrange(x, '... (d r) -> ... d r', r = 2)
|
| 304 |
+
x1, x2 = x.unbind(dim = -1)
|
| 305 |
+
x = torch.stack((-x2, x1), dim = -1)
|
| 306 |
+
return rearrange(x, '... d r -> ... (d r)')
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
class VisionRotaryEmbedding(nn.Module):
|
| 311 |
+
def __init__(
|
| 312 |
+
self,
|
| 313 |
+
dim,
|
| 314 |
+
pt_seq_len,
|
| 315 |
+
ft_seq_len=None,
|
| 316 |
+
custom_freqs = None,
|
| 317 |
+
freqs_for = 'lang',
|
| 318 |
+
theta = 10000,
|
| 319 |
+
max_freq = 10,
|
| 320 |
+
num_freqs = 1,
|
| 321 |
+
):
|
| 322 |
+
super().__init__()
|
| 323 |
+
if custom_freqs:
|
| 324 |
+
freqs = custom_freqs
|
| 325 |
+
elif freqs_for == 'lang':
|
| 326 |
+
freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
|
| 327 |
+
elif freqs_for == 'pixel':
|
| 328 |
+
freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
|
| 329 |
+
elif freqs_for == 'constant':
|
| 330 |
+
freqs = torch.ones(num_freqs).float()
|
| 331 |
+
else:
|
| 332 |
+
raise ValueError(f'unknown modality {freqs_for}')
|
| 333 |
+
|
| 334 |
+
if ft_seq_len is None: ft_seq_len = pt_seq_len
|
| 335 |
+
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
|
| 336 |
+
|
| 337 |
+
freqs_h = torch.einsum('..., f -> ... f', t, freqs)
|
| 338 |
+
freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)
|
| 339 |
+
|
| 340 |
+
freqs_w = torch.einsum('..., f -> ... f', t, freqs)
|
| 341 |
+
freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)
|
| 342 |
+
|
| 343 |
+
freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1)
|
| 344 |
+
|
| 345 |
+
self.register_buffer("freqs_cos", freqs.cos())
|
| 346 |
+
self.register_buffer("freqs_sin", freqs.sin())
|
| 347 |
+
|
| 348 |
+
# print('======== shape of rope freq', self.freqs_cos.shape, '========')
|
| 349 |
+
|
| 350 |
+
def forward(self, t, start_index = 0):
|
| 351 |
+
rot_dim = self.freqs_cos.shape[-1]
|
| 352 |
+
end_index = start_index + rot_dim
|
| 353 |
+
assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
|
| 354 |
+
t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
|
| 355 |
+
t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
|
| 356 |
+
return torch.cat((t_left, t, t_right), dim = -1)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
class VisionRotaryEmbeddingFast(nn.Module):
|
| 362 |
+
def __init__(
|
| 363 |
+
self,
|
| 364 |
+
dim,
|
| 365 |
+
pt_seq_len=16,
|
| 366 |
+
ft_seq_len=None,
|
| 367 |
+
custom_freqs = None,
|
| 368 |
+
freqs_for = 'lang',
|
| 369 |
+
theta = 10000,
|
| 370 |
+
max_freq = 10,
|
| 371 |
+
num_freqs = 1,
|
| 372 |
+
):
|
| 373 |
+
super().__init__()
|
| 374 |
+
if custom_freqs:
|
| 375 |
+
freqs = custom_freqs
|
| 376 |
+
elif freqs_for == 'lang':
|
| 377 |
+
freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
|
| 378 |
+
elif freqs_for == 'pixel':
|
| 379 |
+
freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
|
| 380 |
+
elif freqs_for == 'constant':
|
| 381 |
+
freqs = torch.ones(num_freqs).float()
|
| 382 |
+
else:
|
| 383 |
+
raise ValueError(f'unknown modality {freqs_for}')
|
| 384 |
+
|
| 385 |
+
if ft_seq_len is None: ft_seq_len = pt_seq_len
|
| 386 |
+
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
|
| 387 |
+
|
| 388 |
+
freqs = torch.einsum('..., f -> ... f', t, freqs)
|
| 389 |
+
freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
|
| 390 |
+
freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)
|
| 391 |
+
|
| 392 |
+
freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
|
| 393 |
+
freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
|
| 394 |
+
|
| 395 |
+
self.register_buffer("freqs_cos", freqs_cos)
|
| 396 |
+
self.register_buffer("freqs_sin", freqs_sin)
|
| 397 |
+
|
| 398 |
+
# print('======== shape of rope freq', self.freqs_cos.shape, '========')
|
| 399 |
+
|
| 400 |
+
def forward(self, t): return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
class FrozenBatchNorm2d(nn.Module):
|
| 404 |
+
"""
|
| 405 |
+
BatchNorm2d where the batch statistics and the affine parameters are fixed.
|
| 406 |
+
It contains non-trainable buffers called
|
| 407 |
+
"weight" and "bias", "running_mean", "running_var",
|
| 408 |
+
initialized to perform identity transformation.
|
| 409 |
+
The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
|
| 410 |
+
which are computed from the original four parameters of BN.
|
| 411 |
+
The affine transform `x * weight + bias` will perform the equivalent
|
| 412 |
+
computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
|
| 413 |
+
When loading a backbone model from Caffe2, "running_mean" and "running_var"
|
| 414 |
+
will be left unchanged as identity transformation.
|
| 415 |
+
Other pre-trained backbone models may contain all 4 parameters.
|
| 416 |
+
The forward is implemented by `F.batch_norm(..., training=False)`.
|
| 417 |
+
"""
|
| 418 |
+
|
| 419 |
+
_version = 3
|
| 420 |
+
|
| 421 |
+
def __init__(self, num_features, eps=1e-5):
|
| 422 |
+
super().__init__()
|
| 423 |
+
self.num_features = num_features
|
| 424 |
+
self.eps = eps
|
| 425 |
+
self.register_buffer("weight", torch.ones(num_features))
|
| 426 |
+
self.register_buffer("bias", torch.zeros(num_features))
|
| 427 |
+
self.register_buffer("running_mean", torch.zeros(num_features))
|
| 428 |
+
self.register_buffer("running_var", torch.ones(num_features) - eps)
|
| 429 |
+
|
| 430 |
+
def forward(self, x):
|
| 431 |
+
if x.requires_grad:
|
| 432 |
+
# When gradients are needed, F.batch_norm will use extra memory
|
| 433 |
+
# because its backward op computes gradients for weight/bias as well.
|
| 434 |
+
scale = self.weight * (self.running_var + self.eps).rsqrt()
|
| 435 |
+
bias = self.bias - self.running_mean * scale
|
| 436 |
+
scale = scale.reshape(1, -1, 1, 1)
|
| 437 |
+
bias = bias.reshape(1, -1, 1, 1)
|
| 438 |
+
out_dtype = x.dtype # may be half
|
| 439 |
+
return x * scale.to(out_dtype) + bias.to(out_dtype)
|
| 440 |
+
else:
|
| 441 |
+
# When gradients are not needed, F.batch_norm is a single fused op
|
| 442 |
+
# and provide more optimization opportunities.
|
| 443 |
+
return F.batch_norm(
|
| 444 |
+
x,
|
| 445 |
+
self.running_mean,
|
| 446 |
+
self.running_var,
|
| 447 |
+
self.weight,
|
| 448 |
+
self.bias,
|
| 449 |
+
training=False,
|
| 450 |
+
eps=self.eps,
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
def _load_from_state_dict(
|
| 454 |
+
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
| 455 |
+
):
|
| 456 |
+
version = local_metadata.get("version", None)
|
| 457 |
+
|
| 458 |
+
if version is None or version < 2:
|
| 459 |
+
# No running_mean/var in early versions
|
| 460 |
+
# This will silent the warnings
|
| 461 |
+
if prefix + "running_mean" not in state_dict:
|
| 462 |
+
state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
|
| 463 |
+
if prefix + "running_var" not in state_dict:
|
| 464 |
+
state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)
|
| 465 |
+
|
| 466 |
+
super()._load_from_state_dict(
|
| 467 |
+
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
def __repr__(self):
|
| 471 |
+
return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)
|
| 472 |
+
|
| 473 |
+
@classmethod
|
| 474 |
+
def convert_frozen_batchnorm(cls, module):
|
| 475 |
+
"""
|
| 476 |
+
Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.
|
| 477 |
+
Args:
|
| 478 |
+
module (torch.nn.Module):
|
| 479 |
+
Returns:
|
| 480 |
+
If module is BatchNorm/SyncBatchNorm, returns a new module.
|
| 481 |
+
Otherwise, in-place convert module and return it.
|
| 482 |
+
Similar to convert_sync_batchnorm in
|
| 483 |
+
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
|
| 484 |
+
"""
|
| 485 |
+
bn_module = nn.modules.batchnorm
|
| 486 |
+
bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
|
| 487 |
+
res = module
|
| 488 |
+
if isinstance(module, bn_module):
|
| 489 |
+
res = cls(module.num_features)
|
| 490 |
+
if module.affine:
|
| 491 |
+
res.weight.data = module.weight.data.clone().detach()
|
| 492 |
+
res.bias.data = module.bias.data.clone().detach()
|
| 493 |
+
res.running_mean.data = module.running_mean.data
|
| 494 |
+
res.running_var.data = module.running_var.data
|
| 495 |
+
res.eps = module.eps
|
| 496 |
+
else:
|
| 497 |
+
for name, child in module.named_children():
|
| 498 |
+
new_child = cls.convert_frozen_batchnorm(child)
|
| 499 |
+
if new_child is not child:
|
| 500 |
+
res.add_module(name, new_child)
|
| 501 |
+
return res
|
| 502 |
+
|
| 503 |
+
class LayerNorm(nn.Module):
|
| 504 |
+
"""
|
| 505 |
+
A LayerNorm variant, popularized by Transformers, that performs point-wise mean and
|
| 506 |
+
variance normalization over the channel dimension for inputs that have shape
|
| 507 |
+
(batch_size, channels, height, width).
|
| 508 |
+
https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa B950
|
| 509 |
+
"""
|
| 510 |
+
|
| 511 |
+
def __init__(self, normalized_shape, eps=1e-6):
|
| 512 |
+
super().__init__()
|
| 513 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
| 514 |
+
self.bias = nn.Parameter(torch.zeros(normalized_shape))
|
| 515 |
+
self.eps = eps
|
| 516 |
+
self.normalized_shape = (normalized_shape,)
|
| 517 |
+
|
| 518 |
+
def forward(self, x):
|
| 519 |
+
u = x.mean(1, keepdim=True)
|
| 520 |
+
s = (x - u).pow(2).mean(1, keepdim=True)
|
| 521 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
| 522 |
+
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
| 523 |
+
return x
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
class CNNBlockBase(nn.Module):
|
| 527 |
+
"""
|
| 528 |
+
A CNN block is assumed to have input channels, output channels and a stride.
|
| 529 |
+
The input and output of `forward()` method must be NCHW tensors.
|
| 530 |
+
The method can perform arbitrary computation but must match the given
|
| 531 |
+
channels and stride specification.
|
| 532 |
+
Attribute:
|
| 533 |
+
in_channels (int):
|
| 534 |
+
out_channels (int):
|
| 535 |
+
stride (int):
|
| 536 |
+
"""
|
| 537 |
+
|
| 538 |
+
def __init__(self, in_channels, out_channels, stride):
|
| 539 |
+
"""
|
| 540 |
+
The `__init__` method of any subclass should also contain these arguments.
|
| 541 |
+
Args:
|
| 542 |
+
in_channels (int):
|
| 543 |
+
out_channels (int):
|
| 544 |
+
stride (int):
|
| 545 |
+
"""
|
| 546 |
+
super().__init__()
|
| 547 |
+
self.in_channels = in_channels
|
| 548 |
+
self.out_channels = out_channels
|
| 549 |
+
self.stride = stride
|
| 550 |
+
|
| 551 |
+
def freeze(self):
|
| 552 |
+
"""
|
| 553 |
+
Make this block not trainable.
|
| 554 |
+
This method sets all parameters to `requires_grad=False`,
|
| 555 |
+
and convert all BatchNorm layers to FrozenBatchNorm
|
| 556 |
+
Returns:
|
| 557 |
+
the block itself
|
| 558 |
+
"""
|
| 559 |
+
for p in self.parameters():
|
| 560 |
+
p.requires_grad = False
|
| 561 |
+
FrozenBatchNorm2d.convert_frozen_batchnorm(self)
|
| 562 |
+
return self
|
| 563 |
+
|
| 564 |
+
def get_norm(norm, out_channels):
|
| 565 |
+
"""
|
| 566 |
+
Args:
|
| 567 |
+
norm (str or callable): either one of BN, SyncBN, FrozenBN, GN;
|
| 568 |
+
or a callable that takes a channel number and returns
|
| 569 |
+
the normalization layer as a nn.Module.
|
| 570 |
+
Returns:
|
| 571 |
+
nn.Module or None: the normalization layer
|
| 572 |
+
"""
|
| 573 |
+
if norm is None:
|
| 574 |
+
return None
|
| 575 |
+
if isinstance(norm, str):
|
| 576 |
+
if len(norm) == 0:
|
| 577 |
+
return None
|
| 578 |
+
norm = {
|
| 579 |
+
"BN": BatchNorm2d,
|
| 580 |
+
# Fixed in https://github.com/pytorch/pytorch/pull/36382
|
| 581 |
+
"SyncBN": nn.SyncBatchNorm,
|
| 582 |
+
"FrozenBN": FrozenBatchNorm2d,
|
| 583 |
+
"GN": lambda channels: nn.GroupNorm(32, channels),
|
| 584 |
+
# for debugging:
|
| 585 |
+
"nnSyncBN": nn.SyncBatchNorm,
|
| 586 |
+
"LN": lambda channels: LayerNorm(channels)
|
| 587 |
+
}[norm]
|
| 588 |
+
return norm(out_channels)
|
| 589 |
+
|
| 590 |
+
class DropPath(nn.Module):
|
| 591 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 592 |
+
"""
|
| 593 |
+
|
| 594 |
+
def __init__(self, drop_prob=None):
|
| 595 |
+
super(DropPath, self).__init__()
|
| 596 |
+
self.drop_prob = drop_prob
|
| 597 |
+
|
| 598 |
+
def forward(self, x):
|
| 599 |
+
if self.drop_prob == 0. or not self.training:
|
| 600 |
+
return x
|
| 601 |
+
keep_prob = 1 - self.drop_prob
|
| 602 |
+
# work with diff dim tensors, not just 2D ConvNets
|
| 603 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
|
| 604 |
+
random_tensor = keep_prob + \
|
| 605 |
+
torch.rand(shape, dtype=x.dtype, device=x.device)
|
| 606 |
+
random_tensor.floor_() # binarize
|
| 607 |
+
output = x.div(keep_prob) * random_tensor
|
| 608 |
+
return output
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
class SwiGLU(nn.Module):
|
| 613 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
|
| 614 |
+
norm_layer=nn.LayerNorm, subln=False
|
| 615 |
+
):
|
| 616 |
+
super().__init__()
|
| 617 |
+
out_features = out_features or in_features
|
| 618 |
+
hidden_features = hidden_features or in_features
|
| 619 |
+
|
| 620 |
+
self.w1 = nn.Linear(in_features, hidden_features)
|
| 621 |
+
self.w2 = nn.Linear(in_features, hidden_features)
|
| 622 |
+
|
| 623 |
+
self.act = act_layer()
|
| 624 |
+
self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
|
| 625 |
+
self.w3 = nn.Linear(hidden_features, out_features)
|
| 626 |
+
|
| 627 |
+
self.drop = nn.Dropout(drop)
|
| 628 |
+
|
| 629 |
+
def forward(self, x):
|
| 630 |
+
x1 = self.w1(x)
|
| 631 |
+
x2 = self.w2(x)
|
| 632 |
+
hidden = self.act(x1) * x2
|
| 633 |
+
x = self.ffn_ln(hidden)
|
| 634 |
+
x = self.w3(x)
|
| 635 |
+
x = self.drop(x)
|
| 636 |
+
return x
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
class Attention(nn.Module):
|
| 640 |
+
def __init__(
|
| 641 |
+
self,
|
| 642 |
+
dim,
|
| 643 |
+
num_heads=8,
|
| 644 |
+
qkv_bias=True,
|
| 645 |
+
qk_scale=None,
|
| 646 |
+
attn_head_dim=None,
|
| 647 |
+
norm_layer=nn.LayerNorm,
|
| 648 |
+
rope=None,
|
| 649 |
+
xattn=True,
|
| 650 |
+
subln=False
|
| 651 |
+
):
|
| 652 |
+
super().__init__()
|
| 653 |
+
self.num_heads = num_heads
|
| 654 |
+
head_dim = dim // num_heads
|
| 655 |
+
if attn_head_dim is not None:
|
| 656 |
+
head_dim = attn_head_dim
|
| 657 |
+
all_head_dim = head_dim * self.num_heads
|
| 658 |
+
self.scale = qk_scale or head_dim ** -0.5
|
| 659 |
+
|
| 660 |
+
self.subln = subln
|
| 661 |
+
self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
|
| 662 |
+
self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
|
| 663 |
+
self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
|
| 664 |
+
|
| 665 |
+
if qkv_bias:
|
| 666 |
+
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
| 667 |
+
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
| 668 |
+
else:
|
| 669 |
+
self.q_bias = None
|
| 670 |
+
self.v_bias = None
|
| 671 |
+
|
| 672 |
+
self.rope = rope
|
| 673 |
+
self.xattn = xattn
|
| 674 |
+
self.proj = nn.Linear(all_head_dim, dim)
|
| 675 |
+
self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
|
| 676 |
+
|
| 677 |
+
if self.xattn:
|
| 678 |
+
factory_kwargs = {'device': 'cuda', 'dtype': torch.float16}
|
| 679 |
+
self.inner_attn = FlashAttention(attention_dropout=0.0, **factory_kwargs)
|
| 680 |
+
|
| 681 |
+
def forward(self, x):
|
| 682 |
+
B, H, W, C = x.shape
|
| 683 |
+
x = x.view(B, -1, C)
|
| 684 |
+
N = H * W
|
| 685 |
+
|
| 686 |
+
q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
|
| 687 |
+
k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
|
| 688 |
+
v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
|
| 689 |
+
|
| 690 |
+
q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
|
| 691 |
+
k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
|
| 692 |
+
v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
|
| 693 |
+
|
| 694 |
+
## rope
|
| 695 |
+
q = self.rope(q).type_as(v)
|
| 696 |
+
k = self.rope(k).type_as(v)
|
| 697 |
+
|
| 698 |
+
if self.xattn:
|
| 699 |
+
q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
|
| 700 |
+
k = k.permute(0, 2, 1, 3)
|
| 701 |
+
v = v.permute(0, 2, 1, 3)
|
| 702 |
+
|
| 703 |
+
kv = torch.stack([k, v], dim=2)
|
| 704 |
+
x, attn_weights = self.inner_attn(q, kv, key_padding_mask=None, causal=False)
|
| 705 |
+
# x = xops.memory_efficient_attention(q, k, v)
|
| 706 |
+
x = x.reshape(B, N, -1)
|
| 707 |
+
x = self.inner_attn_ln(x)
|
| 708 |
+
else:
|
| 709 |
+
q = q * self.scale
|
| 710 |
+
attn = (q @ k.transpose(-2, -1))
|
| 711 |
+
attn = attn.softmax(dim=-1).type_as(x)
|
| 712 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
| 713 |
+
x = self.inner_attn_ln(x)
|
| 714 |
+
|
| 715 |
+
x = self.proj(x)
|
| 716 |
+
x = x.view(B, H, W, C)
|
| 717 |
+
|
| 718 |
+
return x
|
| 719 |
+
|
| 720 |
+
|
| 721 |
+
class ResBottleneckBlock(CNNBlockBase):
|
| 722 |
+
"""
|
| 723 |
+
The standard bottleneck residual block without the last activation layer.
|
| 724 |
+
It contains 3 conv layers with kernels 1x1, 3x3, 1x1.
|
| 725 |
+
"""
|
| 726 |
+
|
| 727 |
+
def __init__(
|
| 728 |
+
self,
|
| 729 |
+
in_channels,
|
| 730 |
+
out_channels,
|
| 731 |
+
bottleneck_channels,
|
| 732 |
+
norm="LN",
|
| 733 |
+
act_layer=nn.GELU,
|
| 734 |
+
):
|
| 735 |
+
"""
|
| 736 |
+
Args:
|
| 737 |
+
in_channels (int): Number of input channels.
|
| 738 |
+
out_channels (int): Number of output channels.
|
| 739 |
+
bottleneck_channels (int): number of output channels for the 3x3
|
| 740 |
+
"bottleneck" conv layers.
|
| 741 |
+
norm (str or callable): normalization for all conv layers.
|
| 742 |
+
See :func:`layers.get_norm` for supported format.
|
| 743 |
+
act_layer (callable): activation for all conv layers.
|
| 744 |
+
"""
|
| 745 |
+
super().__init__(in_channels, out_channels, 1)
|
| 746 |
+
|
| 747 |
+
self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False)
|
| 748 |
+
self.norm1 = get_norm(norm, bottleneck_channels)
|
| 749 |
+
self.act1 = act_layer()
|
| 750 |
+
|
| 751 |
+
self.conv2 = Conv2d(
|
| 752 |
+
bottleneck_channels,
|
| 753 |
+
bottleneck_channels,
|
| 754 |
+
3,
|
| 755 |
+
padding=1,
|
| 756 |
+
bias=False,
|
| 757 |
+
)
|
| 758 |
+
self.norm2 = get_norm(norm, bottleneck_channels)
|
| 759 |
+
self.act2 = act_layer()
|
| 760 |
+
|
| 761 |
+
self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False)
|
| 762 |
+
self.norm3 = get_norm(norm, out_channels)
|
| 763 |
+
|
| 764 |
+
for layer in [self.conv1, self.conv2, self.conv3]:
|
| 765 |
+
weight_init.c2_msra_fill(layer)
|
| 766 |
+
for layer in [self.norm1, self.norm2]:
|
| 767 |
+
layer.weight.data.fill_(1.0)
|
| 768 |
+
layer.bias.data.zero_()
|
| 769 |
+
# zero init last norm layer.
|
| 770 |
+
self.norm3.weight.data.zero_()
|
| 771 |
+
self.norm3.bias.data.zero_()
|
| 772 |
+
|
| 773 |
+
def forward(self, x):
|
| 774 |
+
out = x
|
| 775 |
+
for layer in self.children():
|
| 776 |
+
out = layer(out)
|
| 777 |
+
|
| 778 |
+
out = x + out
|
| 779 |
+
return out
|
| 780 |
+
|
| 781 |
+
|
| 782 |
+
class Block(nn.Module):
|
| 783 |
+
"""Transformer blocks with support of window attention and residual propagation blocks"""
|
| 784 |
+
|
| 785 |
+
def __init__(
|
| 786 |
+
self,
|
| 787 |
+
dim,
|
| 788 |
+
num_heads,
|
| 789 |
+
mlp_ratio=4*2/3,
|
| 790 |
+
qkv_bias=True,
|
| 791 |
+
drop_path=0.0,
|
| 792 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
| 793 |
+
window_size=0,
|
| 794 |
+
use_residual_block=False,
|
| 795 |
+
rope=None,
|
| 796 |
+
xattn=True,
|
| 797 |
+
subln=False,
|
| 798 |
+
# with_cp=True,
|
| 799 |
+
):
|
| 800 |
+
"""
|
| 801 |
+
Args:
|
| 802 |
+
dim (int): Number of input channels.
|
| 803 |
+
num_heads (int): Number of attention heads in each ViT block.
|
| 804 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
| 805 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
| 806 |
+
drop_path (float): Stochastic depth rate.
|
| 807 |
+
norm_layer (nn.Module): Normalization layer.
|
| 808 |
+
act_layer (nn.Module): Activation layer.
|
| 809 |
+
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
| 810 |
+
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
| 811 |
+
window_size (int): Window size for window attention blocks. If it equals 0, then not
|
| 812 |
+
use window attention.
|
| 813 |
+
use_residual_block (bool): If True, use a residual block after the MLP block.
|
| 814 |
+
input_size (int or None): Input resolution for calculating the relative positional
|
| 815 |
+
parameter size.
|
| 816 |
+
"""
|
| 817 |
+
super().__init__()
|
| 818 |
+
self.norm1 = norm_layer(dim)
|
| 819 |
+
self.attn = Attention(
|
| 820 |
+
dim,
|
| 821 |
+
num_heads=num_heads,
|
| 822 |
+
qkv_bias=qkv_bias,
|
| 823 |
+
rope=rope,
|
| 824 |
+
xattn=xattn,
|
| 825 |
+
subln=subln
|
| 826 |
+
)
|
| 827 |
+
|
| 828 |
+
|
| 829 |
+
# self.with_cp = with_cp
|
| 830 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 831 |
+
self.norm2 = norm_layer(dim)
|
| 832 |
+
self.mlp = SwiGLU(
|
| 833 |
+
in_features=dim,
|
| 834 |
+
hidden_features=int(dim * mlp_ratio),
|
| 835 |
+
subln=True,
|
| 836 |
+
norm_layer=norm_layer,
|
| 837 |
+
)
|
| 838 |
+
|
| 839 |
+
self.window_size = window_size
|
| 840 |
+
|
| 841 |
+
self.use_residual_block = use_residual_block
|
| 842 |
+
if use_residual_block:
|
| 843 |
+
# Use a residual block with bottleneck channel as dim // 2
|
| 844 |
+
self.residual = ResBottleneckBlock(
|
| 845 |
+
in_channels=dim,
|
| 846 |
+
out_channels=dim,
|
| 847 |
+
bottleneck_channels=dim // 2,
|
| 848 |
+
norm="LN",
|
| 849 |
+
)
|
| 850 |
+
|
| 851 |
+
def _forward(self, x):
|
| 852 |
+
shortcut = x
|
| 853 |
+
x = self.norm1(x)
|
| 854 |
+
|
| 855 |
+
# Window partition
|
| 856 |
+
if self.window_size > 0:
|
| 857 |
+
H, W = x.shape[1], x.shape[2]
|
| 858 |
+
x, pad_hw = window_partition(x, self.window_size)
|
| 859 |
+
|
| 860 |
+
x = self.attn(x)
|
| 861 |
+
|
| 862 |
+
# Reverse window partition
|
| 863 |
+
if self.window_size > 0:
|
| 864 |
+
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
|
| 865 |
+
|
| 866 |
+
x = shortcut + self.drop_path(x)
|
| 867 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 868 |
+
|
| 869 |
+
if self.use_residual_block:
|
| 870 |
+
x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
|
| 871 |
+
|
| 872 |
+
return x
|
| 873 |
+
|
| 874 |
+
def forward(self, x, with_cp=False):
|
| 875 |
+
# if self.with_cp and self.training:
|
| 876 |
+
if with_cp:
|
| 877 |
+
x = cp.checkpoint(self._forward, x)
|
| 878 |
+
else:
|
| 879 |
+
x = self._forward(x)
|
| 880 |
+
return x
|
| 881 |
+
|
| 882 |
+
#@BACKBONES.register_module()
|
| 883 |
+
class EVAViT(nn.Module):
|
| 884 |
+
"""
|
| 885 |
+
This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`.
|
| 886 |
+
"Exploring Plain Vision Transformer Backbones for Object Detection",
|
| 887 |
+
https://arxiv.org/abs/2203.16527
|
| 888 |
+
"""
|
| 889 |
+
|
| 890 |
+
def __init__(
|
| 891 |
+
self,
|
| 892 |
+
img_size=1024,
|
| 893 |
+
patch_size=16,
|
| 894 |
+
in_chans=3,
|
| 895 |
+
embed_dim=768,
|
| 896 |
+
depth=12,
|
| 897 |
+
num_heads=12,
|
| 898 |
+
mlp_ratio=4*2/3,
|
| 899 |
+
qkv_bias=True,
|
| 900 |
+
drop_path_rate=0.0,
|
| 901 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
| 902 |
+
act_layer=nn.GELU,
|
| 903 |
+
use_abs_pos=True,
|
| 904 |
+
use_rel_pos=False,
|
| 905 |
+
# sim_fpn=None,
|
| 906 |
+
rope=True,
|
| 907 |
+
pt_hw_seq_len=16,
|
| 908 |
+
intp_freq=True,
|
| 909 |
+
window_size=0,
|
| 910 |
+
global_window_size=0,
|
| 911 |
+
window_block_indexes=(),
|
| 912 |
+
residual_block_indexes=(),
|
| 913 |
+
pretrain_img_size=224,
|
| 914 |
+
pretrain_use_cls_token=True,
|
| 915 |
+
out_feature="last_feat",
|
| 916 |
+
subln=False,
|
| 917 |
+
xattn=True,
|
| 918 |
+
# with_cp=True,
|
| 919 |
+
frozen=False,
|
| 920 |
+
):
|
| 921 |
+
"""
|
| 922 |
+
Args:
|
| 923 |
+
img_size (int): Input image size.
|
| 924 |
+
patch_size (int): Patch size.
|
| 925 |
+
in_chans (int): Number of input image channels.
|
| 926 |
+
embed_dim (int): Patch embedding dimension.
|
| 927 |
+
depth (int): Depth of ViT.
|
| 928 |
+
num_heads (int): Number of attention heads in each ViT block.
|
| 929 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
| 930 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
| 931 |
+
drop_path_rate (float): Stochastic depth rate.
|
| 932 |
+
norm_layer (nn.Module): Normalization layer.
|
| 933 |
+
act_layer (nn.Module): Activation layer.
|
| 934 |
+
use_abs_pos (bool): If True, use absolute positional embeddings.
|
| 935 |
+
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
| 936 |
+
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
| 937 |
+
window_size (int): Window size for window attention blocks.
|
| 938 |
+
window_block_indexes (list): Indexes for blocks using window attention.
|
| 939 |
+
residual_block_indexes (list): Indexes for blocks using conv propagation.
|
| 940 |
+
use_act_checkpoint (bool): If True, use activation checkpointing.
|
| 941 |
+
pretrain_img_size (int): input image size for pretraining models.
|
| 942 |
+
pretrain_use_cls_token (bool): If True, pretrainig models use class token.
|
| 943 |
+
out_feature (str): name of the feature from the last block.
|
| 944 |
+
"""
|
| 945 |
+
super().__init__()
|
| 946 |
+
self.pretrain_use_cls_token = pretrain_use_cls_token
|
| 947 |
+
self.patch_embed = PatchEmbed(
|
| 948 |
+
kernel_size=(patch_size, patch_size),
|
| 949 |
+
stride=(patch_size, patch_size),
|
| 950 |
+
in_chans=in_chans,
|
| 951 |
+
embed_dim=embed_dim,
|
| 952 |
+
)
|
| 953 |
+
self.frozen = frozen
|
| 954 |
+
self.gradient_checkpointing = False
|
| 955 |
+
|
| 956 |
+
if use_abs_pos:
|
| 957 |
+
# Initialize absolute positional embedding with pretrain image size.
|
| 958 |
+
num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
|
| 959 |
+
num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
|
| 960 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
|
| 961 |
+
else:
|
| 962 |
+
self.pos_embed = None
|
| 963 |
+
|
| 964 |
+
half_head_dim = embed_dim // num_heads // 2
|
| 965 |
+
hw_seq_len = img_size // patch_size
|
| 966 |
+
|
| 967 |
+
self.rope_win = VisionRotaryEmbeddingFast(
|
| 968 |
+
dim=half_head_dim,
|
| 969 |
+
pt_seq_len=pt_hw_seq_len,
|
| 970 |
+
ft_seq_len=window_size if intp_freq else None,
|
| 971 |
+
)
|
| 972 |
+
self.rope_glb = VisionRotaryEmbeddingFast(
|
| 973 |
+
dim=half_head_dim,
|
| 974 |
+
pt_seq_len=pt_hw_seq_len,
|
| 975 |
+
ft_seq_len=hw_seq_len if intp_freq else None,
|
| 976 |
+
)
|
| 977 |
+
|
| 978 |
+
# stochastic depth decay rule
|
| 979 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
|
| 980 |
+
|
| 981 |
+
self.blocks = nn.ModuleList()
|
| 982 |
+
for i in range(depth):
|
| 983 |
+
block = Block(
|
| 984 |
+
dim=embed_dim,
|
| 985 |
+
num_heads=num_heads,
|
| 986 |
+
mlp_ratio=mlp_ratio,
|
| 987 |
+
qkv_bias=qkv_bias,
|
| 988 |
+
drop_path=dpr[i],
|
| 989 |
+
norm_layer=norm_layer,
|
| 990 |
+
window_size=window_size if i in window_block_indexes else global_window_size,
|
| 991 |
+
use_residual_block=i in residual_block_indexes,
|
| 992 |
+
rope=self.rope_win if i in window_block_indexes else self.rope_glb,
|
| 993 |
+
xattn=xattn,
|
| 994 |
+
subln=subln,
|
| 995 |
+
# with_cp=with_cp,
|
| 996 |
+
)
|
| 997 |
+
|
| 998 |
+
self.blocks.append(block)
|
| 999 |
+
|
| 1000 |
+
self._out_feature_channels = {out_feature: embed_dim}
|
| 1001 |
+
self._out_feature_strides = {out_feature: patch_size}
|
| 1002 |
+
self._out_features = [out_feature]
|
| 1003 |
+
|
| 1004 |
+
if self.pos_embed is not None:
|
| 1005 |
+
nn.init.normal_(self.pos_embed, std=0.02)
|
| 1006 |
+
|
| 1007 |
+
self._freeze_stages()
|
| 1008 |
+
|
| 1009 |
+
def _freeze_stages(self):
|
| 1010 |
+
if self.frozen:
|
| 1011 |
+
self.eval()
|
| 1012 |
+
for m in self.parameters():
|
| 1013 |
+
m.requires_grad = False
|
| 1014 |
+
|
| 1015 |
+
def forward(self, x):
|
| 1016 |
+
x = self.patch_embed(x)
|
| 1017 |
+
if self.pos_embed is not None:
|
| 1018 |
+
x = x + get_abs_pos(
|
| 1019 |
+
self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
|
| 1020 |
+
)
|
| 1021 |
+
|
| 1022 |
+
for blk in self.blocks:
|
| 1023 |
+
x = blk(x, with_cp=self.gradient_checkpointing) # b, h, w, c
|
| 1024 |
+
x = x.permute(0, 3, 1, 2) # b, c, h, w
|
| 1025 |
+
|
| 1026 |
+
return x
|
| 1027 |
+
|
| 1028 |
+
|
| 1029 |
+
class EVAVITVisionTower(nn.Module):
|
| 1030 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
| 1031 |
+
super().__init__()
|
| 1032 |
+
|
| 1033 |
+
self.is_loaded = False
|
| 1034 |
+
self.vision_tower_name = vision_tower
|
| 1035 |
+
self.select_layer = args.mm_vision_select_layer # NOTE: not implemented yet, this parameter has no effect
|
| 1036 |
+
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
|
| 1037 |
+
|
| 1038 |
+
self.args = args
|
| 1039 |
+
self.vision_tower, vision_tower_config = build_eva_vit(args=args,
|
| 1040 |
+
model_name=vision_tower,
|
| 1041 |
+
image_size=args.input_image_size
|
| 1042 |
+
)
|
| 1043 |
+
self.input_image_size=args.input_image_size
|
| 1044 |
+
self.vision_tower.config = vision_tower_config
|
| 1045 |
+
self.freeze_vision = args.freeze_vision
|
| 1046 |
+
|
| 1047 |
+
if not self.is_loaded:
|
| 1048 |
+
self.load_model()
|
| 1049 |
+
|
| 1050 |
+
|
| 1051 |
+
def load_model(self):
|
| 1052 |
+
if self.is_loaded:
|
| 1053 |
+
return
|
| 1054 |
+
|
| 1055 |
+
# hardcode
|
| 1056 |
+
self.image_processor = CLIPImageProcessor(crop_size={"height": self.args.input_image_size, "width": self.args.input_image_size},
|
| 1057 |
+
size={'shortest_edge': self.args.input_image_size},
|
| 1058 |
+
image_mean=[0.48145466, 0.4578275, 0.40821073],
|
| 1059 |
+
image_std=[0.26862954, 0.26130258, 0.27577711])
|
| 1060 |
+
|
| 1061 |
+
# load weights
|
| 1062 |
+
if self.args.vision_tower_pretrained_from is not None:
|
| 1063 |
+
if not os.path.exists(self.args.vision_tower_pretrained_from):
|
| 1064 |
+
import warnings
|
| 1065 |
+
warnings.warn("The vision tower weights for EVA-02 vision tower does not exists, this will cause problem if you are training the model from scratch!")
|
| 1066 |
+
self.is_loaded = True
|
| 1067 |
+
return
|
| 1068 |
+
|
| 1069 |
+
pretrained_params = torch.load(self.args.vision_tower_pretrained_from)
|
| 1070 |
+
if 'ema_state' in pretrained_params:
|
| 1071 |
+
pretrained_params = pretrained_params['ema_state']
|
| 1072 |
+
elif 'module' in pretrained_params:
|
| 1073 |
+
pretrained_params = pretrained_params['module']
|
| 1074 |
+
|
| 1075 |
+
from collections import OrderedDict
|
| 1076 |
+
new_params = OrderedDict()
|
| 1077 |
+
|
| 1078 |
+
kw = ""
|
| 1079 |
+
if "det" in self.args.vision_tower_pretrained_from.lower():
|
| 1080 |
+
kw = "backbone.net."
|
| 1081 |
+
elif "clip" in self.args.vision_tower_pretrained_from.lower():
|
| 1082 |
+
kw = "visual."
|
| 1083 |
+
|
| 1084 |
+
for k, v in pretrained_params.items():
|
| 1085 |
+
if len(kw) > 0:
|
| 1086 |
+
if kw in k and ("rope" not in k):
|
| 1087 |
+
new_params[k.replace(kw, "")] = v
|
| 1088 |
+
else:
|
| 1089 |
+
if "rope" not in k:
|
| 1090 |
+
new_params[k] = v
|
| 1091 |
+
|
| 1092 |
+
incompatiblekeys = self.vision_tower.load_state_dict(new_params, strict=False)
|
| 1093 |
+
for k in incompatiblekeys[0]:
|
| 1094 |
+
if "rope" not in k:
|
| 1095 |
+
warnings.warn(f"Find incompatible keys {k} in state dict.")
|
| 1096 |
+
|
| 1097 |
+
|
| 1098 |
+
if self.freeze_vision:
|
| 1099 |
+
self.vision_tower.requires_grad_(False)
|
| 1100 |
+
|
| 1101 |
+
self.is_loaded = True
|
| 1102 |
+
|
| 1103 |
+
|
| 1104 |
+
# @torch.no_grad()
|
| 1105 |
+
def forward(self, images):
|
| 1106 |
+
if type(images) is list:
|
| 1107 |
+
image_features = []
|
| 1108 |
+
for image in images:
|
| 1109 |
+
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
|
| 1110 |
+
image_feature = image_forward_out.flatten(2,3).transpose(1,2) # b, n, c
|
| 1111 |
+
image_features.append(image_feature)
|
| 1112 |
+
else:
|
| 1113 |
+
image_forward_out = self.vision_tower(images.to(device=self.device, dtype=self.dtype))
|
| 1114 |
+
|
| 1115 |
+
return image_forward_out
|
| 1116 |
+
|
| 1117 |
+
@property
|
| 1118 |
+
def dummy_feature(self):
|
| 1119 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
| 1120 |
+
|
| 1121 |
+
@property
|
| 1122 |
+
def dtype(self):
|
| 1123 |
+
return next(self.vision_tower.parameters()).dtype
|
| 1124 |
+
|
| 1125 |
+
@property
|
| 1126 |
+
def device(self):
|
| 1127 |
+
return next(self.vision_tower.parameters()).device
|
| 1128 |
+
|
| 1129 |
+
@property
|
| 1130 |
+
def config(self):
|
| 1131 |
+
return self.vision_tower.config
|
| 1132 |
+
|
| 1133 |
+
@property
|
| 1134 |
+
def hidden_size(self):
|
| 1135 |
+
#return self.config.hidden_size
|
| 1136 |
+
return self.config['hidden_dim']
|
| 1137 |
+
|
| 1138 |
+
@property
|
| 1139 |
+
def num_patches(self):
|
| 1140 |
+
# return (self.config.image_size // self.config.patch_size) ** 2
|
| 1141 |
+
return self.config['num_patches']
|
| 1142 |
+
|
| 1143 |
+
|
| 1144 |
+
def build_eva_vit(args,
|
| 1145 |
+
model_name=None,
|
| 1146 |
+
image_size=224,
|
| 1147 |
+
window_attn=True
|
| 1148 |
+
):
|
| 1149 |
+
|
| 1150 |
+
if "336" in args.vision_tower_pretrained_from:
|
| 1151 |
+
pretrained_image_size = 336
|
| 1152 |
+
else:
|
| 1153 |
+
pretrained_image_size = 224
|
| 1154 |
+
|
| 1155 |
+
if "clip" in args.vision_tower_pretrained_from.lower():
|
| 1156 |
+
subln = True
|
| 1157 |
+
else:
|
| 1158 |
+
subln = False
|
| 1159 |
+
|
| 1160 |
+
if model_name == 'eva02-l-16':
|
| 1161 |
+
# shilong said that use this: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/det/eva02_L_coco_det_sys_o365.pth
|
| 1162 |
+
if window_attn:
|
| 1163 |
+
window_block_indexes = (list(range(0, 2)) + list(range(3, 5)) + list(range(6, 8)) + list(range(9, 11)) + list(range(12, 14)) + list(range(15, 17)) + list(range(18, 20)) + list(range(21, 23)))
|
| 1164 |
+
else:
|
| 1165 |
+
window_block_indexes = ()
|
| 1166 |
+
|
| 1167 |
+
model = EVAViT(
|
| 1168 |
+
img_size=image_size,
|
| 1169 |
+
patch_size=16,
|
| 1170 |
+
window_size=16,
|
| 1171 |
+
in_chans=3,
|
| 1172 |
+
embed_dim=1024,
|
| 1173 |
+
depth=24,
|
| 1174 |
+
num_heads=16,
|
| 1175 |
+
mlp_ratio=4*2/3,
|
| 1176 |
+
window_block_indexes = window_block_indexes,
|
| 1177 |
+
qkv_bias=True,
|
| 1178 |
+
drop_path_rate=0.0,
|
| 1179 |
+
xattn=False,
|
| 1180 |
+
# with_cp=False,
|
| 1181 |
+
# frozen=True,
|
| 1182 |
+
)
|
| 1183 |
+
# image_size = 224 # HARDCODE
|
| 1184 |
+
eva_config = dict(image_size=image_size,
|
| 1185 |
+
patch_size=16,
|
| 1186 |
+
window_size=16,
|
| 1187 |
+
hidden_dim=1024,
|
| 1188 |
+
depth=24,
|
| 1189 |
+
num_heads=16,
|
| 1190 |
+
window_block_indexes=window_block_indexes,
|
| 1191 |
+
num_patches=image_size ** 2 // 16 ** 2,
|
| 1192 |
+
pretrained_from=args.vision_tower_pretrained_from
|
| 1193 |
+
)
|
| 1194 |
+
|
| 1195 |
+
elif model_name == 'eva02-l-14':
|
| 1196 |
+
# shilong said that use this: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/det/eva02_L_coco_det_sys_o365.pth
|
| 1197 |
+
if window_attn:
|
| 1198 |
+
window_block_indexes = (list(range(0, 2)) + list(range(3, 5)) + list(range(6, 8)) + list(range(9, 11)) + list(range(12, 14)) + list(range(15, 17)) + list(range(18, 20)) + list(range(21, 23)))
|
| 1199 |
+
else:
|
| 1200 |
+
window_block_indexes = ()
|
| 1201 |
+
|
| 1202 |
+
model = EVAViT(
|
| 1203 |
+
img_size=image_size,
|
| 1204 |
+
pretrain_img_size=pretrained_image_size,
|
| 1205 |
+
patch_size=14,
|
| 1206 |
+
window_size=16,
|
| 1207 |
+
in_chans=3,
|
| 1208 |
+
embed_dim=1024,
|
| 1209 |
+
depth=24,
|
| 1210 |
+
num_heads=16,
|
| 1211 |
+
mlp_ratio=4*2/3,
|
| 1212 |
+
window_block_indexes = window_block_indexes,
|
| 1213 |
+
qkv_bias=True,
|
| 1214 |
+
drop_path_rate=0.0,
|
| 1215 |
+
xattn=False,
|
| 1216 |
+
# with_cp=False,
|
| 1217 |
+
subln=subln,
|
| 1218 |
+
# frozen=True,
|
| 1219 |
+
)
|
| 1220 |
+
# image_size = 224 # HARDCODE
|
| 1221 |
+
eva_config = dict(image_size=image_size,
|
| 1222 |
+
patch_size=14,
|
| 1223 |
+
window_size=16,
|
| 1224 |
+
hidden_dim=1024,
|
| 1225 |
+
depth=24,
|
| 1226 |
+
num_heads=16,
|
| 1227 |
+
window_block_indexes=window_block_indexes,
|
| 1228 |
+
num_patches=image_size ** 2 // 14 ** 2,
|
| 1229 |
+
pretrained_from=args.vision_tower_pretrained_from
|
| 1230 |
+
)
|
| 1231 |
+
|
| 1232 |
+
else:
|
| 1233 |
+
raise NotImplementedError
|
| 1234 |
+
|
| 1235 |
+
return model, eva_config
|
EAGLE/eagle/model/multimodal_projector/__init__.py
ADDED
|
File without changes
|
EAGLE/eagle/model/multimodal_projector/builder.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
class IdentityMap(nn.Module):
|
| 6 |
+
def __init__(self):
|
| 7 |
+
super().__init__()
|
| 8 |
+
|
| 9 |
+
def forward(self, x, *args, **kwargs):
|
| 10 |
+
return x
|
| 11 |
+
|
| 12 |
+
@property
|
| 13 |
+
def config(self):
|
| 14 |
+
return {"mm_projector_type": 'identity'}
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class SimpleResBlock(nn.Module):
|
| 18 |
+
def __init__(self, channels):
|
| 19 |
+
super().__init__()
|
| 20 |
+
self.pre_norm = nn.LayerNorm(channels)
|
| 21 |
+
|
| 22 |
+
self.proj = nn.Sequential(
|
| 23 |
+
nn.Linear(channels, channels),
|
| 24 |
+
nn.GELU(),
|
| 25 |
+
nn.Linear(channels, channels)
|
| 26 |
+
)
|
| 27 |
+
def forward(self, x):
|
| 28 |
+
x = self.pre_norm(x)
|
| 29 |
+
return x + self.proj(x)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def build_vision_projector(config, delay_load=False, fpn_input_dim=[], **kwargs):
|
| 33 |
+
projector_type = getattr(config, 'mm_projector_type', 'linear')
|
| 34 |
+
|
| 35 |
+
if projector_type == 'linear':
|
| 36 |
+
return nn.Linear(config.mm_hidden_size, config.hidden_size)
|
| 37 |
+
|
| 38 |
+
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
|
| 39 |
+
if mlp_gelu_match:
|
| 40 |
+
mlp_depth = int(mlp_gelu_match.group(1))
|
| 41 |
+
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
|
| 42 |
+
for _ in range(1, mlp_depth):
|
| 43 |
+
modules.append(nn.GELU())
|
| 44 |
+
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
|
| 45 |
+
return nn.Sequential(*modules)
|
| 46 |
+
|
| 47 |
+
if projector_type == 'identity':
|
| 48 |
+
return IdentityMap()
|
| 49 |
+
|
| 50 |
+
raise ValueError(f'Unknown projector type: {projector_type}')
|
EAGLE/lmms_eval/api/__init__.py
ADDED
|
File without changes
|
EAGLE/lmms_eval/api/filter.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
from lmms_eval.api.instance import Instance
|
| 5 |
+
from datasets import Dataset
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Filter:
|
| 9 |
+
"""
|
| 10 |
+
Filter classes operate on a per-task level.
|
| 11 |
+
They take all model outputs (`instance.resps` for all `task.instances`)
|
| 12 |
+
across all instances of a task, and perform operations.
|
| 13 |
+
In a single run, one can configure any number of separate filters or lists of filters.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
def __init__(self, *args, **kwargs) -> None:
|
| 18 |
+
"""
|
| 19 |
+
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def apply(self, resps, docs):
|
| 23 |
+
"""
|
| 24 |
+
Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
|
| 25 |
+
Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
|
| 26 |
+
if pass in [<inst.resps for instance 0>, <inst.resps for instance 1>] should return
|
| 27 |
+
[<filtered resps for instance 0>, <filtered resps for instance 1>]
|
| 28 |
+
"""
|
| 29 |
+
return resps
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@dataclass
|
| 33 |
+
class FilterEnsemble:
|
| 34 |
+
"""
|
| 35 |
+
FilterEnsemble creates a pipeline applying multiple filters.
|
| 36 |
+
Its intended usage is to stack multiple post-processing steps in order.
|
| 37 |
+
`task.apply_filters` should use a list of FilterEnsemble classes that it stores, to apply each
|
| 38 |
+
pipeline separately.
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
name: str
|
| 42 |
+
filters: List[Filter]
|
| 43 |
+
|
| 44 |
+
def apply(self, instances: List[Instance], docs: List[Dataset]) -> None:
|
| 45 |
+
resps = [inst.resps for inst in instances] # operate just on the model responses
|
| 46 |
+
for f in self.filters:
|
| 47 |
+
# apply filters in sequence
|
| 48 |
+
resps = f.apply(resps, docs)
|
| 49 |
+
|
| 50 |
+
# add the end results after filtering to filtered_requests of their respective source instances.
|
| 51 |
+
# has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
|
| 52 |
+
for inst, resp in zip(instances, resps):
|
| 53 |
+
inst.filtered_resps[self.name] = resp
|
EAGLE/lmms_eval/api/instance.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
+
from typing import Literal, Tuple
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
@dataclass
|
| 6 |
+
class Instance:
|
| 7 |
+
request_type: Literal["loglikelihood", "generate_until"]
|
| 8 |
+
arguments: tuple
|
| 9 |
+
idx: int
|
| 10 |
+
metadata: Tuple[str, int, int] = field(default_factory=lambda: (None, None, None)) # TODO: better typehints here
|
| 11 |
+
resps: list = field(default_factory=list)
|
| 12 |
+
filtered_resps: dict = field(default_factory=dict)
|
| 13 |
+
|
| 14 |
+
# initialized after init
|
| 15 |
+
task_name: str = None
|
| 16 |
+
doc_id: str = None
|
| 17 |
+
repeats: str = None
|
| 18 |
+
doc: dict = None
|
| 19 |
+
|
| 20 |
+
def __post_init__(self) -> None:
|
| 21 |
+
# unpack metadata field
|
| 22 |
+
self.task_name, self.doc_id, self.repeats = self.metadata
|
| 23 |
+
|
| 24 |
+
@property
|
| 25 |
+
def args(self):
|
| 26 |
+
"""
|
| 27 |
+
Returns (string,) where `string` is the string to calculate loglikelihood over
|
| 28 |
+
"""
|
| 29 |
+
return self.arguments if isinstance(self.arguments, tuple) else (self.arguments,)
|
EAGLE/lmms_eval/api/metrics.py
ADDED
|
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from collections.abc import Iterable
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import sacrebleu
|
| 6 |
+
import sklearn.metrics
|
| 7 |
+
import random
|
| 8 |
+
import evaluate
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from lmms_eval.api.registry import register_metric, register_aggregation
|
| 12 |
+
|
| 13 |
+
import logging
|
| 14 |
+
|
| 15 |
+
eval_logger = logging.getLogger("lmms-eval")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Register Aggregations First
|
| 19 |
+
@register_aggregation("mean")
|
| 20 |
+
def mean(arr):
|
| 21 |
+
return sum(arr) / len(arr)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@register_aggregation("median")
|
| 25 |
+
def median(arr):
|
| 26 |
+
return arr[len(arr) // 2]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Certain metrics must be calculated across all documents in a benchmark.
|
| 30 |
+
# We use them as aggregation metrics, paired with no-op passthrough metric fns.
|
| 31 |
+
@register_aggregation("perplexity")
|
| 32 |
+
def perplexity(items):
|
| 33 |
+
# return math.exp(-mean(items))
|
| 34 |
+
items = torch.exp(torch.tensor(items)).tolist()
|
| 35 |
+
return sum(items) / len(items)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@register_aggregation("weighted_perplexity")
|
| 39 |
+
def weighted_perplexity(items):
|
| 40 |
+
return math.exp(-weighted_mean(items))
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@register_aggregation("bits_per_byte")
|
| 44 |
+
def bits_per_byte(items):
|
| 45 |
+
return -weighted_mean(items) / math.log(2)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@register_aggregation("f1")
|
| 49 |
+
def f1_score(items):
|
| 50 |
+
unzipped_list = list(zip(*items))
|
| 51 |
+
golds = unzipped_list[0]
|
| 52 |
+
preds = unzipped_list[1]
|
| 53 |
+
fscore = sklearn.metrics.f1_score(golds, preds)
|
| 54 |
+
|
| 55 |
+
return np.max(fscore)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@register_aggregation("matthews_corrcoef")
|
| 59 |
+
def matthews_corrcoef(items):
|
| 60 |
+
unzipped_list = list(zip(*items))
|
| 61 |
+
golds = unzipped_list[0]
|
| 62 |
+
preds = unzipped_list[1]
|
| 63 |
+
# print(preds)
|
| 64 |
+
return sklearn.metrics.matthews_corrcoef(golds, preds)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@register_aggregation("bleu")
|
| 68 |
+
def bleu(items):
|
| 69 |
+
"""The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
|
| 70 |
+
for evaluating a generated sentence to a reference sentence. It counts matching
|
| 71 |
+
n-grams in the candidate translation to n-grams in the reference text, where
|
| 72 |
+
1-gram or unigram would be each token and a bigram comparison would be each
|
| 73 |
+
word pair. The comparison is made regardless of word order
|
| 74 |
+
Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
|
| 75 |
+
Paper: https://www.aclweb.org/anthology/P02-1040/
|
| 76 |
+
|
| 77 |
+
Higher is better
|
| 78 |
+
"""
|
| 79 |
+
refs = list(zip(*items))[0]
|
| 80 |
+
preds = list(zip(*items))[1]
|
| 81 |
+
refs, preds = _sacreformat(refs, preds)
|
| 82 |
+
return sacrebleu.corpus_bleu(preds, refs).score
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@register_aggregation("chrf")
|
| 86 |
+
def chrf(items):
|
| 87 |
+
"""chrF++ is a tool for automatic evaluation of machine translation output
|
| 88 |
+
based on character n-gram precision and recall enhanced with word n-grams.
|
| 89 |
+
Source: https://github.com/m-popovic/chrF
|
| 90 |
+
Paper: https://www.aclweb.org/anthology/W15-3049.pdf
|
| 91 |
+
|
| 92 |
+
Higher is better # TODO I think
|
| 93 |
+
"""
|
| 94 |
+
refs = list(zip(*items))[0]
|
| 95 |
+
preds = list(zip(*items))[1]
|
| 96 |
+
refs, preds = _sacreformat(refs, preds)
|
| 97 |
+
return sacrebleu.corpus_chrf(preds, refs).score
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
@register_aggregation("ter")
|
| 101 |
+
def ter(items):
|
| 102 |
+
"""Translation Error Rate is an error metric for machine translation that
|
| 103 |
+
measures the number of edits required to change a system output into one
|
| 104 |
+
of the references
|
| 105 |
+
Source: http://www.cs.umd.edu/~snover/tercom/
|
| 106 |
+
Paper: http://mt-archive.info/AMTA-2006-Snover.pdf
|
| 107 |
+
|
| 108 |
+
Lower is better
|
| 109 |
+
"""
|
| 110 |
+
refs = list(zip(*items))[0]
|
| 111 |
+
preds = list(zip(*items))[1]
|
| 112 |
+
refs, preds = _sacreformat(refs, preds)
|
| 113 |
+
return sacrebleu.corpus_ter(preds, refs).score
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@register_metric(
|
| 117 |
+
metric="acc",
|
| 118 |
+
higher_is_better=True,
|
| 119 |
+
output_type=["loglikelihood", "multiple_choice"],
|
| 120 |
+
aggregation="mean",
|
| 121 |
+
)
|
| 122 |
+
def acc_fn(items): # This is a passthrough function
|
| 123 |
+
return items
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
@register_metric(
|
| 127 |
+
metric="acc_norm",
|
| 128 |
+
higher_is_better=True,
|
| 129 |
+
output_type=["loglikelihood", "multiple_choice"],
|
| 130 |
+
aggregation="mean",
|
| 131 |
+
)
|
| 132 |
+
def acc_norm_fn(items): # This is a passthrough function
|
| 133 |
+
return items
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
@register_metric(
|
| 137 |
+
metric="acc_mutual_info",
|
| 138 |
+
higher_is_better=True,
|
| 139 |
+
output_type="multiple_choice",
|
| 140 |
+
aggregation="mean",
|
| 141 |
+
)
|
| 142 |
+
def acc_mutual_info_fn(items): # This is a passthrough function
|
| 143 |
+
return items
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
exact_match = evaluate.load("exact_match")
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
@register_metric(
|
| 150 |
+
metric="exact_match",
|
| 151 |
+
higher_is_better=True,
|
| 152 |
+
output_type="generate_until",
|
| 153 |
+
aggregation="mean",
|
| 154 |
+
)
|
| 155 |
+
def exact_match_fn(**kwargs):
|
| 156 |
+
return exact_match.compute(**kwargs)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
@register_metric(
|
| 160 |
+
metric="perplexity",
|
| 161 |
+
higher_is_better=False,
|
| 162 |
+
output_type="loglikelihood",
|
| 163 |
+
aggregation="perplexity",
|
| 164 |
+
)
|
| 165 |
+
def perplexity_fn(items): # This is a passthrough function
|
| 166 |
+
return items
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def levenshtein_distance(s1, s2):
|
| 170 |
+
if len(s1) > len(s2):
|
| 171 |
+
s1, s2 = s2, s1
|
| 172 |
+
|
| 173 |
+
distances = range(len(s1) + 1)
|
| 174 |
+
for i2, c2 in enumerate(s2):
|
| 175 |
+
distances_ = [i2 + 1]
|
| 176 |
+
for i1, c1 in enumerate(s1):
|
| 177 |
+
if c1 == c2:
|
| 178 |
+
distances_.append(distances[i1])
|
| 179 |
+
else:
|
| 180 |
+
distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
|
| 181 |
+
distances = distances_
|
| 182 |
+
return distances[-1]
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
@register_metric(
|
| 186 |
+
metric="anls",
|
| 187 |
+
higher_is_better=True,
|
| 188 |
+
output_type="generate_until",
|
| 189 |
+
aggregation="mean",
|
| 190 |
+
)
|
| 191 |
+
def anls(
|
| 192 |
+
references,
|
| 193 |
+
predictions,
|
| 194 |
+
thresh_hold=0.5,
|
| 195 |
+
): # This is a passthrough function
|
| 196 |
+
"""https://github.com/QwenLM/Qwen-VL/blob/master/eval_mm/infographicsvqa_eval.py"""
|
| 197 |
+
values = []
|
| 198 |
+
for answer in references:
|
| 199 |
+
# preprocess both the answers - gt and prediction
|
| 200 |
+
gt_answer = " ".join(answer.strip().lower().split())
|
| 201 |
+
det_answer = " ".join(predictions[0].strip().lower().split())
|
| 202 |
+
|
| 203 |
+
# dist = levenshtein_distance(answer.lower(), detObject['answer'].lower())
|
| 204 |
+
dist = levenshtein_distance(gt_answer, det_answer)
|
| 205 |
+
length = max(len(answer.upper()), len(predictions[0].upper()))
|
| 206 |
+
values.append(0.0 if length == 0 else float(dist) / float(length))
|
| 207 |
+
|
| 208 |
+
question_result = 1 - min(values)
|
| 209 |
+
|
| 210 |
+
if question_result < thresh_hold:
|
| 211 |
+
question_result = 0
|
| 212 |
+
return {"anls": question_result}
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def pop_stddev(arr):
|
| 216 |
+
mu = mean(arr)
|
| 217 |
+
return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def sample_stddev(arr):
|
| 221 |
+
mu = mean(arr)
|
| 222 |
+
return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def mean_stderr(arr):
|
| 226 |
+
return sample_stddev(arr) / math.sqrt(len(arr))
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
@register_metric(
|
| 230 |
+
metric="mcc",
|
| 231 |
+
higher_is_better=True,
|
| 232 |
+
output_type="multiple_choice",
|
| 233 |
+
aggregation="matthews_corrcoef",
|
| 234 |
+
)
|
| 235 |
+
def mcc_fn(items): # This is a passthrough function
|
| 236 |
+
return items
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
@register_metric(
|
| 240 |
+
metric="f1",
|
| 241 |
+
higher_is_better=True,
|
| 242 |
+
output_type="multiple_choice",
|
| 243 |
+
aggregation="f1",
|
| 244 |
+
)
|
| 245 |
+
def f1_fn(items): # This is a passthrough function
|
| 246 |
+
return items
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
@register_metric(
|
| 250 |
+
metric="bleu",
|
| 251 |
+
higher_is_better=True,
|
| 252 |
+
output_type="generate_until",
|
| 253 |
+
aggregation="bleu",
|
| 254 |
+
)
|
| 255 |
+
def bleu_fn(items): # This is a passthrough function
|
| 256 |
+
return items
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
@register_metric(
|
| 260 |
+
metric="chrf",
|
| 261 |
+
higher_is_better=True,
|
| 262 |
+
output_type="generate_until",
|
| 263 |
+
aggregation="chrf",
|
| 264 |
+
)
|
| 265 |
+
def chrf_fn(items): # This is a passthrough function
|
| 266 |
+
return items
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
@register_metric(
|
| 270 |
+
metric="ter",
|
| 271 |
+
higher_is_better=True,
|
| 272 |
+
output_type="generate_until",
|
| 273 |
+
aggregation="ter",
|
| 274 |
+
)
|
| 275 |
+
def ter_fn(items): # This is a passthrough function
|
| 276 |
+
return items
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
@register_metric(
|
| 280 |
+
metric="acc_all",
|
| 281 |
+
higher_is_better=True,
|
| 282 |
+
output_type="loglikelihood",
|
| 283 |
+
aggregation="mean",
|
| 284 |
+
)
|
| 285 |
+
def acc_all(items):
|
| 286 |
+
# Only count as correct if all answers are labeled correctly for each question
|
| 287 |
+
question_scoring_dict = {}
|
| 288 |
+
preds = list(zip(*items))[0]
|
| 289 |
+
docs = list(zip(*items))[1]
|
| 290 |
+
|
| 291 |
+
for doc, pred in zip(docs, preds):
|
| 292 |
+
paragraph_id = doc["idx"]["paragraph"]
|
| 293 |
+
question_id = doc["idx"]["question"]
|
| 294 |
+
if (paragraph_id, question_id) not in question_scoring_dict:
|
| 295 |
+
question_scoring_dict[(paragraph_id, question_id)] = []
|
| 296 |
+
|
| 297 |
+
gold_label = doc["label"] == 1
|
| 298 |
+
|
| 299 |
+
question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
|
| 300 |
+
acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
|
| 301 |
+
return acc
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def acc_all_stderr(items):
|
| 305 |
+
# Only count as correct if all answers are labeled correctly for each question
|
| 306 |
+
question_scoring_dict = {}
|
| 307 |
+
preds = list(zip(*items))[0]
|
| 308 |
+
docs = list(zip(*items))[1]
|
| 309 |
+
|
| 310 |
+
for doc, pred in zip(docs, preds):
|
| 311 |
+
question_id = doc["idx"]["question"]
|
| 312 |
+
if question_id not in question_scoring_dict:
|
| 313 |
+
question_scoring_dict[question_id] = []
|
| 314 |
+
|
| 315 |
+
gold_label = doc["label"] == 1
|
| 316 |
+
question_scoring_dict[question_id].append(gold_label == pred)
|
| 317 |
+
|
| 318 |
+
acc = mean_stderr([int(all(x)) for x in question_scoring_dict.values()])
|
| 319 |
+
return acc
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
|
| 323 |
+
"""Compute max metric between prediction and each ground truth."""
|
| 324 |
+
scores_for_ground_truths = []
|
| 325 |
+
for ground_truth in ground_truths:
|
| 326 |
+
score = metric_fn(prediction, ground_truth)
|
| 327 |
+
scores_for_ground_truths.append(score)
|
| 328 |
+
return max(scores_for_ground_truths)
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def weighted_mean(items):
|
| 332 |
+
a, b = zip(*items)
|
| 333 |
+
return sum(a) / sum(b)
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def is_non_str_iterable(obj):
|
| 337 |
+
return isinstance(obj, Iterable) and not isinstance(obj, str)
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def _sacreformat(refs, preds):
|
| 341 |
+
"""Format refs and preds for sacrebleu corpus calculation. It is very particular"""
|
| 342 |
+
# Sacrebleu expects (List[str], List[List[str])
|
| 343 |
+
# e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])
|
| 344 |
+
|
| 345 |
+
# Note [ref1_stream] is the first reference for each pred.
|
| 346 |
+
# So lists are size N and (M, N) for N preds and M possible refs for each pred
|
| 347 |
+
# This is a different order of dimensions that I would expect
|
| 348 |
+
|
| 349 |
+
# We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
|
| 350 |
+
# Must become List[List[str]] with the inner list corresponding to preds
|
| 351 |
+
if not is_non_str_iterable(refs):
|
| 352 |
+
refs = list(refs)
|
| 353 |
+
if not is_non_str_iterable(refs[0]):
|
| 354 |
+
refs = [[ref] for ref in refs]
|
| 355 |
+
refs = list(zip(*refs))
|
| 356 |
+
# Note the number of refs in each ref list much match the number of preds
|
| 357 |
+
|
| 358 |
+
# We expect preds to be List[str] or List[List[str]]. Must become List[str]
|
| 359 |
+
if not is_non_str_iterable(preds):
|
| 360 |
+
preds = list(preds)
|
| 361 |
+
if is_non_str_iterable(preds[0]):
|
| 362 |
+
assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
|
| 363 |
+
preds = [pred[0] for pred in preds]
|
| 364 |
+
|
| 365 |
+
return refs, preds
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
# stderr stuff
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
class _bootstrap_internal:
|
| 372 |
+
def __init__(self, f, n) -> None:
|
| 373 |
+
self.f = f
|
| 374 |
+
self.n = n
|
| 375 |
+
|
| 376 |
+
def __call__(self, v):
|
| 377 |
+
i, xs = v
|
| 378 |
+
rnd = random.Random()
|
| 379 |
+
rnd.seed(i)
|
| 380 |
+
res = []
|
| 381 |
+
for _ in range(self.n):
|
| 382 |
+
res.append(self.f(rnd.choices(xs, k=len(xs))))
|
| 383 |
+
return res
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def bootstrap_stderr(f, xs, iters):
|
| 387 |
+
import multiprocessing as mp
|
| 388 |
+
|
| 389 |
+
pool = mp.Pool(mp.cpu_count())
|
| 390 |
+
# this gives a biased estimate of the stderr (i.e w/ the mean, it gives something
|
| 391 |
+
# equivalent to stderr calculated without Bessel's correction in the stddev.
|
| 392 |
+
# Unfortunately, I haven't been able to figure out what the right correction is
|
| 393 |
+
# to make the bootstrap unbiased - i considered multiplying by sqrt(n/(n-1)) but
|
| 394 |
+
# that would be ad-hoc and I can't prove that that would actually be an unbiased estimator)
|
| 395 |
+
# Thankfully, shouldn't matter because our samples are pretty big usually anyways
|
| 396 |
+
res = []
|
| 397 |
+
chunk_size = min(1000, iters)
|
| 398 |
+
from tqdm import tqdm
|
| 399 |
+
|
| 400 |
+
print("bootstrapping for stddev:", f.__name__)
|
| 401 |
+
for bootstrap in tqdm(
|
| 402 |
+
pool.imap(
|
| 403 |
+
_bootstrap_internal(f, chunk_size),
|
| 404 |
+
[(i, xs) for i in range(iters // chunk_size)],
|
| 405 |
+
),
|
| 406 |
+
total=iters // chunk_size,
|
| 407 |
+
):
|
| 408 |
+
# sample w replacement
|
| 409 |
+
res.extend(bootstrap)
|
| 410 |
+
|
| 411 |
+
pool.close()
|
| 412 |
+
return sample_stddev(res)
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def stderr_for_metric(metric, bootstrap_iters):
|
| 416 |
+
bootstrappable = [
|
| 417 |
+
median,
|
| 418 |
+
matthews_corrcoef,
|
| 419 |
+
f1_score,
|
| 420 |
+
perplexity,
|
| 421 |
+
bleu,
|
| 422 |
+
chrf,
|
| 423 |
+
ter,
|
| 424 |
+
]
|
| 425 |
+
|
| 426 |
+
if metric in bootstrappable:
|
| 427 |
+
return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)
|
| 428 |
+
|
| 429 |
+
stderr = {mean: mean_stderr, acc_all: acc_all_stderr}
|
| 430 |
+
|
| 431 |
+
return stderr.get(metric, None)
|
EAGLE/lmms_eval/api/model.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
from typing import Union, List, Tuple, Optional, Type, TypeVar
|
| 5 |
+
from sqlitedict import SqliteDict
|
| 6 |
+
import json
|
| 7 |
+
import hashlib
|
| 8 |
+
from lmms_eval.api.instance import Instance
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from lmms_eval import utils
|
| 11 |
+
import logging
|
| 12 |
+
|
| 13 |
+
eval_logger = logging.getLogger("lmms-eval")
|
| 14 |
+
|
| 15 |
+
T = TypeVar("T", bound="lmms")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class lmms(abc.ABC):
|
| 19 |
+
def __init__(self) -> None:
|
| 20 |
+
"""Defines the interface that should be implemented by all lmms subclasses.
|
| 21 |
+
lmmss are assumed to take image-text as input and yield strings as output
|
| 22 |
+
(inputs/outputs should be tokenization-agnostic.)
|
| 23 |
+
"""
|
| 24 |
+
# set rank and world size to a single process, by default.
|
| 25 |
+
self._rank = 0
|
| 26 |
+
self._world_size = 1
|
| 27 |
+
self.cache_hook = CacheHook(None)
|
| 28 |
+
self.task_dict = {}
|
| 29 |
+
|
| 30 |
+
@abc.abstractmethod
|
| 31 |
+
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
|
| 32 |
+
"""Compute log-likelihood of generating a continuation from a context.
|
| 33 |
+
Downstream tasks should attempt to use loglikelihood instead of other
|
| 34 |
+
LMM calls whenever possible.
|
| 35 |
+
|
| 36 |
+
:param requests: list[Instance]
|
| 37 |
+
A list of Instance objects, with property `args` which returns a tuple (context, continuation).
|
| 38 |
+
`context: str`
|
| 39 |
+
Context string. Implementations of LMM must be able to handle an
|
| 40 |
+
empty context string.
|
| 41 |
+
`continuation: str`
|
| 42 |
+
The continuation over which log likelihood will be calculated. If
|
| 43 |
+
there is a word boundary, the space should be in the continuation.
|
| 44 |
+
For example, context="hello" continuation=" world" is correct.
|
| 45 |
+
'visual_list: list[dict]'
|
| 46 |
+
Visual input to the model. Can be None.
|
| 47 |
+
|
| 48 |
+
:return: list[tuple[float, bool]]
|
| 49 |
+
A list of pairs (logprob, isgreedy)
|
| 50 |
+
`logprob: float`
|
| 51 |
+
The log probability of `continuation`.
|
| 52 |
+
`isgreedy`:
|
| 53 |
+
Whether `continuation` would be generated by greedy sampling from `context`.
|
| 54 |
+
"""
|
| 55 |
+
pass
|
| 56 |
+
|
| 57 |
+
# TODO: Add an optional max length
|
| 58 |
+
@abc.abstractmethod
|
| 59 |
+
def generate_until(self, requests) -> List[str]:
|
| 60 |
+
"""Generate greedily until a stopping sequence
|
| 61 |
+
|
| 62 |
+
:param requests: list[Instance]
|
| 63 |
+
A list of Instance objects with property `args` which returns a tuple (context, until).
|
| 64 |
+
context: str
|
| 65 |
+
Context string
|
| 66 |
+
generation_kwargs: dict
|
| 67 |
+
Generation Kwargs
|
| 68 |
+
'visual_list: list[dict]'
|
| 69 |
+
Visual input to the model. Can be None.
|
| 70 |
+
:return: list[str]
|
| 71 |
+
A list of strings continuation
|
| 72 |
+
continuation: str
|
| 73 |
+
The generated continuation.
|
| 74 |
+
"""
|
| 75 |
+
pass
|
| 76 |
+
|
| 77 |
+
@classmethod
|
| 78 |
+
def create_from_arg_string(cls: Type[T], arg_string: str, additional_config: Optional[dict] = None) -> T:
|
| 79 |
+
"""
|
| 80 |
+
Creates an instance of the LMM class using the given argument string and additional config.
|
| 81 |
+
|
| 82 |
+
Parameters:
|
| 83 |
+
- arg_string: A string containing arguments in the format key1=value1,key2=value2.
|
| 84 |
+
- additional_config: Optional dictionary containing additional configuration parameters.
|
| 85 |
+
|
| 86 |
+
Returns:
|
| 87 |
+
- Instance of the LMM class.
|
| 88 |
+
"""
|
| 89 |
+
additional_config = {} if additional_config is None else additional_config
|
| 90 |
+
args = utils.simple_parse_args_string(arg_string)
|
| 91 |
+
args2 = {k: v for k, v in additional_config.items() if v is not None}
|
| 92 |
+
return cls(**args, **args2)
|
| 93 |
+
|
| 94 |
+
@property
|
| 95 |
+
def rank(self):
|
| 96 |
+
# used in the case of parallelism. Hardcoded to
|
| 97 |
+
# ensure no errors arise using API models which do
|
| 98 |
+
# not support multi-device parallelism nor expect it.
|
| 99 |
+
return self._rank
|
| 100 |
+
|
| 101 |
+
@property
|
| 102 |
+
def world_size(self):
|
| 103 |
+
# used in the case of parallelism. Hardcoded to
|
| 104 |
+
# ensure no errors arise using API models which do
|
| 105 |
+
# not support multi-device parallelism nor expect it.
|
| 106 |
+
return self._world_size
|
| 107 |
+
|
| 108 |
+
def set_cache_hook(self, cache_hook) -> None:
|
| 109 |
+
self.cache_hook = cache_hook
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
### SQLite-based caching of LMM responses
|
| 113 |
+
def hash_args(attr, args):
|
| 114 |
+
dat = json.dumps([attr] + list(args))
|
| 115 |
+
return hashlib.sha256(dat.encode("utf-8")).hexdigest()
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class CacheHook:
|
| 119 |
+
def __init__(self, cachinglm) -> None:
|
| 120 |
+
if cachinglm is None:
|
| 121 |
+
self.dbdict = None
|
| 122 |
+
return
|
| 123 |
+
|
| 124 |
+
self.dbdict = cachinglm.dbdict
|
| 125 |
+
|
| 126 |
+
def add_partial(self, attr, req, res) -> None:
|
| 127 |
+
if self.dbdict is None:
|
| 128 |
+
return
|
| 129 |
+
hsh = hash_args(attr, req)
|
| 130 |
+
self.dbdict[hsh] = res
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class CachingLMM:
|
| 134 |
+
def __init__(self, lm, cache_db) -> None:
|
| 135 |
+
"""LMM wrapper that returns cached results if they exist, and uses the underlying LMM if not.
|
| 136 |
+
|
| 137 |
+
:param lm: LMM
|
| 138 |
+
Underlying LMM
|
| 139 |
+
:param cache_db: str
|
| 140 |
+
Path to cache db
|
| 141 |
+
"""
|
| 142 |
+
self.lm = lm
|
| 143 |
+
self.cache_db = cache_db
|
| 144 |
+
if os.path.dirname(cache_db):
|
| 145 |
+
os.makedirs(os.path.dirname(cache_db), exist_ok=True)
|
| 146 |
+
self.dbdict = SqliteDict(cache_db, autocommit=True)
|
| 147 |
+
|
| 148 |
+
# add hook to lm
|
| 149 |
+
lm.set_cache_hook(self.get_cache_hook())
|
| 150 |
+
|
| 151 |
+
def __getattr__(self, attr):
|
| 152 |
+
lm_attr = getattr(self.lm, attr)
|
| 153 |
+
if not callable(lm_attr):
|
| 154 |
+
return lm_attr
|
| 155 |
+
|
| 156 |
+
def fn(requests):
|
| 157 |
+
res = []
|
| 158 |
+
remaining_reqs = []
|
| 159 |
+
warned = False
|
| 160 |
+
# figure out which ones are cached and which ones are new
|
| 161 |
+
eval_logger.info(f"Loading '{attr}' responses from cache '{self.cache_db}' where possible...")
|
| 162 |
+
for req in tqdm(requests):
|
| 163 |
+
hsh = hash_args(attr, req.args)
|
| 164 |
+
if attr == "generate_until" and req.args[1].get("do_sample", False):
|
| 165 |
+
# when we are doing non-greedy generation, don't use the cache
|
| 166 |
+
# (else every "randomly sampled" generation would be identical for repeats > 1).
|
| 167 |
+
if not warned:
|
| 168 |
+
eval_logger.warning(f"Arguments to lm.generate_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests.")
|
| 169 |
+
warned = True
|
| 170 |
+
res.append(None)
|
| 171 |
+
remaining_reqs.append(req)
|
| 172 |
+
elif hsh in self.dbdict:
|
| 173 |
+
ob = self.dbdict[hsh]
|
| 174 |
+
|
| 175 |
+
assert ob is not None
|
| 176 |
+
|
| 177 |
+
res.append(ob)
|
| 178 |
+
else:
|
| 179 |
+
res.append(None)
|
| 180 |
+
remaining_reqs.append(req)
|
| 181 |
+
|
| 182 |
+
# actually run the LMM on the requests that do not have cached results
|
| 183 |
+
rem_res = getattr(self.lm, attr)(remaining_reqs)
|
| 184 |
+
|
| 185 |
+
# stick the new ones back into the list and also cache any of the new ones
|
| 186 |
+
resptr = 0
|
| 187 |
+
for req, r in zip(remaining_reqs, rem_res):
|
| 188 |
+
while res[resptr] is not None:
|
| 189 |
+
resptr += 1
|
| 190 |
+
|
| 191 |
+
res[resptr] = r
|
| 192 |
+
|
| 193 |
+
# caching
|
| 194 |
+
hsh = hash_args(attr, req.args)
|
| 195 |
+
self.dbdict[hsh] = r
|
| 196 |
+
self.dbdict.commit()
|
| 197 |
+
|
| 198 |
+
return res
|
| 199 |
+
|
| 200 |
+
return fn
|
| 201 |
+
|
| 202 |
+
def get_cache_hook(self):
|
| 203 |
+
return CacheHook(self)
|
EAGLE/lmms_eval/api/registry.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from lmms_eval.api.model import lmms
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
|
| 5 |
+
eval_logger = logging.getLogger("lmms-eval")
|
| 6 |
+
|
| 7 |
+
MODEL_REGISTRY = {}
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def register_model(*names):
|
| 11 |
+
# either pass a list or a single alias.
|
| 12 |
+
# function receives them as a tuple of strings
|
| 13 |
+
|
| 14 |
+
def decorate(cls):
|
| 15 |
+
for name in names:
|
| 16 |
+
assert issubclass(cls, lmms), f"Model '{name}' ({cls.__name__}) must extend lmms class"
|
| 17 |
+
|
| 18 |
+
assert name not in MODEL_REGISTRY, f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
|
| 19 |
+
|
| 20 |
+
MODEL_REGISTRY[name] = cls
|
| 21 |
+
return cls
|
| 22 |
+
|
| 23 |
+
return decorate
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_model(model_name):
|
| 27 |
+
try:
|
| 28 |
+
return MODEL_REGISTRY[model_name]
|
| 29 |
+
except KeyError:
|
| 30 |
+
raise ValueError(f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
TASK_REGISTRY = {} # Key: task name, Value: task ConfigurableTask class
|
| 34 |
+
GROUP_REGISTRY = {} # Key: group name, Value: list of task names or group names
|
| 35 |
+
ALL_TASKS = set() # Set of all task names and group names
|
| 36 |
+
func2task_index = {} # Key: task ConfigurableTask class, Value: task name
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def register_task(name):
|
| 40 |
+
def decorate(fn):
|
| 41 |
+
assert name not in TASK_REGISTRY, f"task named '{name}' conflicts with existing registered task!"
|
| 42 |
+
|
| 43 |
+
TASK_REGISTRY[name] = fn
|
| 44 |
+
ALL_TASKS.add(name)
|
| 45 |
+
func2task_index[fn.__name__] = name
|
| 46 |
+
return fn
|
| 47 |
+
|
| 48 |
+
return decorate
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def register_group(name):
|
| 52 |
+
def decorate(fn):
|
| 53 |
+
func_name = func2task_index[fn.__name__]
|
| 54 |
+
if name in GROUP_REGISTRY:
|
| 55 |
+
GROUP_REGISTRY[name].append(func_name)
|
| 56 |
+
else:
|
| 57 |
+
GROUP_REGISTRY[name] = [func_name]
|
| 58 |
+
ALL_TASKS.add(name)
|
| 59 |
+
return fn
|
| 60 |
+
|
| 61 |
+
return decorate
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
OUTPUT_TYPE_REGISTRY = {}
|
| 65 |
+
METRIC_REGISTRY = {}
|
| 66 |
+
METRIC_AGGREGATION_REGISTRY = {}
|
| 67 |
+
AGGREGATION_REGISTRY = {}
|
| 68 |
+
HIGHER_IS_BETTER_REGISTRY = {}
|
| 69 |
+
|
| 70 |
+
DEFAULT_METRIC_REGISTRY = {
|
| 71 |
+
"loglikelihood": [
|
| 72 |
+
"perplexity",
|
| 73 |
+
"acc",
|
| 74 |
+
],
|
| 75 |
+
"multiple_choice": ["acc", "acc_norm"],
|
| 76 |
+
"generate_until": ["exact_match"],
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def register_metric(**args):
|
| 81 |
+
# TODO: do we want to enforce a certain interface to registered metrics?
|
| 82 |
+
def decorate(fn):
|
| 83 |
+
assert "metric" in args
|
| 84 |
+
name = args["metric"]
|
| 85 |
+
|
| 86 |
+
for key, registry in [
|
| 87 |
+
("metric", METRIC_REGISTRY),
|
| 88 |
+
("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
|
| 89 |
+
("aggregation", METRIC_AGGREGATION_REGISTRY),
|
| 90 |
+
]:
|
| 91 |
+
if key in args:
|
| 92 |
+
value = args[key]
|
| 93 |
+
assert value not in registry, f"{key} named '{value}' conflicts with existing registered {key}!"
|
| 94 |
+
|
| 95 |
+
if key == "metric":
|
| 96 |
+
registry[name] = fn
|
| 97 |
+
elif key == "aggregation":
|
| 98 |
+
registry[name] = AGGREGATION_REGISTRY[value]
|
| 99 |
+
else:
|
| 100 |
+
registry[name] = value
|
| 101 |
+
|
| 102 |
+
return fn
|
| 103 |
+
|
| 104 |
+
return decorate
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def register_aggregation(name):
|
| 108 |
+
def decorate(fn):
|
| 109 |
+
assert name not in AGGREGATION_REGISTRY, f"aggregation named '{name}' conflicts with existing registered aggregation!"
|
| 110 |
+
|
| 111 |
+
AGGREGATION_REGISTRY[name] = fn
|
| 112 |
+
return fn
|
| 113 |
+
|
| 114 |
+
return decorate
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def get_aggregation(name):
|
| 118 |
+
try:
|
| 119 |
+
return AGGREGATION_REGISTRY[name]
|
| 120 |
+
except KeyError:
|
| 121 |
+
eval_logger.warning(
|
| 122 |
+
"{} not a registered aggregation metric!".format(name),
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def get_metric_aggregation(name):
|
| 127 |
+
try:
|
| 128 |
+
return METRIC_AGGREGATION_REGISTRY[name]
|
| 129 |
+
except KeyError:
|
| 130 |
+
eval_logger.warning(
|
| 131 |
+
"{} metric is not assigned a default aggregation!".format(name),
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def is_higher_better(metric_name):
|
| 136 |
+
try:
|
| 137 |
+
return HIGHER_IS_BETTER_REGISTRY[metric_name]
|
| 138 |
+
except KeyError:
|
| 139 |
+
eval_logger.warning(f"higher_is_better not specified for metric '{metric_name}'!")
|
EAGLE/lmms_eval/api/samplers.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class ContextSampler:
|
| 2 |
+
def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None:
|
| 3 |
+
self.rnd = rnd
|
| 4 |
+
assert self.rnd, "must pass rnd to FewShotSampler!"
|
| 5 |
+
|
| 6 |
+
self.task = task
|
| 7 |
+
self.config = task._config
|
| 8 |
+
|
| 9 |
+
self.target_delimiter = self.config.target_delimiter
|
| 10 |
+
self.fewshot_delimiter = self.config.fewshot_delimiter
|
| 11 |
+
|
| 12 |
+
self.doc_to_text = self.task.doc_to_text
|
| 13 |
+
self.doc_to_target = self.task.doc_to_target
|
| 14 |
+
self.doc_to_choice = self.task.doc_to_choice
|
| 15 |
+
|
| 16 |
+
self.docs = docs # HF dataset split, provided by task._fewshot_docs()
|
| 17 |
+
if fewshot_indices: # subset few-shot docs from
|
| 18 |
+
self.docs = self.docs.select(fewshot_indices)
|
| 19 |
+
|
| 20 |
+
def get_context(self, doc, num_fewshot):
|
| 21 |
+
# draw an extra fewshot sample if using same split as evaluating on
|
| 22 |
+
n_samples = num_fewshot + 1 if self.config.fewshot_split == self.config.test_split else num_fewshot
|
| 23 |
+
|
| 24 |
+
# draw `n_samples` docs from fewshot_docs
|
| 25 |
+
fewshotex = self.sample(n_samples)
|
| 26 |
+
|
| 27 |
+
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
|
| 28 |
+
# TODO: should we just stop people from using fewshot from same split as evaluating?
|
| 29 |
+
selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]
|
| 30 |
+
|
| 31 |
+
labeled_examples = (
|
| 32 |
+
self.fewshot_delimiter.join(
|
| 33 |
+
[
|
| 34 |
+
# TODO: is separating doc_to_text and doc_to_target by one space always desired?
|
| 35 |
+
(self.doc_to_text(doc) if (self.config.doc_to_choice is None or type(self.doc_to_text(doc)) is str) else self.doc_to_choice(doc)[self.doc_to_text(doc)])
|
| 36 |
+
+ self.target_delimiter
|
| 37 |
+
+ (
|
| 38 |
+
str(self.doc_to_target(doc)[0])
|
| 39 |
+
if type(self.doc_to_target(doc)) is list
|
| 40 |
+
else self.doc_to_target(doc) if (self.config.doc_to_choice is None or type(self.doc_to_target(doc)) is str) else str(self.doc_to_choice(doc)[self.doc_to_target(doc)])
|
| 41 |
+
)
|
| 42 |
+
for doc in selected_docs
|
| 43 |
+
]
|
| 44 |
+
)
|
| 45 |
+
+ self.fewshot_delimiter
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
return labeled_examples
|
| 49 |
+
|
| 50 |
+
def sample(self, n):
|
| 51 |
+
"""
|
| 52 |
+
Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
return self.rnd.sample(self.docs, n)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class FirstNSampler(ContextSampler):
|
| 59 |
+
def sample(self, n) -> None:
|
| 60 |
+
"""
|
| 61 |
+
Draw the first `n` samples in order from the specified split.
|
| 62 |
+
Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
|
| 63 |
+
"""
|
| 64 |
+
assert n <= len(self.docs), f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available."
|
| 65 |
+
return self.docs[:n]
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class BalancedSampler(ContextSampler):
|
| 69 |
+
def sample(self, n) -> None:
|
| 70 |
+
"""
|
| 71 |
+
TODO: this should return approximately class-balanced samples from our fewshot examples.
|
| 72 |
+
TODO: what order should they be in? maybe random?
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
pass
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class ManualSampler(ContextSampler):
|
| 79 |
+
def sample(self, n) -> None:
|
| 80 |
+
""" """
|
| 81 |
+
pass
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
SAMPLER_REGISTRY = {
|
| 85 |
+
"default": ContextSampler,
|
| 86 |
+
"first_n": FirstNSampler,
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def get_sampler(name):
|
| 91 |
+
try:
|
| 92 |
+
return SAMPLER_REGISTRY[name]
|
| 93 |
+
except KeyError:
|
| 94 |
+
raise ValueError(f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported model names: {', '.join(SAMPLER_REGISTRY.keys())}")
|
EAGLE/lmms_eval/api/task.py
ADDED
|
@@ -0,0 +1,1118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
from dataclasses import dataclass, field, asdict
|
| 3 |
+
|
| 4 |
+
import itertools
|
| 5 |
+
import os
|
| 6 |
+
import re
|
| 7 |
+
import ast
|
| 8 |
+
import logging
|
| 9 |
+
import random
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
import datasets
|
| 13 |
+
from datasets import Image, Sequence
|
| 14 |
+
import numpy as np
|
| 15 |
+
from PIL import ImageFile
|
| 16 |
+
|
| 17 |
+
from datasets import DownloadConfig
|
| 18 |
+
from typing import Union, List, Any
|
| 19 |
+
from collections.abc import Callable
|
| 20 |
+
from tenacity import retry, stop_after_attempt, wait_fixed
|
| 21 |
+
|
| 22 |
+
from lmms_eval import utils
|
| 23 |
+
from lmms_eval.api import samplers
|
| 24 |
+
from lmms_eval.api.instance import Instance
|
| 25 |
+
|
| 26 |
+
from lmms_eval.filters import build_filter_ensemble
|
| 27 |
+
from lmms_eval.api.registry import (
|
| 28 |
+
get_aggregation,
|
| 29 |
+
get_metric_aggregation,
|
| 30 |
+
is_higher_better,
|
| 31 |
+
DEFAULT_METRIC_REGISTRY,
|
| 32 |
+
METRIC_REGISTRY,
|
| 33 |
+
OUTPUT_TYPE_REGISTRY,
|
| 34 |
+
AGGREGATION_REGISTRY,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
ALL_OUTPUT_TYPES = [
|
| 38 |
+
"loglikelihood",
|
| 39 |
+
"multiple_choice",
|
| 40 |
+
"generate_until",
|
| 41 |
+
]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
eval_logger = logging.getLogger("lmms-eval")
|
| 45 |
+
|
| 46 |
+
# HuggingfaceM4/NoCaps contains truncated image in test split
|
| 47 |
+
# Include this inside code block to avoid error
|
| 48 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@dataclass
|
| 52 |
+
class TaskConfig(dict):
|
| 53 |
+
# task naming/registry
|
| 54 |
+
task: str = None
|
| 55 |
+
task_alias: str = None
|
| 56 |
+
group: Union[str, list] = None
|
| 57 |
+
group_alias: Union[str, list] = None
|
| 58 |
+
# HF dataset options.
|
| 59 |
+
# which dataset to use,
|
| 60 |
+
# and what splits for what purpose
|
| 61 |
+
dataset_path: str = None
|
| 62 |
+
dataset_name: str = None
|
| 63 |
+
dataset_kwargs: dict = None
|
| 64 |
+
training_split: str = None
|
| 65 |
+
validation_split: str = None
|
| 66 |
+
test_split: str = None
|
| 67 |
+
fewshot_split: str = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
|
| 68 |
+
# formatting / prompting options.
|
| 69 |
+
# see docs/advanced_task_guide.md for more info
|
| 70 |
+
process_docs: Callable = None
|
| 71 |
+
doc_to_visual: Union[Callable, str] = None
|
| 72 |
+
doc_to_text: Union[Callable, str] = None
|
| 73 |
+
doc_to_target: Union[Callable, str] = None
|
| 74 |
+
doc_to_choice: Union[Callable, str, dict, list] = None
|
| 75 |
+
process_results: Union[Callable, str] = None
|
| 76 |
+
use_prompt: str = None
|
| 77 |
+
description: str = ""
|
| 78 |
+
target_delimiter: str = " "
|
| 79 |
+
fewshot_delimiter: str = "\n\n"
|
| 80 |
+
fewshot_config: dict = None
|
| 81 |
+
# runtime configuration options
|
| 82 |
+
num_fewshot: int = None
|
| 83 |
+
# scoring options
|
| 84 |
+
metric_list: list = None
|
| 85 |
+
output_type: str = "generate_until"
|
| 86 |
+
generation_kwargs: dict = None
|
| 87 |
+
repeats: int = 1
|
| 88 |
+
filter_list: Union[str, list] = None
|
| 89 |
+
should_decontaminate: bool = False
|
| 90 |
+
doc_to_decontamination_query: str = None
|
| 91 |
+
|
| 92 |
+
metadata: Union[str, list] = None # by default, not used in the code. allows for users to pass arbitrary info to tasks
|
| 93 |
+
|
| 94 |
+
model_specific_prompt_kwargs: dict = None
|
| 95 |
+
model_specific_generation_kwargs: dict = None
|
| 96 |
+
model_specific_target_kwargs: dict = None
|
| 97 |
+
|
| 98 |
+
def __post_init__(self) -> None:
|
| 99 |
+
if self.dataset_path and os.path.exists(os.path.dirname(self.dataset_path)):
|
| 100 |
+
import inspect
|
| 101 |
+
from importlib import import_module
|
| 102 |
+
|
| 103 |
+
self.dataset_path = inspect.getfile(import_module(self.dataset_path))
|
| 104 |
+
|
| 105 |
+
if self.generation_kwargs is not None:
|
| 106 |
+
if self.output_type != "generate_until":
|
| 107 |
+
eval_logger.warning(f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!")
|
| 108 |
+
assert self.output_type != "generate_until"
|
| 109 |
+
|
| 110 |
+
if "temperature" in self.generation_kwargs:
|
| 111 |
+
self.generation_kwargs["temperature"] = float(self.generation_kwargs["temperature"])
|
| 112 |
+
|
| 113 |
+
if "until" not in self.generation_kwargs:
|
| 114 |
+
self.generation_kwargs["until"] = [self.fewshot_delimiter]
|
| 115 |
+
else:
|
| 116 |
+
if self.output_type == "generate_until":
|
| 117 |
+
# ensure that we greedily generate in absence of explicit arguments otherwise
|
| 118 |
+
self.generation_kwargs = {
|
| 119 |
+
"until": None if self.fewshot_delimiter is None else [self.fewshot_delimiter],
|
| 120 |
+
"do_sample": False,
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
# TODO: how to make TaskConfigs be de- and re-serializable, even when using the !function constructor?
|
| 124 |
+
|
| 125 |
+
def __getitem__(self, item):
|
| 126 |
+
return getattr(self, item)
|
| 127 |
+
|
| 128 |
+
def __setitem__(self, item, value):
|
| 129 |
+
return setattr(self, item, value)
|
| 130 |
+
|
| 131 |
+
def to_dict(self):
|
| 132 |
+
"""dumps the current config as a dictionary object, as a printable format.
|
| 133 |
+
null fields will not be printed.
|
| 134 |
+
Used for dumping results alongside full task configuration
|
| 135 |
+
|
| 136 |
+
:return: dict
|
| 137 |
+
A printable dictionary version of the TaskConfig object.
|
| 138 |
+
|
| 139 |
+
# TODO: should any default value in the TaskConfig not be printed?
|
| 140 |
+
"""
|
| 141 |
+
cfg_dict = asdict(self)
|
| 142 |
+
# remove values that are `None`
|
| 143 |
+
for k, v in list(cfg_dict.items()):
|
| 144 |
+
if v is None:
|
| 145 |
+
cfg_dict.pop(k)
|
| 146 |
+
elif isinstance(v, Callable):
|
| 147 |
+
# TODO: this should handle Promptsource template objects as a separate case?
|
| 148 |
+
cfg_dict[k] = str(v)
|
| 149 |
+
return cfg_dict
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class Task(abc.ABC):
|
| 153 |
+
"""A task represents an entire benchmark including its dataset, problems,
|
| 154 |
+
answers, and evaluation methods. See BoolQ for a simple example implementation
|
| 155 |
+
|
| 156 |
+
A `doc` can be any python object which represents one instance of evaluation.
|
| 157 |
+
This is usually a dictionary e.g.
|
| 158 |
+
{"question": ..., "answer": ...} or
|
| 159 |
+
{"question": ..., question, answer)
|
| 160 |
+
"""
|
| 161 |
+
|
| 162 |
+
VERSION = None
|
| 163 |
+
|
| 164 |
+
# The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
|
| 165 |
+
# or a path to a custom `datasets` loading script.
|
| 166 |
+
DATASET_PATH: str = None
|
| 167 |
+
|
| 168 |
+
# The name of a subset within `DATASET_PATH`.
|
| 169 |
+
DATASET_NAME: str = None
|
| 170 |
+
|
| 171 |
+
OUTPUT_TYPE: str = None
|
| 172 |
+
|
| 173 |
+
def __init__(
|
| 174 |
+
self,
|
| 175 |
+
data_dir=None,
|
| 176 |
+
cache_dir=None,
|
| 177 |
+
download_mode=None,
|
| 178 |
+
config=None,
|
| 179 |
+
) -> None:
|
| 180 |
+
"""
|
| 181 |
+
:param data_dir: str
|
| 182 |
+
Stores the path to a local folder containing the `Task`'s data files.
|
| 183 |
+
Use this to specify the path to manually downloaded data (usually when
|
| 184 |
+
the dataset is not publicly accessible).
|
| 185 |
+
:param cache_dir: str
|
| 186 |
+
The directory to read/write the `Task` dataset. This follows the
|
| 187 |
+
HuggingFace `datasets` API with the default cache directory located at:
|
| 188 |
+
`~/.cache/huggingface/datasets`
|
| 189 |
+
NOTE: You can change the cache location globally for a given process
|
| 190 |
+
to another directory:
|
| 191 |
+
`export HF_DATASETS_CACHE="/path/to/another/directory"`
|
| 192 |
+
:param download_mode: datasets.DownloadMode
|
| 193 |
+
How to treat pre-existing `Task` downloads and data.
|
| 194 |
+
- `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
|
| 195 |
+
Reuse download and reuse dataset.
|
| 196 |
+
- `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
|
| 197 |
+
Reuse download with fresh dataset.
|
| 198 |
+
- `datasets.DownloadMode.FORCE_REDOWNLOAD`
|
| 199 |
+
Fresh download and fresh dataset.
|
| 200 |
+
"""
|
| 201 |
+
self.download(data_dir, cache_dir, download_mode)
|
| 202 |
+
self._training_docs = None
|
| 203 |
+
self._fewshot_docs = None
|
| 204 |
+
self._instances = None
|
| 205 |
+
|
| 206 |
+
self._config = TaskConfig({**config}) if config else TaskConfig()
|
| 207 |
+
|
| 208 |
+
self._filters = [build_filter_ensemble("none", [["take_first", None]])]
|
| 209 |
+
|
| 210 |
+
def download(self, data_dir=None, cache_dir=None, download_mode=None) -> None:
|
| 211 |
+
"""Downloads and returns the task dataset.
|
| 212 |
+
Override this method to download the dataset from a custom API.
|
| 213 |
+
|
| 214 |
+
:param data_dir: str
|
| 215 |
+
Stores the path to a local folder containing the `Task`'s data files.
|
| 216 |
+
Use this to specify the path to manually downloaded data (usually when
|
| 217 |
+
the dataset is not publicly accessible).
|
| 218 |
+
:param cache_dir: str
|
| 219 |
+
The directory to read/write the `Task` dataset. This follows the
|
| 220 |
+
HuggingFace `datasets` API with the default cache directory located at:
|
| 221 |
+
`~/.cache/huggingface/datasets`
|
| 222 |
+
NOTE: You can change the cache location globally for a given process
|
| 223 |
+
by setting the shell environment variable, `HF_DATASETS_CACHE`,
|
| 224 |
+
to another directory:
|
| 225 |
+
`export HF_DATASETS_CACHE="/path/to/another/directory"`
|
| 226 |
+
:param download_mode: datasets.DownloadMode
|
| 227 |
+
How to treat pre-existing `Task` downloads and data.
|
| 228 |
+
- `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
|
| 229 |
+
Reuse download and reuse dataset.
|
| 230 |
+
- `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
|
| 231 |
+
Reuse download with fresh dataset.
|
| 232 |
+
- `datasets.DownloadMode.FORCE_REDOWNLOAD`
|
| 233 |
+
Fresh download and fresh dataset.
|
| 234 |
+
"""
|
| 235 |
+
self.dataset = datasets.load_dataset(
|
| 236 |
+
path=self.DATASET_PATH,
|
| 237 |
+
name=self.DATASET_NAME,
|
| 238 |
+
data_dir=data_dir,
|
| 239 |
+
cache_dir=cache_dir,
|
| 240 |
+
download_mode=download_mode,
|
| 241 |
+
)
|
| 242 |
+
self.dataset_no_image = datasets.load_dataset(
|
| 243 |
+
path=self.DATASET_PATH,
|
| 244 |
+
name=self.DATASET_NAME,
|
| 245 |
+
data_dir=data_dir,
|
| 246 |
+
cache_dir=cache_dir,
|
| 247 |
+
download_mode=download_mode,
|
| 248 |
+
)
|
| 249 |
+
for doc_name in self.dataset_no_image:
|
| 250 |
+
remove_cols = []
|
| 251 |
+
features = self.dataset_no_image[doc_name].features
|
| 252 |
+
# If it is an Image instance or a Sequence of Image instance. Remove it
|
| 253 |
+
for feature in features:
|
| 254 |
+
if isinstance(features[feature], Image):
|
| 255 |
+
remove_cols.append(feature)
|
| 256 |
+
elif isinstance(features[feature], Sequence) and isinstance(features[feature].feature, Image):
|
| 257 |
+
remove_cols.append(feature)
|
| 258 |
+
for remove_col in remove_cols:
|
| 259 |
+
self.dataset_no_image[doc_name] = self.dataset_no_image[doc_name].remove_columns(remove_col)
|
| 260 |
+
|
| 261 |
+
@property
|
| 262 |
+
def config(self):
|
| 263 |
+
"""Returns the TaskConfig associated with this class."""
|
| 264 |
+
return self._config
|
| 265 |
+
|
| 266 |
+
@abc.abstractmethod
|
| 267 |
+
def has_training_docs(self):
|
| 268 |
+
"""Whether the task has a training set"""
|
| 269 |
+
pass
|
| 270 |
+
|
| 271 |
+
@abc.abstractmethod
|
| 272 |
+
def has_validation_docs(self):
|
| 273 |
+
"""Whether the task has a validation set"""
|
| 274 |
+
pass
|
| 275 |
+
|
| 276 |
+
@abc.abstractmethod
|
| 277 |
+
def has_test_docs(self):
|
| 278 |
+
"""Whether the task has a test set"""
|
| 279 |
+
pass
|
| 280 |
+
|
| 281 |
+
def training_docs(self):
|
| 282 |
+
"""
|
| 283 |
+
:return: Iterable[obj]
|
| 284 |
+
A iterable of any object, that doc_to_text can handle
|
| 285 |
+
"""
|
| 286 |
+
return []
|
| 287 |
+
|
| 288 |
+
def validation_docs(self):
|
| 289 |
+
"""
|
| 290 |
+
:return: Iterable[obj]
|
| 291 |
+
A iterable of any object, that doc_to_text can handle
|
| 292 |
+
"""
|
| 293 |
+
return []
|
| 294 |
+
|
| 295 |
+
def test_docs(self):
|
| 296 |
+
"""
|
| 297 |
+
:return: Iterable[obj]
|
| 298 |
+
A iterable of any object, that doc_to_text can handle
|
| 299 |
+
"""
|
| 300 |
+
return []
|
| 301 |
+
|
| 302 |
+
def fewshot_docs(self):
|
| 303 |
+
"""
|
| 304 |
+
:return: Iterable[obj]
|
| 305 |
+
A iterable of any object, that doc_to_text can handle
|
| 306 |
+
"""
|
| 307 |
+
if self.has_training_docs():
|
| 308 |
+
return self.training_docs()
|
| 309 |
+
elif self.has_validation_docs():
|
| 310 |
+
return self.validation_docs()
|
| 311 |
+
else:
|
| 312 |
+
if self.config.num_fewshot is not None:
|
| 313 |
+
eval_logger.warning("has_training_docs and has_validation_docs are False" ", using test_docs as fewshot_docs but this is not recommended.")
|
| 314 |
+
return self.test_docs()
|
| 315 |
+
|
| 316 |
+
def _process_doc(self, doc):
|
| 317 |
+
"""
|
| 318 |
+
Override this to process (detokenize, strip, replace, etc.) individual
|
| 319 |
+
documents. This can be used in a map over documents of a data split.
|
| 320 |
+
E.g. `map(self._process_doc, self.dataset["validation"])`
|
| 321 |
+
|
| 322 |
+
:return: dict
|
| 323 |
+
The processed version of the specified `doc`.
|
| 324 |
+
"""
|
| 325 |
+
return doc
|
| 326 |
+
|
| 327 |
+
@property
|
| 328 |
+
def instances(self):
|
| 329 |
+
"""After calling `task.build_all_requests()`, tasks
|
| 330 |
+
maintain a list of the dataset instances which will be evaluated.
|
| 331 |
+
"""
|
| 332 |
+
return self._instances
|
| 333 |
+
|
| 334 |
+
def fewshot_examples(self, k, rnd):
|
| 335 |
+
if self._training_docs is None:
|
| 336 |
+
self._training_docs = list(self.training_docs())
|
| 337 |
+
|
| 338 |
+
return rnd.sample(self._training_docs, k)
|
| 339 |
+
|
| 340 |
+
def doc_to_decontamination_query(self, doc) -> None:
|
| 341 |
+
print("Override doc_to_decontamination_query with document specific decontamination query.")
|
| 342 |
+
assert False
|
| 343 |
+
|
| 344 |
+
@abc.abstractmethod
|
| 345 |
+
def doc_to_text(self, doc):
|
| 346 |
+
pass
|
| 347 |
+
|
| 348 |
+
@abc.abstractmethod
|
| 349 |
+
def doc_to_target(self, doc):
|
| 350 |
+
pass
|
| 351 |
+
|
| 352 |
+
# @profile
|
| 353 |
+
def build_all_requests(self, limit=None, rank=None, world_size=None) -> None:
|
| 354 |
+
"""Build a set of Instances for a task, and store them in task.instances"""
|
| 355 |
+
if self.has_test_docs():
|
| 356 |
+
docs = self.test_docs()
|
| 357 |
+
split = self.config.test_split
|
| 358 |
+
elif self.has_validation_docs():
|
| 359 |
+
docs = self.validation_docs()
|
| 360 |
+
split = self.config.validation_split
|
| 361 |
+
else:
|
| 362 |
+
assert False, f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
|
| 363 |
+
|
| 364 |
+
eval_logger.info(f"Building contexts for task {self.CONFIG.task} on rank {rank}...")
|
| 365 |
+
instances = []
|
| 366 |
+
doc_id_iterator = utils.create_iterator([i for i in range(len(docs))], rank, world_size, limit)
|
| 367 |
+
doc_id_iterator, doc_id_iterator_counting = itertools.tee(doc_id_iterator)
|
| 368 |
+
total_docs = sum(1 for _ in doc_id_iterator_counting)
|
| 369 |
+
pbar = tqdm(total=total_docs, desc=f"Building context", disable=(rank != 0))
|
| 370 |
+
for doc_id in doc_id_iterator:
|
| 371 |
+
# sample fewshot context #TODO: need to offset doc_id by rank now!
|
| 372 |
+
fewshot_ctx = self.fewshot_context(doc_id, 0 if self.config.num_fewshot is None else self.config.num_fewshot, self.config.training_split if self.has_training_docs() else split)
|
| 373 |
+
|
| 374 |
+
# TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
|
| 375 |
+
inst = self.construct_requests(doc_id=doc_id, ctx=fewshot_ctx, metadata=(self.config["task"], doc_id, self.config.repeats), split=split)
|
| 376 |
+
|
| 377 |
+
if not isinstance(inst, list):
|
| 378 |
+
inst = [inst]
|
| 379 |
+
|
| 380 |
+
instances.extend(inst)
|
| 381 |
+
pbar.update(1)
|
| 382 |
+
|
| 383 |
+
pbar.close()
|
| 384 |
+
self._instances = instances
|
| 385 |
+
assert len(self._instances) != 0, "task.build_requests() did not find any docs!"
|
| 386 |
+
|
| 387 |
+
@abc.abstractmethod
|
| 388 |
+
def construct_requests(self, doc_id, ctx, **kwargs):
|
| 389 |
+
"""Uses RequestFactory to construct Requests and returns an iterable of
|
| 390 |
+
Requests which will be sent to the LMM.
|
| 391 |
+
|
| 392 |
+
:param doc_id: int
|
| 393 |
+
The index of a document within `self.test_docs()` or `self.validation_docs()`,
|
| 394 |
+
whichever is the main split used.
|
| 395 |
+
:param ctx: str
|
| 396 |
+
The context string, generated by fewshot_context. This includes the natural
|
| 397 |
+
language description, as well as the few shot examples, and the question
|
| 398 |
+
part of the document for `doc`.
|
| 399 |
+
:param repeats: int
|
| 400 |
+
TODO: update this docstring
|
| 401 |
+
The number of times each instance in a dataset is inferred on. Defaults to 1,
|
| 402 |
+
can be increased for techniques like majority voting.
|
| 403 |
+
"""
|
| 404 |
+
pass
|
| 405 |
+
|
| 406 |
+
@abc.abstractmethod
|
| 407 |
+
def process_results(self, doc, results):
|
| 408 |
+
"""Take a single document and the LMM results and evaluates, returning a
|
| 409 |
+
dict where keys are the names of submetrics and values are the values of
|
| 410 |
+
the metric for that one document
|
| 411 |
+
|
| 412 |
+
:param doc:
|
| 413 |
+
The document as returned from training_docs, validation_docs, or test_docs.
|
| 414 |
+
:param results:
|
| 415 |
+
The results of the requests created in construct_requests.
|
| 416 |
+
"""
|
| 417 |
+
pass
|
| 418 |
+
|
| 419 |
+
@abc.abstractmethod
|
| 420 |
+
def aggregation(self):
|
| 421 |
+
"""
|
| 422 |
+
:returns: {str: [metric_score] -> float}
|
| 423 |
+
A dictionary where keys are the names of submetrics and values are
|
| 424 |
+
functions that aggregate a list of metric scores
|
| 425 |
+
"""
|
| 426 |
+
pass
|
| 427 |
+
|
| 428 |
+
@abc.abstractmethod
|
| 429 |
+
def higher_is_better(self):
|
| 430 |
+
"""
|
| 431 |
+
:returns: {str: bool}
|
| 432 |
+
A dictionary where keys are the names of submetrics and values are
|
| 433 |
+
whether a higher value of the submetric is better
|
| 434 |
+
"""
|
| 435 |
+
pass
|
| 436 |
+
|
| 437 |
+
@classmethod
|
| 438 |
+
def count_bytes(cls, doc):
|
| 439 |
+
"""Used for byte-level perplexity metrics in rolling loglikelihood"""
|
| 440 |
+
return len(doc.encode("utf-8"))
|
| 441 |
+
|
| 442 |
+
@utils.positional_deprecated
|
| 443 |
+
def fewshot_context(
|
| 444 |
+
self,
|
| 445 |
+
doc_id,
|
| 446 |
+
num_fewshot,
|
| 447 |
+
split,
|
| 448 |
+
rnd=random.Random(1234),
|
| 449 |
+
description=None,
|
| 450 |
+
):
|
| 451 |
+
"""Returns a fewshot context string that is made up of a prepended description
|
| 452 |
+
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
|
| 453 |
+
|
| 454 |
+
:param doc_id: int
|
| 455 |
+
The document id as returned from training_docs, validation_docs, or test_docs.
|
| 456 |
+
:param num_fewshot: int
|
| 457 |
+
The number of fewshot examples to provide in the returned context string.
|
| 458 |
+
:param split: str
|
| 459 |
+
The split of the document to retrieve from the dataset
|
| 460 |
+
:param rnd: random.Random
|
| 461 |
+
The pseudo-random number generator used to randomly sample examples.
|
| 462 |
+
WARNING: This is currently a required arg although it's optionalized with a default `None`.
|
| 463 |
+
:param description: str
|
| 464 |
+
The task's description that will be prepended to the fewshot examples.
|
| 465 |
+
:returns: str
|
| 466 |
+
The fewshot context.
|
| 467 |
+
"""
|
| 468 |
+
assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`"
|
| 469 |
+
|
| 470 |
+
description = description if description else ""
|
| 471 |
+
doc = self.dataset_no_image[split][doc_id]
|
| 472 |
+
|
| 473 |
+
if num_fewshot == 0:
|
| 474 |
+
labeled_examples = ""
|
| 475 |
+
else:
|
| 476 |
+
# for sets with no training docs, draw from other set *but ensure no overlap with current doc*
|
| 477 |
+
if self.has_training_docs():
|
| 478 |
+
fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
|
| 479 |
+
else:
|
| 480 |
+
if self._fewshot_docs is None:
|
| 481 |
+
self._fewshot_docs = list(self.validation_docs() if self.has_validation_docs() else self.test_docs())
|
| 482 |
+
|
| 483 |
+
fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
|
| 484 |
+
|
| 485 |
+
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
|
| 486 |
+
fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
|
| 487 |
+
|
| 488 |
+
labeled_examples = "\n\n".join([self.doc_to_text(doc) + self.doc_to_target(doc) for doc in fewshotex]) + "\n\n"
|
| 489 |
+
|
| 490 |
+
example = self.doc_to_text(doc)
|
| 491 |
+
return description + labeled_examples + example
|
| 492 |
+
|
| 493 |
+
def apply_filters(self):
|
| 494 |
+
if hasattr(self, "_filters"):
|
| 495 |
+
for f in self._filters:
|
| 496 |
+
f.apply(self._instances, None)
|
| 497 |
+
else:
|
| 498 |
+
eval_logger.warning("No filter defined, passing through instances")
|
| 499 |
+
return self._instances
|
| 500 |
+
|
| 501 |
+
def dump_config(self) -> dict:
|
| 502 |
+
"""Returns a dictionary representing the task's config.
|
| 503 |
+
|
| 504 |
+
:returns: str
|
| 505 |
+
The fewshot context.
|
| 506 |
+
"""
|
| 507 |
+
# TODO: this should only return the overrides applied to a non-YAML task's configuration.
|
| 508 |
+
# (num_fewshot)
|
| 509 |
+
return self.config.to_dict()
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
class ConfigurableTask(Task):
|
| 513 |
+
VERSION = "Yaml"
|
| 514 |
+
OUTPUT_TYPE = None
|
| 515 |
+
CONFIG = None
|
| 516 |
+
|
| 517 |
+
def __init__(self, model_name) -> None: # TODO no super() call here
|
| 518 |
+
# Get pre-configured attributes
|
| 519 |
+
self._config = self.CONFIG
|
| 520 |
+
# different model requires different prompt, we have to take those into account.
|
| 521 |
+
|
| 522 |
+
self.model_name = model_name
|
| 523 |
+
self._prepare_model_specific_config()
|
| 524 |
+
|
| 525 |
+
assert self.config.output_type in ALL_OUTPUT_TYPES
|
| 526 |
+
self.OUTPUT_TYPE = self.config.output_type
|
| 527 |
+
|
| 528 |
+
self.DATASET_PATH = self.config.dataset_path
|
| 529 |
+
|
| 530 |
+
if self.config.dataset_name is not None:
|
| 531 |
+
self.DATASET_NAME = self.config.dataset_name
|
| 532 |
+
|
| 533 |
+
self._prepare_metric_and_aggregation()
|
| 534 |
+
|
| 535 |
+
self.download(self.config.dataset_kwargs)
|
| 536 |
+
self._training_docs = None
|
| 537 |
+
self._fewshot_docs = None
|
| 538 |
+
|
| 539 |
+
if self.config.filter_list is not None:
|
| 540 |
+
self._filters = []
|
| 541 |
+
for filter_config in self.config.filter_list:
|
| 542 |
+
for filter_pipeline in filter_config:
|
| 543 |
+
filter_name = filter_config["name"]
|
| 544 |
+
filter_functions = filter_config["filter"]
|
| 545 |
+
components = []
|
| 546 |
+
for function in filter_functions:
|
| 547 |
+
kwargs = {key: function[key] for key in function if key != "function"}
|
| 548 |
+
components.append([function["function"], kwargs])
|
| 549 |
+
filter_pipeline = build_filter_ensemble(filter_name, components)
|
| 550 |
+
self._filters.append(filter_pipeline)
|
| 551 |
+
else:
|
| 552 |
+
self._filters = [build_filter_ensemble("none", [["take_first", None]])]
|
| 553 |
+
if self.config.fewshot_config is not None:
|
| 554 |
+
self.sampler = samplers.get_sampler(self.config.fewshot_config.get("sampler", "default") if self.config.fewshot_config else "default")(list(self.fewshot_docs()), self, rnd=random.Random(1234))
|
| 555 |
+
|
| 556 |
+
if self.has_test_docs():
|
| 557 |
+
self.task_docs = self.test_docs()
|
| 558 |
+
elif self.has_validation_docs():
|
| 559 |
+
self.task_docs = self.validation_docs()
|
| 560 |
+
else:
|
| 561 |
+
assert False, f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
|
| 562 |
+
|
| 563 |
+
# Test One Doc
|
| 564 |
+
self.features = list(self.task_docs.features.keys())
|
| 565 |
+
self.multiple_input = 0
|
| 566 |
+
self.multiple_target = 0
|
| 567 |
+
test_doc = self.task_docs[0]
|
| 568 |
+
test_text = self.doc_to_text(test_doc)
|
| 569 |
+
test_target = self.doc_to_target(test_doc)
|
| 570 |
+
|
| 571 |
+
if self.config.doc_to_choice is not None:
|
| 572 |
+
test_choice = self.doc_to_choice(test_doc)
|
| 573 |
+
if type(test_choice) is not list:
|
| 574 |
+
eval_logger.error("doc_to_choice must return list")
|
| 575 |
+
else:
|
| 576 |
+
num_choice = len(test_choice)
|
| 577 |
+
|
| 578 |
+
if type(test_text) is int:
|
| 579 |
+
self.multiple_input = num_choice
|
| 580 |
+
else:
|
| 581 |
+
test_choice = None
|
| 582 |
+
|
| 583 |
+
if type(test_target) is list:
|
| 584 |
+
self.multiple_target = len(test_target)
|
| 585 |
+
else:
|
| 586 |
+
if (type(test_target) is int) and (test_choice is not None):
|
| 587 |
+
test_target = test_choice[test_target]
|
| 588 |
+
else:
|
| 589 |
+
test_target = str(test_target)
|
| 590 |
+
|
| 591 |
+
if test_choice is not None:
|
| 592 |
+
check_choices = test_choice
|
| 593 |
+
else:
|
| 594 |
+
check_choices = [test_target]
|
| 595 |
+
if self.config.doc_to_choice is not None:
|
| 596 |
+
for choice in check_choices:
|
| 597 |
+
choice_has_whitespace = True if choice[0].isspace() else False
|
| 598 |
+
delimiter_has_whitespace = True if self.config.target_delimiter.rstrip() != self.config.target_delimiter else False
|
| 599 |
+
|
| 600 |
+
if delimiter_has_whitespace and choice_has_whitespace:
|
| 601 |
+
eval_logger.warning(f'Both target_delimiter and target choice: "{choice}" have whitespace')
|
| 602 |
+
elif (not delimiter_has_whitespace) and (not choice_has_whitespace):
|
| 603 |
+
eval_logger.warning(f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace')
|
| 604 |
+
|
| 605 |
+
def _prepare_model_specific_config(self):
|
| 606 |
+
self.model_specific_prompt_kwargs = self.config.model_specific_prompt_kwargs
|
| 607 |
+
if self.model_specific_prompt_kwargs is not None:
|
| 608 |
+
if self.model_name in self.model_specific_prompt_kwargs:
|
| 609 |
+
self.model_specific_prompt_kwargs = self.model_specific_prompt_kwargs[self.model_name]
|
| 610 |
+
else:
|
| 611 |
+
self.model_specific_prompt_kwargs = self.model_specific_prompt_kwargs.get("default", None)
|
| 612 |
+
|
| 613 |
+
self.model_specific_target_kwargs = self.config.model_specific_target_kwargs
|
| 614 |
+
if self.model_specific_target_kwargs is not None:
|
| 615 |
+
if self.model_name in self.model_specific_target_kwargs:
|
| 616 |
+
self.model_specific_target_kwargs = self.model_specific_target_kwargs[self.model_name]
|
| 617 |
+
else:
|
| 618 |
+
self.model_specific_target_kwargs = self.model_specific_target_kwargs.get("default", None)
|
| 619 |
+
self.model_specific_generation_kwargs = self.config.model_specific_generation_kwargs
|
| 620 |
+
if self.model_specific_generation_kwargs is not None:
|
| 621 |
+
if self.model_name in self.model_specific_generation_kwargs:
|
| 622 |
+
self.model_specific_generation_kwargs = self.model_specific_generation_kwargs[self.model_name]
|
| 623 |
+
else:
|
| 624 |
+
self.model_specific_generation_kwargs = self.model_specific_generation_kwargs.get("default", {})
|
| 625 |
+
|
| 626 |
+
self.config.generation_kwargs.update(self.model_specific_generation_kwargs)
|
| 627 |
+
|
| 628 |
+
def _prepare_metric_and_aggregation(self):
|
| 629 |
+
self._metric_fn_list = {}
|
| 630 |
+
self._metric_fn_kwargs = {}
|
| 631 |
+
self._aggregation_list = {}
|
| 632 |
+
self._higher_is_better = {}
|
| 633 |
+
|
| 634 |
+
if self.config.metric_list is None:
|
| 635 |
+
# TODO: handle this in TaskConfig.__post_init__ ?
|
| 636 |
+
_metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]
|
| 637 |
+
|
| 638 |
+
for metric_name in _metric_list:
|
| 639 |
+
self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name]
|
| 640 |
+
self._metric_fn_kwargs[metric_name] = {}
|
| 641 |
+
self._aggregation_list[metric_name] = get_metric_aggregation(metric_name)
|
| 642 |
+
self._higher_is_better[metric_name] = is_higher_better(metric_name)
|
| 643 |
+
else:
|
| 644 |
+
for metric_config in self.config.metric_list:
|
| 645 |
+
assert "metric" in metric_config
|
| 646 |
+
metric_name = metric_config["metric"]
|
| 647 |
+
kwargs = {key: metric_config[key] for key in metric_config if key not in ["metric", "aggregation", "higher_is_better"]}
|
| 648 |
+
|
| 649 |
+
if self.config.process_results is not None:
|
| 650 |
+
self._metric_fn_list[metric_name] = None
|
| 651 |
+
self._metric_fn_kwargs[metric_name] = {}
|
| 652 |
+
elif callable(metric_name):
|
| 653 |
+
metric_fn = metric_name.__call__
|
| 654 |
+
metric_name = metric_name.__name__
|
| 655 |
+
self._metric_fn_list[metric_name] = metric_fn
|
| 656 |
+
self._metric_fn_kwargs[metric_name] = kwargs
|
| 657 |
+
else:
|
| 658 |
+
self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name]
|
| 659 |
+
self._metric_fn_kwargs[metric_name] = kwargs
|
| 660 |
+
|
| 661 |
+
if "aggregation" in metric_config:
|
| 662 |
+
agg_name = metric_config["aggregation"]
|
| 663 |
+
if type(agg_name) == str:
|
| 664 |
+
self._aggregation_list[metric_name] = get_aggregation(agg_name)
|
| 665 |
+
elif callable(agg_name):
|
| 666 |
+
self._aggregation_list[metric_name] = metric_config["aggregation"]
|
| 667 |
+
else:
|
| 668 |
+
INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
|
| 669 |
+
metric_agg = get_metric_aggregation(metric_name)
|
| 670 |
+
eval_logger.warning(f"[Task: {self._config.task}] metric {metric_name} is defined, but aggregation is not. " f"using default " f"aggregation={INV_AGG_REGISTRY[metric_agg]}")
|
| 671 |
+
self._aggregation_list[metric_name] = metric_agg
|
| 672 |
+
|
| 673 |
+
if "higher_is_better" in metric_config:
|
| 674 |
+
self._higher_is_better[metric_name] = metric_config["higher_is_better"]
|
| 675 |
+
else:
|
| 676 |
+
eval_logger.warning(f"[Task: {self._config.task}] metric {metric_name} is defined, but higher_is_better is not. " f"using default " f"higher_is_better={is_higher_better(metric_name)}")
|
| 677 |
+
self._higher_is_better[metric_name] = is_higher_better(metric_name)
|
| 678 |
+
|
| 679 |
+
@retry(stop=stop_after_attempt(5), wait=wait_fixed(2))
|
| 680 |
+
def download(self, dataset_kwargs=None) -> None:
|
| 681 |
+
download_config = DownloadConfig()
|
| 682 |
+
download_config.max_retries = dataset_kwargs.get("max_retries", 3) if dataset_kwargs is not None else 3
|
| 683 |
+
download_config.num_proc = dataset_kwargs.get("num_proc", 8) if dataset_kwargs is not None else 8
|
| 684 |
+
self.dataset = datasets.load_dataset(
|
| 685 |
+
path=self.DATASET_PATH,
|
| 686 |
+
name=self.DATASET_NAME,
|
| 687 |
+
download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS,
|
| 688 |
+
**dataset_kwargs if dataset_kwargs is not None else {},
|
| 689 |
+
)
|
| 690 |
+
self.dataset_no_image = datasets.load_dataset(
|
| 691 |
+
path=self.DATASET_PATH,
|
| 692 |
+
name=self.DATASET_NAME,
|
| 693 |
+
download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS,
|
| 694 |
+
**dataset_kwargs if dataset_kwargs is not None else {},
|
| 695 |
+
)
|
| 696 |
+
for doc_name in self.dataset_no_image:
|
| 697 |
+
remove_cols = []
|
| 698 |
+
features = self.dataset_no_image[doc_name].features
|
| 699 |
+
# If it is an Image instance or a Sequence of Image instance. Remove it
|
| 700 |
+
for feature in features:
|
| 701 |
+
if isinstance(features[feature], Image):
|
| 702 |
+
remove_cols.append(feature)
|
| 703 |
+
elif isinstance(features[feature], Sequence) and isinstance(features[feature].feature, Image):
|
| 704 |
+
remove_cols.append(feature)
|
| 705 |
+
for remove_col in remove_cols:
|
| 706 |
+
self.dataset_no_image[doc_name] = self.dataset_no_image[doc_name].remove_columns(remove_col)
|
| 707 |
+
|
| 708 |
+
def has_training_docs(self) -> bool:
|
| 709 |
+
if self.config.training_split is not None:
|
| 710 |
+
return True
|
| 711 |
+
else:
|
| 712 |
+
return False
|
| 713 |
+
|
| 714 |
+
def has_validation_docs(self) -> bool:
|
| 715 |
+
if self.config.validation_split is not None:
|
| 716 |
+
return True
|
| 717 |
+
else:
|
| 718 |
+
return False
|
| 719 |
+
|
| 720 |
+
def has_test_docs(self) -> bool:
|
| 721 |
+
if self.config.test_split is not None:
|
| 722 |
+
return True
|
| 723 |
+
else:
|
| 724 |
+
return False
|
| 725 |
+
|
| 726 |
+
def training_docs(self) -> datasets.Dataset:
|
| 727 |
+
if self.has_training_docs():
|
| 728 |
+
if self.config.process_docs is not None:
|
| 729 |
+
return self.config.process_docs(self.dataset[self.config.training_split])
|
| 730 |
+
return self.dataset[self.config.training_split]
|
| 731 |
+
|
| 732 |
+
def validation_docs(self) -> datasets.Dataset:
|
| 733 |
+
if self.has_validation_docs():
|
| 734 |
+
if self.config.process_docs is not None:
|
| 735 |
+
return self.config.process_docs(self.dataset[self.config.validation_split])
|
| 736 |
+
return self.dataset[self.config.validation_split]
|
| 737 |
+
|
| 738 |
+
def test_docs(self) -> datasets.Dataset:
|
| 739 |
+
if self.has_test_docs():
|
| 740 |
+
if self.config.process_docs is not None:
|
| 741 |
+
return self.config.process_docs(self.dataset[self.config.test_split])
|
| 742 |
+
return self.dataset[self.config.test_split]
|
| 743 |
+
|
| 744 |
+
def fewshot_docs(self):
|
| 745 |
+
if self.config.fewshot_split is not None:
|
| 746 |
+
return self.dataset[self.config.fewshot_split]
|
| 747 |
+
else:
|
| 748 |
+
if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0):
|
| 749 |
+
eval_logger.warning(f"Task '{self.config.task}': " "num_fewshot > 0 but fewshot_split is None. " "using preconfigured rule.")
|
| 750 |
+
return super().fewshot_docs()
|
| 751 |
+
|
| 752 |
+
@utils.positional_deprecated
|
| 753 |
+
def fewshot_context(self, doc_id, num_fewshot, split):
|
| 754 |
+
"""Returns a fewshot context string that is made up of a prepended description
|
| 755 |
+
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
|
| 756 |
+
|
| 757 |
+
:param doc_id: str
|
| 758 |
+
The document id as returned from training_docs, validation_docs, or test_docs.
|
| 759 |
+
:param num_fewshot: int
|
| 760 |
+
The number of fewshot examples to provide in the returned context string.
|
| 761 |
+
:returns: str
|
| 762 |
+
The fewshot context.
|
| 763 |
+
"""
|
| 764 |
+
doc = self.dataset_no_image[split][doc_id]
|
| 765 |
+
if num_fewshot == 0:
|
| 766 |
+
# always prepend the (possibly empty) task description
|
| 767 |
+
labeled_examples = self.config.description
|
| 768 |
+
else:
|
| 769 |
+
labeled_examples = self.config.description + self.sampler.get_context(doc, num_fewshot)
|
| 770 |
+
example = self.doc_to_text(doc)
|
| 771 |
+
if type(example) == str:
|
| 772 |
+
return labeled_examples + example
|
| 773 |
+
elif type(example) == list:
|
| 774 |
+
return [labeled_examples + ex for ex in example]
|
| 775 |
+
elif type(example) == int:
|
| 776 |
+
if self.config.doc_to_choice is not None:
|
| 777 |
+
choices = self.doc_to_choice(doc)
|
| 778 |
+
return labeled_examples + choices[example]
|
| 779 |
+
else:
|
| 780 |
+
return labeled_examples + str(example)
|
| 781 |
+
|
| 782 |
+
def apply_filters(self):
|
| 783 |
+
if hasattr(self, "_filters"):
|
| 784 |
+
for f in self._filters:
|
| 785 |
+
f.apply(self._instances, self.task_docs)
|
| 786 |
+
else:
|
| 787 |
+
eval_logger.warning("No filter defined, passing through instances")
|
| 788 |
+
return self._instances
|
| 789 |
+
|
| 790 |
+
def should_decontaminate(self):
|
| 791 |
+
return self.config.should_decontaminate
|
| 792 |
+
|
| 793 |
+
def doc_to_decontamination_query(self, doc):
|
| 794 |
+
if self.config.should_decontaminate:
|
| 795 |
+
if self.config.doc_to_decontamination_query is None:
|
| 796 |
+
return self.doc_to_text(doc)
|
| 797 |
+
else:
|
| 798 |
+
doc_to_decontamination_query = self.config.doc_to_decontamination_query
|
| 799 |
+
if doc_to_decontamination_query in self.features:
|
| 800 |
+
return doc[doc_to_decontamination_query]
|
| 801 |
+
elif callable(doc_to_decontamination_query):
|
| 802 |
+
return doc_to_decontamination_query(doc)
|
| 803 |
+
else:
|
| 804 |
+
return ast.literal_eval(utils.apply_template(self.config.doc_to_decontamination_query, doc))
|
| 805 |
+
|
| 806 |
+
def _process_doc(self, doc):
|
| 807 |
+
"""
|
| 808 |
+
Override this to process (detokenize, strip, replace, etc.) individual
|
| 809 |
+
documents. This can be used in a map over documents of a data split.
|
| 810 |
+
E.g. `map(self._process_doc, self.dataset["validation"])`
|
| 811 |
+
|
| 812 |
+
:return: dict
|
| 813 |
+
The processed version of the specified `doc`.
|
| 814 |
+
"""
|
| 815 |
+
return doc
|
| 816 |
+
|
| 817 |
+
def doc_to_text(self, doc):
|
| 818 |
+
doc_to_text = self.config.doc_to_text
|
| 819 |
+
|
| 820 |
+
if type(doc_to_text) == int:
|
| 821 |
+
return doc_to_text
|
| 822 |
+
elif type(doc_to_text) == str:
|
| 823 |
+
if doc_to_text in self.features:
|
| 824 |
+
# if self.config.doc_to_choice is not None:
|
| 825 |
+
# return self.doc_to_choice(doc)[doc[doc_to_text]]
|
| 826 |
+
# else:
|
| 827 |
+
return doc[doc_to_text]
|
| 828 |
+
else:
|
| 829 |
+
text_string = utils.apply_template(doc_to_text, doc)
|
| 830 |
+
if text_string.isdigit() and self._config.doc_to_choice is not None:
|
| 831 |
+
return ast.literal_eval(text_string)
|
| 832 |
+
else:
|
| 833 |
+
return text_string
|
| 834 |
+
elif callable(doc_to_text):
|
| 835 |
+
return (
|
| 836 |
+
doc_to_text(doc, self.model_specific_prompt_kwargs)
|
| 837 |
+
if self.model_specific_prompt_kwargs is not None
|
| 838 |
+
else doc_to_text(
|
| 839 |
+
doc,
|
| 840 |
+
)
|
| 841 |
+
)
|
| 842 |
+
# Used when applying a Promptsource template
|
| 843 |
+
elif hasattr(doc_to_text, "apply"):
|
| 844 |
+
applied_prompt = doc_to_text.apply(doc)
|
| 845 |
+
if len(applied_prompt) == 2:
|
| 846 |
+
return applied_prompt[0]
|
| 847 |
+
else:
|
| 848 |
+
eval_logger.warning("Applied prompt returns empty string")
|
| 849 |
+
return self.config.fewshot_delimiter
|
| 850 |
+
else:
|
| 851 |
+
print(type(doc_to_text))
|
| 852 |
+
raise TypeError
|
| 853 |
+
|
| 854 |
+
def doc_to_target(self, doc: dict) -> Union[int, str, list]:
|
| 855 |
+
doc_to_target = self.config.doc_to_target
|
| 856 |
+
|
| 857 |
+
if type(doc_to_target) == int:
|
| 858 |
+
return doc_to_target
|
| 859 |
+
elif type(doc_to_target) == str:
|
| 860 |
+
if doc_to_target in self.features:
|
| 861 |
+
# if self.config.doc_to_choice is not None:
|
| 862 |
+
# return self.doc_to_choice(doc)[doc[doc_to_target]]
|
| 863 |
+
# else:
|
| 864 |
+
return doc[doc_to_target]
|
| 865 |
+
else:
|
| 866 |
+
target_string = utils.apply_template(doc_to_target, doc)
|
| 867 |
+
if target_string.isdigit() and self._config.doc_to_choice is not None:
|
| 868 |
+
return ast.literal_eval(target_string)
|
| 869 |
+
elif len(target_string) >= 2 and (target_string[0] == "[") and (target_string[-1] == "]"):
|
| 870 |
+
try:
|
| 871 |
+
return ast.literal_eval(target_string)
|
| 872 |
+
except (SyntaxError, ValueError):
|
| 873 |
+
return target_string
|
| 874 |
+
else:
|
| 875 |
+
return target_string
|
| 876 |
+
elif type(doc_to_target) == list:
|
| 877 |
+
return doc_to_target
|
| 878 |
+
elif callable(doc_to_target):
|
| 879 |
+
return doc_to_target(doc, self.model_specific_target_kwargs) if self.model_specific_target_kwargs is not None else doc_to_target(doc)
|
| 880 |
+
# Used when applying a Promptsource template
|
| 881 |
+
elif hasattr(doc_to_target, "apply"):
|
| 882 |
+
applied_prompt = doc_to_target.apply(doc)
|
| 883 |
+
if len(applied_prompt) == 2:
|
| 884 |
+
return applied_prompt[1]
|
| 885 |
+
else:
|
| 886 |
+
eval_logger.warning("Applied prompt returns empty string")
|
| 887 |
+
return self.config.fewshot_delimiter
|
| 888 |
+
else:
|
| 889 |
+
raise TypeError
|
| 890 |
+
|
| 891 |
+
def doc_to_visual(self, doc: dict) -> Union[int, str, list]:
|
| 892 |
+
self.config.doc_to_visual
|
| 893 |
+
if type(self.config.doc_to_visual) == str:
|
| 894 |
+
assert self.config.doc_to_visual in self.features
|
| 895 |
+
# Single image. Still return a list for consistency.
|
| 896 |
+
return [doc[self.config.doc_to_visual]]
|
| 897 |
+
else:
|
| 898 |
+
assert callable(self.config.doc_to_visual)
|
| 899 |
+
return self.config.doc_to_visual(doc)
|
| 900 |
+
|
| 901 |
+
def doc_to_choice(self, doc: Any) -> List[str]:
|
| 902 |
+
if self.config.doc_to_choice is None:
|
| 903 |
+
eval_logger.error("doc_to_choice was called but not set in config")
|
| 904 |
+
else:
|
| 905 |
+
doc_to_choice = self.config.doc_to_choice
|
| 906 |
+
|
| 907 |
+
if type(doc_to_choice) == str:
|
| 908 |
+
if doc_to_choice in self.features:
|
| 909 |
+
return doc[doc_to_choice]
|
| 910 |
+
else:
|
| 911 |
+
return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
|
| 912 |
+
elif type(doc_to_choice) == list:
|
| 913 |
+
return doc_to_choice
|
| 914 |
+
elif type(doc_to_choice) == dict:
|
| 915 |
+
return list(doc_to_choice.values())
|
| 916 |
+
elif callable(doc_to_choice):
|
| 917 |
+
return doc_to_choice(doc)
|
| 918 |
+
elif hasattr(doc_to_choice, "get_answer_choices_list"):
|
| 919 |
+
return doc_to_choice.get_answer_choices_list(doc)
|
| 920 |
+
else:
|
| 921 |
+
raise TypeError
|
| 922 |
+
|
| 923 |
+
def construct_requests(self, doc_id: int, ctx: str, **kwargs) -> Union[List[Instance], Instance]:
|
| 924 |
+
split = kwargs.get("split")
|
| 925 |
+
kwargs.pop("split")
|
| 926 |
+
if self.OUTPUT_TYPE == "loglikelihood":
|
| 927 |
+
arguments = (ctx, self.doc_to_target, self.doc_to_visual, doc_id, self.config.task, split)
|
| 928 |
+
elif self.OUTPUT_TYPE == "multiple_choice":
|
| 929 |
+
doc = self.dataset[split][doc_id]
|
| 930 |
+
choices = self.doc_to_choice(doc)
|
| 931 |
+
target_delimiter = self.config.target_delimiter
|
| 932 |
+
if self.multiple_input:
|
| 933 |
+
# If there are multiple inputs, choices are placed in the ctx
|
| 934 |
+
cont = self.doc_to_target(doc)
|
| 935 |
+
arguments = [(ctx, f"{target_delimiter}{cont}", self.doc_to_visual, doc_id, self.config.task, split) for ctx in choices]
|
| 936 |
+
else:
|
| 937 |
+
# Otherwise they are placed in the continuation
|
| 938 |
+
arguments = [(ctx, f"{target_delimiter}{cont}", self.doc_to_visual, doc_id, self.config.task, split) for cont in choices]
|
| 939 |
+
request_list = [
|
| 940 |
+
Instance(
|
| 941 |
+
request_type="loglikelihood",
|
| 942 |
+
# doc=doc,
|
| 943 |
+
arguments=arg,
|
| 944 |
+
idx=i,
|
| 945 |
+
**kwargs,
|
| 946 |
+
)
|
| 947 |
+
for i, arg in enumerate(arguments)
|
| 948 |
+
]
|
| 949 |
+
# TODO: we should raise a warning telling users this will at most ~2x runtime.
|
| 950 |
+
if "acc_mutual_info" in self._metric_fn_list.keys():
|
| 951 |
+
# if we are calculating multiple choice accuracy
|
| 952 |
+
# using mutual information instead of raw loglikelihood as metric, need unconditional lls.
|
| 953 |
+
|
| 954 |
+
# here mutual info refers to calculating
|
| 955 |
+
# log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
|
| 956 |
+
# in other words normalizing by subtracting the unconditional logprob of each choice.
|
| 957 |
+
request_list.extend(
|
| 958 |
+
[
|
| 959 |
+
Instance(
|
| 960 |
+
request_type="loglikelihood",
|
| 961 |
+
# doc=doc,
|
| 962 |
+
arguments=("", "{}".format(choice)),
|
| 963 |
+
idx=i,
|
| 964 |
+
**kwargs,
|
| 965 |
+
)
|
| 966 |
+
for i, choice in enumerate(choices)
|
| 967 |
+
]
|
| 968 |
+
)
|
| 969 |
+
return request_list
|
| 970 |
+
|
| 971 |
+
elif self.OUTPUT_TYPE == "generate_until":
|
| 972 |
+
arguments = (ctx, self.config.generation_kwargs, self.doc_to_visual, doc_id, self.config.task, split)
|
| 973 |
+
return Instance(request_type=self.OUTPUT_TYPE, arguments=arguments, idx=0, **kwargs)
|
| 974 |
+
|
| 975 |
+
def process_results(self, doc, results):
|
| 976 |
+
if callable(self.config.process_results):
|
| 977 |
+
return self.config.process_results(doc, results)
|
| 978 |
+
|
| 979 |
+
result_dict = {}
|
| 980 |
+
use_metric = list(self._metric_fn_list.keys())
|
| 981 |
+
if self.OUTPUT_TYPE == "loglikelihood":
|
| 982 |
+
results = results[0]
|
| 983 |
+
ll, is_greedy = results
|
| 984 |
+
return {
|
| 985 |
+
**({"perplexity": ll} if "perplexity" in use_metric else {}),
|
| 986 |
+
**({"acc": int(is_greedy)} if "acc" in use_metric else {}),
|
| 987 |
+
}
|
| 988 |
+
elif self.OUTPUT_TYPE == "multiple_choice":
|
| 989 |
+
lls, is_greedy = zip(*results)
|
| 990 |
+
|
| 991 |
+
# retrieve choices in List[str] form, to compute choice lengths, etc.
|
| 992 |
+
choices = self.doc_to_choice(doc)
|
| 993 |
+
completion_len = np.array([float(len(i)) for i in choices])
|
| 994 |
+
|
| 995 |
+
if 2 * len(choices) == len(lls) and "acc_mutual_info" in self._metric_fn_list.keys():
|
| 996 |
+
# then we are doing mutual info.
|
| 997 |
+
# this stores the "dryrun" / unconditional answer loglikelihoods
|
| 998 |
+
lls_unconditional = lls[1::2]
|
| 999 |
+
assert len(lls_unconditional) == len(choices)
|
| 1000 |
+
# and this stores our "regular" conditional loglikelihoods
|
| 1001 |
+
lls = lls[::2]
|
| 1002 |
+
|
| 1003 |
+
pred = np.argmax(lls)
|
| 1004 |
+
pred_norm = np.argmax(lls / completion_len)
|
| 1005 |
+
|
| 1006 |
+
if self.multiple_input:
|
| 1007 |
+
gold = self.doc_to_text(doc)
|
| 1008 |
+
else:
|
| 1009 |
+
gold = self.doc_to_target(doc)
|
| 1010 |
+
|
| 1011 |
+
gold_index_error = False
|
| 1012 |
+
if type(gold) is list:
|
| 1013 |
+
gold = [i if i < len(choices) else -100 for i in gold]
|
| 1014 |
+
if -100 in gold:
|
| 1015 |
+
gold_index_error = True
|
| 1016 |
+
else:
|
| 1017 |
+
if type(gold) is int:
|
| 1018 |
+
gold = gold if gold < len(choices) else -100
|
| 1019 |
+
elif type(gold) is str:
|
| 1020 |
+
gold = choices.index(gold) if gold in choices else -100
|
| 1021 |
+
|
| 1022 |
+
if gold == -100:
|
| 1023 |
+
gold_index_error = True
|
| 1024 |
+
|
| 1025 |
+
if gold_index_error:
|
| 1026 |
+
eval_logger.warning(f"Label index was not in within range of available choices," f"Sample:\n\n{doc}\n\n")
|
| 1027 |
+
|
| 1028 |
+
if self.multiple_target:
|
| 1029 |
+
acc = 1.0 if pred in gold else 0.0
|
| 1030 |
+
acc_norm = 1.0 if pred_norm in gold else 0.0
|
| 1031 |
+
exact_match = int(any([is_greedy[i] if i != -100 else 0 for i in gold]))
|
| 1032 |
+
else:
|
| 1033 |
+
acc = 1.0 if pred == gold else 0.0
|
| 1034 |
+
acc_norm = 1.0 if pred_norm == gold else 0.0
|
| 1035 |
+
# TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
|
| 1036 |
+
exact_match = int(is_greedy[gold]) if gold != -100 else 0
|
| 1037 |
+
|
| 1038 |
+
result_dict = {
|
| 1039 |
+
**({"acc": acc} if "acc" in use_metric else {}),
|
| 1040 |
+
**({"f1": (gold, pred)} if "f1" in use_metric else {}),
|
| 1041 |
+
**({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
|
| 1042 |
+
**({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
|
| 1043 |
+
**({"exact_match": exact_match} if "exact_match" in use_metric else {}),
|
| 1044 |
+
}
|
| 1045 |
+
|
| 1046 |
+
if "acc_mutual_info" in use_metric:
|
| 1047 |
+
lls_mutual_info = [ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)]
|
| 1048 |
+
acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
|
| 1049 |
+
result_dict["acc_mutual_info"] = acc_mutual_info
|
| 1050 |
+
|
| 1051 |
+
elif self.OUTPUT_TYPE == "generate_until":
|
| 1052 |
+
gold = self.doc_to_target(doc)
|
| 1053 |
+
result = results[0]
|
| 1054 |
+
if self.config.doc_to_choice is not None:
|
| 1055 |
+
# If you set doc_to_choice,
|
| 1056 |
+
# it assumes that doc_to_target returns a number.
|
| 1057 |
+
choices = self.doc_to_choice(doc)
|
| 1058 |
+
gold = choices[gold]
|
| 1059 |
+
# we expect multiple_targets to be a list.
|
| 1060 |
+
elif self.multiple_target:
|
| 1061 |
+
gold = list(gold)
|
| 1062 |
+
elif type(gold) != type(result):
|
| 1063 |
+
# cast gold to the same type as result
|
| 1064 |
+
gold = type(result)(gold)
|
| 1065 |
+
|
| 1066 |
+
for metric in self._metric_fn_list.keys():
|
| 1067 |
+
if self.multiple_target:
|
| 1068 |
+
# in the case where we have multiple targets,
|
| 1069 |
+
# return true if any are true
|
| 1070 |
+
# TODO: this may break for multipLe_target, non zero-or-1 metrics
|
| 1071 |
+
scores = []
|
| 1072 |
+
if not isinstance(gold, list):
|
| 1073 |
+
# sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
|
| 1074 |
+
# print(gold)
|
| 1075 |
+
gold = [gold]
|
| 1076 |
+
for gold_option in gold:
|
| 1077 |
+
try:
|
| 1078 |
+
result_score = self._metric_fn_list[metric](
|
| 1079 |
+
references=[gold_option],
|
| 1080 |
+
predictions=[result],
|
| 1081 |
+
**self._metric_fn_kwargs[metric],
|
| 1082 |
+
)
|
| 1083 |
+
except TypeError: # TODO: this is hacky and I don't want to do it
|
| 1084 |
+
result_score = self._metric_fn_list[metric]([gold_option, result])
|
| 1085 |
+
if isinstance(result_score, dict):
|
| 1086 |
+
# TODO: this handles the case where HF evaluate returns a dict.
|
| 1087 |
+
result_score = result_score[metric]
|
| 1088 |
+
scores.append(result_score)
|
| 1089 |
+
if any(scores):
|
| 1090 |
+
result_score = 1.0
|
| 1091 |
+
else:
|
| 1092 |
+
result_score = 0.0
|
| 1093 |
+
else:
|
| 1094 |
+
try:
|
| 1095 |
+
result_score = self._metric_fn_list[metric](
|
| 1096 |
+
references=[gold],
|
| 1097 |
+
predictions=[result],
|
| 1098 |
+
**self._metric_fn_kwargs[metric],
|
| 1099 |
+
)
|
| 1100 |
+
except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
|
| 1101 |
+
result_score = self._metric_fn_list[metric]([gold, result])
|
| 1102 |
+
if isinstance(result_score, dict):
|
| 1103 |
+
# TODO: this handles the case where HF evaluate returns a dict.
|
| 1104 |
+
result_score = result_score[metric]
|
| 1105 |
+
result_dict[metric] = result_score
|
| 1106 |
+
else:
|
| 1107 |
+
raise ValueError(
|
| 1108 |
+
f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
|
| 1109 |
+
"'loglikelihood','generate_until' or 'multiple_choice'",
|
| 1110 |
+
)
|
| 1111 |
+
|
| 1112 |
+
return result_dict
|
| 1113 |
+
|
| 1114 |
+
def aggregation(self):
|
| 1115 |
+
return self._aggregation_list
|
| 1116 |
+
|
| 1117 |
+
def higher_is_better(self):
|
| 1118 |
+
return self._higher_is_better
|
EAGLE/lmms_eval/filters/__init__.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from lmms_eval.api.filter import FilterEnsemble
|
| 2 |
+
from . import selection
|
| 3 |
+
from . import extraction
|
| 4 |
+
from . import transformation
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
FILTER_REGISTRY = {
|
| 8 |
+
"take_first": selection.TakeFirstFilter,
|
| 9 |
+
"regex": extraction.RegexFilter,
|
| 10 |
+
"majority_vote": selection.MajorityVoteFilter,
|
| 11 |
+
"take_first_k": selection.TakeKFilter,
|
| 12 |
+
"remove_whitespace": extraction.WhitespaceFilter,
|
| 13 |
+
"lowercase": transformation.LowercaseFilter,
|
| 14 |
+
"uppercase": transformation.UppercaseFilter,
|
| 15 |
+
"map": transformation.MapFilter,
|
| 16 |
+
# TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function
|
| 17 |
+
# that takes an input and returns a scalar and then should select the max reward,
|
| 18 |
+
# or should implement different filters for different ways of handling a reward model's inference.
|
| 19 |
+
# "arg_max": selection.ArgMaxFilter,
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_filter(filter_name):
|
| 24 |
+
if filter_name in FILTER_REGISTRY:
|
| 25 |
+
return FILTER_REGISTRY[filter_name]
|
| 26 |
+
else:
|
| 27 |
+
return filter_name
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def build_filter_ensemble(filter_name, components):
|
| 31 |
+
"""
|
| 32 |
+
Create a filtering pipeline.
|
| 33 |
+
"""
|
| 34 |
+
filters = []
|
| 35 |
+
for function, kwargs in components:
|
| 36 |
+
if kwargs is None:
|
| 37 |
+
f = get_filter(function)()
|
| 38 |
+
else:
|
| 39 |
+
# create a filter given its name in the registry
|
| 40 |
+
f = get_filter(function)(**kwargs) # TODO: pass kwargs to filters properly
|
| 41 |
+
# add the filter as a pipeline step
|
| 42 |
+
filters.append(f)
|
| 43 |
+
|
| 44 |
+
return FilterEnsemble(name=filter_name, filters=filters)
|
EAGLE/lmms_eval/filters/decontamination.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from lmms_eval.api.filter import Filter
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class DecontaminationFilter(Filter):
|
| 5 |
+
"""
|
| 6 |
+
A filter which evaluates
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
name = "track_decontamination"
|
| 10 |
+
|
| 11 |
+
def __init__(self, path) -> None:
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").
|
| 15 |
+
should further cache result on a given (task_name, doc_id)
|
| 16 |
+
"""
|
| 17 |
+
self._decontam_results = None
|
| 18 |
+
|
| 19 |
+
def apply(self, resps, docs) -> None:
|
| 20 |
+
"""
|
| 21 |
+
Return {"no_contamination", "only_contamination"} keys for the 2 different subsets
|
| 22 |
+
"""
|
| 23 |
+
pass
|
EAGLE/lmms_eval/filters/extraction.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
+
from lmms_eval.api.filter import Filter
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class RegexFilter(Filter):
|
| 7 |
+
""" """
|
| 8 |
+
|
| 9 |
+
def __init__(self, regex_pattern: str = r"#### (\-?[0-9\.\,]+)", fallback: str = "[invalid]") -> None:
|
| 10 |
+
"""
|
| 11 |
+
pass a string `regex` to run `re.compile(r"regex")` on.
|
| 12 |
+
`fallback` defines the output returned if no matches for the regex are located.
|
| 13 |
+
"""
|
| 14 |
+
self.regex_pattern = regex_pattern
|
| 15 |
+
self.regex = re.compile(regex_pattern)
|
| 16 |
+
self.fallback = fallback
|
| 17 |
+
|
| 18 |
+
def apply(self, resps, docs):
|
| 19 |
+
# here, we assume we have a list, in which each element is
|
| 20 |
+
# a list of model responses for some particular input/target pair.
|
| 21 |
+
# so we process each of these (same input/target response sets)
|
| 22 |
+
# independently (and keep them a list.)
|
| 23 |
+
def filter_set(inst):
|
| 24 |
+
filtered = []
|
| 25 |
+
for resp in inst:
|
| 26 |
+
match = self.regex.search(resp)
|
| 27 |
+
if match:
|
| 28 |
+
match = match.group(1).strip()
|
| 29 |
+
else:
|
| 30 |
+
match = self.fallback
|
| 31 |
+
filtered.append(match)
|
| 32 |
+
return filtered
|
| 33 |
+
|
| 34 |
+
# print(resps)
|
| 35 |
+
filtered_resps = list(map(lambda x: filter_set(x), resps))
|
| 36 |
+
# print(filtered_resps)
|
| 37 |
+
|
| 38 |
+
return filtered_resps
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class WhitespaceFilter(Filter):
|
| 42 |
+
""" """
|
| 43 |
+
|
| 44 |
+
def __init__(self) -> None:
|
| 45 |
+
pass
|
| 46 |
+
|
| 47 |
+
def apply(self, resps, docs):
|
| 48 |
+
def filter_set(inst):
|
| 49 |
+
filtered_resp = []
|
| 50 |
+
for resp in inst:
|
| 51 |
+
if resp.startswith(" "):
|
| 52 |
+
resp = resp[1:]
|
| 53 |
+
|
| 54 |
+
filtered_resp.append(resp)
|
| 55 |
+
|
| 56 |
+
return filtered_resp
|
| 57 |
+
|
| 58 |
+
filtered_resps = [filter_set(resp) for resp in resps]
|
| 59 |
+
|
| 60 |
+
return filtered_resps
|
EAGLE/lmms_eval/filters/selection.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import Counter
|
| 2 |
+
|
| 3 |
+
from lmms_eval.api.filter import Filter
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class TakeFirstFilter(Filter):
|
| 7 |
+
def __init__(self) -> None:
|
| 8 |
+
"""
|
| 9 |
+
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
def apply(self, resps, docs):
|
| 13 |
+
"""
|
| 14 |
+
Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
|
| 15 |
+
"""
|
| 16 |
+
return map(lambda r: r[0], resps)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class TakeKFilter(Filter):
|
| 20 |
+
def __init__(self, *args, **kwargs) -> None:
|
| 21 |
+
self.k = kwargs.pop("k")
|
| 22 |
+
|
| 23 |
+
super().__init__(*args, **kwargs)
|
| 24 |
+
|
| 25 |
+
def apply(self, resps, docs):
|
| 26 |
+
# check we have at least k responses per doc, else we can't take the first k
|
| 27 |
+
assert len(resps[0]) >= self.k, f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ."
|
| 28 |
+
return map(lambda r: r[: self.k], resps)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class MajorityVoteFilter(Filter):
|
| 32 |
+
def __init__(self) -> None:
|
| 33 |
+
"""
|
| 34 |
+
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
def apply(self, resps, docs):
|
| 38 |
+
"""
|
| 39 |
+
Each entry of `resps` is a list of model responses.
|
| 40 |
+
We select the response that occurs most frequently in each entry of `resps`.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
def select_majority(resp):
|
| 44 |
+
counts = Counter(resp)
|
| 45 |
+
vote = counts.most_common(1)[0][0]
|
| 46 |
+
return vote
|
| 47 |
+
|
| 48 |
+
return map(lambda r: [select_majority(r)], resps)
|
EAGLE/lmms_eval/filters/transformation.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from lmms_eval.api.filter import Filter
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class LowercaseFilter(Filter):
|
| 5 |
+
def __init__(self) -> None:
|
| 6 |
+
pass
|
| 7 |
+
|
| 8 |
+
def apply(self, resps, docs):
|
| 9 |
+
def filter_set(inst):
|
| 10 |
+
return [resp.lower() for resp in inst]
|
| 11 |
+
|
| 12 |
+
return [filter_set(resp) for resp in resps]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class UppercaseFilter(Filter):
|
| 16 |
+
def __init__(self) -> None:
|
| 17 |
+
pass
|
| 18 |
+
|
| 19 |
+
def apply(self, resps, docs):
|
| 20 |
+
def filter_set(inst):
|
| 21 |
+
return [resp.upper() for resp in inst]
|
| 22 |
+
|
| 23 |
+
return [filter_set(resp) for resp in resps]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class MapFilter(Filter):
|
| 27 |
+
def __init__(self, mapping_dict: dict = {}, default_value=None) -> None:
|
| 28 |
+
"""
|
| 29 |
+
Initializes the MapFilter with a given mapping dictionary and default value.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
- mapping_dict (dict): A dictionary containing the key-value mappings.
|
| 33 |
+
Default is an empty dictionary.
|
| 34 |
+
- default_value (Any): The value to be returned when a key is not found in the mapping_dict.
|
| 35 |
+
Default is None.
|
| 36 |
+
|
| 37 |
+
Example:
|
| 38 |
+
mapper = MapFilter({'A': 1, 'B': 2}, default_value=0)
|
| 39 |
+
"""
|
| 40 |
+
assert isinstance(mapping_dict, dict), "Provided mapping_dict is not a dictionary"
|
| 41 |
+
self.mapping_dict = mapping_dict
|
| 42 |
+
self.default_value = default_value
|
| 43 |
+
|
| 44 |
+
def apply(self, resps, docs):
|
| 45 |
+
def filter_set(inst):
|
| 46 |
+
return [self.mapping_dict.get(resp, self.default_value) for resp in inst]
|
| 47 |
+
|
| 48 |
+
return [filter_set(resp) for resp in resps]
|
EAGLE/lmms_eval/models/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
AVAILABLE_MODELS = {
|
| 4 |
+
"eagle": "Eagle",
|
| 5 |
+
}
|
| 6 |
+
|
| 7 |
+
for model_name, model_class in AVAILABLE_MODELS.items():
|
| 8 |
+
try:
|
| 9 |
+
exec(f"from .{model_name} import {model_class}")
|
| 10 |
+
except ImportError:
|
| 11 |
+
pass
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
import hf_transfer
|
| 15 |
+
|
| 16 |
+
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
EAGLE/lmms_eval/models/eagle.py
ADDED
|
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from PIL import Image
|
| 3 |
+
|
| 4 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
import copy
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from datetime import timedelta
|
| 10 |
+
|
| 11 |
+
from lmms_eval import utils
|
| 12 |
+
from lmms_eval.api.instance import Instance
|
| 13 |
+
from lmms_eval.api.model import lmms
|
| 14 |
+
from lmms_eval.api.registry import register_model
|
| 15 |
+
from lmms_eval.utils import stop_sequences_criteria
|
| 16 |
+
|
| 17 |
+
from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs
|
| 18 |
+
from accelerate.state import AcceleratorState
|
| 19 |
+
from typing import List, Optional, Union, Tuple
|
| 20 |
+
import warnings
|
| 21 |
+
|
| 22 |
+
warnings.filterwarnings("ignore")
|
| 23 |
+
|
| 24 |
+
eval_logger = logging.getLogger("lmms-eval")
|
| 25 |
+
|
| 26 |
+
try:
|
| 27 |
+
from eagle.model.builder import load_pretrained_model
|
| 28 |
+
from eagle.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
|
| 29 |
+
from eagle.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
|
| 30 |
+
from eagle.conversation import conv_templates, SeparatorStyle
|
| 31 |
+
except ImportError:
|
| 32 |
+
eval_logger.error("Please add a symbolic link pointing to the eagle folder of repo ")
|
| 33 |
+
|
| 34 |
+
from transformers.integrations.deepspeed import (
|
| 35 |
+
is_deepspeed_zero3_enabled,
|
| 36 |
+
set_hf_deepspeed_config,
|
| 37 |
+
unset_hf_deepspeed_config,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
def resize_image_with_aspect_ratio(img, min_size):
|
| 41 |
+
"""
|
| 42 |
+
Resize an image while maintaining its aspect ratio.
|
| 43 |
+
|
| 44 |
+
Parameters:
|
| 45 |
+
- image_path: str, path to the input image.
|
| 46 |
+
- min_size: int, the minimum size for the shortest side of the image.
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
- resized_image: PIL.Image object, the resized image.
|
| 50 |
+
"""
|
| 51 |
+
# Get the original dimensions of the image
|
| 52 |
+
original_width, original_height = img.size
|
| 53 |
+
|
| 54 |
+
# Determine the aspect ratio
|
| 55 |
+
aspect_ratio = original_width / original_height
|
| 56 |
+
|
| 57 |
+
# Calculate new dimensions based on the shortest side
|
| 58 |
+
if original_width < original_height:
|
| 59 |
+
new_width = min_size
|
| 60 |
+
new_height = int(min_size / aspect_ratio)
|
| 61 |
+
else:
|
| 62 |
+
new_height = min_size
|
| 63 |
+
new_width = int(min_size * aspect_ratio)
|
| 64 |
+
|
| 65 |
+
# Resize the image while maintaining aspect ratio
|
| 66 |
+
resized_image = img.resize((new_width, new_height), Image.LANCZOS)# Image.ANTIALIAS)
|
| 67 |
+
|
| 68 |
+
return resized_image
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@register_model("eagle")
|
| 72 |
+
class Eagle(lmms):
|
| 73 |
+
"""
|
| 74 |
+
EAGLE Model
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
def __init__(
|
| 78 |
+
self,
|
| 79 |
+
pretrained: str = "NVEagle/Eagle-X5-7B",
|
| 80 |
+
truncation: Optional[bool] = True,
|
| 81 |
+
device: Optional[str] = "cuda",
|
| 82 |
+
dtype: Optional[Union[str, torch.dtype]] = "",
|
| 83 |
+
batch_size: Optional[Union[int, str]] = 1,
|
| 84 |
+
trust_remote_code: Optional[bool] = False,
|
| 85 |
+
revision=None,
|
| 86 |
+
use_flash_attention_2=True,
|
| 87 |
+
device_map="",
|
| 88 |
+
conv_template="vicuna_v1",
|
| 89 |
+
use_cache=True,
|
| 90 |
+
truncate_context=False,
|
| 91 |
+
**kwargs,
|
| 92 |
+
) -> None:
|
| 93 |
+
super().__init__()
|
| 94 |
+
# Do not use kwargs for now
|
| 95 |
+
assert kwargs == {}, f"Unexpected kwargs: {kwargs}"
|
| 96 |
+
|
| 97 |
+
accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
|
| 98 |
+
accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
|
| 99 |
+
if accelerator.num_processes > 1 and device_map == "":
|
| 100 |
+
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
|
| 101 |
+
self.device_map = f"cuda:{accelerator.local_process_index}"
|
| 102 |
+
else:
|
| 103 |
+
self._device = torch.device(device)
|
| 104 |
+
self.device_map = device_map
|
| 105 |
+
|
| 106 |
+
self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, get_model_name_from_path(pretrained), device_map=self.device_map, use_flash_attention_2=use_flash_attention_2)
|
| 107 |
+
self._config = self._model.config
|
| 108 |
+
self.model.eval()
|
| 109 |
+
self.model.tie_weights()
|
| 110 |
+
self.truncation = truncation
|
| 111 |
+
self.batch_size_per_gpu = int(batch_size)
|
| 112 |
+
self.conv_template = conv_template
|
| 113 |
+
self.use_cache = use_cache
|
| 114 |
+
self.truncate_context = truncate_context
|
| 115 |
+
|
| 116 |
+
if accelerator.num_processes > 1 and device_map == "":
|
| 117 |
+
assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
|
| 118 |
+
# If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model
|
| 119 |
+
# Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works
|
| 120 |
+
# I tried to set different parameters in the kwargs to let default zero 2 stage works, but it didn't work.
|
| 121 |
+
if accelerator.distributed_type == DistributedType.DEEPSPEED:
|
| 122 |
+
kwargs = {
|
| 123 |
+
"train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
|
| 124 |
+
"train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
|
| 125 |
+
}
|
| 126 |
+
AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs)
|
| 127 |
+
eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")
|
| 128 |
+
|
| 129 |
+
if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED:
|
| 130 |
+
self._model = accelerator.prepare(self.model)
|
| 131 |
+
else:
|
| 132 |
+
self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
|
| 133 |
+
self.accelerator = accelerator
|
| 134 |
+
if self.accelerator.is_local_main_process:
|
| 135 |
+
eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
|
| 136 |
+
self._rank = self.accelerator.local_process_index
|
| 137 |
+
self._world_size = self.accelerator.num_processes
|
| 138 |
+
elif accelerator.num_processes == 1 and device_map == "auto":
|
| 139 |
+
eval_logger.info(f"Using {accelerator.num_processes} devices with tensor parallelism")
|
| 140 |
+
self._rank = 0
|
| 141 |
+
self._word_size = 1
|
| 142 |
+
else:
|
| 143 |
+
eval_logger.info(f"Using single device: {self._device}")
|
| 144 |
+
self.model.to(self._device)
|
| 145 |
+
self._rank = 0
|
| 146 |
+
self._world_size = 1
|
| 147 |
+
|
| 148 |
+
@property
|
| 149 |
+
def config(self):
|
| 150 |
+
# return the associated transformers.AutoConfig for the given pretrained model.
|
| 151 |
+
return self._config
|
| 152 |
+
|
| 153 |
+
@property
|
| 154 |
+
def tokenizer(self):
|
| 155 |
+
return self._tokenizer
|
| 156 |
+
|
| 157 |
+
@property
|
| 158 |
+
def model(self):
|
| 159 |
+
# returns the model, unwrapping it if using Accelerate
|
| 160 |
+
if hasattr(self, "accelerator"):
|
| 161 |
+
return self.accelerator.unwrap_model(self._model)
|
| 162 |
+
else:
|
| 163 |
+
return self._model
|
| 164 |
+
|
| 165 |
+
@property
|
| 166 |
+
def eot_token_id(self):
|
| 167 |
+
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
|
| 168 |
+
return self.tokenizer.eos_token_id
|
| 169 |
+
|
| 170 |
+
@property
|
| 171 |
+
def max_length(self):
|
| 172 |
+
return self._max_length
|
| 173 |
+
|
| 174 |
+
def pad_sequence(self, input_ids, batch_first, padding_value):
|
| 175 |
+
if self.tokenizer.padding_side == "left":
|
| 176 |
+
input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
|
| 177 |
+
input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value)
|
| 178 |
+
if self.tokenizer.padding_side == "left":
|
| 179 |
+
input_ids = torch.flip(input_ids, [1])
|
| 180 |
+
return input_ids
|
| 181 |
+
|
| 182 |
+
@property
|
| 183 |
+
def batch_size(self):
|
| 184 |
+
return self.batch_size_per_gpu
|
| 185 |
+
|
| 186 |
+
@property
|
| 187 |
+
def device(self):
|
| 188 |
+
return self._device
|
| 189 |
+
|
| 190 |
+
@property
|
| 191 |
+
def rank(self):
|
| 192 |
+
return self._rank
|
| 193 |
+
|
| 194 |
+
@property
|
| 195 |
+
def world_size(self):
|
| 196 |
+
return self._world_size
|
| 197 |
+
|
| 198 |
+
def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]:
|
| 199 |
+
""" """
|
| 200 |
+
add_special_tokens = False if add_special_tokens is None else add_special_tokens
|
| 201 |
+
encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
|
| 202 |
+
# left-truncate the encoded context to be at most `left_truncate_len` tokens long
|
| 203 |
+
if left_truncate_len:
|
| 204 |
+
encoding = encoding[-left_truncate_len:]
|
| 205 |
+
return encoding
|
| 206 |
+
|
| 207 |
+
def tok_decode(self, tokens):
|
| 208 |
+
return self.tokenizer.decode(tokens)
|
| 209 |
+
|
| 210 |
+
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
|
| 211 |
+
# TODO
|
| 212 |
+
res = []
|
| 213 |
+
pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
|
| 214 |
+
|
| 215 |
+
for contexts, doc_to_target, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
|
| 216 |
+
# encode, pad, and truncate contexts for this batch
|
| 217 |
+
if type(doc_to_target) == str:
|
| 218 |
+
continuation = doc_to_target
|
| 219 |
+
else:
|
| 220 |
+
continuation = doc_to_target(self.task_dict[task][split][doc_id])
|
| 221 |
+
visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
|
| 222 |
+
visuals = self.flatten(visuals)
|
| 223 |
+
if visuals:
|
| 224 |
+
image = process_images(visuals, self._image_processor, self._config)
|
| 225 |
+
if type(image) is list:
|
| 226 |
+
image = [_image.to(dtype=torch.float16, device=self.device) for _image in image]
|
| 227 |
+
else:
|
| 228 |
+
image = image.to(dtype=torch.float16, device=self.device)
|
| 229 |
+
else:
|
| 230 |
+
image = None
|
| 231 |
+
|
| 232 |
+
prompts_input = contexts[0]
|
| 233 |
+
|
| 234 |
+
if image is not None and len(image) != 0 and DEFAULT_IMAGE_TOKEN not in prompts_input:
|
| 235 |
+
"""
|
| 236 |
+
Three senarios:
|
| 237 |
+
1. No image, and there for, no image token should be added.
|
| 238 |
+
2. image token is already specified in the context, so we don't need to add it.
|
| 239 |
+
3. image token is not specified in the context and there is image inputs, so we need to add it. In this case, we add the image token at the beginning of the context and add a new line.
|
| 240 |
+
"""
|
| 241 |
+
image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visuals)
|
| 242 |
+
image_tokens = " ".join(image_tokens)
|
| 243 |
+
prompts_input = image_tokens + "\n" + contexts[0]
|
| 244 |
+
|
| 245 |
+
conv = conv_templates[self.conv_template].copy()
|
| 246 |
+
conv.append_message(conv.roles[0], prompts_input)
|
| 247 |
+
conv.append_message(conv.roles[1], None)
|
| 248 |
+
prompt = conv.get_prompt()
|
| 249 |
+
pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
|
| 250 |
+
contxt_id = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device)
|
| 251 |
+
# Add the answer of the second role
|
| 252 |
+
conv.messages[1][1] = continuation
|
| 253 |
+
|
| 254 |
+
prompt = conv.get_prompt()
|
| 255 |
+
input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device)
|
| 256 |
+
labels = input_ids.clone()
|
| 257 |
+
# Context part no need to calculate for loss
|
| 258 |
+
labels[0, : contxt_id.shape[1]] = -100
|
| 259 |
+
with torch.inference_mode():
|
| 260 |
+
outputs = self.model(input_ids=input_ids, labels=labels, images=image, use_cache=True)
|
| 261 |
+
loss = outputs["loss"]
|
| 262 |
+
# loss = torch.exp(loss)
|
| 263 |
+
logits = outputs["logits"]
|
| 264 |
+
greedy_tokens = logits.argmax(dim=-1)
|
| 265 |
+
cont_toks = input_ids[:, contxt_id.shape[1] :] # [1, seq]
|
| 266 |
+
greedy_tokens = greedy_tokens[:, contxt_id.shape[1] : input_ids.shape[1]] # [1, seq]
|
| 267 |
+
max_equal = (greedy_tokens == cont_toks).all()
|
| 268 |
+
res.append((float(loss.item()), bool(max_equal)))
|
| 269 |
+
pbar.update(1)
|
| 270 |
+
pbar.close()
|
| 271 |
+
return res
|
| 272 |
+
|
| 273 |
+
def flatten(self, input):
|
| 274 |
+
new_list = []
|
| 275 |
+
for i in input:
|
| 276 |
+
for j in i:
|
| 277 |
+
new_list.append(j)
|
| 278 |
+
return new_list
|
| 279 |
+
|
| 280 |
+
def generate_until(self, requests: List[Instance]) -> List[str]:
|
| 281 |
+
res = []
|
| 282 |
+
|
| 283 |
+
def _collate(x):
|
| 284 |
+
# the negative sign on len(toks) sorts descending - this has a few advantages:
|
| 285 |
+
# - time estimates will always be over not underestimates, which is more useful for planning
|
| 286 |
+
# - to know the size of a batch when going through the list, you know the first one is always the batch
|
| 287 |
+
# padded context length. this is useful to simplify the batching logic and more importantly to make
|
| 288 |
+
# automatic adaptive batches much much easier to implement
|
| 289 |
+
# - any OOMs will happen right away rather than near the end
|
| 290 |
+
toks = self.tok_encode(x[0])
|
| 291 |
+
return -len(toks), x[0]
|
| 292 |
+
|
| 293 |
+
# we group requests by their generation_kwargs,
|
| 294 |
+
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
|
| 295 |
+
# in the same batch.
|
| 296 |
+
re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
|
| 297 |
+
chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
|
| 298 |
+
num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1
|
| 299 |
+
pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")
|
| 300 |
+
for chunk in chunks:
|
| 301 |
+
contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
|
| 302 |
+
task = task[0]
|
| 303 |
+
split = split[0]
|
| 304 |
+
visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]
|
| 305 |
+
visuals = self.flatten(visuals)
|
| 306 |
+
# we assume all gen kwargs in the batch are the same
|
| 307 |
+
# this is safe to assume because the `grouper` object ensures it.
|
| 308 |
+
gen_kwargs = all_gen_kwargs[0]
|
| 309 |
+
|
| 310 |
+
# Set default values for until and max_new_tokens
|
| 311 |
+
until = [self.tok_decode(self.eot_token_id)]
|
| 312 |
+
|
| 313 |
+
# Update values from gen_kwargs if present
|
| 314 |
+
if "until" in gen_kwargs:
|
| 315 |
+
until = gen_kwargs.pop("until")
|
| 316 |
+
if isinstance(until, str):
|
| 317 |
+
until = [until]
|
| 318 |
+
elif not isinstance(until, list):
|
| 319 |
+
raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}")
|
| 320 |
+
|
| 321 |
+
if "image_aspect_ratio" in gen_kwargs.keys() and "image_aspect_ratio" not in self._config.__dict__:
|
| 322 |
+
# here we should pop it out of gen_kwargs so that it doesn't get passed to the model for next step of generation
|
| 323 |
+
self._config.image_aspect_ratio = gen_kwargs.pop("image_aspect_ratio")
|
| 324 |
+
eval_logger.info(f"Setting image aspect ratio: {self._config.image_aspect_ratio}")
|
| 325 |
+
|
| 326 |
+
if visuals:
|
| 327 |
+
image_tensor = process_images(visuals, self._image_processor, self._config)
|
| 328 |
+
if type(image_tensor) is list:
|
| 329 |
+
image_tensor = [_image.to(dtype=torch.float16, device=self.device) for _image in image_tensor]
|
| 330 |
+
else:
|
| 331 |
+
image_tensor = image_tensor.to(dtype=torch.float16, device=self.device)
|
| 332 |
+
else:
|
| 333 |
+
image_tensor = None
|
| 334 |
+
|
| 335 |
+
# prompts_input = contexts[0]
|
| 336 |
+
|
| 337 |
+
question_input = []
|
| 338 |
+
|
| 339 |
+
for visual, context in zip(visuals, contexts):
|
| 340 |
+
if image_tensor is not None and len(image_tensor) != 0 and DEFAULT_IMAGE_TOKEN not in context:
|
| 341 |
+
"""
|
| 342 |
+
Three senarios:
|
| 343 |
+
1. No image, and there for, no image token should be added.
|
| 344 |
+
2. image token is already specified in the context, so we don't need to add it.
|
| 345 |
+
3. image token is not specified in the context and there is image inputs, so we need to add it. In this case, we add the image token at the beginning of the context and add a new line.
|
| 346 |
+
"""
|
| 347 |
+
image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visual) if isinstance(visual, list) else [DEFAULT_IMAGE_TOKEN]
|
| 348 |
+
image_tokens = " ".join(image_tokens)
|
| 349 |
+
question = image_tokens + "\n" + context
|
| 350 |
+
else:
|
| 351 |
+
question = context
|
| 352 |
+
|
| 353 |
+
conv = conv_templates[self.conv_template].copy()
|
| 354 |
+
conv.append_message(conv.roles[0], question)
|
| 355 |
+
conv.append_message(conv.roles[1], None)
|
| 356 |
+
prompt_question = conv.get_prompt()
|
| 357 |
+
question_input.append(prompt_question)
|
| 358 |
+
|
| 359 |
+
# The above for loop has bugs. When there is no visuals, e.g. pure text,
|
| 360 |
+
# there will be no for loop execute resulting in an empty question_input (because no visuals)
|
| 361 |
+
# Scenario 1 won't even be execute
|
| 362 |
+
if len(visuals) == 0:
|
| 363 |
+
for context in contexts:
|
| 364 |
+
question = context
|
| 365 |
+
conv = conv_templates[self.conv_template].copy()
|
| 366 |
+
conv.append_message(conv.roles[0], question)
|
| 367 |
+
conv.append_message(conv.roles[1], None)
|
| 368 |
+
prompt_question = conv.get_prompt()
|
| 369 |
+
question_input.append(prompt_question)
|
| 370 |
+
|
| 371 |
+
# input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device)
|
| 372 |
+
# preconfigure gen_kwargs with defaults
|
| 373 |
+
gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))]
|
| 374 |
+
if "max_new_tokens" not in gen_kwargs:
|
| 375 |
+
gen_kwargs["max_new_tokens"] = 1024
|
| 376 |
+
if "temperature" not in gen_kwargs:
|
| 377 |
+
gen_kwargs["temperature"] = 0
|
| 378 |
+
if "top_p" not in gen_kwargs:
|
| 379 |
+
gen_kwargs["top_p"] = None
|
| 380 |
+
if "num_beams" not in gen_kwargs:
|
| 381 |
+
gen_kwargs["num_beams"] = 1
|
| 382 |
+
|
| 383 |
+
input_ids_list = [tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") for prompt in question_input]
|
| 384 |
+
pad_token_ids = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
|
| 385 |
+
input_ids = self.pad_sequence(input_ids_list, batch_first=True, padding_value=pad_token_ids).to(self.device)
|
| 386 |
+
attention_masks = input_ids.ne(pad_token_ids).to(self.device)
|
| 387 |
+
|
| 388 |
+
try:
|
| 389 |
+
cont = self.model.generate(
|
| 390 |
+
input_ids,
|
| 391 |
+
attention_mask=attention_masks,
|
| 392 |
+
pad_token_id=pad_token_ids,
|
| 393 |
+
images=image_tensor,
|
| 394 |
+
image_sizes=gen_kwargs["image_sizes"],
|
| 395 |
+
do_sample=True if gen_kwargs["temperature"] > 0 else False,
|
| 396 |
+
temperature=gen_kwargs["temperature"],
|
| 397 |
+
top_p=gen_kwargs["top_p"],
|
| 398 |
+
num_beams=gen_kwargs["num_beams"],
|
| 399 |
+
max_new_tokens=gen_kwargs["max_new_tokens"],
|
| 400 |
+
use_cache=self.use_cache,
|
| 401 |
+
)
|
| 402 |
+
text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)
|
| 403 |
+
except Exception as e:
|
| 404 |
+
eval_logger.error(f"Error {e} in generating")
|
| 405 |
+
cont = ""
|
| 406 |
+
text_outputs = [""]
|
| 407 |
+
|
| 408 |
+
res.extend(text_outputs)
|
| 409 |
+
self.cache_hook.add_partial("generate_until", (context, gen_kwargs), text_outputs)
|
| 410 |
+
pbar.update(1)
|
| 411 |
+
|
| 412 |
+
res = re_ords.get_original(res)
|
| 413 |
+
|
| 414 |
+
pbar.close()
|
| 415 |
+
return res
|
EAGLE/lmms_eval/models/gpt4v.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from io import BytesIO
|
| 2 |
+
from copy import deepcopy
|
| 3 |
+
import os
|
| 4 |
+
import base64
|
| 5 |
+
from typing import List, Tuple
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import requests as url_requests
|
| 8 |
+
import time
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
from lmms_eval.api.instance import Instance
|
| 12 |
+
from lmms_eval.api.model import lmms
|
| 13 |
+
from lmms_eval.api.registry import register_model
|
| 14 |
+
from lmms_eval import utils
|
| 15 |
+
|
| 16 |
+
from PIL import Image
|
| 17 |
+
|
| 18 |
+
API_TYPE = os.getenv("API_TYPE", "openai")
|
| 19 |
+
NUM_SECONDS_TO_SLEEP = 5
|
| 20 |
+
eval_logger = logging.getLogger("lmms-eval")
|
| 21 |
+
|
| 22 |
+
if API_TYPE == "openai":
|
| 23 |
+
API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
|
| 24 |
+
API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
|
| 25 |
+
headers = {
|
| 26 |
+
"Authorization": f"Bearer {API_KEY}",
|
| 27 |
+
"Content-Type": "application/json",
|
| 28 |
+
}
|
| 29 |
+
elif API_TYPE == "azure":
|
| 30 |
+
API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
|
| 31 |
+
API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
|
| 32 |
+
headers = {
|
| 33 |
+
"api-key": API_KEY,
|
| 34 |
+
"Content-Type": "application/json",
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@register_model("gpt4V")
|
| 39 |
+
class GPT4V(lmms):
|
| 40 |
+
def __init__(self, **kwargs) -> None:
|
| 41 |
+
super().__init__()
|
| 42 |
+
# Manually set a image token for GPT4V so that we can search for it
|
| 43 |
+
# and split the text and image
|
| 44 |
+
# Here we just use the same token as llava for convenient
|
| 45 |
+
self.image_token = "<image>"
|
| 46 |
+
|
| 47 |
+
# Function to encode the image
|
| 48 |
+
def encode_image(self, image: Image):
|
| 49 |
+
output_buffer = BytesIO()
|
| 50 |
+
image.save(output_buffer, format="JPEG")
|
| 51 |
+
byte_data = output_buffer.getvalue()
|
| 52 |
+
base64_str = base64.b64encode(byte_data).decode("utf-8")
|
| 53 |
+
return base64_str
|
| 54 |
+
|
| 55 |
+
def flatten(self, input):
|
| 56 |
+
new_list = []
|
| 57 |
+
for i in input:
|
| 58 |
+
for j in i:
|
| 59 |
+
new_list.append(j)
|
| 60 |
+
return new_list
|
| 61 |
+
|
| 62 |
+
def generate_until(self, requests) -> List[str]:
|
| 63 |
+
res = []
|
| 64 |
+
pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
|
| 65 |
+
|
| 66 |
+
for contexts, gen_kwargs, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
|
| 67 |
+
# encode, pad, and truncate contexts for this batch
|
| 68 |
+
visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
|
| 69 |
+
visuals = self.flatten(visuals)
|
| 70 |
+
imgs = []
|
| 71 |
+
for visual in visuals:
|
| 72 |
+
img = self.encode_image(visual)
|
| 73 |
+
imgs.append(img)
|
| 74 |
+
|
| 75 |
+
payload = {"model": "gpt-4-vision-preview", "messages": []}
|
| 76 |
+
response_json = {"role": "user", "content": []}
|
| 77 |
+
# When there is no image token in the context, append the image to the text
|
| 78 |
+
if self.image_token not in contexts:
|
| 79 |
+
payload["messages"].append(deepcopy(response_json))
|
| 80 |
+
payload["messages"][0]["content"].append({"type": "text", "text": contexts})
|
| 81 |
+
for img in imgs:
|
| 82 |
+
payload["messages"][0]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img}"}})
|
| 83 |
+
else:
|
| 84 |
+
contexts = contexts.split(self.image_token)
|
| 85 |
+
for idx, img in enumerate(imgs):
|
| 86 |
+
payload["messages"].append(deepcopy(response_json))
|
| 87 |
+
payload["messages"][idx]["content"].append({"type": "text", "text": contexts[idx]})
|
| 88 |
+
payload["messages"][idx]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img}"}})
|
| 89 |
+
|
| 90 |
+
# If n image tokens are in the contexts
|
| 91 |
+
# contexts will be splitted into n+1 chunks
|
| 92 |
+
# Manually add it into the payload
|
| 93 |
+
payload["messages"].append(deepcopy(response_json))
|
| 94 |
+
payload["messages"][-1]["content"].append({"type": "text", "text": contexts[-1]})
|
| 95 |
+
|
| 96 |
+
if "max_new_tokens" not in gen_kwargs:
|
| 97 |
+
gen_kwargs["max_new_tokens"] = 1024
|
| 98 |
+
if "temperature" not in gen_kwargs:
|
| 99 |
+
gen_kwargs["temperature"] = 0
|
| 100 |
+
if "top_p" not in gen_kwargs:
|
| 101 |
+
gen_kwargs["top_p"] = None
|
| 102 |
+
if "num_beams" not in gen_kwargs:
|
| 103 |
+
gen_kwargs["num_beams"] = 1
|
| 104 |
+
|
| 105 |
+
# payload["max_tokens"] = gen_kwargs["max_new_tokens"]
|
| 106 |
+
# payload["temperature"] = gen_kwargs["temperature"]
|
| 107 |
+
|
| 108 |
+
for attempt in range(5):
|
| 109 |
+
try:
|
| 110 |
+
response = url_requests.post(API_URL, headers=headers, json=payload, timeout=20)
|
| 111 |
+
response_data = response.json()
|
| 112 |
+
|
| 113 |
+
content = response_data["choices"][0]["message"]["content"].strip()
|
| 114 |
+
break # If successful, break out of the loop
|
| 115 |
+
|
| 116 |
+
except Exception as e:
|
| 117 |
+
eval_logger.info(f"Attempt {attempt + 1} failed with error: {str(e)}")
|
| 118 |
+
if attempt < 5 - 1: # If we have retries left, sleep and then continue to next attempt
|
| 119 |
+
time.sleep(NUM_SECONDS_TO_SLEEP)
|
| 120 |
+
else: # If this was the last attempt, log and return empty
|
| 121 |
+
eval_logger.error(f"All 5 attempts failed. Last error message: {str(e)}")
|
| 122 |
+
content = ""
|
| 123 |
+
res.append(content)
|
| 124 |
+
pbar.update(1)
|
| 125 |
+
return res
|
| 126 |
+
|
| 127 |
+
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
|
| 128 |
+
# TODO
|
| 129 |
+
assert False, "GPT4V not support"
|
EAGLE/lmms_eval/tasks/__init__.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List, Union, Dict
|
| 3 |
+
|
| 4 |
+
from lmms_eval import utils
|
| 5 |
+
|
| 6 |
+
# from lmms_eval import prompts
|
| 7 |
+
from lmms_eval.api.task import TaskConfig, Task, ConfigurableTask
|
| 8 |
+
from lmms_eval.api.registry import (
|
| 9 |
+
register_task,
|
| 10 |
+
register_group,
|
| 11 |
+
TASK_REGISTRY,
|
| 12 |
+
GROUP_REGISTRY,
|
| 13 |
+
ALL_TASKS,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
import logging
|
| 17 |
+
|
| 18 |
+
eval_logger = logging.getLogger("lmms-eval")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def register_configurable_task(config: Dict[str, str]) -> int:
|
| 22 |
+
SubClass = type(
|
| 23 |
+
config["task"] + "ConfigurableTask",
|
| 24 |
+
(ConfigurableTask,),
|
| 25 |
+
{"CONFIG": TaskConfig(**config)},
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
if "task" in config:
|
| 29 |
+
task_name = "{}".format(config["task"])
|
| 30 |
+
register_task(task_name)(SubClass)
|
| 31 |
+
|
| 32 |
+
if "group" in config:
|
| 33 |
+
if config["group"] == config["task"]:
|
| 34 |
+
raise ValueError("task and group name cannot be the same")
|
| 35 |
+
elif type(config["group"]) == str:
|
| 36 |
+
group_name = [config["group"]]
|
| 37 |
+
else:
|
| 38 |
+
group_name = config["group"]
|
| 39 |
+
|
| 40 |
+
for group in group_name:
|
| 41 |
+
register_group(group)(SubClass)
|
| 42 |
+
|
| 43 |
+
return 0
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def register_configurable_group(config: Dict[str, str]) -> int:
|
| 47 |
+
group = config["group"]
|
| 48 |
+
task_list = config["task"]
|
| 49 |
+
task_names = utils.pattern_match(task_list, ALL_TASKS)
|
| 50 |
+
for task in task_names:
|
| 51 |
+
if (task in TASK_REGISTRY) or (task in GROUP_REGISTRY):
|
| 52 |
+
if group in GROUP_REGISTRY:
|
| 53 |
+
GROUP_REGISTRY[group].append(task)
|
| 54 |
+
else:
|
| 55 |
+
GROUP_REGISTRY[group] = [task]
|
| 56 |
+
ALL_TASKS.add(group)
|
| 57 |
+
return 0
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def get_task_name_from_config(task_config: Dict[str, str]) -> str:
|
| 61 |
+
if "dataset_name" in task_config:
|
| 62 |
+
return "{dataset_path}_{dataset_name}".format(**task_config)
|
| 63 |
+
else:
|
| 64 |
+
return "{dataset_path}".format(**task_config)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def include_task_folder(task_dir: str, register_task: bool = True) -> None:
|
| 68 |
+
"""
|
| 69 |
+
Calling this function
|
| 70 |
+
"""
|
| 71 |
+
for root, subdirs, file_list in os.walk(task_dir):
|
| 72 |
+
# if (subdirs == [] or subdirs == ["__pycache__"]) and (len(file_list) > 0):
|
| 73 |
+
for f in file_list:
|
| 74 |
+
if f.endswith(".yaml"):
|
| 75 |
+
yaml_path = os.path.join(root, f)
|
| 76 |
+
try:
|
| 77 |
+
config = utils.load_yaml_config(yaml_path)
|
| 78 |
+
|
| 79 |
+
if "task" not in config:
|
| 80 |
+
continue
|
| 81 |
+
|
| 82 |
+
if register_task:
|
| 83 |
+
if type(config["task"]) == str:
|
| 84 |
+
register_configurable_task(config)
|
| 85 |
+
else:
|
| 86 |
+
if type(config["task"]) == list:
|
| 87 |
+
register_configurable_group(config)
|
| 88 |
+
|
| 89 |
+
# Log this silently and show it only when
|
| 90 |
+
# the user defines the appropriate verbosity.
|
| 91 |
+
except ModuleNotFoundError as e:
|
| 92 |
+
eval_logger.debug(f"{yaml_path}: {e}. Config will not be added to registry.")
|
| 93 |
+
except Exception as error:
|
| 94 |
+
import traceback
|
| 95 |
+
|
| 96 |
+
eval_logger.debug(f"Failed to load config in {yaml_path}. Config will not be added to registry\n" f"Error: {error}\n" f"Traceback: {traceback.format_exc()}")
|
| 97 |
+
return 0
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def include_path(task_dir):
|
| 101 |
+
include_task_folder(task_dir)
|
| 102 |
+
# Register Benchmarks after all tasks have been added
|
| 103 |
+
include_task_folder(task_dir, register_task=False)
|
| 104 |
+
return 0
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def initialize_tasks(verbosity="INFO"):
|
| 108 |
+
eval_logger.setLevel(getattr(logging, f"{verbosity}"))
|
| 109 |
+
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
|
| 110 |
+
include_path(task_dir)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def get_task(task_name, model_name):
|
| 114 |
+
try:
|
| 115 |
+
return TASK_REGISTRY[task_name](model_name=model_name)
|
| 116 |
+
except KeyError:
|
| 117 |
+
eval_logger.info("Available tasks:")
|
| 118 |
+
eval_logger.info(list(TASK_REGISTRY) + list(GROUP_REGISTRY))
|
| 119 |
+
raise KeyError(f"Missing task {task_name}")
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def get_task_name_from_object(task_object):
|
| 123 |
+
for name, class_ in TASK_REGISTRY.items():
|
| 124 |
+
if class_ is task_object:
|
| 125 |
+
return name
|
| 126 |
+
|
| 127 |
+
# TODO: scrap this
|
| 128 |
+
# this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
|
| 129 |
+
return task_object.EVAL_HARNESS_NAME if hasattr(task_object, "EVAL_HARNESS_NAME") else type(task_object).__name__
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
# TODO: pass num_fewshot and other cmdline overrides in a better way
|
| 133 |
+
def get_task_dict(task_name_list: List[Union[str, Dict, Task]], model_name: str):
|
| 134 |
+
all_task_dict = {}
|
| 135 |
+
|
| 136 |
+
# Ensure task_name_list is a list to simplify processing
|
| 137 |
+
if not isinstance(task_name_list, list):
|
| 138 |
+
task_name_list = [task_name_list]
|
| 139 |
+
|
| 140 |
+
for task_element in task_name_list:
|
| 141 |
+
if isinstance(task_element, str) and task_element in GROUP_REGISTRY:
|
| 142 |
+
group_name = task_element
|
| 143 |
+
for task_name in GROUP_REGISTRY[task_element]:
|
| 144 |
+
if task_name not in all_task_dict:
|
| 145 |
+
# Recursively get the task dictionary for nested groups
|
| 146 |
+
task_obj = get_task_dict([task_name], model_name)
|
| 147 |
+
# Merge the dictionaries
|
| 148 |
+
all_task_dict.update({task_name: (group_name, task_obj.get(task_name, None))})
|
| 149 |
+
else:
|
| 150 |
+
task_name = task_element if isinstance(task_element, str) else task_element.EVAL_HARNESS_NAME
|
| 151 |
+
if task_name not in all_task_dict:
|
| 152 |
+
task_obj = get_task(task_name=task_name, model_name=model_name)
|
| 153 |
+
all_task_dict[task_name] = task_obj
|
| 154 |
+
|
| 155 |
+
return all_task_dict
|
EAGLE/lmms_eval/tasks/_task_utils/file_utils.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def generate_submission_file(file_name, args, subpath="submissions"):
|
| 5 |
+
path = os.path.join(args.output_path, subpath)
|
| 6 |
+
os.makedirs(path, exist_ok=True)
|
| 7 |
+
path = os.path.join(path, file_name)
|
| 8 |
+
return os.path.abspath(path)
|
EAGLE/lmms_eval/tasks/_task_utils/gpt_eval_utils.py
ADDED
|
File without changes
|
EAGLE/lmms_eval/tasks/_task_utils/vqa_eval_metric.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class EvalAIAnswerProcessor:
|
| 5 |
+
"""
|
| 6 |
+
Processes an answer similar to Eval AI
|
| 7 |
+
copied from
|
| 8 |
+
https://github.com/facebookresearch/mmf/blob/c46b3b3391275b4181567db80943473a89ab98ab/pythia/tasks/processors.py#L897
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
CONTRACTIONS = {
|
| 12 |
+
"aint": "ain't",
|
| 13 |
+
"arent": "aren't",
|
| 14 |
+
"cant": "can't",
|
| 15 |
+
"couldve": "could've",
|
| 16 |
+
"couldnt": "couldn't",
|
| 17 |
+
"couldn'tve": "couldn't've",
|
| 18 |
+
"couldnt've": "couldn't've",
|
| 19 |
+
"didnt": "didn't",
|
| 20 |
+
"doesnt": "doesn't",
|
| 21 |
+
"dont": "don't",
|
| 22 |
+
"hadnt": "hadn't",
|
| 23 |
+
"hadnt've": "hadn't've",
|
| 24 |
+
"hadn'tve": "hadn't've",
|
| 25 |
+
"hasnt": "hasn't",
|
| 26 |
+
"havent": "haven't",
|
| 27 |
+
"hed": "he'd",
|
| 28 |
+
"hed've": "he'd've",
|
| 29 |
+
"he'dve": "he'd've",
|
| 30 |
+
"hes": "he's",
|
| 31 |
+
"howd": "how'd",
|
| 32 |
+
"howll": "how'll",
|
| 33 |
+
"hows": "how's",
|
| 34 |
+
"Id've": "I'd've",
|
| 35 |
+
"I'dve": "I'd've",
|
| 36 |
+
"Im": "I'm",
|
| 37 |
+
"Ive": "I've",
|
| 38 |
+
"isnt": "isn't",
|
| 39 |
+
"itd": "it'd",
|
| 40 |
+
"itd've": "it'd've",
|
| 41 |
+
"it'dve": "it'd've",
|
| 42 |
+
"itll": "it'll",
|
| 43 |
+
"let's": "let's",
|
| 44 |
+
"maam": "ma'am",
|
| 45 |
+
"mightnt": "mightn't",
|
| 46 |
+
"mightnt've": "mightn't've",
|
| 47 |
+
"mightn'tve": "mightn't've",
|
| 48 |
+
"mightve": "might've",
|
| 49 |
+
"mustnt": "mustn't",
|
| 50 |
+
"mustve": "must've",
|
| 51 |
+
"neednt": "needn't",
|
| 52 |
+
"notve": "not've",
|
| 53 |
+
"oclock": "o'clock",
|
| 54 |
+
"oughtnt": "oughtn't",
|
| 55 |
+
"ow's'at": "'ow's'at",
|
| 56 |
+
"'ows'at": "'ow's'at",
|
| 57 |
+
"'ow'sat": "'ow's'at",
|
| 58 |
+
"shant": "shan't",
|
| 59 |
+
"shed've": "she'd've",
|
| 60 |
+
"she'dve": "she'd've",
|
| 61 |
+
"she's": "she's",
|
| 62 |
+
"shouldve": "should've",
|
| 63 |
+
"shouldnt": "shouldn't",
|
| 64 |
+
"shouldnt've": "shouldn't've",
|
| 65 |
+
"shouldn'tve": "shouldn't've",
|
| 66 |
+
"somebody'd": "somebodyd",
|
| 67 |
+
"somebodyd've": "somebody'd've",
|
| 68 |
+
"somebody'dve": "somebody'd've",
|
| 69 |
+
"somebodyll": "somebody'll",
|
| 70 |
+
"somebodys": "somebody's",
|
| 71 |
+
"someoned": "someone'd",
|
| 72 |
+
"someoned've": "someone'd've",
|
| 73 |
+
"someone'dve": "someone'd've",
|
| 74 |
+
"someonell": "someone'll",
|
| 75 |
+
"someones": "someone's",
|
| 76 |
+
"somethingd": "something'd",
|
| 77 |
+
"somethingd've": "something'd've",
|
| 78 |
+
"something'dve": "something'd've",
|
| 79 |
+
"somethingll": "something'll",
|
| 80 |
+
"thats": "that's",
|
| 81 |
+
"thered": "there'd",
|
| 82 |
+
"thered've": "there'd've",
|
| 83 |
+
"there'dve": "there'd've",
|
| 84 |
+
"therere": "there're",
|
| 85 |
+
"theres": "there's",
|
| 86 |
+
"theyd": "they'd",
|
| 87 |
+
"theyd've": "they'd've",
|
| 88 |
+
"they'dve": "they'd've",
|
| 89 |
+
"theyll": "they'll",
|
| 90 |
+
"theyre": "they're",
|
| 91 |
+
"theyve": "they've",
|
| 92 |
+
"twas": "'twas",
|
| 93 |
+
"wasnt": "wasn't",
|
| 94 |
+
"wed've": "we'd've",
|
| 95 |
+
"we'dve": "we'd've",
|
| 96 |
+
"weve": "we've",
|
| 97 |
+
"werent": "weren't",
|
| 98 |
+
"whatll": "what'll",
|
| 99 |
+
"whatre": "what're",
|
| 100 |
+
"whats": "what's",
|
| 101 |
+
"whatve": "what've",
|
| 102 |
+
"whens": "when's",
|
| 103 |
+
"whered": "where'd",
|
| 104 |
+
"wheres": "where's",
|
| 105 |
+
"whereve": "where've",
|
| 106 |
+
"whod": "who'd",
|
| 107 |
+
"whod've": "who'd've",
|
| 108 |
+
"who'dve": "who'd've",
|
| 109 |
+
"wholl": "who'll",
|
| 110 |
+
"whos": "who's",
|
| 111 |
+
"whove": "who've",
|
| 112 |
+
"whyll": "why'll",
|
| 113 |
+
"whyre": "why're",
|
| 114 |
+
"whys": "why's",
|
| 115 |
+
"wont": "won't",
|
| 116 |
+
"wouldve": "would've",
|
| 117 |
+
"wouldnt": "wouldn't",
|
| 118 |
+
"wouldnt've": "wouldn't've",
|
| 119 |
+
"wouldn'tve": "wouldn't've",
|
| 120 |
+
"yall": "y'all",
|
| 121 |
+
"yall'll": "y'all'll",
|
| 122 |
+
"y'allll": "y'all'll",
|
| 123 |
+
"yall'd've": "y'all'd've",
|
| 124 |
+
"y'alld've": "y'all'd've",
|
| 125 |
+
"y'all'dve": "y'all'd've",
|
| 126 |
+
"youd": "you'd",
|
| 127 |
+
"youd've": "you'd've",
|
| 128 |
+
"you'dve": "you'd've",
|
| 129 |
+
"youll": "you'll",
|
| 130 |
+
"youre": "you're",
|
| 131 |
+
"youve": "you've",
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
NUMBER_MAP = {
|
| 135 |
+
"none": "0",
|
| 136 |
+
"zero": "0",
|
| 137 |
+
"one": "1",
|
| 138 |
+
"two": "2",
|
| 139 |
+
"three": "3",
|
| 140 |
+
"four": "4",
|
| 141 |
+
"five": "5",
|
| 142 |
+
"six": "6",
|
| 143 |
+
"seven": "7",
|
| 144 |
+
"eight": "8",
|
| 145 |
+
"nine": "9",
|
| 146 |
+
"ten": "10",
|
| 147 |
+
}
|
| 148 |
+
ARTICLES = ["a", "an", "the"]
|
| 149 |
+
PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
|
| 150 |
+
COMMA_STRIP = re.compile(r"(?<=\d)(\,)+(?=\d)")
|
| 151 |
+
PUNCTUATIONS = [
|
| 152 |
+
";",
|
| 153 |
+
r"/",
|
| 154 |
+
"[",
|
| 155 |
+
"]",
|
| 156 |
+
'"',
|
| 157 |
+
"{",
|
| 158 |
+
"}",
|
| 159 |
+
"(",
|
| 160 |
+
")",
|
| 161 |
+
"=",
|
| 162 |
+
"+",
|
| 163 |
+
"\\",
|
| 164 |
+
"_",
|
| 165 |
+
"-",
|
| 166 |
+
">",
|
| 167 |
+
"<",
|
| 168 |
+
"@",
|
| 169 |
+
"`",
|
| 170 |
+
",",
|
| 171 |
+
"?",
|
| 172 |
+
"!",
|
| 173 |
+
]
|
| 174 |
+
|
| 175 |
+
def __init__(self, *args, **kwargs):
|
| 176 |
+
pass
|
| 177 |
+
|
| 178 |
+
def word_tokenize(self, word):
|
| 179 |
+
word = word.lower()
|
| 180 |
+
word = word.replace(",", "").replace("?", "").replace("'s", " 's")
|
| 181 |
+
return word.strip()
|
| 182 |
+
|
| 183 |
+
def process_punctuation(self, in_text):
|
| 184 |
+
out_text = in_text
|
| 185 |
+
for p in self.PUNCTUATIONS:
|
| 186 |
+
if (p + " " in in_text or " " + p in in_text) or (re.search(self.COMMA_STRIP, in_text) is not None):
|
| 187 |
+
out_text = out_text.replace(p, "")
|
| 188 |
+
else:
|
| 189 |
+
out_text = out_text.replace(p, " ")
|
| 190 |
+
out_text = self.PERIOD_STRIP.sub("", out_text, re.UNICODE)
|
| 191 |
+
return out_text
|
| 192 |
+
|
| 193 |
+
def process_digit_article(self, in_text):
|
| 194 |
+
out_text = []
|
| 195 |
+
temp_text = in_text.lower().split()
|
| 196 |
+
for word in temp_text:
|
| 197 |
+
word = self.NUMBER_MAP.setdefault(word, word)
|
| 198 |
+
if word not in self.ARTICLES:
|
| 199 |
+
out_text.append(word)
|
| 200 |
+
else:
|
| 201 |
+
pass
|
| 202 |
+
for word_id, word in enumerate(out_text):
|
| 203 |
+
if word in self.CONTRACTIONS:
|
| 204 |
+
out_text[word_id] = self.CONTRACTIONS[word]
|
| 205 |
+
out_text = " ".join(out_text)
|
| 206 |
+
return out_text
|
| 207 |
+
|
| 208 |
+
def __call__(self, item):
|
| 209 |
+
item = self.word_tokenize(item)
|
| 210 |
+
item = item.replace("\n", " ").replace("\t", " ").strip()
|
| 211 |
+
item = self.process_punctuation(item)
|
| 212 |
+
item = self.process_digit_article(item)
|
| 213 |
+
return item
|
EAGLE/lmms_eval/tasks/cmmmu/_cmmmu.yaml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
group: cmmmu
|
| 2 |
+
task:
|
| 3 |
+
- cmmmu_val
|
| 4 |
+
- cmmmu_test
|
EAGLE/lmms_eval/tasks/cmmmu/_default_template_cmmmu_yaml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset_path: lmms-lab/CMMMU
|
| 2 |
+
output_type: generate_until
|
| 3 |
+
doc_to_visual: !function utils.cmmmu_doc_to_visual
|
| 4 |
+
doc_to_text: !function utils.cmmmu_doc_to_text
|
| 5 |
+
doc_to_target: "answer"
|
| 6 |
+
generation_kwargs:
|
| 7 |
+
max_new_tokens: 16
|
| 8 |
+
image_aspect_ratio: original
|
EAGLE/lmms_eval/tasks/cmmmu/cmmmu_test.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
task: "cmmmu_test"
|
| 2 |
+
test_split: test
|
| 3 |
+
# The return value of process_results will be used by metrics
|
| 4 |
+
process_results: !function utils.cmmmu_process_test_results_for_submission
|
| 5 |
+
# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results
|
| 6 |
+
metric_list:
|
| 7 |
+
- metric: submission
|
| 8 |
+
aggregation: !function utils.cmmmu_test_aggregate_results_for_submission
|
| 9 |
+
higher_is_better: false
|
| 10 |
+
metadata:
|
| 11 |
+
- version: 0.0
|
| 12 |
+
include: _default_template_cmmmu_yaml
|
EAGLE/lmms_eval/tasks/cmmmu/cmmmu_val.yaml
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
task: "cmmmu_val"
|
| 2 |
+
test_split: val
|
| 3 |
+
# The return value of process_results will be used by metrics
|
| 4 |
+
process_results: !function utils.cmmmu_process_results
|
| 5 |
+
# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results
|
| 6 |
+
generation_kwargs:
|
| 7 |
+
max_new_tokens: 16
|
| 8 |
+
image_aspect_ratio: original
|
| 9 |
+
metric_list:
|
| 10 |
+
- metric: cmmmu_acc
|
| 11 |
+
aggregation: !function utils.cmmmu_aggregate_results
|
| 12 |
+
higher_is_better: true
|
| 13 |
+
metadata:
|
| 14 |
+
- version: 0.0
|
| 15 |
+
include: _default_template_cmmmu_yaml
|
EAGLE/lmms_eval/tasks/cmmmu/utils.py
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
import re
|
| 3 |
+
import random
|
| 4 |
+
import os
|
| 5 |
+
import json
|
| 6 |
+
import logging
|
| 7 |
+
from collections import Counter
|
| 8 |
+
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
|
| 9 |
+
|
| 10 |
+
eval_logger = logging.getLogger("lmms-eval")
|
| 11 |
+
|
| 12 |
+
PROMPT = {
|
| 13 |
+
"task_instructions": [
|
| 14 |
+
"请回答以下多项选择题,并选出正确选项。这些题目可能包括单选和多选题型。如果所提供的信息不足以确定一个明确的答案,那么请根据可用的数据和你的判断来选择最可能正确的选项。",
|
| 15 |
+
"请回答以下判断题,并根据题目描述和所给的信息来判断问题中陈述的对错。如果信息不完整或不足以作出绝对判断,请运用你的逻辑推理和现有信息来做出最可能的判断。",
|
| 16 |
+
"请回答以下填空题,并根据题目的要求和所提供的信息来给出最恰当的答案。如果信息不足以确切回答,那么请依据现有的数据和你的推理能力来填写最合理的答案。",
|
| 17 |
+
],
|
| 18 |
+
"multi_choice_example_format": ["问题:{}\n选项:\n{}\n正确答案:\n"],
|
| 19 |
+
"T/F_example_format": ["问题:{}\n正确答案:\n"],
|
| 20 |
+
"short_ans_example_format": ["问题:{}\n正确答案:\n"],
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def construct_prompt(sample):
|
| 25 |
+
question = sample["question"]
|
| 26 |
+
task_instructions = PROMPT["task_instructions"]
|
| 27 |
+
|
| 28 |
+
if sample["type"] == "选择":
|
| 29 |
+
formatted_options = ""
|
| 30 |
+
start_chr = "A"
|
| 31 |
+
for i in range(1, 5):
|
| 32 |
+
formatted_options += f"({start_chr}) {sample[f'option{i}']}\n"
|
| 33 |
+
start_chr = chr(ord(start_chr) + 1)
|
| 34 |
+
|
| 35 |
+
current_example_template = PROMPT["multi_choice_example_format"][0]
|
| 36 |
+
current_example = current_example_template.format(question, formatted_options)
|
| 37 |
+
final_input_prompt = task_instructions[0] + "\n\n" + current_example
|
| 38 |
+
|
| 39 |
+
elif sample["type"] == "判断":
|
| 40 |
+
current_example_template = PROMPT["T/F_example_format"][0]
|
| 41 |
+
current_example = current_example_template.format(question)
|
| 42 |
+
final_input_prompt = task_instructions[1] + "\n\n" + current_example
|
| 43 |
+
|
| 44 |
+
else: # For fill in the blanks questions.
|
| 45 |
+
current_example_template = PROMPT["short_ans_example_format"][0]
|
| 46 |
+
current_example = current_example_template.format(question)
|
| 47 |
+
final_input_prompt = task_instructions[2] + "\n\n" + current_example
|
| 48 |
+
|
| 49 |
+
for i in range(1, 6):
|
| 50 |
+
final_input_prompt = final_input_prompt.replace(f'<img="{sample[f"image_{i}_filename"]}">', f"<图片 {i}>")
|
| 51 |
+
|
| 52 |
+
return final_input_prompt
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def cmmmu_doc_to_text(doc):
|
| 56 |
+
return construct_prompt(doc)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def cmmmu_doc_to_visual(doc):
|
| 60 |
+
prompt = construct_prompt(doc)
|
| 61 |
+
image_tokens = re.findall(r"<图片 \d+>", prompt)
|
| 62 |
+
# Remove <> and swap space as _
|
| 63 |
+
image_tokens = [image_token.strip("<>").replace(" ", "_").replace("图片", "image") for image_token in image_tokens]
|
| 64 |
+
visual = [doc[image_token].convert("RGB") for image_token in image_tokens]
|
| 65 |
+
return visual
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def cmmmu_process_results(doc, results):
|
| 69 |
+
pred = results[0]
|
| 70 |
+
if doc["type"] == "选择":
|
| 71 |
+
index2ans, all_choices = get_multi_choice_info([doc[f"option{i}"] for i in range(1, 5)])
|
| 72 |
+
parsed_pred = get_multi_choice_prediction(pred, all_choices, index2ans)
|
| 73 |
+
elif doc["type"] == "判断":
|
| 74 |
+
parsed_pred = get_TF_prediction(pred)
|
| 75 |
+
else:
|
| 76 |
+
parsed_pred = get_fill_blank_prediction(pred, doc["answer"])
|
| 77 |
+
return {"cmmmu_acc": {"id": doc["id"], "subdomain": doc["subcategory"], "question_type": doc["type"], "answer": doc["answer"], "parsed_pred": parsed_pred}}
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def cmmmu_aggregate_results(results):
|
| 81 |
+
evaluation_result = {}
|
| 82 |
+
subset_to_eval_samples = defaultdict(list)
|
| 83 |
+
for result in results:
|
| 84 |
+
subset_to_eval_samples[result["subdomain"]].append(result)
|
| 85 |
+
for subset, sub_eval_samples in subset_to_eval_samples.items():
|
| 86 |
+
metric_dict = eval_cmmmu(sub_eval_samples)
|
| 87 |
+
evaluation_result[subset] = metric_dict
|
| 88 |
+
|
| 89 |
+
printable_results = {}
|
| 90 |
+
for domain, in_domain_cats in DOMAIN_CAT2SUB_CAT.items():
|
| 91 |
+
in_domain_cat_results = {}
|
| 92 |
+
for cat_name in in_domain_cats:
|
| 93 |
+
if cat_name in evaluation_result.keys():
|
| 94 |
+
in_domain_cat_results[cat_name] = evaluation_result[cat_name]
|
| 95 |
+
else:
|
| 96 |
+
pass
|
| 97 |
+
in_domain_ins_acc = calculate_ins_level_acc(in_domain_cat_results)
|
| 98 |
+
in_domain_data_num = sum([cat_results["entries_num"] for cat_results in in_domain_cat_results.values()])
|
| 99 |
+
printable_results["Overall-" + domain] = {
|
| 100 |
+
"num": int(in_domain_data_num),
|
| 101 |
+
"acc": round(in_domain_ins_acc, 3),
|
| 102 |
+
}
|
| 103 |
+
# add sub category
|
| 104 |
+
for cat_name, cat_results in in_domain_cat_results.items():
|
| 105 |
+
printable_results[cat_name] = {
|
| 106 |
+
"num": int(cat_results["entries_num"]),
|
| 107 |
+
"acc": round(cat_results["acc"], 3),
|
| 108 |
+
}
|
| 109 |
+
all_ins_acc = calculate_ins_level_acc(evaluation_result)
|
| 110 |
+
printable_results["Overall"] = {
|
| 111 |
+
"num": sum([cat_results["entries_num"] for cat_results in evaluation_result.values()]),
|
| 112 |
+
"acc": round(all_ins_acc, 3),
|
| 113 |
+
}
|
| 114 |
+
print(printable_results)
|
| 115 |
+
return printable_results["Overall"]["acc"]
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def cmmmu_process_test_results_for_submission(doc, results):
|
| 119 |
+
response = results[0]
|
| 120 |
+
return {"submission": {"id": doc["id"], "type": doc["type"], "response": response}}
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def cmmmu_test_aggregate_results_for_submission(results, args):
|
| 124 |
+
file = generate_submission_file("cmmmu_test_for_submission.jsonl", args)
|
| 125 |
+
with open(file, "w", encoding="utf8") as f:
|
| 126 |
+
for result in results:
|
| 127 |
+
json.dump(result, f, ensure_ascii=False)
|
| 128 |
+
f.write("\n")
|
| 129 |
+
eval_logger.info(f"Submission file saved to {file}")
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
##################
|
| 133 |
+
# Helper functions
|
| 134 |
+
##################
|
| 135 |
+
|
| 136 |
+
DOMAIN_CAT2SUB_CAT = {
|
| 137 |
+
"艺术与设计": ["艺术", "艺术理论", "设计", "音乐"],
|
| 138 |
+
"商业": ["会计", "经济", "金融", "管理", "营销"],
|
| 139 |
+
"科学": ["生物", "化学", "地理", "数学", "物理"],
|
| 140 |
+
"健康与医学": ["基础医学", "临床医学", "诊断学与实验室医学", "制药", "公共卫生"],
|
| 141 |
+
"人文社会科学": ["历史", "文献学", "社会学", "心理学"],
|
| 142 |
+
"技术与工程": ["农业", "建筑学", "计算机科学", "电子学", "能源和电力", "材料", "机械工程"],
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def eval_cmmmu(entries):
|
| 147 |
+
correct_cnt = 0
|
| 148 |
+
for entry in entries:
|
| 149 |
+
parsed_pred = entry.get("parsed_pred", "")
|
| 150 |
+
correct = False
|
| 151 |
+
if entry.get("question_type") == "选择":
|
| 152 |
+
if parsed_pred == entry["answer"]:
|
| 153 |
+
correct_cnt += 1
|
| 154 |
+
correct = True
|
| 155 |
+
|
| 156 |
+
elif entry.get("question_type") == "填空":
|
| 157 |
+
norm_answers = normalize_str(entry["answer"], entry["answer"])
|
| 158 |
+
|
| 159 |
+
for pred in parsed_pred:
|
| 160 |
+
# already normalized
|
| 161 |
+
if isinstance(pred, str): # if it's a string, then find if ans in the pred_i
|
| 162 |
+
for norm_ans in norm_answers:
|
| 163 |
+
# only see if the string answer in the string pred
|
| 164 |
+
# print(norm_ans, pred)
|
| 165 |
+
if isinstance(norm_ans, str) and norm_ans in pred:
|
| 166 |
+
if not correct:
|
| 167 |
+
correct_cnt += 1
|
| 168 |
+
correct = True
|
| 169 |
+
break
|
| 170 |
+
else: # it's a number
|
| 171 |
+
if pred in norm_answers:
|
| 172 |
+
if not correct:
|
| 173 |
+
correct_cnt += 1
|
| 174 |
+
correct = True
|
| 175 |
+
break
|
| 176 |
+
|
| 177 |
+
else:
|
| 178 |
+
positive_keywords = ["正确", "对", "准确", "肯定", "对的"]
|
| 179 |
+
negative_keywords = ["不对", "错误", "不正确", "不准确", "不合适", "否定", "错的", "错"]
|
| 180 |
+
ambiguous_keywords = ["对错", "是否正确", "否正确", "或者", "是否", "正确性", "对不"]
|
| 181 |
+
|
| 182 |
+
def judge_similarity(pred_list, positive_keywords, negative_keywords):
|
| 183 |
+
positive_count = 0
|
| 184 |
+
negative_count = 0
|
| 185 |
+
|
| 186 |
+
for pred in pred_list:
|
| 187 |
+
if any(pos_word in pred for pos_word in positive_keywords):
|
| 188 |
+
positive_count += 1
|
| 189 |
+
elif any(neg_word in pred for neg_word in negative_keywords):
|
| 190 |
+
negative_count += 1
|
| 191 |
+
|
| 192 |
+
if positive_count > negative_count:
|
| 193 |
+
return "对"
|
| 194 |
+
elif negative_count > positive_count:
|
| 195 |
+
return "错"
|
| 196 |
+
else:
|
| 197 |
+
return random.choice(["对", "错"])
|
| 198 |
+
|
| 199 |
+
answer = entry["answer"]
|
| 200 |
+
parsed_pred = [word for word in parsed_pred if not any(ambiguous in word for ambiguous in ambiguous_keywords)]
|
| 201 |
+
result = judge_similarity(parsed_pred, positive_keywords, negative_keywords)
|
| 202 |
+
if result == answer:
|
| 203 |
+
correct_cnt += 1
|
| 204 |
+
correct = True
|
| 205 |
+
if correct:
|
| 206 |
+
entry["judge"] = "正确"
|
| 207 |
+
else:
|
| 208 |
+
entry["judge"] = "错误"
|
| 209 |
+
|
| 210 |
+
if len(entries) == 0:
|
| 211 |
+
print("entries_num == 0, please check your file")
|
| 212 |
+
results_count = {"correct_num": 0, "entries_num": 0, "acc": 0}
|
| 213 |
+
else:
|
| 214 |
+
results_count = {"correct_num": correct_cnt, "entries_num": len(entries), "acc": correct_cnt / len(entries)}
|
| 215 |
+
|
| 216 |
+
return results_count
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def get_multi_choice_prediction(response, all_choices, index2ans):
|
| 220 |
+
for char in [",", ".", "!", "?", ";", ":", "'"]:
|
| 221 |
+
response = response.strip(char)
|
| 222 |
+
response = " " + response + " " # add space to avoid partial match
|
| 223 |
+
|
| 224 |
+
candidates = []
|
| 225 |
+
|
| 226 |
+
for choice in all_choices: # (A) (B) (C) (D)
|
| 227 |
+
# Add the choice to candidates each time it appears in the response
|
| 228 |
+
candidates.extend([choice for _ in range(response.count(f"({choice})"))])
|
| 229 |
+
|
| 230 |
+
if len(candidates) == 0:
|
| 231 |
+
for choice in all_choices: # A B C D
|
| 232 |
+
# Similarly, add the choice for each occurrence
|
| 233 |
+
candidates.extend([choice for _ in range(response.count(f"{choice}"))])
|
| 234 |
+
|
| 235 |
+
if len(candidates) == 0 and len(response.split()) >= 1:
|
| 236 |
+
for index, ans in index2ans.items():
|
| 237 |
+
# Add index for each occurrence of ans in response
|
| 238 |
+
candidates.extend([index for _ in range(response.count(ans))])
|
| 239 |
+
|
| 240 |
+
# if all above doesn't get candidates, check if the content is larger than 5 tokens and try to parse the example
|
| 241 |
+
if len(candidates) == 0 and len(response.split()) >= 1:
|
| 242 |
+
for index, ans in index2ans.items():
|
| 243 |
+
if ans in response:
|
| 244 |
+
candidates.append(index)
|
| 245 |
+
index_ans = False # it's content ans.
|
| 246 |
+
|
| 247 |
+
if len(candidates) == 0: # still not get answer, randomly choose one.
|
| 248 |
+
return random.choice(all_choices)
|
| 249 |
+
# return ''
|
| 250 |
+
else:
|
| 251 |
+
# Count the occurrence of each candidate
|
| 252 |
+
candidate_counts = Counter(candidates)
|
| 253 |
+
|
| 254 |
+
# Select the most frequent candidates
|
| 255 |
+
max_count = max(candidate_counts.values())
|
| 256 |
+
most_frequent_candidates = [c for c in all_choices if candidate_counts.get(c, 0) == max_count]
|
| 257 |
+
|
| 258 |
+
# Combine the most frequent candidates in ABCD order
|
| 259 |
+
return "".join(most_frequent_candidates)
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def extract_numbers(string):
|
| 263 |
+
# Pattern for numbers with Chinese commas
|
| 264 |
+
pattern_commas = r"-?\d{1,3}(?:,\d{3})+"
|
| 265 |
+
# Pattern for scientific notation
|
| 266 |
+
pattern_scientific = r"-?\d+(?:\.\d+)?[eE][+-]?\d+"
|
| 267 |
+
# Pattern for simple numbers without Chinese commas
|
| 268 |
+
pattern_simple = r"-?(?:\d+\.\d+|\.\d+|\d+)(?![eE][+-]?\d+)(?!,\d)"
|
| 269 |
+
|
| 270 |
+
# Extract numbers with Chinese commas
|
| 271 |
+
numbers_with_commas = re.findall(pattern_commas, string)
|
| 272 |
+
# Extract numbers in scientific notation
|
| 273 |
+
numbers_scientific = re.findall(pattern_scientific, string)
|
| 274 |
+
# Extract simple numbers without Chinese commas
|
| 275 |
+
numbers_simple = re.findall(pattern_simple, string)
|
| 276 |
+
|
| 277 |
+
# Combine all extracted numbers
|
| 278 |
+
all_numbers = numbers_with_commas + numbers_scientific + numbers_simple
|
| 279 |
+
return all_numbers
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def check_is_number(string):
|
| 283 |
+
try:
|
| 284 |
+
float(string.replace(",", ""))
|
| 285 |
+
return True
|
| 286 |
+
except ValueError:
|
| 287 |
+
# check if there's comma inside
|
| 288 |
+
return False
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def count_letters(string):
|
| 292 |
+
return sum(c.isalpha() and "a" <= c <= "z" or "A" <= c <= "Z" for c in string)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def normalize_str(string, answer):
|
| 296 |
+
# check if characters in the string
|
| 297 |
+
|
| 298 |
+
# if number, numerize it.
|
| 299 |
+
if string == None:
|
| 300 |
+
return [string]
|
| 301 |
+
string = string.strip()
|
| 302 |
+
|
| 303 |
+
is_number = check_is_number(string)
|
| 304 |
+
|
| 305 |
+
if is_number:
|
| 306 |
+
string = string.replace(",", "")
|
| 307 |
+
string = float(string)
|
| 308 |
+
# leave 2 decimal
|
| 309 |
+
string = round(string, 2)
|
| 310 |
+
return [string]
|
| 311 |
+
else: # it's likely to be a string
|
| 312 |
+
if len(string) > len(answer) + 20 or count_letters(string) > count_letters(answer) + 2:
|
| 313 |
+
return []
|
| 314 |
+
return [string]
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def get_fill_blank_prediction(response, answer):
|
| 318 |
+
"""get the prediction from the generated response,
|
| 319 |
+
return a list of predicted strings or numbers"""
|
| 320 |
+
|
| 321 |
+
def get_key_subresponses(response):
|
| 322 |
+
key_responses = []
|
| 323 |
+
response = response.strip("。").strip()
|
| 324 |
+
sub_responses = re.split(r"。|\n", response)
|
| 325 |
+
indicators_of_keys = ["是", "为", "所以", "等于", "方案", "选择", "正确答案", "因此", "最后", "答案", "结果"]
|
| 326 |
+
key_responses = []
|
| 327 |
+
for index, resp in enumerate(sub_responses):
|
| 328 |
+
# if last one, accept it's an equation (the entire response can be just one sentence with equation)
|
| 329 |
+
if index == len(sub_responses) - 1:
|
| 330 |
+
indicators_of_keys.extend(["="])
|
| 331 |
+
shortest_key_response = None # the shortest response that may contain the answer (tail part of the response)
|
| 332 |
+
for indicator in indicators_of_keys:
|
| 333 |
+
if indicator in resp:
|
| 334 |
+
if not shortest_key_response:
|
| 335 |
+
shortest_key_response = resp.split(indicator)[-1].strip()
|
| 336 |
+
else:
|
| 337 |
+
if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
|
| 338 |
+
shortest_key_response = resp.split(indicator)[-1].strip()
|
| 339 |
+
|
| 340 |
+
if shortest_key_response:
|
| 341 |
+
# and it's not trivial
|
| 342 |
+
if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]:
|
| 343 |
+
key_responses.append(shortest_key_response)
|
| 344 |
+
if len(key_responses) == 0: # did not found any
|
| 345 |
+
return [response]
|
| 346 |
+
return key_responses
|
| 347 |
+
|
| 348 |
+
key_responses = get_key_subresponses(response)
|
| 349 |
+
|
| 350 |
+
pred_list = key_responses.copy() # keep the original string response
|
| 351 |
+
for resp in key_responses:
|
| 352 |
+
pred_list.extend(extract_numbers(resp))
|
| 353 |
+
|
| 354 |
+
tmp_pred_list = []
|
| 355 |
+
for i in range(len(pred_list)):
|
| 356 |
+
tmp_pred_list.extend(normalize_str(pred_list[i], answer))
|
| 357 |
+
pred_list = tmp_pred_list
|
| 358 |
+
|
| 359 |
+
# remove duplicates
|
| 360 |
+
pred_list = list(set(pred_list))
|
| 361 |
+
|
| 362 |
+
return pred_list
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def get_TF_prediction(response):
|
| 366 |
+
"""get the prediction from the generated response,
|
| 367 |
+
return a list of predicted strings or numbers"""
|
| 368 |
+
|
| 369 |
+
def get_key_subresponses(response):
|
| 370 |
+
key_responses = []
|
| 371 |
+
response = response.strip("。").strip()
|
| 372 |
+
sub_responses = re.split(r"。|\n", response)
|
| 373 |
+
indicators_of_keys = ["是", "为", "所以", "判断", "陈述", "说法", "表达", "答案", "结果"]
|
| 374 |
+
key_responses = []
|
| 375 |
+
for index, resp in enumerate(sub_responses):
|
| 376 |
+
shortest_key_response = None # the shortest response that may contain the answer (tail part of the response)
|
| 377 |
+
for indicator in indicators_of_keys:
|
| 378 |
+
if indicator in resp:
|
| 379 |
+
if not shortest_key_response:
|
| 380 |
+
shortest_key_response = resp.split(indicator)[-1].strip()
|
| 381 |
+
else:
|
| 382 |
+
if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
|
| 383 |
+
shortest_key_response = resp.split(indicator)[-1].strip()
|
| 384 |
+
|
| 385 |
+
if shortest_key_response:
|
| 386 |
+
# and it's not trivial
|
| 387 |
+
if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]:
|
| 388 |
+
key_responses.append(shortest_key_response)
|
| 389 |
+
if len(key_responses) == 0: # did not found any
|
| 390 |
+
return [response]
|
| 391 |
+
return key_responses
|
| 392 |
+
|
| 393 |
+
key_responses = get_key_subresponses(response)
|
| 394 |
+
|
| 395 |
+
pred_list = key_responses.copy() # keep the original string response
|
| 396 |
+
# remove duplicates
|
| 397 |
+
pred_list = list(set(pred_list))
|
| 398 |
+
|
| 399 |
+
return pred_list
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def get_multi_choice_info(options):
|
| 403 |
+
start_chr = "A"
|
| 404 |
+
all_choices = []
|
| 405 |
+
index2ans = {}
|
| 406 |
+
for i, option in enumerate(options):
|
| 407 |
+
index2ans[chr(ord(start_chr) + i)] = option
|
| 408 |
+
all_choices.append(chr(ord(start_chr) + i))
|
| 409 |
+
|
| 410 |
+
return index2ans, all_choices
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def calculate_ins_level_acc(results):
|
| 414 |
+
correct_sum = 0
|
| 415 |
+
entries_sum = 0
|
| 416 |
+
for cat_results in results.values():
|
| 417 |
+
correct_sum += cat_results["correct_num"]
|
| 418 |
+
entries_sum += cat_results["entries_num"]
|
| 419 |
+
if entries_sum == 0:
|
| 420 |
+
return 0
|
| 421 |
+
return correct_sum / entries_sum
|
EAGLE/lmms_eval/tasks/gqa/gqa.yaml
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset_path: lmms-lab/GQA
|
| 2 |
+
dataset_name: testdev_balanced_instructions
|
| 3 |
+
dataset_kwargs:
|
| 4 |
+
token: True
|
| 5 |
+
task: "gqa"
|
| 6 |
+
test_split: testdev
|
| 7 |
+
output_type: generate_until
|
| 8 |
+
doc_to_visual: !function utils.gqa_doc_to_visual
|
| 9 |
+
doc_to_text: !function utils.gqa_doc_to_text
|
| 10 |
+
doc_to_target: "answer"
|
| 11 |
+
generation_kwargs:
|
| 12 |
+
max_new_tokens: 16
|
| 13 |
+
temperature: 0
|
| 14 |
+
top_p: 0
|
| 15 |
+
num_beams: 1
|
| 16 |
+
do_sample: false
|
| 17 |
+
metric_list:
|
| 18 |
+
- metric: exact_match
|
| 19 |
+
aggregation: mean
|
| 20 |
+
higher_is_better: true
|
| 21 |
+
ignore_case: true
|
| 22 |
+
ignore_punctuation: true
|
| 23 |
+
metadata:
|
| 24 |
+
- version: 0.0
|
| 25 |
+
|
| 26 |
+
model_specific_prompt_kwargs:
|
| 27 |
+
default:
|
| 28 |
+
pre_prompt: ""
|
| 29 |
+
post_prompt: "\nAnswer the question using a single word or phrase."
|
| 30 |
+
qwen_vl:
|
| 31 |
+
pre_prompt: ""
|
| 32 |
+
post_prompt: " Answer:"
|
EAGLE/lmms_eval/tasks/gqa/utils.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datasets import load_dataset
|
| 2 |
+
|
| 3 |
+
GQA_RAW_IMAGE_DATASET = None
|
| 4 |
+
GQA_ID2IMAGE = None
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def gqa_doc_to_visual(doc):
|
| 8 |
+
global GQA_RAW_IMAGE_DATASET
|
| 9 |
+
global GQA_ID2IMAGE
|
| 10 |
+
if GQA_RAW_IMAGE_DATASET is None:
|
| 11 |
+
GQA_RAW_IMAGE_DATASET = load_dataset("lmms-lab/GQA", "testdev_balanced_images", split="testdev", token=True)
|
| 12 |
+
GQA_ID2IMAGE = {}
|
| 13 |
+
for row in GQA_RAW_IMAGE_DATASET:
|
| 14 |
+
GQA_ID2IMAGE[row["id"]] = row["image"].convert("RGB")
|
| 15 |
+
image = GQA_ID2IMAGE[doc["imageId"]]
|
| 16 |
+
return [image]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def gqa_doc_to_text(doc, model_specific_prompt_kwargs):
|
| 20 |
+
question = doc["question"]
|
| 21 |
+
pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
|
| 22 |
+
post_prompt = model_specific_prompt_kwargs["post_prompt"]
|
| 23 |
+
return f"{pre_prompt}{question}{post_prompt}"
|
EAGLE/lmms_eval/tasks/llava-in-the-wild/llava-in-the-wild.yaml
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset_path: lmms-lab/llava-bench-in-the-wild
|
| 2 |
+
dataset_kwargs:
|
| 3 |
+
token: True
|
| 4 |
+
task: "llava_in_the_wild"
|
| 5 |
+
test_split: train
|
| 6 |
+
output_type: generate_until
|
| 7 |
+
doc_to_visual: !function utils.llava_doc_to_visual
|
| 8 |
+
doc_to_text: !function utils.llava_doc_to_text
|
| 9 |
+
doc_to_target: "gpt_answer"
|
| 10 |
+
generation_kwargs:
|
| 11 |
+
until:
|
| 12 |
+
- "ASSISTANT:"
|
| 13 |
+
image_aspect_ratio: original
|
| 14 |
+
max_new_tokens: 1024
|
| 15 |
+
temperature: 0
|
| 16 |
+
top_p: 0
|
| 17 |
+
num_beams: 1
|
| 18 |
+
do_sample: false
|
| 19 |
+
process_results: !function utils.llava_process_results
|
| 20 |
+
metric_list:
|
| 21 |
+
- metric: gpt_eval_llava_all
|
| 22 |
+
aggregation: !function utils.llava_all_aggregation
|
| 23 |
+
higher_is_better: true
|
| 24 |
+
- metric: gpt_eval_llava_conv
|
| 25 |
+
aggregation: !function utils.llava_conv_aggregation
|
| 26 |
+
higher_is_better: true
|
| 27 |
+
- metric: gpt_eval_llava_detail
|
| 28 |
+
aggregation: !function utils.llava_detail_aggregation
|
| 29 |
+
higher_is_better: true
|
| 30 |
+
- metric: gpt_eval_llava_complex
|
| 31 |
+
aggregation: !function utils.llava_complex_aggregation
|
| 32 |
+
higher_is_better: true
|
| 33 |
+
metadata:
|
| 34 |
+
version: 0.0
|
| 35 |
+
gpt_eval_model_name: "gpt-4-0613"
|
| 36 |
+
model_specific_prompt_kwargs:
|
| 37 |
+
default:
|
| 38 |
+
pre_prompt: ""
|
| 39 |
+
post_prompt: ""
|
EAGLE/lmms_eval/tasks/llava-in-the-wild/rule.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"coding": {"role": "Assistant", "prompt": "Your task is to evaluate the coding abilities of the above two assistants. They have been asked to implement a program to solve a given problem. Please review their code submissions, paying close attention to their problem-solving approach, code structure, readability, and the inclusion of helpful comments.\n\nPlease ensure that the assistants' submissions:\n\n1. Correctly implement the given problem statement.\n2. Contain accurate and efficient code.\n3. Include clear and concise comments that explain the code's logic and functionality.\n4. Adhere to proper coding standards and best practices.\n\nOnce you have carefully reviewed both submissions, provide detailed feedback on their strengths and weaknesses, along with any suggestions for improvement. You should first output a single line containing two scores on the scale of 1-10 (1: no code/no sense; 10: perfect) for Assistant 1 and 2, respectively. Then give extra comments starting from the next line."},
|
| 3 |
+
"math": {"role": "Assistant", "prompt": "We would like to request your feedback on the mathematical proficiency of two AI assistants regarding the given user question.\nFirstly, please solve the problem independently, without referring to the answers provided by Assistant 1 and Assistant 2.\nAfterward, please examine the problem-solving process of Assistant 1 and Assistant 2 step-by-step to ensure their correctness, identifying any incorrect steps if present. Your evaluation should take into account not only the answer but also the problem-solving steps.\nFinally, please output a Python tuple containing two numerical scores for Assistant 1 and Assistant 2, ranging from 1 to 10, respectively. If applicable, explain the reasons for any variations in their scores and determine which assistant performed better."},
|
| 4 |
+
"default": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
|
| 5 |
+
"conv": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
|
| 6 |
+
"detail": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
|
| 7 |
+
"complex": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
|
| 8 |
+
"llava_bench_conv": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
|
| 9 |
+
"llava_bench_detail": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
|
| 10 |
+
"llava_bench_complex": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}
|
| 11 |
+
}
|
EAGLE/lmms_eval/tasks/llava-in-the-wild/utils.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import requests
|
| 5 |
+
import numpy as np
|
| 6 |
+
import openai
|
| 7 |
+
from openai import OpenAI
|
| 8 |
+
import time
|
| 9 |
+
import yaml
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from copy import deepcopy
|
| 12 |
+
|
| 13 |
+
eval_logger = logging.getLogger("lmms-eval")
|
| 14 |
+
NUM_SECONDS_TO_SLEEP = 5
|
| 15 |
+
|
| 16 |
+
LLAVA_W_METRICS = ["gpt_eval_llava_conv", "gpt_eval_llava_detail", "gpt_eval_llava_complex"]
|
| 17 |
+
|
| 18 |
+
rule_dict = json.load(open(os.path.join(os.path.dirname(os.path.abspath(__file__)), "rule.json"), "r"))
|
| 19 |
+
|
| 20 |
+
with open(Path(__file__).parent / "llava-in-the-wild.yaml", "r") as f:
|
| 21 |
+
raw_data = f.readlines()
|
| 22 |
+
safe_data = []
|
| 23 |
+
for i, line in enumerate(raw_data):
|
| 24 |
+
# remove function definition since yaml load cannot handle it
|
| 25 |
+
if "!function" not in line:
|
| 26 |
+
safe_data.append(line)
|
| 27 |
+
|
| 28 |
+
config = yaml.safe_load("".join(safe_data))
|
| 29 |
+
|
| 30 |
+
GPT_EVAL_MODEL_NAME = config["metadata"]["gpt_eval_model_name"]
|
| 31 |
+
|
| 32 |
+
API_TYPE = os.getenv("API_TYPE", "openai")
|
| 33 |
+
|
| 34 |
+
if API_TYPE == "openai":
|
| 35 |
+
API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
|
| 36 |
+
API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
|
| 37 |
+
headers = {
|
| 38 |
+
"Authorization": f"Bearer {API_KEY}",
|
| 39 |
+
"Content-Type": "application/json",
|
| 40 |
+
}
|
| 41 |
+
elif API_TYPE == "azure":
|
| 42 |
+
API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
|
| 43 |
+
API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
|
| 44 |
+
headers = {
|
| 45 |
+
"api-key": API_KEY,
|
| 46 |
+
"Content-Type": "application/json",
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_eval(content: str, max_tokens: int, retries: int = 5):
|
| 51 |
+
global headers
|
| 52 |
+
|
| 53 |
+
messages = [
|
| 54 |
+
{
|
| 55 |
+
"role": "system",
|
| 56 |
+
"content": "You are a helpful and precise assistant for checking the quality of the answer.",
|
| 57 |
+
},
|
| 58 |
+
{"role": "user", "content": content},
|
| 59 |
+
]
|
| 60 |
+
|
| 61 |
+
payload = {
|
| 62 |
+
"model": GPT_EVAL_MODEL_NAME,
|
| 63 |
+
"messages": messages,
|
| 64 |
+
"temperature": 0.2,
|
| 65 |
+
"max_tokens": max_tokens,
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
for attempt in range(retries):
|
| 69 |
+
try:
|
| 70 |
+
response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
|
| 71 |
+
response.raise_for_status()
|
| 72 |
+
response_data = response.json()
|
| 73 |
+
|
| 74 |
+
content = response_data["choices"][0]["message"]["content"].strip()
|
| 75 |
+
if content != "":
|
| 76 |
+
return content, response_data["model"]
|
| 77 |
+
break # If successful, break out of the loop
|
| 78 |
+
|
| 79 |
+
except Exception as e:
|
| 80 |
+
eval_logger.info(f"Attempt {attempt + 1} failed with error: {e}")
|
| 81 |
+
if attempt < retries: # If we have retries left, sleep and then continue to next attempt
|
| 82 |
+
time.sleep(NUM_SECONDS_TO_SLEEP)
|
| 83 |
+
else: # If this was the last attempt, log and return empty
|
| 84 |
+
eval_logger.error(f"All {retries} attempts failed. Last error message: {e}")
|
| 85 |
+
return "", ""
|
| 86 |
+
return "", ""
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def parse_score(review):
|
| 90 |
+
try:
|
| 91 |
+
score_pair = review.split("\n")[0]
|
| 92 |
+
score_pair = score_pair.replace(",", " ")
|
| 93 |
+
sp = score_pair.split(" ")
|
| 94 |
+
if len(sp) == 2:
|
| 95 |
+
return [float(sp[0]), float(sp[1])]
|
| 96 |
+
else:
|
| 97 |
+
eval_logger.debug(f"Can not split: {review}. Returning [-1, -1]")
|
| 98 |
+
return [-1, -1]
|
| 99 |
+
except Exception as e:
|
| 100 |
+
eval_logger.debug(f"Error: {e}. Returning [-1, -1]")
|
| 101 |
+
return [-1, -1]
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def llava_doc_to_visual(doc):
|
| 105 |
+
return [doc["image"].convert("RGB")]
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def llava_doc_to_text(doc, model_specific_prompt_kwargs=None):
|
| 109 |
+
if model_specific_prompt_kwargs is None:
|
| 110 |
+
model_specific_prompt_kwargs = {}
|
| 111 |
+
pre_prompt = model_specific_prompt_kwargs.get("pre_prompt", "")
|
| 112 |
+
post_prompt = model_specific_prompt_kwargs.get("post_prompt", "")
|
| 113 |
+
return f"{pre_prompt}{doc['question']}{post_prompt}"
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def llava_process_results(doc, result):
|
| 117 |
+
"""
|
| 118 |
+
Args:
|
| 119 |
+
doc: a instance of the eval dataset
|
| 120 |
+
results: [pred]
|
| 121 |
+
Returns:
|
| 122 |
+
a dictionary with key: metric name (in this case coco_bleu), value: metric value
|
| 123 |
+
"""
|
| 124 |
+
try:
|
| 125 |
+
question = doc.get("question", "")
|
| 126 |
+
ans1 = doc.get("gpt_answer", "")
|
| 127 |
+
ans2 = result[0] if result else ""
|
| 128 |
+
captions = doc.get("caption", [])
|
| 129 |
+
context = "\n".join(captions) if isinstance(captions, list) else captions
|
| 130 |
+
category = "llava_bench_" + doc.get("category", "")
|
| 131 |
+
rule = rule_dict.get(category, {})
|
| 132 |
+
prompt = rule.get("prompt", "")
|
| 133 |
+
role = rule.get("role", "user")
|
| 134 |
+
content = f"[Context]\n{context}\n\n" f"[Question]\n{question}\n\n" f"[{role} 1]\n{ans1}\n\n[End of {role} 1]\n\n" f"[{role} 2]\n{ans2}\n\n[End of {role} 2]\n\n" f"[System]\n{prompt}\n\n"
|
| 135 |
+
|
| 136 |
+
review, model_name = get_eval(content, 1024)
|
| 137 |
+
scores = parse_score(review)
|
| 138 |
+
except Exception as e:
|
| 139 |
+
eval_logger.error(f"Error for Question ID: {doc.get('question_id', 'Unknown')}: {e}")
|
| 140 |
+
review = "Failed to Get a Proper Review."
|
| 141 |
+
model_name = "Failed Request"
|
| 142 |
+
scores = [-1, -1]
|
| 143 |
+
|
| 144 |
+
metric = f"gpt_eval_llava_{doc.get('category', 'all')}"
|
| 145 |
+
category_review_dict = {"question": question, "ans1": ans1, "ans2": ans2, "context": context, "category": category, "review": review, "scores": scores, "eval_model": model_name, "content": content}
|
| 146 |
+
|
| 147 |
+
non_category_review_dict = deepcopy(category_review_dict)
|
| 148 |
+
non_category_review_dict["scores"] = [-999, -999]
|
| 149 |
+
|
| 150 |
+
data_dict = {}
|
| 151 |
+
for m in LLAVA_W_METRICS:
|
| 152 |
+
if m == metric:
|
| 153 |
+
data_dict[m] = category_review_dict
|
| 154 |
+
else:
|
| 155 |
+
data_dict[m] = non_category_review_dict
|
| 156 |
+
data_dict["gpt_eval_llava_all"] = category_review_dict
|
| 157 |
+
|
| 158 |
+
# return {"gpt_eval_llava_all": review_dict}
|
| 159 |
+
return data_dict
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def llava_conv_aggregation(results):
|
| 163 |
+
return llava_aggregation(results, "conv")
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def llava_complex_aggregation(results):
|
| 167 |
+
return llava_aggregation(results, "complex")
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def llava_detail_aggregation(results):
|
| 171 |
+
return llava_aggregation(results, "detail")
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def llava_all_aggregation(results):
|
| 175 |
+
return llava_aggregation(results, "all")
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def llava_aggregation(results, category):
|
| 179 |
+
try:
|
| 180 |
+
scores = []
|
| 181 |
+
for result in results:
|
| 182 |
+
if -999 in result["scores"]:
|
| 183 |
+
continue
|
| 184 |
+
scores.append(result["scores"])
|
| 185 |
+
|
| 186 |
+
stats = np.asarray(scores).mean(0).tolist()
|
| 187 |
+
stats = [round(x, 3) for x in stats]
|
| 188 |
+
# gpt4_score_percentage = stats[0] * 10
|
| 189 |
+
# model_score_percentage = stats[1] * 10
|
| 190 |
+
# eval_logger.info(f"Category: {category}")
|
| 191 |
+
# eval_logger.info(f"GPT4 Score: {gpt4_score_percentage:.1f}%")
|
| 192 |
+
# eval_logger.info(f"Model Score: {model_score_percentage:.1f}%")
|
| 193 |
+
# eval_logger.info("=========================")
|
| 194 |
+
return round(stats[1] / stats[0] * 100, 1)
|
| 195 |
+
except Exception as e:
|
| 196 |
+
eval_logger.info(f"Error in llava_aggregation: {e}, and in category: {category}")
|
| 197 |
+
return None
|
EAGLE/lmms_eval/tasks/mmbench/_default_template_mmbench_cn_yaml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset_path: lmms-lab/MMBench
|
| 2 |
+
dataset_kwargs:
|
| 3 |
+
token: True
|
| 4 |
+
doc_to_target: "answer"
|
| 5 |
+
dataset_name: "cn"
|
| 6 |
+
output_type: generate_until
|
| 7 |
+
doc_to_visual: !function cn_utils.mmbench_doc_to_visual
|
| 8 |
+
doc_to_text: !function cn_utils.mmbench_doc_to_text
|
| 9 |
+
generation_kwargs:
|
| 10 |
+
max_new_tokens: 256
|
| 11 |
+
temperature: 0
|
| 12 |
+
top_p: 0
|
| 13 |
+
num_beams: 1
|
| 14 |
+
do_sample: false
|
| 15 |
+
process_results: !function cn_utils.mmbench_process_results
|
| 16 |
+
model_specific_prompt_kwargs:
|
| 17 |
+
default:
|
| 18 |
+
pre_prompt: ""
|
| 19 |
+
post_prompt: "\n请直接使用所提供的选项字母作为答案回答。"
|
| 20 |
+
model_specific_generation_kwargs:
|
| 21 |
+
llava:
|
| 22 |
+
image_aspect_ratio: original
|
EAGLE/lmms_eval/tasks/mmbench/_default_template_mmbench_en_yaml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset_path: lmms-lab/MMBench
|
| 2 |
+
dataset_kwargs:
|
| 3 |
+
token: True
|
| 4 |
+
doc_to_target: "answer"
|
| 5 |
+
model_specific_prompt_kwargs:
|
| 6 |
+
default:
|
| 7 |
+
pre_prompt: ""
|
| 8 |
+
post_prompt: "\nAnswer with the option's letter from the given choices directly."
|
| 9 |
+
doc_to_visual: !function en_utils.mmbench_doc_to_visual
|
| 10 |
+
doc_to_text: !function en_utils.mmbench_doc_to_text
|
| 11 |
+
doc_to_target: "answer"
|
| 12 |
+
process_results: !function en_utils.mmbench_process_results
|
| 13 |
+
model_specific_generation_kwargs:
|
| 14 |
+
llava:
|
| 15 |
+
image_aspect_ratio: original
|
| 16 |
+
output_type: generate_until
|
| 17 |
+
dataset_name: "en"
|
| 18 |
+
generation_kwargs:
|
| 19 |
+
until:
|
| 20 |
+
- "ASSISTANT:"
|
| 21 |
+
max_new_tokens: 1024
|
| 22 |
+
temperature: 0
|
| 23 |
+
top_p: 0
|
| 24 |
+
num_beams: 1
|
| 25 |
+
do_sample: false
|
EAGLE/lmms_eval/tasks/mmbench/cc_utils.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import yaml
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import json
|
| 7 |
+
|
| 8 |
+
eval_logger = logging.getLogger("lmms-eval")
|
| 9 |
+
from lmms_eval.tasks.mmbench.mmbench_evals import MMBench_Evaluator
|
| 10 |
+
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
|
| 11 |
+
|
| 12 |
+
with open(Path(__file__).parent / "mmbench.yaml", "r") as f:
|
| 13 |
+
raw_data = f.readlines()
|
| 14 |
+
safe_data = []
|
| 15 |
+
for i, line in enumerate(raw_data):
|
| 16 |
+
# remove function definition since yaml load cannot handle it
|
| 17 |
+
if "!function" not in line:
|
| 18 |
+
safe_data.append(line)
|
| 19 |
+
|
| 20 |
+
config = yaml.safe_load("".join(safe_data))
|
| 21 |
+
|
| 22 |
+
GPT_EVAL_MODEL_NAME = config["metadata"]["gpt_eval_model_name"]
|
| 23 |
+
API_TYPE = os.getenv("API_TYPE", "openai")
|
| 24 |
+
|
| 25 |
+
if API_TYPE == "openai":
|
| 26 |
+
API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
|
| 27 |
+
API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
|
| 28 |
+
elif API_TYPE == "azure":
|
| 29 |
+
API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
|
| 30 |
+
API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
mmbench_evaluator = MMBench_Evaluator(sys_prompt=config["metadata"]["sys_prompt"], API_KEY=API_KEY, API_URL=API_URL, model_version=GPT_EVAL_MODEL_NAME)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def mmbench_doc_to_visual(doc):
|
| 37 |
+
return [doc["image"].convert("RGB")]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def mmbench_cn_cc_doc_to_text(doc, model_specific_prompt_kwargs=None):
|
| 41 |
+
option_candidate = ["A", "B", "C", "D", "E"]
|
| 42 |
+
options_prompt, options_dict = mmbench_evaluator.create_options_prompt(doc, option_candidate)
|
| 43 |
+
|
| 44 |
+
data = {
|
| 45 |
+
# "img": doc["image"],
|
| 46 |
+
"question": doc["question"],
|
| 47 |
+
"answer": doc.get("answer", None),
|
| 48 |
+
"options": options_prompt,
|
| 49 |
+
"category": doc["category"],
|
| 50 |
+
"options_dict": options_dict,
|
| 51 |
+
"index": doc["index"],
|
| 52 |
+
"source": doc["source"],
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
query_prompt = f"{data['question']} {data['options']}"
|
| 56 |
+
|
| 57 |
+
if model_specific_prompt_kwargs:
|
| 58 |
+
query_prompt = f"{query_prompt}\n{model_specific_prompt_kwargs['post_prompt']}"
|
| 59 |
+
|
| 60 |
+
return query_prompt
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def mmbench_cn_cc_process_results(doc, results):
|
| 64 |
+
model_response = results[0].strip()
|
| 65 |
+
data = {
|
| 66 |
+
"gpt_eval_score": {
|
| 67 |
+
"index": doc["index"],
|
| 68 |
+
"question": doc["question"],
|
| 69 |
+
"answer": doc["answer"],
|
| 70 |
+
"prediction": model_response,
|
| 71 |
+
"source": doc["source"],
|
| 72 |
+
"category": doc["category"],
|
| 73 |
+
},
|
| 74 |
+
"submission": {
|
| 75 |
+
"index": doc["index"],
|
| 76 |
+
"question": doc["question"],
|
| 77 |
+
"answer": doc["answer"],
|
| 78 |
+
"prediction": model_response,
|
| 79 |
+
"source": doc["source"],
|
| 80 |
+
"category": doc["category"],
|
| 81 |
+
},
|
| 82 |
+
}
|
| 83 |
+
option_candidate = ["A", "B", "C", "D", "E"]
|
| 84 |
+
for c in option_candidate:
|
| 85 |
+
data["submission"][c] = doc.get(c, "nan")
|
| 86 |
+
data["gpt_eval_score"][c] = doc.get(c, "nan")
|
| 87 |
+
return data
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def mmbench_cn_cc_aggregate_dev_results_eval(results, args):
|
| 91 |
+
print(f"============= MMBench-CN(CC) Detailed Results =============")
|
| 92 |
+
overall_acc, category_acc, l2_category_acc = mmbench_evaluator.eval_result(results, eval_method="openai")
|
| 93 |
+
file = generate_submission_file("mmbench_cn_cc_results.json", args)
|
| 94 |
+
details_info = {
|
| 95 |
+
"overall_acc": overall_acc,
|
| 96 |
+
"category_acc": category_acc,
|
| 97 |
+
"l2_category_acc": l2_category_acc,
|
| 98 |
+
}
|
| 99 |
+
with open(file, "w") as f:
|
| 100 |
+
json.dump(details_info, f)
|
| 101 |
+
return overall_acc * 100
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def mmbench_cn_cc_aggregate_results(results, args):
|
| 105 |
+
df = pd.DataFrame(results)
|
| 106 |
+
file = generate_submission_file("mmbench_cn_cc_results.xlsx", args)
|
| 107 |
+
with pd.ExcelWriter(file) as writer:
|
| 108 |
+
df.to_excel(writer, index=False)
|
| 109 |
+
eval_logger.info(f"Saved results to {file}")
|
EAGLE/lmms_eval/tasks/mmbench/cn_utils.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import yaml
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import json
|
| 7 |
+
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
|
| 8 |
+
|
| 9 |
+
eval_logger = logging.getLogger("lmms-eval")
|
| 10 |
+
from lmms_eval.tasks.mmbench.mmbench_evals import MMBench_Evaluator
|
| 11 |
+
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
|
| 12 |
+
|
| 13 |
+
with open(Path(__file__).parent / "mmbench.yaml", "r") as f:
|
| 14 |
+
raw_data = f.readlines()
|
| 15 |
+
safe_data = []
|
| 16 |
+
for i, line in enumerate(raw_data):
|
| 17 |
+
# remove function definition since yaml load cannot handle it
|
| 18 |
+
if "!function" not in line:
|
| 19 |
+
safe_data.append(line)
|
| 20 |
+
|
| 21 |
+
config = yaml.safe_load("".join(safe_data))
|
| 22 |
+
|
| 23 |
+
GPT_EVAL_MODEL_NAME = config["metadata"]["gpt_eval_model_name"]
|
| 24 |
+
API_TYPE = os.getenv("API_TYPE", "openai")
|
| 25 |
+
|
| 26 |
+
if API_TYPE == "openai":
|
| 27 |
+
API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
|
| 28 |
+
API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
|
| 29 |
+
elif API_TYPE == "azure":
|
| 30 |
+
API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
|
| 31 |
+
API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
mmbench_evaluator = MMBench_Evaluator(sys_prompt=config["metadata"]["sys_prompt"], API_KEY=API_KEY, API_URL=API_URL, model_version=GPT_EVAL_MODEL_NAME)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def mmbench_doc_to_visual(doc):
|
| 38 |
+
return [doc["image"].convert("RGB")]
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def mmbench_doc_to_text(doc, model_specific_prompt_kwargs=None):
|
| 42 |
+
option_candidate = ["A", "B", "C", "D", "E"]
|
| 43 |
+
options_prompt, options_dict = mmbench_evaluator.create_options_prompt(doc, option_candidate)
|
| 44 |
+
|
| 45 |
+
data = {
|
| 46 |
+
# "img": doc["image"],
|
| 47 |
+
"question": doc["question"],
|
| 48 |
+
"answer": doc.get("answer", None),
|
| 49 |
+
"options": options_prompt,
|
| 50 |
+
"category": doc["category"],
|
| 51 |
+
"L2-category": doc["L2-category"],
|
| 52 |
+
"options_dict": options_dict,
|
| 53 |
+
"index": doc["index"],
|
| 54 |
+
"hint": doc["hint"],
|
| 55 |
+
"source": doc["source"],
|
| 56 |
+
"split": doc["split"],
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
query_prompt = f"{data['hint']} {data['question']} {data['options']}" if pd.notna(data["hint"]) else f"{data['question']} {data['options']}"
|
| 60 |
+
|
| 61 |
+
if model_specific_prompt_kwargs:
|
| 62 |
+
query_prompt = f"{query_prompt}\n{model_specific_prompt_kwargs['post_prompt']}"
|
| 63 |
+
|
| 64 |
+
return query_prompt
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def mmbench_process_results(doc, results):
|
| 68 |
+
model_response = results[0].strip()
|
| 69 |
+
data = {
|
| 70 |
+
"gpt_eval_score": {
|
| 71 |
+
"index": doc["index"],
|
| 72 |
+
"question": doc["question"],
|
| 73 |
+
"answer": doc["answer"],
|
| 74 |
+
"prediction": model_response,
|
| 75 |
+
"hint": doc["hint"],
|
| 76 |
+
"source": doc["source"],
|
| 77 |
+
"split": doc["split"],
|
| 78 |
+
"category": doc["category"],
|
| 79 |
+
"L2-category": doc["L2-category"],
|
| 80 |
+
},
|
| 81 |
+
"submission": {
|
| 82 |
+
"index": doc["index"],
|
| 83 |
+
"question": doc["question"],
|
| 84 |
+
"answer": doc["answer"],
|
| 85 |
+
"prediction": model_response,
|
| 86 |
+
"hint": doc["hint"],
|
| 87 |
+
"source": doc["source"],
|
| 88 |
+
"split": doc["split"],
|
| 89 |
+
"category": doc["category"],
|
| 90 |
+
"L2-category": doc["L2-category"],
|
| 91 |
+
},
|
| 92 |
+
}
|
| 93 |
+
option_candidate = ["A", "B", "C", "D", "E"]
|
| 94 |
+
for c in option_candidate:
|
| 95 |
+
data["submission"][c] = doc.get(c, "nan")
|
| 96 |
+
data["gpt_eval_score"][c] = doc.get(c, "nan")
|
| 97 |
+
return data
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def mmbench_aggregate_dev_results_eval(results, args):
|
| 101 |
+
print(f"============= MMBench-CN(Dev) Detailed Results =============")
|
| 102 |
+
overall_acc, category_acc, l2_category_acc = mmbench_evaluator.eval_result(results, eval_method="openai")
|
| 103 |
+
file = generate_submission_file("mmbench_cn_dev_results.json", args)
|
| 104 |
+
details_info = {
|
| 105 |
+
"overall_acc": overall_acc,
|
| 106 |
+
"category_acc": category_acc,
|
| 107 |
+
"l2_category_acc": l2_category_acc,
|
| 108 |
+
}
|
| 109 |
+
with open(file, "w") as f:
|
| 110 |
+
json.dump(details_info, f)
|
| 111 |
+
return overall_acc * 100
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def mmbench_aggregate_dev_results(results, args):
|
| 115 |
+
df = pd.DataFrame(results)
|
| 116 |
+
excel_write_path = generate_submission_file("mmbench_cn_dev_results.xlsx", args)
|
| 117 |
+
with pd.ExcelWriter(excel_write_path) as writer:
|
| 118 |
+
df.to_excel(writer, index=False)
|
| 119 |
+
eval_logger.info(f"Saved results to {excel_write_path}")
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def mmbench_aggregate_test_results(results, args):
|
| 123 |
+
df = pd.DataFrame(results)
|
| 124 |
+
excel_write_path = generate_submission_file("mmbench_cn_test_results.xlsx", args)
|
| 125 |
+
with pd.ExcelWriter(excel_write_path) as writer:
|
| 126 |
+
df.to_excel(writer, index=False)
|
| 127 |
+
eval_logger.info(f"Saved results to {excel_write_path}")
|
EAGLE/lmms_eval/tasks/mmbench/en_utils.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import yaml
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import json
|
| 7 |
+
|
| 8 |
+
eval_logger = logging.getLogger("lmms-eval")
|
| 9 |
+
from lmms_eval.tasks.mmbench.mmbench_evals import MMBench_Evaluator
|
| 10 |
+
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
|
| 11 |
+
|
| 12 |
+
with open(Path(__file__).parent / "mmbench.yaml", "r") as f:
|
| 13 |
+
raw_data = f.readlines()
|
| 14 |
+
safe_data = []
|
| 15 |
+
for i, line in enumerate(raw_data):
|
| 16 |
+
# remove function definition since yaml load cannot handle it
|
| 17 |
+
if "!function" not in line:
|
| 18 |
+
safe_data.append(line)
|
| 19 |
+
|
| 20 |
+
config = yaml.safe_load("".join(safe_data))
|
| 21 |
+
|
| 22 |
+
GPT_EVAL_MODEL_NAME = config["metadata"]["gpt_eval_model_name"]
|
| 23 |
+
API_TYPE = os.getenv("API_TYPE", "openai")
|
| 24 |
+
|
| 25 |
+
if API_TYPE == "openai":
|
| 26 |
+
API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
|
| 27 |
+
API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
|
| 28 |
+
elif API_TYPE == "azure":
|
| 29 |
+
API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
|
| 30 |
+
API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
mmbench_evaluator = MMBench_Evaluator(sys_prompt=config["metadata"]["sys_prompt"], API_KEY=API_KEY, API_URL=API_URL, model_version=GPT_EVAL_MODEL_NAME)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def mmbench_doc_to_visual(doc):
|
| 37 |
+
return [doc["image"].convert("RGB")]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def mmbench_doc_to_text(doc, model_specific_prompt_kwargs=None):
|
| 41 |
+
option_candidate = ["A", "B", "C", "D", "E"]
|
| 42 |
+
options_prompt, options_dict = mmbench_evaluator.create_options_prompt(doc, option_candidate)
|
| 43 |
+
|
| 44 |
+
data = {
|
| 45 |
+
# "img": doc["image"],
|
| 46 |
+
"question": doc["question"],
|
| 47 |
+
"answer": doc.get("answer", None),
|
| 48 |
+
"options": options_prompt,
|
| 49 |
+
"category": doc["category"],
|
| 50 |
+
"L2-category": doc["L2-category"],
|
| 51 |
+
"options_dict": options_dict,
|
| 52 |
+
"index": doc["index"],
|
| 53 |
+
"hint": doc["hint"],
|
| 54 |
+
"source": doc["source"],
|
| 55 |
+
"split": doc["split"],
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
query_prompt = f"{data['hint']} {data['question']} {data['options']}" if pd.notna(data["hint"]) and data["hint"] != "nan" else f"{data['question']} {data['options']}"
|
| 59 |
+
|
| 60 |
+
if model_specific_prompt_kwargs:
|
| 61 |
+
query_prompt = f"{query_prompt}\n{model_specific_prompt_kwargs['post_prompt']}"
|
| 62 |
+
|
| 63 |
+
return query_prompt
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def mmbench_process_results(doc, results):
|
| 67 |
+
model_response = results[0].strip()
|
| 68 |
+
data = {
|
| 69 |
+
"gpt_eval_score": {
|
| 70 |
+
"index": doc["index"],
|
| 71 |
+
"question": doc["question"],
|
| 72 |
+
"answer": doc["answer"],
|
| 73 |
+
"prediction": model_response,
|
| 74 |
+
"hint": doc["hint"],
|
| 75 |
+
"source": doc["source"],
|
| 76 |
+
"split": doc["split"],
|
| 77 |
+
"category": doc["category"],
|
| 78 |
+
"L2-category": doc["L2-category"],
|
| 79 |
+
},
|
| 80 |
+
"submission": {
|
| 81 |
+
"index": doc["index"],
|
| 82 |
+
"question": doc["question"],
|
| 83 |
+
"answer": doc["answer"],
|
| 84 |
+
"prediction": model_response,
|
| 85 |
+
"hint": doc["hint"],
|
| 86 |
+
"source": doc["source"],
|
| 87 |
+
"split": doc["split"],
|
| 88 |
+
"category": doc["category"],
|
| 89 |
+
"L2-category": doc["L2-category"],
|
| 90 |
+
},
|
| 91 |
+
}
|
| 92 |
+
option_candidate = ["A", "B", "C", "D", "E"]
|
| 93 |
+
for c in option_candidate:
|
| 94 |
+
data["submission"][c] = doc.get(c, "nan")
|
| 95 |
+
data["gpt_eval_score"][c] = doc.get(c, "nan")
|
| 96 |
+
return data
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def mmbench_aggregate_dev_results_eval(results, args):
|
| 100 |
+
print(f"============= MMBench-EN(Dev) Detailed Results =============")
|
| 101 |
+
overall_acc, category_acc, l2_category_acc = mmbench_evaluator.eval_result(results, eval_method="openai")
|
| 102 |
+
file = generate_submission_file("mmbench_en_dev_results.json", args)
|
| 103 |
+
details_info = {
|
| 104 |
+
"overall_acc": overall_acc,
|
| 105 |
+
"category_acc": category_acc,
|
| 106 |
+
"l2_category_acc": l2_category_acc,
|
| 107 |
+
}
|
| 108 |
+
with open(file, "w") as f:
|
| 109 |
+
json.dump(details_info, f)
|
| 110 |
+
return overall_acc * 100
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def mmbench_aggregate_dev_results_submission(results, args):
|
| 114 |
+
df = pd.DataFrame(results)
|
| 115 |
+
excel_write_path = generate_submission_file("mmbench_en_dev_results.xlsx", args)
|
| 116 |
+
with pd.ExcelWriter(excel_write_path) as writer:
|
| 117 |
+
df.to_excel(writer, index=False)
|
| 118 |
+
eval_logger.info(f"Saved results to {excel_write_path}")
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def mmbench_aggregate_test_results(results, args):
|
| 122 |
+
df = pd.DataFrame(results)
|
| 123 |
+
excel_write_path = generate_submission_file("mmbench_en_test_results.xlsx", args)
|
| 124 |
+
with pd.ExcelWriter(excel_write_path) as writer:
|
| 125 |
+
df.to_excel(writer, index=False)
|
| 126 |
+
eval_logger.info(f"Saved results to {excel_write_path}")
|
EAGLE/lmms_eval/tasks/mmbench/mmbench.yaml
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
group: mmbench
|
| 2 |
+
task:
|
| 3 |
+
- mmbench_en_dev
|
| 4 |
+
- mmbench_en_test
|
| 5 |
+
- mmbench_cn_dev
|
| 6 |
+
- mmbench_cn_test
|
| 7 |
+
- mmbench_cn_cc
|
| 8 |
+
metadata:
|
| 9 |
+
version: 0.0
|
| 10 |
+
sys_prompt: "There are several options:"
|
| 11 |
+
gpt_eval_model_name: "gpt-3.5-turbo-0613"
|
EAGLE/lmms_eval/tasks/mmbench/mmbench_cc.yaml
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset_path: lmms-lab/MMBench
|
| 2 |
+
dataset_name: cc
|
| 3 |
+
dataset_kwargs:
|
| 4 |
+
token: True
|
| 5 |
+
task: "mmbench_cn_cc"
|
| 6 |
+
test_split: test
|
| 7 |
+
output_type: generate_until
|
| 8 |
+
doc_to_visual: !function cc_utils.mmbench_doc_to_visual
|
| 9 |
+
doc_to_text: !function cc_utils.mmbench_cn_cc_doc_to_text
|
| 10 |
+
doc_to_target: "answer"
|
| 11 |
+
generation_kwargs:
|
| 12 |
+
max_new_tokens: 256
|
| 13 |
+
temperature: 0
|
| 14 |
+
top_p: 0
|
| 15 |
+
num_beams: 1
|
| 16 |
+
do_sample: false
|
| 17 |
+
process_results: !function cc_utils.mmbench_cn_cc_process_results
|
| 18 |
+
metric_list:
|
| 19 |
+
- metric: gpt_eval_score
|
| 20 |
+
aggregation: !function cc_utils.mmbench_cn_cc_aggregate_dev_results_eval
|
| 21 |
+
higher_is_better: true
|
| 22 |
+
- metric: submission
|
| 23 |
+
aggregation: !function cc_utils.mmbench_cn_cc_aggregate_results
|
| 24 |
+
metadata:
|
| 25 |
+
version: 0.0
|
| 26 |
+
gpt_eval_model_name: "gpt-3.5-turbo-0613"
|
| 27 |
+
|
| 28 |
+
model_specific_prompt_kwargs:
|
| 29 |
+
default:
|
| 30 |
+
pre_prompt: ""
|
| 31 |
+
post_prompt: "\n请直接使用所提供的选项字母作为答案回答。"
|
| 32 |
+
model_specific_generation_kwargs:
|
| 33 |
+
llava:
|
| 34 |
+
image_aspect_ratio: original
|
EAGLE/lmms_eval/tasks/mmbench/mmbench_cn.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
group: mmbench_cn
|
| 2 |
+
task:
|
| 3 |
+
- mmbench_cn_dev
|
| 4 |
+
- mmbench_cn_test
|
| 5 |
+
- mmbench_cn_cc
|
| 6 |
+
metadata:
|
| 7 |
+
version: 0.0
|
| 8 |
+
gpt_eval_model_name: "gpt-3.5-turbo-0613"
|
| 9 |
+
sys_prompt: "有如下几个选项:"
|