File size: 15,242 Bytes
b0c0df0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 |
import re
import warnings
from typing import List, Optional, Tuple, Union
import librosa
import numpy as np
import PIL
import torch
from accelerate import Accelerator, DistributedType
from accelerate.state import AcceleratorState
from decord import VideoReader, cpu
from PIL import Image
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor
from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model
from lmms_eval.models.model_utils.audio_processing import downsample_audio
warnings.filterwarnings("ignore")
from loguru import logger as eval_logger
@register_model("phi4_multimodal")
class Phi4(lmms):
    """
    Phi-4-multimodal model for lmms-eval, loaded through Hugging Face Transformers:
    https://huggingface.co/microsoft/Phi-4-multimodal-instruct

    ``generate_until`` accepts images, videos (sampled into frames) and audio clips.
    ``loglikelihood`` and multi-round generation are not implemented.

    Example usage:

    accelerate launch --num_processes=8 --main_process_port 12345 -m lmms_eval \
        --model phi4_multimodal \
        --model_args pretrained=microsoft/Phi-4-multimodal-instruct \
        --tasks seedbench \
        --batch_size 1 \
        --output_path ./logs/ \
        --log_samples
    """

    def __init__(
        self,
        pretrained: str = "microsoft/Phi-4-multimodal-instruct",
        revision: str = "main",
        device: str = "cuda",
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        batch_size: int = 1,
        trust_remote_code: Optional[bool] = True,
        attn_implementation: Optional[str] = None,
        device_map: str = "",
        chat_template: Optional[str] = None,
        use_cache: bool = True,
        max_frames_num: Optional[int] = 16,
        **kwargs,
    ) -> None:
        """Load the model and processor and configure the execution mode.

        Args:
            pretrained: Hugging Face model id or local path.
            revision: Model revision (branch, tag or commit) to load.
            device: Device used when running as a single process without ``device_map="auto"``.
            dtype: ``"auto"``, a ``torch.dtype``, or the name of a ``torch`` dtype attribute
                (e.g. ``"bfloat16"``).
            batch_size: Number of requests grouped per chunk in ``generate_until``.
            trust_remote_code: Forwarded to ``from_pretrained`` for model and processor.
            attn_implementation: Forwarded to ``AutoModelForCausalLM.from_pretrained``.
            device_map: ``""`` (default) places the model on one device per process
                (``cuda:<local rank>`` when several processes are launched); ``"auto"`` on a
                single process enables pipeline parallelism; other values are forwarded to
                ``from_pretrained``.
            chat_template: Stored on the instance; not read anywhere in this class.
            use_cache: Forwarded to ``generate``.
            max_frames_num: Frames sampled per video; ``None`` keeps every frame.
        """
        super().__init__()
        # Do not use kwargs for now
        assert kwargs == {}, f"Unexpected kwargs: {kwargs}"
        accelerator = Accelerator()
        if accelerator.num_processes > 1 and device_map == "":
            # Data parallel: each process owns a full model copy on its local GPU.
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
            self.device_map = f"cuda:{accelerator.local_process_index}"
        else:
            self._device = torch.device(device)
            self.device_map = device_map
        if isinstance(dtype, str) and dtype != "auto":
            dtype = getattr(torch, dtype)
        self.max_frames_num = max_frames_num
        self._model = AutoModelForCausalLM.from_pretrained(
            pretrained,
            revision=revision,
            torch_dtype=dtype,
            device_map=self.device_map,
            trust_remote_code=trust_remote_code,
            attn_implementation=attn_implementation,
        )
        self.pretrained = pretrained
        self._processor = AutoProcessor.from_pretrained(pretrained, revision=revision, trust_remote_code=trust_remote_code)
        # Pad from left for batched generation: https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava#usage-tips
        self._processor.tokenizer.padding_side = "left"
        self._tokenizer = self._processor.tokenizer
        self._config = self._model.config
        # `max_length` previously read an attribute that was never set (AttributeError).
        # NOTE(review): taken from the config's max_position_embeddings when present; the
        # 2048 fallback is a conventional default, not a Phi-4-specific value — confirm.
        self._max_length = getattr(self._config, "max_position_embeddings", 2048)
        self.batch_size_per_gpu = int(batch_size)
        self.chat_template = chat_template
        self.use_cache = use_cache
        if accelerator.num_processes > 1 and device_map == "":
            assert accelerator.distributed_type in [
                DistributedType.FSDP,
                DistributedType.MULTI_GPU,
                DistributedType.DEEPSPEED,
            ], "Unsupported distributed type provided. Only DDP, FSDP and DeepSpeed are supported."
            # If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model
            # Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works
            # I tried to set different parameters in the kwargs to let default zero 2 stage works, but it didn't work.
            if accelerator.distributed_type == DistributedType.DEEPSPEED:
                ds_kwargs = {
                    "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
                    "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
                }
                AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **ds_kwargs)
                eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")
            if accelerator.distributed_type in (DistributedType.FSDP, DistributedType.DEEPSPEED):
                self._model = accelerator.prepare(self.model)
            else:
                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
            self.accelerator = accelerator
            if self.accelerator.is_local_main_process:
                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes
        elif accelerator.num_processes == 1 and device_map == "auto":
            eval_logger.info(f"Using {accelerator.num_processes} devices with pipeline parallelism")
            self._rank = 0
            self._world_size = 1
        else:
            eval_logger.info(f"Using single device: {self._device}")
            self.model.to(self._device)
            self._rank = 0
            self._world_size = 1
            self.accelerator = accelerator

    @property
    def config(self):
        """The ``config`` object of the loaded model."""
        return self._config

    @property
    def tokenizer(self):
        """The tokenizer owned by the processor (left-padded)."""
        return self._tokenizer

    @property
    def model(self):
        """The model, unwrapped from Accelerate's wrapper once ``self.accelerator`` exists."""
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        return self._model

    @property
    def eot_token_id(self):
        """Tokenizer EOS id; also used as the pad token id during generation."""
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        """Maximum sequence length, as set in ``__init__``."""
        return self._max_length

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]:
        """Encode ``string`` into token ids.

        Args:
            string: Text to encode.
            left_truncate_len: If truthy, keep only the last ``left_truncate_len`` tokens.
            add_special_tokens: Forwarded to the tokenizer; ``None`` means ``False``.
        """
        add_special_tokens = False if add_special_tokens is None else add_special_tokens
        encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        if left_truncate_len:
            encoding = encoding[-left_truncate_len:]
        return encoding

    def tok_decode(self, tokens):
        """Decode token ids back into text."""
        return self.tokenizer.decode(tokens)

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        raise NotImplementedError("TODO: Implement loglikelihood for Phi-4")

    def flatten(self, input):
        """Flatten one level of nesting: ``[[a, b], [c]] -> [a, b, c]``."""
        return [item for sublist in input for item in sublist]

    def load_video(self, video_path, max_frames_num):
        """Decode a video with decord and return uniformly sampled frames.

        Args:
            video_path: Path to the video, or a sequence whose first element is the path.
            max_frames_num: Number of indices spread evenly from the first to the last
                frame (indices repeat when the video is shorter); ``None`` keeps every frame.

        Returns:
            The frames as a numpy array from ``get_batch(...).asnumpy()``.
        """
        path = video_path if isinstance(video_path, str) else video_path[0]
        vr = VideoReader(path, ctx=cpu(0))
        total_frame_num = len(vr)
        if max_frames_num is None:
            frame_idx = list(range(total_frame_num))
        else:
            frame_idx = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int).tolist()
        sampled_frames = vr.get_batch(frame_idx).asnumpy()
        return sampled_frames  # (frames, height, width, channels)

    def default_process(self, visuals, contexts):
        """Build a Phi-4 prompt with every media item placed before the question.

        Args:
            visuals: ``str`` items are treated as video paths and expanded into frames,
                ``PIL.Image.Image`` items are used as images, and dicts containing an
                ``"array"`` key (plus ``"sampling_rate"``) are resampled as audio. Other
                items are skipped with a warning.
            contexts: Sequence of prompts; only ``contexts[0]`` is used.

        Returns:
            ``(text, images, audios)``; ``images``/``audios`` are ``None`` when empty.
        """
        text = "<|user|>"
        images = []
        audios = []
        vision_start = 1
        audio_start = 1
        for visual in visuals:
            if isinstance(visual, str):
                frames = self.load_video(visual, self.max_frames_num)
                for image in frames:
                    text += f"<|image_{vision_start}|>"
                    images.append(Image.fromarray(np.uint8(image)))
                    vision_start += 1
            elif isinstance(visual, PIL.Image.Image):
                images.append(visual)
                text += f"<|image_{vision_start}|>"
                vision_start += 1
            elif isinstance(visual, dict) and "array" in visual:
                target_sr = self._processor.audio_processor.sampling_rate
                audio = downsample_audio(visual["array"], visual["sampling_rate"], target_sr)
                audios.append([audio, target_sr])
                text += f"<|audio_{audio_start}|>"
                audio_start += 1
            else:
                eval_logger.warning(f"Skipping unsupported visual of type {type(visual).__name__}")
        text += f"{contexts[0]}<|end|><|assistant|>"
        return text, images or None, audios or None

    def process_av_odessy(self, visuals, context):
        """Build a Phi-4 prompt for AV-Odyssey, placing media at the ``<media_N>`` tags.

        The context is split on ``<media_N>`` tags. The i-th visual is inserted after the
        i-th text segment. Any visuals beyond the number of tags are appended right after
        the previous one, and all remaining text follows the last visual. With no tags,
        all media precede the context, which is emitted exactly once.

        Args:
            visuals: File paths; their type is detected with ``filetype.guess``. Paths
                whose type cannot be detected are skipped with a warning.
            context: Prompt text possibly containing ``<media_N>`` tags.

        Returns:
            ``(text, images, audios)``; ``images``/``audios`` are ``None`` when empty.
        """
        import filetype

        # k tags -> k + 1 text segments; no tags -> [context].
        segments = re.split(r"<media_\d+>", context)
        num_slots = len(segments) - 1
        text = "<|user|>"
        images = []
        audios = []
        vision_start = 1
        audio_start = 1
        for idx, visual in enumerate(visuals):
            # The text preceding the idx-th tag goes in front of the idx-th media item.
            if idx < num_slots:
                text += segments[idx]
            file_type = filetype.guess(visual)
            if file_type is None:
                eval_logger.warning(f"Could not determine the file type of {visual}; skipping it.")
                continue
            if "audio" in file_type.mime:
                target_sr = self._processor.audio_processor.sampling_rate
                audio = librosa.load(visual, sr=target_sr)[0]
                audios.append([audio, target_sr])
                text += f"<|audio_{audio_start}|>"
                audio_start += 1
            elif "video" in file_type.mime:
                frames = self.load_video(visual, self.max_frames_num)
                for image in frames:
                    text += f"<|image_{vision_start}|>"
                    images.append(Image.fromarray(np.uint8(image)))
                    vision_start += 1
            elif "image" in file_type.mime:
                images.append(Image.open(visual))
                text += f"<|image_{vision_start}|>"
                vision_start += 1
        # Remaining text after the last consumed tag (the whole context when there were no tags).
        text += "".join(segments[min(len(visuals), num_slots) :])
        text += "<|end|><|assistant|>"
        return text, images or None, audios or None

    def _generate_one(self, context, gen_kwargs, doc_to_visual, doc_id, task, split):
        """Build the prompt for a single request, run ``generate`` and return the decoded text.

        Generation errors are logged and yield an empty string, so one bad sample does not
        abort the run.
        """
        visuals = self.flatten([doc_to_visual(self.task_dict[task][split][doc_id]) or []])
        if task == "av_odyssey":
            text, images, audios = self.process_av_odessy(visuals, context)
        else:
            text, images, audios = self.default_process(visuals, [context])
        inputs = self._processor(text=text, images=images, audios=audios, return_tensors="pt").to(self.device)

        # Copy so that filling in defaults does not mutate the request's kwargs dict.
        gen_kwargs = dict(gen_kwargs)
        gen_kwargs.setdefault("max_new_tokens", 1024)
        gen_kwargs.setdefault("temperature", 0)
        gen_kwargs.setdefault("top_p", None)
        gen_kwargs.setdefault("num_beams", 1)

        log_this = self.accelerator.is_main_process and doc_id % 100 == 0
        if log_this:
            eval_logger.debug(f"Prompt for doc ID {doc_id}:\n\n{text}\n")
        try:
            cont = self.model.generate(
                **inputs,
                do_sample=gen_kwargs["temperature"] > 0,
                temperature=gen_kwargs["temperature"],
                top_p=gen_kwargs["top_p"],
                num_beams=gen_kwargs["num_beams"],
                max_new_tokens=gen_kwargs["max_new_tokens"],
                use_cache=self.use_cache,
                pad_token_id=self.eot_token_id,
                num_logits_to_keep=0,
            )
        except Exception as e:
            eval_logger.error(f"Error generating text for doc ID {doc_id}: {e}")
            # Slicing the prompt by its own length below yields an empty completion.
            cont = inputs["input_ids"]
        # Drop the prompt tokens and keep only the newly generated ones.
        cont = cont[:, inputs["input_ids"].shape[-1] :]
        text_outputs = self._processor.batch_decode(cont, skip_special_tokens=True)[0]
        if log_this:
            eval_logger.debug(f"Generated text for doc ID {doc_id}:\n\n{text_outputs}\n")
        self.cache_hook.add_partial("generate_until", (context, gen_kwargs), text_outputs)
        return text_outputs

    def generate_until(self, requests: List[Instance]) -> List[str]:
        """Generate one completion per request, returned in the original request order.

        Requests are sorted (longest context first) and grouped by generation kwargs via
        ``utils.Collator``. Every request in a chunk is generated individually, so results
        stay one-to-one with requests for any ``batch_size``.
        """
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tok_encode(x[0])
            return -len(toks), x[0]

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
        for chunk in chunks:
            for context, gen_kwargs, doc_to_visual, doc_id, task, split in chunk:
                res.append(self._generate_one(context, gen_kwargs, doc_to_visual, doc_id, task, split))
                pbar.update(1)
        # reorder this group of results back to original unsorted form
        res = re_ords.get_original(res)
        pbar.close()
        return res

    def generate_until_multi_round(self, requests) -> List[str]:
        raise NotImplementedError("TODO: Implement multi-round generation for Phi-4")
|