|
|
import warnings |
|
|
from typing import List, Optional, Tuple, Union |
|
|
|
|
|
import torch |
|
|
from accelerate import Accelerator, DistributedType |
|
|
from accelerate.state import AcceleratorState |
|
|
from tqdm import tqdm |
|
|
from transformers import AutoModel, AutoTokenizer |
|
|
|
|
|
from lmms_eval import utils |
|
|
from lmms_eval.api.instance import Instance |
|
|
from lmms_eval.api.model import lmms |
|
|
from lmms_eval.api.registry import register_model |
|
|
|
|
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
from loguru import logger as eval_logger |
|
|
|
|
|
|
|
|
@register_model("minimonkey")
class MiniMonkey(lmms):
    """
    MiniMonkey Model

    lmms-eval wrapper that loads a MiniMonkey checkpoint through
    ``transformers.AutoModel`` / ``AutoTokenizer`` and answers
    ``generate_until`` requests one sample at a time via the model's
    ``chat`` method. Only a per-process batch size of 1 and a single image
    per request are supported by ``generate_until``.
    """

    def __init__(
        self,
        pretrained: str = "mx262/MiniMonkey",
        device: Optional[str] = "cuda",
        dtype: Optional[Union[str, torch.dtype]] = torch.bfloat16,
        batch_size: Optional[Union[int, str]] = 1,
        trust_remote_code: Optional[bool] = True,
        **kwargs,
    ) -> None:
        """Load the model and tokenizer and set up (optional) data parallelism.

        Args:
            pretrained: Checkpoint name or path passed to ``from_pretrained``.
            device: Device used when running in a single process. With more
                than one process the device is ``cuda:<local_process_index>``.
            dtype: Passed as ``torch_dtype`` when loading the model and used to
                cast the pixel tensors in ``generate_until``.
            batch_size: Per-process batch size; converted with ``int()``.
            trust_remote_code: Forwarded to both ``from_pretrained`` calls.
            **kwargs: Must be empty; anything else trips the assertion below.
        """
        super().__init__()

        assert kwargs == {}, f"Unexpected kwargs: {kwargs}"

        accelerator = Accelerator()
        if accelerator.num_processes > 1:
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
        else:
            self._device = device
        self.dtype = dtype
        self._model = AutoModel.from_pretrained(pretrained, trust_remote_code=trust_remote_code, torch_dtype=dtype, device_map=self._device)
        self._tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=trust_remote_code)
        self._config = self._model.config
        self.model.eval()
        self.model.tie_weights()
        self.batch_size_per_gpu = int(batch_size)
        if accelerator.num_processes > 1:
            assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."

            if accelerator.distributed_type == DistributedType.DEEPSPEED:
                # DeepSpeed validates these batch-size fields against its own
                # config (must_match=True), so derive them from our settings.
                kwargs = {
                    "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
                    "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
                }
                AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs)
                eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")
            if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED:
                self._model = accelerator.prepare(self.model)
            else:
                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
            self.accelerator = accelerator
            if self.accelerator.is_local_main_process:
                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes
        else:
            self._rank = 0
            self._world_size = 1

    @property
    def config(self):
        """The ``config`` object of the loaded model."""
        return self._config

    @property
    def tokenizer(self):
        """The tokenizer loaded alongside the model."""
        return self._tokenizer

    @property
    def model(self):
        """The underlying model, unwrapped from Accelerate when it was prepared."""
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

    @property
    def eot_token_id(self):
        """The tokenizer's ``eos_token_id``."""
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        """Return ``self._max_length``.

        NOTE(review): ``_max_length`` is not assigned anywhere in this class;
        unless the ``lmms`` base class provides it, reading this property
        raises ``AttributeError`` -- confirm before relying on it.
        """
        return self._max_length

    @property
    def batch_size(self):
        """Per-process batch size."""
        return self.batch_size_per_gpu

    @property
    def device(self):
        """Device selected in ``__init__``."""
        return self._device

    @property
    def rank(self):
        """Process rank recorded in ``__init__`` (0 when single-process)."""
        return self._rank

    @property
    def world_size(self):
        """Number of processes recorded in ``__init__`` (1 when single-process)."""
        return self._world_size

    def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]:
        """Tokenize ``string`` and optionally keep only the last tokens.

        Args:
            string: Text to encode.
            left_truncate_len: When truthy, keep only the final
                ``left_truncate_len`` token ids.
            add_special_tokens: Forwarded to the tokenizer; ``None`` is
                treated as ``False``.

        Returns:
            The list of token ids.
        """
        add_special_tokens = False if add_special_tokens is None else add_special_tokens
        encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)

        if left_truncate_len:
            encoding = encoding[-left_truncate_len:]
        return encoding

    def tok_decode(self, tokens):
        """Decode token ids back to text with the model's tokenizer."""
        return self.tokenizer.decode(tokens)

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        """Not implemented for MiniMonkey; always fails the assertion."""
        assert False, "We have not implemented this function for MiniMonkey yet"

    def flatten(self, input):
        """Flatten one level of nesting: a list of iterables becomes one list."""
        return [item for group in input for item in group]

    def generate_until(self, requests: List[Instance]) -> List[str]:
        """Generate one response per request using the model's ``chat`` method.

        Each request's ``args`` is unpacked as ``(context, gen_kwargs,
        doc_to_visual, doc_id, task, split)``. Generation defaults
        (``max_new_tokens``, ``temperature``, ``top_p``, ``num_beams``) are
        filled into ``gen_kwargs`` when absent. If preprocessing or generation
        raises, the error is logged and an empty string is recorded for that
        sample.

        Args:
            requests: Instances to answer.

        Returns:
            Responses in the original request order.

        Raises:
            ValueError: If ``gen_kwargs['until']`` is neither ``str`` nor ``list``.
            AssertionError: If the batch size is not 1 or a request does not
                resolve to exactly one image.
        """
        res = []

        def _collate(x):
            # Sort key for the Collator: longest context first, then the text.
            toks = self.tok_encode(x[0])
            return -len(toks), x[0]

        re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
        # Ceiling division: number of batches needed to cover all requests.
        num_iters = (len(requests) + self.batch_size - 1) // self.batch_size
        pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")
        for chunk in chunks:
            contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
            task = task[0]
            split = split[0]
            visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]
            visuals = self.flatten(visuals)

            gen_kwargs = all_gen_kwargs[0]

            # Stop sequences are validated and removed from gen_kwargs so they
            # are not forwarded to `chat`; they are not otherwise used here.
            if "until" in gen_kwargs:
                until = gen_kwargs.pop("until")
                if not isinstance(until, (str, list)):
                    raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}")
            assert self.batch_size_per_gpu == 1, "Do not support batch_size_per_gpu > 1 for now"
            assert len(visuals) == 1, "MiniMonkey interface does not support bn_image > 1 for now"
            context = contexts[0]
            if "<image>" in context:
                context = context.replace("<image>", "")

            gen_kwargs.setdefault("max_new_tokens", 1024)
            gen_kwargs.setdefault("temperature", 0)
            gen_kwargs.setdefault("top_p", None)
            gen_kwargs.setdefault("num_beams", 1)

            image, prompt = visuals[0], context
            try:
                pixel_values, target_aspect_ratio = load_image(image, min_num=4, max_num=12)
                pixel_values2 = load_image2(image, min_num=3, max_num=7, target_aspect_ratio=target_aspect_ratio)
                # Order: second-pass entries except the last, first-pass entries
                # except the last, then the last second-pass entry (the
                # thumbnail whenever that pass produced more than one tile).
                pixel_values = torch.cat([pixel_values2[:-1], pixel_values[:-1], pixel_values2[-1:]], 0).to(self._device).to(self.dtype)

                response, history = self.model.chat(self.tokenizer, pixel_values, target_aspect_ratio, prompt, gen_kwargs, history=None, return_history=True)

                context = [{"role": "user", "content": prompt}, {"role": "assistant", "content": response}]
            except Exception as e:
                eval_logger.error(f"Error {e} in generating")
                # Bind `response` explicitly: otherwise the code below would hit
                # an unbound name on the first sample or silently reuse the
                # previous sample's answer.
                response = ""
            res.append(response)
            self.cache_hook.add_partial("generate_until", (context, gen_kwargs), response)
            pbar.update(1)

        res = re_ords.get_original(res)

        pbar.close()
        return res

    def generate_until_multi_round(self, requests) -> List[str]:
        """Multi-round generation is not implemented for this model."""
        raise NotImplementedError("TODO: Implement multi-round generation")
