from typing import List, Optional, Tuple, Union
import torch
from accelerate import Accelerator, DistributedType
from apps.plm.generate import (
PackedCausalTransformerGenerator,
PackedCausalTransformerGeneratorArgs,
load_consolidated_model_and_tokenizer,
)
from core.args import dataclass_from_dict
from core.transforms.image_transform import get_image_transform
from core.transforms.video_transform import get_video_transform
from loguru import logger as eval_logger
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm
from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model
@register_model("plm")
class PerceptionLM(lmms):
"""
Perception Language Model (PLM)
"Paste the paper link"
"Paste the github link"
"Paste the huggingface link"
"""
def __init__(
self,
pretrained: str = "facebook/Perception-LM-8B",
device: Optional[str] = "cuda",
batch_size: Optional[Union[int, str]] = 1,
compile_prefilling: bool = False,
reduce_generation_overhead: bool = False,
max_tokens: int = 11264,
**kwargs,
) -> None:
super().__init__()
accelerator = Accelerator()
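# Pin this process to its local GPU; with multiple processes, each rank gets its own device.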
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
# Collect all arguments into a dictionary
args = {
"pretrained": pretrained,
"device": device,
"batch_size": batch_size,
"compile_prefilling": compile_prefilling,
"reduce_generation_overhead": reduce_generation_overhead,
"max_tokens": max_tokens,
**kwargs, # Include any additional keyword arguments
}
# Convert the dictionary to a dotlist format
dotlist = [f"{key}={value}" for key, value in args.items()]
cfg = OmegaConf.from_dotlist(dotlist)
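# Extract the generator arguments recognized by PackedCausalTransformerGeneratorArgs from the merged config.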
gen_cfg = dataclass_from_dict(PackedCausalTransformerGeneratorArgs, cfg, strict=False)
# Load PLM model
eval_logger.info(f"Lodding PLM model from {cfg.pretrained}")
model, tokenizer, config = load_consolidated_model_and_tokenizer(cfg.pretrained)
# Create preprocessors (transforms)
processor = {}
vision_input_type = config.get("data").get("vision_input_type", "thumb+tile")
max_num_tiles = config.get("data").get("max_num_tiles", 36)
processor["image"] = get_image_transform(vision_input_type=vision_input_type, image_res=model.vision_model.image_size, max_num_tiles=max_num_tiles)
processor["video"] = get_video_transform(image_res=model.vision_model.image_size)
self._max_video_frames = config.get("data").get("max_video_frames", 32)
# Create PLM generator
eval_logger.info(f"Creating packed generator with gen_cfg: {gen_cfg}")
generator = PackedCausalTransformerGenerator(gen_cfg, model, tokenizer)
# Set the class variables
self._tokenizer = tokenizer
self._processor = processor
self._model = model
self._generator = generator
self.batch_size_per_gpu = int(batch_size)
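# Multi-process launch: shard evaluation across GPUs via Accelerate (DDP or FSDP only).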
if accelerator.num_processes > 1:
assert accelerator.distributed_type in [
DistributedType.FSDP,
DistributedType.MULTI_GPU,
], "Unsupported distributed type provided. Only DDP and FSDP are supported."
if accelerator.distributed_type == DistributedType.FSDP:
self._model = accelerator.prepare(self.model)
else:
self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
self.accelerator = accelerator
if self.accelerator.is_local_main_process:
eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
self._rank = self.accelerator.process_index
self._world_size = self.accelerator.num_processes
else:
self._rank = 0
self._world_size = 1
@property
def generator(self):
return self._generator
@property
def tokenizer(self):
return self._tokenizer
@property
def processor(self):
return self._processor
@property
def model(self):
# returns the model, unwrapping it if using Accelerate
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(self._model)
else:
return self._model
@property
def eot_token_id(self):
# we use EOT because end of text is more accurate for what we're doing than end of sentence
return self.tokenizer.eos_token_id
@property
def batch_size(self):
return self.batch_size_per_gpu
@property
def max_video_frames(self):
return self._max_video_frames
@property
def device(self):
return self._device
@property
def rank(self):
return self._rank
@property
def world_size(self):
return self._world_size
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
raise NotImplementedError("Loglikelihood is not implemented for PLM")
def flatten(self, nested):
# Flatten one level of nesting, e.g. [[a, b], [c]] -> [a, b, c].
return [item for sublist in nested for item in sublist]
def generate_until(self, requests: List[Instance]) -> List[str]:
res = []
def _collate(x):
# the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over- rather than under-estimates, which is more useful for planning
# - when iterating through the list, the first element of each batch gives that batch's padded context
#   length, which simplifies the batching logic and makes automatic adaptive batching much easier to implement
# - any OOMs will happen right away rather than near the end
toks = self.tokenizer.encode(x[0], add_bos=False, add_eos=False)
return -len(toks), x[0]
pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
# we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch.
re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
for chunk in chunks:
contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
task = task[0]
split = split[0]
visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]
visuals = self.flatten(visuals)
messages = []
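# Build (text, visual) message pairs, preprocessing each visual with the matching image or video transform.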
for i, context in enumerate(contexts):
if len(visuals) > 0:
visual = visuals[i] if i < len(visuals) else None
if isinstance(visual, str) and visual.endswith((".mp4", ".avi", ".mov")): # Video file
video_info = (visual, self.max_video_frames, None, None, None)
visual, _ = self.processor["video"](video_info)
message = (context, visual)
elif isinstance(visual, Image.Image): # Single image
visual = visual.convert("RGB")
visual, _ = self.processor["image"](visual)
message = (context, visual)
elif isinstance(visual, (list, tuple)) and all(isinstance(v, Image.Image) for v in visual): # Multiple images or Video Frames
visual = [image.convert("RGB") for image in visual]
visual, _ = self.processor["video"]._process_multiple_images_pil(visual)
message = (context, visual)
else:
# Text-only sample
raise NotImplementedError("Text-only input is not yet supported.")
else:
# Text-only sample
raise NotImplementedError("Text-only input is not yet supported.")
messages.append(message)
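# All requests in this chunk share the same generation kwargs (Collator groups by kwargs), so take the first.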
gen_kwargs = all_gen_kwargs[0]
if "max_new_tokens" in gen_kwargs:
self.generator.max_gen_len = gen_kwargs["max_new_tokens"]
if "temperature" in gen_kwargs:
self.generator.temperature = gen_kwargs["temperature"]
# PLM decoding defaults: disable top-p (nucleus) sampling and use top-k = 100.
self.generator.top_p = None
self.generator.top_k = 100
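# generate() returns decoded texts plus per-sequence loglikelihoods and greedy flags; only the text is used here.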
generation, loglikelihood, greedy = self.generator.generate(messages)
for gen, context in zip(generation, contexts):
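# Drop a trailing period from the generated answer before caching and returning it.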
if gen.endswith("."):
gen = gen[:-1]
res.append(gen)
self.cache_hook.add_partial("generate_until", (context, gen_kwargs), gen)
pbar.update(1)
# reorder this group of results back to original unsorted form
res = re_ords.get_original(res)
pbar.close()
return res
def generate_until_multi_round(self, requests) -> List[str]:
raise NotImplementedError("Multi-round generation is not implemented yet.")