import time
import warnings
from typing import List, Optional, Tuple, Union

from tqdm import tqdm

from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model
from lmms_eval.models.model_utils.gen_metrics import log_metrics
from lmms_eval.protocol import ChatMessages

warnings.filterwarnings("ignore")

from loguru import logger as eval_logger

from lmms_eval.models.simple.llava_hf import LlavaHf as LlavaHfSimple

DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_VIDEO_TOKEN = "<video>"

VICUNA_CHAT_TEMPLATE = "{% for message in messages %}{% if loop.index0 == 0 %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {{ message['content'] }} {% elif message['role'] == 'user' %}USER: {{ message['content'] }} {% else %} ASSISTANT: {{ message['content'] }}{{ eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}"


@register_model("llava_hf_chat")
class LlavaHf(LlavaHfSimple):
    is_simple = False

    def generate_until(self, requests: List[Instance]) -> List[str]:
        res = []

        def _collate(x):
            return x[2], x[2]
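
        # Re-order requests, grouping by gen_kwargs (x[2]) so that requests with
        # identical generation settings end up in the same batch.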
        re_ords = utils.Collator([reg.args for reg in requests], _collate, group_fn=lambda x: x[2], grouping=True)
        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
        num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1
        pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")
        e2e_latency = 0
        total_tokens = 0
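
        # Accumulate end-to-end latency and generated-token counts across all
        # batches for the throughput metrics reported after the loop.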
        for chunk in chunks:
            ctx, doc_to_messages, all_gen_kwargs, doc_id, task, split = zip(*chunk)
            task = task[0]
            split = split[0]
            chat_messages = [doc_to_messages[0](self.task_dict[task][split][ids]) for ids in doc_id]
            chat_messages: List[ChatMessages] = [ChatMessages(**{"messages": message}) for message in chat_messages]
            visuals = []
            videos = []
            for messages in chat_messages:
                visual, video, _ = messages.extract_media()
                visuals.append(visual)
                videos.append(video)
            visuals = self.flatten(visuals)
            videos = self.flatten(videos)
            assert self.batch_size_per_gpu == 1, "batch_size_per_gpu > 1 is not supported for now"
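
            # Render the chat messages into a single prompt string with the
            # processor's chat template.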
            messages = chat_messages[0].model_dump()["messages"]
            text = self._image_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
                eval_logger.debug(f"Prompt for doc ID {doc_id[0]}:\n\n{text}\n")

            if len(videos) == 0:
                videos = None
            inputs = self._image_processor(images=visuals, videos=videos, text=text, return_tensors="pt").to(self._device, self.model.dtype)
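
            # Take the generation kwargs for this batch and fill in defaults for
            # anything the task config leaves unspecified.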
            gen_kwargs = all_gen_kwargs[0]

            gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))]
            if "max_new_tokens" not in gen_kwargs:
                gen_kwargs["max_new_tokens"] = 1024
            if "temperature" not in gen_kwargs:
                gen_kwargs["temperature"] = 0
            if "top_p" not in gen_kwargs:
                gen_kwargs["top_p"] = None
            if "num_beams" not in gen_kwargs:
                gen_kwargs["num_beams"] = 1
            do_sample = gen_kwargs["temperature"] > 0
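
            # Run generation under a try/except and time it end-to-end; on failure,
            # fall back to an empty output so evaluation can continue.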
            try:
                start_time = time.time()
                cont = self.model.generate(
                    **inputs,
                    do_sample=do_sample,
                    temperature=gen_kwargs["temperature"] if do_sample else None,
                    top_p=gen_kwargs["top_p"],
                    num_beams=gen_kwargs["num_beams"],
                    max_new_tokens=gen_kwargs["max_new_tokens"],
                    use_cache=self.use_cache,
                    pad_token_id=self.eot_token_id,
                    eos_token_id=self.eot_token_id,
                )
                end_time = time.time()
                # Keep only the newly generated tokens (drop the echoed prompt).
                cont = cont[:, inputs["input_ids"].shape[-1] :]

                e2e_latency += end_time - start_time
                total_tokens += cont.shape[-1] if len(cont.shape) > 1 else len(cont)
            except Exception as e:
                eval_logger.error(f"Error {e} in generating")
                cont = ""

            # `cont` is a tensor on success and an empty string on failure.
            text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0] if not isinstance(cont, str) else ""

            if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
                eval_logger.debug(f"Generated text for doc ID {doc_id[0]}:\n\n{text_outputs}\n")

            res.append(text_outputs)
            self.cache_hook.add_partial("generate_until", (text, gen_kwargs), text_outputs)
            pbar.update(1)
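
        # Restore the original request order that the Collator changed.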
        res = re_ords.get_original(res)
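
        # Report aggregate throughput metrics (total tokens, end-to-end latency,
        # tokens per second) for this rank.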
        metric_dict = {
            "total_tokens": total_tokens,
            "e2e_latency": e2e_latency,
            "avg_speed": total_tokens / e2e_latency if e2e_latency > 0 else 0,
            "additional_metrics": {
                "rank": self.rank,
            },
        }
        log_metrics(**metric_dict)

        pbar.close()
        return res