import copy
from typing import Dict, List, Optional

import transformers
from more_itertools import distribute
from tqdm import tqdm

from lm_eval.api.instance import Instance
from lm_eval.api.registry import register_model
from lm_eval.models.utils import (
    Collator,
    handle_stop_sequences,
    replace_placeholders,
    undistribute,
)
from lm_eval.models.vllm_causallms import VLLM
from lm_eval.utils import eval_logger


# vLLM and ray are optional dependencies: the import is allowed to fail here so that
# this module can still be imported; a missing package surfaces when the backend is used.
try:
    import ray
    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest
    from vllm.transformers_utils.tokenizer import get_tokenizer
except ModuleNotFoundError:
    pass


# marker used in task prompts to indicate where images belong
DEFAULT_IMAGE_PLACEHOLDER = "<image>"


@register_model("vllm-vlm")
class VLLM_VLM(VLLM):
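    """vLLM-backed vision-language (text + image) model wrapper.

    Registered as ``vllm-vlm``; it accepts the arguments of the text-only ``vllm``
    backend plus ``interleave`` and ``max_images``.
    """

    # signals to the harness that this model accepts multimodal (text + image) requests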
    MULTIMODAL = True

    def __init__(
        self,
        pretrained: str,
        trust_remote_code: Optional[bool] = False,
        revision: Optional[str] = None,
        interleave: bool = True,
        # 999 is a sentinel default; any other value also sets vLLM's per-prompt image limit
        max_images: int = 999,
        **kwargs,
    ):
        if max_images != 999:
            kwargs["limit_mm_per_prompt"] = {"image": max_images}
            eval_logger.info(f"Setting limit_mm_per_prompt[image] to {max_images}")
        super().__init__(
            pretrained=pretrained,
            trust_remote_code=trust_remote_code,
            revision=revision,
            **kwargs,
        )
        self.interleave = interleave
        self.max_images = max_images
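        # the HF processor is only used here for chat templating (see apply_chat_template)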
        self.processor = transformers.AutoProcessor.from_pretrained(
            pretrained,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        self.chat_applied: bool = False  # set to True once apply_chat_template has run

    def tok_batch_multimodal_encode(
        self,
        strings: List[str],
        images,  # one list of images per prompt, aligned with `strings`
        left_truncate_len: int = None,  # not used by this implementation
        truncation: bool = False,  # not used by this implementation
    ):
        # keep at most `max_images` images per request
        images = [img[: self.max_images] for img in images]

        # if no chat template was applied, also cap the number of raw <image>
        # placeholders in each prompt at `max_images`
        if self.chat_applied is False:
            strings = [
                replace_placeholders(
                    string,
                    DEFAULT_IMAGE_PLACEHOLDER,
                    DEFAULT_IMAGE_PLACEHOLDER,
                    self.max_images,
                )
                for string in strings
            ]

        # build vLLM-style inputs: each entry pairs a prompt with its list of images
        outputs = []
        for x, i in zip(strings, images):
            inputs = {
                "prompt": x,
                "multi_modal_data": {"image": i},
            }
            outputs.append(inputs)
        return outputs

    def _model_generate(
        self,
        requests: List[List[dict]] = None,
        generate: bool = False,
        max_tokens: int = None,
        stop: Optional[List[str]] = None,
        **kwargs,
    ):
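        # generate=True: sample continuations with the task's generation kwargs;
        # generate=False: score the prompt only (a single dummy token with prompt logprobs)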
        if generate:
            kwargs = self.modify_gen_kwargs(kwargs)
            sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
        else:
            sampling_params = SamplingParams(
                temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
            )
        if self.data_parallel_size > 1:
            # data-parallel path: each ray worker builds its own vLLM engine from the
            # same model args and generates for its share of the requests
            @ray.remote
            def run_inference_one_model(
                model_args: dict, sampling_params, requests: List[List[dict]]
            ):
                llm = LLM(**model_args)
                return llm.generate(requests, sampling_params=sampling_params)

            # dispatch requests to the workers round-robin, which helps keep
            # context lengths balanced across them
            requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
            inputs = ((self.model_args, sampling_params, req) for req in requests)
            object_refs = [run_inference_one_model.remote(*x) for x in inputs]
            results = ray.get(object_refs)
            # shut ray down so subsequent calls start from a clean state
            ray.shutdown()
            # undo the round-robin split so results follow the original request order
            return undistribute(results)

        # single-engine path; pass the LoRA adapter through if one was configured
        if self.lora_request is not None:
            outputs = self.model.generate(
                requests,
                sampling_params=sampling_params,
                use_tqdm=True if self.batch_size == "auto" else False,
                lora_request=self.lora_request,
            )
        else:
            outputs = self.model.generate(
                requests,
                sampling_params=sampling_params,
                use_tqdm=True if self.batch_size == "auto" else False,
            )
        return outputs

    def apply_chat_template(
        self, chat_history: List[Dict[str, str]], add_generation_prompt=True
    ) -> str:
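        # convert each message's text into a list of typed content parts; with
        # interleave=True images stay at their placeholder positions, otherwise
        # all images are listed before the text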
        self.chat_applied = True
        if not self.interleave:
            for content in chat_history:
                c = []
                text = content["content"]

                # count (capped at max_images) and strip the image placeholders
                image_count = min(
                    self.max_images, text.count(DEFAULT_IMAGE_PLACEHOLDER)
                )
                text = text.replace(DEFAULT_IMAGE_PLACEHOLDER, "")

                # add the image entries first, then a single text entry
                for _ in range(image_count):
                    c.append({"type": "image", "image": None})
                c.append({"type": "text", "text": text})

                content["content"] = c
        else:
            for content in chat_history:
                c = []
                text = content["content"]
                expected_image_count = min(
                    self.max_images, text.count(DEFAULT_IMAGE_PLACEHOLDER)
                )
                actual_image_count = 0

                text_parts = text.split(DEFAULT_IMAGE_PLACEHOLDER)

                for i, part in enumerate(text_parts):
                    # emit the text segment, then an image entry for each
                    # placeholder that followed it (up to max_images)
                    if part:
                        c.append({"type": "text", "text": part})
                    if (i < len(text_parts) - 1) and i < self.max_images:
                        c.append({"type": "image"})
                        actual_image_count += 1

                content["content"] = c

                if actual_image_count != expected_image_count:
                    raise ValueError(
                        f"Mismatch in image placeholder count. Expected: {expected_image_count}, Actual: {actual_image_count}"
                    )

        # hand the structured chat off to the HF processor's chat template
        return self.processor.apply_chat_template(
            chat_history,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=not add_generation_prompt,
        )

    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
        res = []

        def _collate(x):
            # sort by descending context length so the longest prompts are scheduled
            # first and any OOM surfaces at the start of the run rather than near the end
            toks = self.tok_encode(x[0])
            return -len(toks), x[0]

        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests with text+image input",
        )

        # group requests by their generation kwargs so that different sampling
        # settings (e.g. greedy vs. temperature sampling) never share a batch
        re_ords = Collator(
            [reg.args for reg in requests],
            _collate,
            group_by="gen_kwargs",
            group_fn=lambda x: x[1],
        )
        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
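        # the decoded EOS token is passed as the fallback stop sequence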
        eos = self.tokenizer.decode(self.eot_token_id)
        for chunk in chunks:
            contexts, all_gen_kwargs, aux_arguments = zip(*chunk)

            visuals = [arg["visual"] for arg in aux_arguments]

            if not isinstance(contexts, list):
                # zip() yields tuples; some processors only accept a list of strings
                contexts = list(contexts)

            # every request in a chunk shares the same generation kwargs, because the
            # Collator grouped requests by gen_kwargs above
            gen_kwargs = all_gen_kwargs[0]
            # unpack the keyword arguments and resolve the stop sequences
            if isinstance(gen_kwargs, dict):
                kwargs = copy.deepcopy(gen_kwargs)  # copy so the shared dict is not mutated
                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
            else:
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                )
            if "max_gen_toks" in kwargs.keys():
                max_gen_toks = kwargs.pop("max_gen_toks")
            else:
                max_gen_toks = self.max_gen_toks

            # reserve room for generation when truncating the context
            max_ctx_len = self.max_length - max_gen_toks

            inputs = self.tok_batch_multimodal_encode(
                contexts,
                visuals,
                left_truncate_len=max_ctx_len,
            )

            # run vLLM generation for this batch
            cont = self._model_generate(
                inputs, stop=until, generate=True, max_tokens=max_gen_toks, **kwargs
            )

            for output, context in zip(cont, contexts):
                generated_text = output.outputs[0].text
                res.append(generated_text)
                # record the response via the cache hook
                self.cache_hook.add_partial(
                    "generate_until", (context, gen_kwargs), generated_text
                )
                pbar.update(1)
        # restore the original (pre-sorting) order of the results
        res = re_ords.get_original(res)

        pbar.close()
        return res