|
|
import warnings |
|
|
from typing import List, Optional, Tuple, Union |
|
|
|
|
|
import torch |
|
|
from accelerate import Accelerator, DistributedType |
|
|
from accelerate.state import AcceleratorState |
|
|
from torchvision.transforms.functional import to_pil_image |
|
|
from tqdm import tqdm |
|
|
from transformers import AutoProcessor, Idefics2ForConditionalGeneration |
|
|
|
|
|
from lmms_eval import utils |
|
|
from lmms_eval.api.instance import Instance |
|
|
from lmms_eval.api.model import lmms |
|
|
from lmms_eval.api.registry import register_model |
|
|
from lmms_eval.models.model_utils.load_video import load_video_decord |
|
|
|
|
|
# Silence every warning category process-wide (same catch-all filter as
# filterwarnings("ignore"): any message, any module, any line).
warnings.simplefilter("ignore")
|
|
|
|
|
from loguru import logger as eval_logger |
|
|
|
|
|
# Literal image placeholder text. In Idefics2.generate_until, a context that already
# contains it gets no extra image entries prepended to its chat message.
DEFAULT_IMAGE_TOKEN: str = "<image>"
|
|
# Choose the default attention backend by probing for the flash_attn package:
# a successful import selects FlashAttention-2, otherwise fall back to eager attention.
try:
    import flash_attn  # noqa: F401  (imported only to test availability)
except ImportError:
    best_fit_attn_implementation = "eager"
else:
    best_fit_attn_implementation = "flash_attention_2"