|
|
|
|
|
|
|
|
import numpy as np |
|
|
import torchvision.transforms as T |
|
|
from PIL import Image |
|
|
from torchvision.transforms.functional import InterpolationMode |
|
|
|
|
|
# Per-channel (R, G, B) mean and standard deviation handed to T.Normalize
# in build_transform().
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
|
|
|
|
|
|
|
|
def build_transform(input_size):
    """Build the per-tile preprocessing pipeline.

    The pipeline converts the image to RGB when needed, resizes it to
    ``input_size`` x ``input_size`` with bicubic interpolation, converts it
    to a tensor and normalizes it with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.

    Args:
        input_size: Side length of the square output, in pixels.

    Returns:
        A ``torchvision.transforms.Compose`` instance.
    """
    steps = [
        T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)
|
|
|
|
|
|
|
|
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the candidate grid whose aspect ratio is closest to ``aspect_ratio``.

    Candidates are scanned in order. A strictly closer candidate always wins.
    A candidate that ties the current best replaces it only when the source
    area ``width * height`` exceeds half of that candidate's full area
    (``image_size * image_size * ratio[0] * ratio[1]``).

    Args:
        aspect_ratio: Aspect ratio to match.
        target_ratios: Iterable of ``(a, b)`` pairs; each is compared as ``a / b``.
        width: Source width, used only for the tie-break area.
        height: Source height, used only for the tie-break area.
        image_size: Side length of a single tile.

    Returns:
        The selected pair, or ``(1, 1)`` when ``target_ratios`` is empty.
    """
    chosen = (1, 1)
    smallest_gap = float("inf")
    source_area = width * height
    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
            continue
        # Tie: switch to this candidate only if the source is large enough
        # to fill more than half of its area.
        if gap == smallest_gap and source_area > 0.5 * image_size * image_size * candidate[0] * candidate[1]:
            chosen = candidate
    return chosen
|
|
|
|
|
|
|
|
def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """Resize an image onto a tile grid and cut it into square tiles.

    Every grid ``(i, j)`` with ``min_num <= i * j <= max_num`` is a candidate.
    The one chosen by ``find_closest_aspect_ratio`` determines the resize
    target, which is then cropped into ``image_size`` x ``image_size`` tiles
    in row-major order.

    Args:
        image: Object exposing ``size``, ``resize`` and ``crop`` (a PIL image
            in this file's callers).
        min_num: Minimum number of tiles in a candidate grid.
        max_num: Maximum number of tiles in a candidate grid.
        image_size: Side length of each tile.
        use_thumbnail: When true and more than one tile was produced, append
            the whole image resized to a single tile.

    Returns:
        ``(tiles, grid)`` where ``grid`` is the selected ``(i, j)`` pair.
    """
    src_w, src_h = image.size
    src_ratio = src_w / src_h

    # Enumerate candidate grids, then order them by tile count.
    candidates = {
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if i * j <= max_num and i * j >= min_num
    }
    candidates = sorted(candidates, key=lambda x: x[0] * x[1])

    grid = find_closest_aspect_ratio(src_ratio, candidates, src_w, src_h, image_size)

    canvas_w = image_size * grid[0]
    canvas_h = image_size * grid[1]
    tile_count = grid[0] * grid[1]

    canvas = image.resize((canvas_w, canvas_h))
    per_row = canvas_w // image_size
    tiles = []
    for index in range(tile_count):
        row, col = divmod(index, per_row)
        left = col * image_size
        top = row * image_size
        tiles.append(canvas.crop((left, top, left + image_size, top + image_size)))
    assert len(tiles) == tile_count
    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles, grid
|
|
|
|
|
|
|
|
def dynamic_preprocess2(image, min_num=1, max_num=12, prior_aspect_ratio=None, image_size=448, use_thumbnail=False):
    """Tile an image on a grid that differs from a previously chosen grid.

    Candidate grids are every ``(i, j)`` with ``min_num <= i * j <= max_num``.
    When ``prior_aspect_ratio`` is given, a candidate is kept only if
    ``prior_aspect_ratio[0] % i`` or ``prior_aspect_ratio[1] % j`` is
    non-zero, i.e. at least one prior dimension is not a multiple of the
    candidate's. The surviving candidate picked by
    ``find_closest_aspect_ratio`` defines the resize target, which is cropped
    into ``image_size`` x ``image_size`` tiles in row-major order.

    Args:
        image: Object exposing ``size``, ``resize`` and ``crop`` (a PIL image
            in this file's callers).
        min_num: Minimum number of tiles in a candidate grid.
        max_num: Maximum number of tiles in a candidate grid.
        prior_aspect_ratio: ``(i, j)`` grid from an earlier pass, or ``None``
            to consider every candidate (previously ``None`` raised
            ``TypeError``).
        image_size: Side length of each tile.
        use_thumbnail: When true and more than one tile was produced, append
            the whole image resized to a single tile.

    Returns:
        The list of tiles (plus the optional thumbnail).
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    target_ratios = set((i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # Without a prior grid there is nothing to exclude; the default used to
    # crash on `None[0]`.
    if prior_aspect_ratio is not None:
        prior_w, prior_h = prior_aspect_ratio[0], prior_aspect_ratio[1]
        target_ratios = [ratio for ratio in target_ratios if prior_w % ratio[0] or prior_h % ratio[1]]

    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    resized_img = image.resize((target_width, target_height))
    tiles_per_row = target_width // image_size
    processed_images = []
    for i in range(blocks):
        row, col = divmod(i, tiles_per_row)
        box = (col * image_size, row * image_size, (col + 1) * image_size, (row + 1) * image_size)
        processed_images.append(resized_img.crop(box))
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images
|
|
|
|
|
|
|
|
def load_image(image, input_size=448, min_num=1, max_num=12):
    """Convert an image into a stacked tensor of normalized tiles.

    The image is converted to RGB, tiled by ``dynamic_preprocess`` (with the
    thumbnail enabled) and each tile is passed through ``build_transform``.

    Args:
        image: Image exposing ``convert`` (a PIL image in this file's callers).
        input_size: Side length of each tile.
        min_num: Minimum number of tiles in the grid.
        max_num: Maximum number of tiles in the grid.

    Returns:
        ``(pixel_values, target_aspect_ratio)``: the stacked tile tensor and
        the grid returned by ``dynamic_preprocess``.
    """
    rgb = image.convert("RGB")
    to_tensor = build_transform(input_size=input_size)
    tiles, target_aspect_ratio = dynamic_preprocess(
        rgb,
        image_size=input_size,
        use_thumbnail=True,
        min_num=min_num,
        max_num=max_num,
    )
    stacked = torch.stack([to_tensor(tile) for tile in tiles])
    return stacked, target_aspect_ratio
|
|
|
|
|
|
|
|
def load_image2(image, input_size=448, min_num=1, max_num=12, target_aspect_ratio=None):
    """Convert an image into normalized tiles using a second tiling pass.

    Same pipeline as ``load_image``, but tiling is done by
    ``dynamic_preprocess2``, which receives ``target_aspect_ratio`` as its
    ``prior_aspect_ratio``.

    Args:
        image: Image exposing ``convert`` (a PIL image in this file's callers).
        input_size: Side length of each tile.
        min_num: Minimum number of tiles in the grid.
        max_num: Maximum number of tiles in the grid.
        target_aspect_ratio: Grid chosen by an earlier pass, forwarded as
            ``prior_aspect_ratio``.

    Returns:
        The stacked tile tensor.
    """
    rgb = image.convert("RGB")
    to_tensor = build_transform(input_size=input_size)
    tiles = dynamic_preprocess2(
        rgb,
        image_size=input_size,
        use_thumbnail=True,
        min_num=min_num,
        max_num=max_num,
        prior_aspect_ratio=target_aspect_ratio,
    )
    return torch.stack([to_tensor(tile) for tile in tiles])
|
|
|