File size: 22,259 Bytes
b0c0df0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 |
import copy
import json
import logging
import math
import re
import warnings
from datetime import timedelta
from typing import List, Optional, Tuple, Union
import numpy as np
import PIL
import torch
import transformers
from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs
from accelerate.state import AcceleratorState
from decord import VideoReader, cpu
from packaging import version
from tqdm import tqdm
from transformers import AutoConfig
from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model
# Module-wide logger shared with the rest of lmms-eval.
eval_logger = logging.getLogger("lmms-eval")

# Ignore all Python warnings for the remainder of the process.
warnings.simplefilter("ignore")

# Allow TF32 math for CUDA matmuls.
torch.backends.cuda.matmul.allow_tf32 = True

# EgoGPT and its audio/video helpers are optional dependencies. A failed import is
# only reported at debug level, and every name after the failing import stays
# undefined until the dependencies are installed.
try:
    import copy
    import os
    import re
    import sys
    import warnings

    import numpy as np
    import requests
    import soundfile as sf
    import torch
    import whisper
    from decord import VideoReader, cpu
    from egogpt.constants import (
        DEFAULT_IMAGE_TOKEN,
        DEFAULT_SPEECH_TOKEN,
        IGNORE_INDEX,
        IMAGE_TOKEN_INDEX,
        SPEECH_TOKEN_INDEX,
    )
    from egogpt.conversation import SeparatorStyle, conv_templates
    from egogpt.mm_utils import get_model_name_from_path, process_images
    from egogpt.model.builder import load_pretrained_model
    from PIL import Image
    from scipy.signal import resample
except ImportError as e:
    eval_logger.debug(f"egogpt is not installed. Please install egogpt to use this model.\nError: {e}")