|
|
|
|
|
|
|
|
@register_model("idefics2")
class Idefics2(lmms):
    r"""
    Idefics2 Model for Hugging Face Transformers: https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics2/modeling_idefics2.py

    Example usage:

    accelerate launch --num_processes=8 -m lmms_eval \
        --model idefics2 \
        --model_args pretrained=HuggingFaceM4/idefics2-8b \
        --tasks mme \
        --batch_size 1 \
        --output_path ./logs/ \
        --log_samples
    """

    def __init__(
        self,
        pretrained: str = "HuggingFaceM4/idefics2-8b",
        revision: str = "main",
        device: str = "cuda",
        dtype: Optional[Union[str, torch.dtype]] = "float16",
        batch_size: int = 1,
        trust_remote_code: Optional[bool] = False,
        attn_implementation: Optional[str] = best_fit_attn_implementation,
        device_map: str = "",
        use_cache: bool = True,
        do_image_splitting: bool = False,
        max_frames_num: int = 16,
        max_length: int = 2048,
        **kwargs,
    ) -> None:
        """Load the Idefics2 model and processor and set up device placement.

        Args:
            pretrained: Hugging Face model id or local path.
            revision: Model revision to load.
            device: Device used when running without data parallelism or "auto" placement.
            dtype: A ``torch.dtype`` or its name (e.g. "float16"); "auto" is passed through as-is.
            batch_size: Number of requests generated together.
            trust_remote_code: Forwarded to ``from_pretrained``.
            attn_implementation: Attention backend; defaults to FlashAttention-2 when
                ``flash_attn`` is importable, else eager.
            device_map: "" (default) lets this class place the model; "auto" gives single-process
                pipeline parallelism; any other value is forwarded to ``from_pretrained``.
            use_cache: Forwarded to ``generate`` unless the task's generation kwargs set it.
            do_image_splitting: Forwarded to the processor.
            max_frames_num: Maximum number of frames sampled from each video.
            max_length: Value reported by the ``max_length`` property (not used to truncate here).

        Raises:
            AssertionError: if unexpected keyword arguments are given, or the accelerate
                distributed type is not FSDP / MULTI_GPU / DEEPSPEED.
        """
        super().__init__()

        assert kwargs == {}, f"Unexpected kwargs: {kwargs}"

        accelerator = Accelerator()
        # Several processes and no explicit map: data parallelism, one local GPU per process.
        data_parallel = accelerator.num_processes > 1 and device_map == ""
        if data_parallel:
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
            self.device_map = f"cuda:{accelerator.local_process_index}"
        else:
            self._device = torch.device(device)
            self.device_map = device_map
        if isinstance(dtype, str) and dtype != "auto":
            dtype = getattr(torch, dtype)

        # "" is this class's "no device map" sentinel; pass None instead so from_pretrained does no
        # dispatching and the single-device branch below moves the model with an explicit .to().
        self._model = Idefics2ForConditionalGeneration.from_pretrained(
            pretrained,
            revision=revision,
            torch_dtype=dtype,
            device_map=self.device_map or None,
            trust_remote_code=trust_remote_code,
            attn_implementation=attn_implementation,
        )
        self._processor = AutoProcessor.from_pretrained(pretrained, do_image_splitting=do_image_splitting, revision=revision, trust_remote_code=trust_remote_code)
        self.max_frames_num = max_frames_num
        # Previously never assigned, which made the `max_length` property raise AttributeError.
        self._max_length = int(max_length)

        self._tokenizer = self._processor.tokenizer
        self._config = self._model.config
        self.batch_size_per_gpu = int(batch_size)
        self.use_cache = use_cache

        if data_parallel:
            assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP, FSDP and DeepSpeed are supported."
            if accelerator.distributed_type == DistributedType.DEEPSPEED:
                ds_kwargs = {
                    "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
                    "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
                }
                AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **ds_kwargs)
                eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")
            if accelerator.distributed_type in (DistributedType.FSDP, DistributedType.DEEPSPEED):
                self._model = accelerator.prepare(self.model)
            else:
                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
            self.accelerator = accelerator
            if self.accelerator.is_local_main_process:
                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes
        elif accelerator.num_processes == 1 and device_map == "auto":
            eval_logger.info(f"Using {accelerator.num_processes} devices with pipeline parallelism")
            self._rank = 0
            self._world_size = 1
        else:
            eval_logger.info(f"Using single device: {self._device}")
            self.model.to(self._device)
            self._rank = 0
            self._world_size = 1

    @property
    def config(self):
        """Config object of the loaded model."""
        return self._config

    @property
    def tokenizer(self):
        """Tokenizer taken from the processor."""
        return self._tokenizer

    @property
    def model(self):
        """The underlying model, unwrapped from accelerate when data parallelism is active."""
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

    @property
    def eot_token_id(self):
        """Token id of the tokenizer's EOS token."""
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        """Maximum length configured via the ``max_length`` constructor argument."""
        return self._max_length

    @property
    def batch_size(self):
        """Per-process batch size."""
        return self.batch_size_per_gpu

    @property
    def device(self):
        """Device that model inputs are moved to."""
        return self._device

    @property
    def rank(self):
        """Local process index (0 when not data parallel)."""
        return self._rank

    @property
    def world_size(self):
        """Number of processes (1 when not data parallel)."""
        return self._world_size

    def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]:
        """Encode ``string`` to token ids, keeping only the last ``left_truncate_len`` if given.

        Special tokens are not added unless ``add_special_tokens`` is explicitly passed.
        """
        add_special_tokens = False if add_special_tokens is None else add_special_tokens
        encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
        if left_truncate_len:
            encoding = encoding[-left_truncate_len:]
        return encoding

    def tok_decode(self, tokens):
        """Decode token ids back to text."""
        return self.tokenizer.decode(tokens)

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        raise NotImplementedError("Loglikelihood is not implemented for Idefics2 model")

    def flatten(self, input):
        """Flatten one level of nesting: ``[[a, b], [c]] -> [a, b, c]``."""
        return [item for sublist in input for item in sublist]

    def generate_until(self, requests: List[Instance]) -> List[str]:
        """Generate one response string per request, returned in the original request order.

        Each request's ``args`` unpack as
        ``(context, gen_kwargs, doc_to_visual, doc_id, task, split)``. A visual list whose first
        element is a ``str`` is treated as a video path and sampled into frames; otherwise its
        items are passed to the processor as images for that sample.
        """
        res = []

        def _collate(x):
            # Ordering key for utils.Collator: (-token count, context text); presumably sorted
            # ascending, i.e. longest contexts first.
            toks = self.tok_encode(x[0])
            return -len(toks), x[0]

        re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
        num_iters = (len(requests) + self.batch_size - 1) // self.batch_size
        pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")
        for chunk in chunks:
            contexts, all_gen_kwargs, doc_to_visuals, doc_id, tasks, splits = zip(*chunk)
            visuals = [doc_to_visual(self.task_dict[task][split][ids]) for ids, task, split, doc_to_visual in zip(doc_id, tasks, splits, doc_to_visuals)]

            # Copy so popping/defaulting keys never mutates the request's own dict.
            gen_kwargs = dict(all_gen_kwargs[0])
            # NOTE(review): `until` stop sequences are discarded, not applied to the decoded
            # text; confirm this is intended before relying on stop strings.
            gen_kwargs.pop("until", None)
            # Popped so it is not forwarded to generate().
            gen_kwargs.pop("image_aspect_ratio", None)
            gen_kwargs.setdefault("max_new_tokens", 1024)
            gen_kwargs.setdefault("temperature", 0)
            gen_kwargs.setdefault("use_cache", self.use_cache)

            prompts = []
            # One image list per sample, so videos and images never leak between samples.
            batch_images = []
            for context, visual in zip(contexts, visuals):
                if visual and isinstance(visual[0], str):
                    # Video: sample frames and give every frame its own image slot.
                    frames = load_video_decord(visual[0], max_frames_num=self.max_frames_num)
                    sample_images = [to_pil_image(frame) for frame in frames]
                    num_placeholders = len(sample_images)
                else:
                    sample_images = list(visual) if visual else []
                    # Contexts that already contain DEFAULT_IMAGE_TOKEN get no extra image entries.
                    num_placeholders = 0 if DEFAULT_IMAGE_TOKEN in context else len(sample_images)
                content = [{"type": "image"} for _ in range(num_placeholders)]
                content.append({"type": "text", "text": context})
                message = [{"role": "user", "content": content}]
                prompts.append(self._processor.apply_chat_template(message, add_generation_prompt=True))
                batch_images.append(sample_images)

            images = batch_images if any(batch_images) else None
            inputs = self._processor(text=prompts, images=images, padding=True, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            output_ids = self.model.generate(**inputs, **gen_kwargs)

            for output_id, input_id in zip(output_ids, inputs["input_ids"]):
                # Drop the (padded) prompt prefix; keep only newly generated tokens.
                generated_id = output_id[len(input_id) :]
                res.append(self.tokenizer.decode(generated_id, skip_special_tokens=True))
            pbar.update(1)

        res = re_ords.get_original(res)

        pbar.close()
        return res

    def generate_until_multi_round(self, requests) -> List[str]:
        raise NotImplementedError("TODO: Implement multi-round generation for Idefics2")
|
|
|