import time
import warnings
from typing import List, Optional, Tuple, Union
from tqdm import tqdm
from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model
from lmms_eval.models.model_utils.gen_metrics import log_metrics
from lmms_eval.protocol import ChatMessages
warnings.filterwarnings("ignore")
from loguru import logger as eval_logger
from lmms_eval.models.simple.llava_hf import LlavaHf as LlavaHfSimple
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_VIDEO_TOKEN = "<video>"
# Default chat template for llava-hf/llava-1.5 models: https://huggingface.co/collections/llava-hf/llava-15-65f762d5b6941db5c2ba07e0
VICUNA_CHAT_TEMPLATE = "{% for message in messages %}{% if loop.index0 == 0 %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {{ message['content'] }} {% elif message['role'] == 'user' %}USER: {{ message['content'] }} {% else %} ASSISTANT: {{ message['content'] }}{{ eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}"
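
# With a single user turn such as "<image> What is shown in this image?", the
# template above renders (roughly) to:
#   A chat between a curious user and an artificial intelligence assistant. The
#   assistant gives helpful, detailed, and polite answers to the user's
#   questions. USER: <image> What is shown in this image? ASSISTANT:
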
@register_model("llava_hf_chat")
class LlavaHf(LlavaHfSimple):
    is_simple = False

    def generate_until(self, requests: List[Instance]) -> List[str]:
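        """Generate free-form responses for a batch of `generate_until` requests.

        Requests are grouped by their generation kwargs, rendered through the
        processor's chat template, and decoded greedily unless a temperature
        above zero is supplied.
        """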
        res = []

        # A dummy collate here to sort by doc id
        def _collate(x):
            return x[2], x[2]

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        re_ords = utils.Collator([reg.args for reg in requests], _collate, group_fn=lambda x: x[2], grouping=True)
        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
        num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1
        pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")
        e2e_latency = 0
        total_tokens = 0
        for chunk in chunks:
            ctx, doc_to_messages, all_gen_kwargs, doc_id, task, split = zip(*chunk)
            task = task[0]
            split = split[0]
            chat_messages = [doc_to_messages[0](self.task_dict[task][split][ids]) for ids in doc_id]
            chat_messages: List[ChatMessages] = [ChatMessages(**{"messages": message}) for message in chat_messages]
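            # Separate the media referenced by each conversation: extract_media()
            # returns the images and videos so they can be passed to the processor
            # alongside the templated text.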
            visuals = []
            videos = []
            for messages in chat_messages:
                visual, video, _ = messages.extract_media()
                visuals.append(visual)
                videos.append(video)
            visuals = self.flatten(visuals)
            videos = self.flatten(videos)

            assert self.batch_size_per_gpu == 1, "Do not support batch_size_per_gpu > 1 for now"

            # Apply chat template
            messages = chat_messages[0].model_dump()["messages"]
            text = self._image_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

            if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
                eval_logger.debug(f"Prompt for doc ID {doc_id[0]}:\n\n{text}\n")

            if len(videos) == 0:
                videos = None
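            # The processor tokenizes the prompt and preprocesses the media into
            # model tensors (input_ids, attention_mask, pixel_values, ...).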
            inputs = self._image_processor(images=visuals, videos=videos, text=text, return_tensors="pt").to(self._device, self.model.dtype)

            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]
            gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))]
            if "max_new_tokens" not in gen_kwargs:
                gen_kwargs["max_new_tokens"] = 1024
            if "temperature" not in gen_kwargs:
                gen_kwargs["temperature"] = 0
            if "top_p" not in gen_kwargs:
                gen_kwargs["top_p"] = None
            if "num_beams" not in gen_kwargs:
                gen_kwargs["num_beams"] = 1
            do_sample = True if gen_kwargs["temperature"] > 0 else False
            try:
                start_time = time.time()
                cont = self.model.generate(
                    **inputs,
                    do_sample=do_sample,
                    temperature=gen_kwargs["temperature"] if do_sample else None,
                    top_p=gen_kwargs["top_p"],
                    num_beams=gen_kwargs["num_beams"],
                    max_new_tokens=gen_kwargs["max_new_tokens"],
                    use_cache=self.use_cache,
                    pad_token_id=self.eot_token_id,
                    eos_token_id=self.eot_token_id,
                )
                end_time = time.time()
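                # generate() returns prompt + continuation; keep only the newly
                # generated tokens.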
                cont = cont[:, inputs["input_ids"].shape[-1] :]

                # Calculate timing metrics
                e2e_latency += end_time - start_time
                total_tokens += cont.shape[-1] if len(cont.shape) > 1 else len(cont)
            except Exception as e:
                eval_logger.error(f"Error {e} in generating")
                cont = ""
                e2e_latency += 0
                total_tokens += 0
            # `cont` is a token tensor on success and an empty string on failure.
            text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0] if not isinstance(cont, str) else ""
            if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
                eval_logger.debug(f"Generated text for doc ID {doc_id[0]}:\n\n{text_outputs}\n")

            res.append(text_outputs)
            self.cache_hook.add_partial("generate_until", (text, gen_kwargs), text_outputs)
            pbar.update(1)

        # reorder this group of results back to original unsorted form
        res = re_ords.get_original(res)
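        # Aggregate simple throughput metrics for this rank: total generated
        # tokens, wall-clock generation time, and tokens per second.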
        metric_dict = {
            "total_tokens": total_tokens,
            "e2e_latency": e2e_latency,
            "avg_speed": total_tokens / e2e_latency if e2e_latency > 0 else 0,
            "additional_metrics": {
                "rank": self.rank,
            },
        }
        log_metrics(**metric_dict)
        pbar.close()
        return res
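
# Example invocation (a sketch; the exact CLI flags and model_args depend on the
# installed lmms_eval version and the checkpoint being evaluated):
#   python -m lmms_eval --model llava_hf_chat \
#       --model_args pretrained=llava-hf/llava-1.5-7b-hf \
#       --tasks mme --batch_size 1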