# Default attention backend: PyTorch SDPA when the installed torch is at least
# 2.1.2, otherwise the eager implementation.
best_fit_attn_implementation = "sdpa" if version.parse(torch.__version__) >= version.parse("2.1.2") else "eager"
@register_model("egogpt")
class EgoGPT(lmms):
    """
    EgoGPT model adapter for lmms-eval.

    Wraps a checkpoint loaded with ``egogpt.model.builder.load_pretrained_model``
    and implements ``generate_until`` for text-only, image / multi-image and
    video requests. ``loglikelihood`` and ``generate_until_multi_round`` are not
    implemented, and only ``batch_size=1`` is supported.
    """

    def __init__(
        self,
        pretrained: str = "checkpoints/egogpt_IT_12k_1126_zero3",
        truncation: Optional[bool] = True,
        device: Optional[str] = "cuda:0",
        batch_size: Optional[Union[int, str]] = 1,
        model_name: Optional[str] = None,  # accepted for interface compatibility; not used in this class
        attn_implementation: Optional[str] = best_fit_attn_implementation,
        device_map: Optional[str] = "cuda:0",
        conv_template: Optional[str] = "qwen_1_5",
        use_cache: Optional[bool] = True,
        truncate_context: Optional[bool] = False,  # whether to truncate the context in generation, set it False for LLaVA-1.6
        customized_config: Optional[str] = None,  # ends in json; accepted for interface compatibility, not used in this class
        max_frames_num: Optional[int] = 32,
        mm_spatial_pool_stride: Optional[int] = 2,
        mm_spatial_pool_mode: Optional[str] = "bilinear",
        token_strategy: Optional[str] = "single",  # could be "single" or "multiple", "multiple" denotes adding multiple <image> tokens for each frame
        video_decode_backend: str = "decord",
        **kwargs,
    ) -> None:
        """Load tokenizer/model/image processor and set up Accelerate.

        Unknown keyword arguments are rejected and ``batch_size`` must be 1.
        ``truncation``, ``use_cache`` and ``truncate_context`` are stored on the
        instance but are not currently consumed by ``generate_until``.
        """
        super().__init__()
        # Do not use kwargs for now
        assert kwargs == {}, f"Unexpected kwargs: {kwargs}"

        accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
        accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
        if accelerator.num_processes == 1 and device_map == "auto":
            # Single process with device_map="auto": let the loader place the model.
            self._device = torch.device(device)
            self.device_map = device_map
        else:
            # Data-parallel worker, or a single explicit device: use this process's GPU.
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
            self.device_map = f"cuda:{accelerator.local_process_index}"

        egogpt_model_args = {}
        if attn_implementation is not None:
            egogpt_model_args["attn_implementation"] = attn_implementation

        self.pretrained = pretrained
        self.token_strategy = token_strategy
        self.max_frames_num = max_frames_num
        self.mm_spatial_pool_stride = mm_spatial_pool_stride
        self.mm_spatial_pool_mode = mm_spatial_pool_mode
        self.video_decode_backend = video_decode_backend

        self._tokenizer, self._model, self._max_length = load_pretrained_model(pretrained, device_map=self.device_map, **egogpt_model_args)
        self._image_processor = self._model.get_vision_tower().image_processor
        self._config = self._model.config
        self.model.eval()
        self.truncation = truncation
        self.batch_size_per_gpu = int(batch_size)
        self.conv_template = conv_template
        self.use_cache = use_cache
        self.truncate_context = truncate_context
        assert self.batch_size_per_gpu == 1

        if accelerator.num_processes > 1:
            assert accelerator.distributed_type in [
                DistributedType.FSDP,
                DistributedType.MULTI_GPU,
                DistributedType.DEEPSPEED,
            ], "Unsupported distributed type provided. Only DDP, FSDP and DeepSpeed are supported."
            # If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model
            # Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works
            # I tried to set different parameters in the kwargs to let default zero 2 stage works, but it didn't work.
            if accelerator.distributed_type == DistributedType.DEEPSPEED:
                ds_kwargs = {
                    "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
                    "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
                }
                AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **ds_kwargs)
                eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")
            if accelerator.distributed_type in (DistributedType.FSDP, DistributedType.DEEPSPEED):
                self._model = accelerator.prepare(self.model)
            else:
                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
            self.accelerator = accelerator
            if self.accelerator.is_local_main_process:
                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes
        elif accelerator.num_processes == 1 and device_map == "auto":
            eval_logger.info(f"Using {accelerator.num_processes} devices with tensor parallelism")
            self._rank = 0
            self._world_size = 1
        else:
            eval_logger.info(f"Using single device: {self._device}")
            self.model.to(self._device)
            self._rank = 0
            self._world_size = 1

    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model.
        return self._config

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def model(self):
        # returns the model, unwrapping it if using Accelerate
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        return self._model

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        return self._max_length

    def pad_sequence(self, input_ids, batch_first, padding_value):
        """Pad a list of id tensors, honouring ``tokenizer.padding_side == "left"``.

        For left padding, each sequence is flipped, right-padded, and the padded
        batch is flipped back along dim 1.
        """
        if self.tokenizer.padding_side == "left":
            input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value)
        if self.tokenizer.padding_side == "left":
            input_ids = torch.flip(input_ids, [1])
        return input_ids

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]:
        """Encode ``string`` to token ids.

        Special tokens are not added unless ``add_special_tokens`` is truthy. If
        ``left_truncate_len`` is truthy, only the last ``left_truncate_len`` ids
        are kept.
        """
        add_special_tokens = False if add_special_tokens is None else add_special_tokens
        encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        if left_truncate_len:
            encoding = encoding[-left_truncate_len:]
        return encoding

    def tok_decode(self, tokens):
        """Decode ``tokens``; if the tokenizer rejects them, retry with ``[tokens]``."""
        try:
            return self.tokenizer.decode(tokens)
        except Exception:  # e.g. a bare int id; never swallow KeyboardInterrupt/SystemExit
            return self.tokenizer.decode([tokens])

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        """Not supported by this adapter."""
        raise NotImplementedError("Loglikelihood is not implemented for EgoGPT")

    def flatten(self, input):
        """Flatten one level of nesting: ``[[a, b], [c]] -> [a, b, c]``."""
        return [item for sublist in input for item in sublist]

    def split_text(self, text, keywords):
        """Split ``text`` around each keyword, keeping the keywords and dropping empty pieces."""
        pattern = "(" + "|".join(map(re.escape, keywords)) + ")"
        return [part for part in re.split(pattern, text) if part]

    def load_video(self, video_path=None, audio_path=None, max_frames_num=16, fps=1, task_name=None):
        """Decode frames (and optional audio features) for one video.

        Args:
            video_path: file opened with ``decord.VideoReader``.
            audio_path: optional audio file. When given it is resampled to 16 kHz
                if needed, averaged to one channel, passed through
                ``whisper.pad_or_trim`` and ``whisper.log_mel_spectrogram``
                (128 mel bins), then transposed with ``permute(1, 0)``. When
                omitted, an all-zero ``(3000, 128)`` placeholder is returned.
            max_frames_num: if > 0 and sampling at ``fps`` yields more frames than
                this, exactly ``max_frames_num`` frames are instead taken uniformly
                over the whole video.
            fps: target sampling rate in frames per second.
            task_name: for ``"egoplan"`` the last frame (current observation) is
                appended to the selection.

        Returns:
            ``(video, speech, speech_lengths)`` where ``video`` is the numpy array
            from ``vr.get_batch(frame_idx).asnumpy()``.
        """
        if audio_path is not None:
            speech, sample_rate = sf.read(audio_path)
            if sample_rate != 16000:
                target_length = int(len(speech) * 16000 / sample_rate)
                speech = resample(speech, target_length)
            if speech.ndim > 1:
                speech = np.mean(speech, axis=1)
            # max_length = 480000
            speech = whisper.pad_or_trim(speech.astype(np.float32))
            speech = whisper.log_mel_spectrogram(speech, n_mels=128).permute(1, 0)
            speech_lengths = torch.LongTensor([speech.shape[0]])
        else:
            speech = torch.zeros(3000, 128)
            speech_lengths = torch.LongTensor([3000])

        vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
        total_frame_num = len(vr)
        # Source-frame stride between samples. Clamp to 1: for videos whose native
        # fps is below fps / 2 the rounded ratio is 0 and range() would raise.
        frame_step = max(1, round(vr.get_avg_fps() / fps))
        frame_idx = list(range(0, total_frame_num, frame_step))
        if max_frames_num > 0 and len(frame_idx) > max_frames_num:
            frame_idx = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int).tolist()
        if task_name == "egoplan":
            # add current observation frame
            frame_idx.append(total_frame_num - 1)
        video = vr.get_batch(frame_idx).asnumpy()
        return video, speech, speech_lengths

    def _to_half_on_device(self, image_tensor):
        """Cast a tensor, or each tensor in a list, to float16 on ``self.device``."""
        if isinstance(image_tensor, list):
            return [_image.to(dtype=torch.float16, device=self.device) for _image in image_tensor]
        return image_tensor.to(dtype=torch.float16, device=self.device)

    def _prepare_visual_inputs(self, visual, gen_kwargs, metadata):
        """Convert one request's visuals into model inputs.

        Returns ``(image_tensor, speech, speech_lengths, task_type)`` where
        ``task_type`` is ``"text"``, ``"image"`` or ``"video"``. ``speech`` is an
        all-zero ``(3000, 128)`` placeholder unless ``load_video`` supplies one.
        May update ``self._config.image_aspect_ratio`` from
        ``gen_kwargs["image_aspect_ratio"]`` (default ``"pad"``).

        Raises:
            NotImplementedError: video input with a non-"decord" backend.
            ValueError: the first visual is neither a PIL image nor a path string.
        """
        # Silent-audio placeholder (same shape load_video uses without audio), so
        # every branch has speech inputs instead of reusing a previous request's.
        speech = torch.zeros(3000, 128)
        speech_lengths = torch.LongTensor([3000])

        if visual is None or visual == []:  # for text-only tasks.
            return None, speech, speech_lengths, "text"

        if len(visual) > 1 or "image_aspect_ratio" not in self._config.__dict__:  # for multi image case, we treat per image aspect ratio as "pad" by default.
            # gen_kwargs is a dict: read the key (getattr() on a dict always hit the default).
            self._config.image_aspect_ratio = gen_kwargs.get("image_aspect_ratio", "pad")
            eval_logger.info(f"In Multi-Image setting, image aspect ratio: {self._config.image_aspect_ratio}")

        if "task_type" in metadata and metadata["task_type"] == "video" and "sample_frames" in metadata:
            # overwrite logic for video task with multiple static image frames
            assert isinstance(visual, list), "visual must be a list of frames when sample_frames is specified for a video task"
            sample_indices = np.linspace(0, len(visual) - 1, metadata["sample_frames"], dtype=int)
            visual = [visual[i] for i in sample_indices]
            assert len(visual) == metadata["sample_frames"]
            image_tensor = self._to_half_on_device(process_images(visual, self._image_processor, self._config))
            if not isinstance(image_tensor, list):
                image_tensor = [image_tensor]
            return image_tensor, speech, speech_lengths, "video"

        if isinstance(visual[0], PIL.Image.Image):  # For image, multi-image tasks
            image_tensor = self._to_half_on_device(process_images(visual, self._image_processor, self._config))
            return image_tensor, speech, speech_lengths, "image"

        if isinstance(visual[0], str):  # For video task
            if self.video_decode_backend != "decord":
                raise NotImplementedError("Only decord backend is supported for video task")
            task_name = "egoplan" if "egoplan" in visual[0] else None
            try:
                frames, speech, speech_lengths = self.load_video(video_path=visual[0], max_frames_num=self.max_frames_num, task_name=task_name)
                processed_frames = self._image_processor.preprocess(frames, return_tensors="pt")["pixel_values"]
            except Exception as e:
                # Without frames there is nothing to fill the <image> slot with, so
                # surface the failure instead of continuing with undefined/stale inputs.
                eval_logger.error(f"Error {e} in loading video")
                raise
            image_tensor = [processed_frames.to(dtype=torch.float16, device=self.device)]
            return image_tensor, speech, speech_lengths, "video"

        raise ValueError(f"Unsupported visual type for EgoGPT: {type(visual[0])}")

    def _build_prompt(self, context: str, add_image_token: bool = True) -> str:
        """Render ``context`` with the configured conversation template.

        Prepends ``DEFAULT_IMAGE_TOKEN`` when ``add_image_token`` is set and the
        context does not already contain it. A JSON context is treated as a list
        of turns (``item["value"]``) alternating between the template's roles.
        """
        if add_image_token and DEFAULT_IMAGE_TOKEN not in context:
            question = DEFAULT_IMAGE_TOKEN + "\n" + context
        else:
            question = context

        # This is much safer for llama3, as we now have some object type in it
        if "llama_3" in self.conv_template:
            conv = copy.deepcopy(conv_templates[self.conv_template])
        else:
            conv = conv_templates[self.conv_template].copy()

        if utils.is_json(question):  # conversational question input
            for idx, item in enumerate(json.loads(question)):
                conv.append_message(conv.roles[idx % 2], item["value"])
            assert len(conv.messages) % 2 == 1
        else:  # only simple string for question
            conv.append_message(conv.roles[0], question)
        conv.append_message(conv.roles[1], None)
        return conv.get_prompt()

    def _tokenize_prompt(self, prompt: str) -> torch.Tensor:
        """Tokenize ``prompt`` into a ``(1, L)`` long tensor on ``self.device``.

        Literal ``<image>`` / ``<speech>`` markers become ``IMAGE_TOKEN_INDEX`` /
        ``SPEECH_TOKEN_INDEX``; all other text goes through the tokenizer.
        """
        input_ids = []
        for part in self.split_text(prompt, ["<image>", "<speech>"]):
            if part == "<image>":
                input_ids.append(IMAGE_TOKEN_INDEX)
            elif part == "<speech>":
                input_ids.append(SPEECH_TOKEN_INDEX)
            else:
                input_ids.extend(self.tokenizer(part).input_ids)
        return torch.tensor(input_ids, dtype=torch.long, device=self.device).unsqueeze(0)

    def generate_until(self, requests: List[Instance]) -> List[str]:
        """Generate one response per request, returned in the original request order.

        Each ``request.args`` is unpacked as
        ``(context, gen_kwargs, doc_to_visual, doc_id, task, split)``. Unless a
        request overrides them, generation uses ``max_new_tokens=1024``,
        ``temperature=0``, ``do_sample=False``, ``top_p=None`` and ``num_beams=1``.
        """
        if not requests:
            return []

        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tok_encode(x[0])
            return -len(toks), x[0]

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        metadata = requests[0].metadata
        re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
        num_iters = (len(requests) + self.batch_size - 1) // self.batch_size
        pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")
        origin_image_aspect_ratio = getattr(self._config, "image_aspect_ratio", None)

        for chunk in chunks:
            batched_contexts, all_gen_kwargs, batched_doc_to_visual, batched_doc_id, batched_task, batched_split = zip(*chunk)
            task = batched_task[0]
            split = batched_split[0]
            batched_visuals = [batched_doc_to_visual[0](self.task_dict[task][split][ids]) for ids in batched_doc_id]  # [B, N]
            assert len(batched_visuals) == 1
            context = batched_contexts[0]
            visual = batched_visuals[0]

            # All gen kwargs in a chunk are the same (the Collator groups by them).
            # Work on a copy so the keys added/removed below do not leak back into
            # the request's own dict (or into other requests, if that dict is shared).
            gen_kwargs = dict(all_gen_kwargs[0])
            gen_kwargs.pop("until", None)

            if origin_image_aspect_ratio is not None and self._config.image_aspect_ratio != origin_image_aspect_ratio:
                self._config.image_aspect_ratio = origin_image_aspect_ratio
                eval_logger.info(f"Resetting image aspect ratio to {origin_image_aspect_ratio}")

            image_tensor, speech, speech_lengths, task_type = self._prepare_visual_inputs(visual, gen_kwargs, metadata)
            speech = speech.unsqueeze(0).to(device=self.device, dtype=torch.float16)

            # Only reserve an <image> slot when there are visuals to fill it.
            prompt_question = self._build_prompt(context, add_image_token=task_type != "text")
            input_ids = self._tokenize_prompt(prompt_question)

            # preconfigure gen_kwargs with defaults
            gen_kwargs.setdefault("max_new_tokens", 1024)
            gen_kwargs.setdefault("temperature", 0)
            gen_kwargs.setdefault("do_sample", False)
            gen_kwargs.setdefault("top_p", None)
            gen_kwargs.setdefault("num_beams", 1)

            if task_type == "image":
                gen_kwargs["image_sizes"] = [image.size for image in visual]
            elif task_type == "video":
                gen_kwargs["modalities"] = ["video"]
                self._config.mm_spatial_pool_stride = self.mm_spatial_pool_stride
                self._config.mm_spatial_pool_mode = self.mm_spatial_pool_mode
            gen_kwargs["eos_token_id"] = self.tokenizer.eos_token_id
            # Consumed by _prepare_visual_inputs; not forwarded to model.generate().
            gen_kwargs.pop("image_aspect_ratio", None)

            with torch.inference_mode():
                cont = self.model.generate(input_ids, images=image_tensor, speech=speech, speech_lengths=speech_lengths, **gen_kwargs)
            text_outputs = [response.strip() for response in self.tokenizer.batch_decode(cont, skip_special_tokens=True)]
            res.extend(text_outputs)
            self.cache_hook.add_partial("generate_until", (context, gen_kwargs), text_outputs)
            pbar.update(1)

        # reorder this group of results back to original unsorted form
        res = re_ords.get_original(res)
        pbar.close()
        return res

    def generate_until_multi_round(self, requests: List[Instance]) -> List[str]:
        """Not supported by this adapter."""
        raise NotImplementedError("generate_until_multi_round is not implemented for EgoGPT")
|