import os
from typing import Any

import torch
import torch.nn as nn
from peft import PeftModel
from safetensors.torch import load_file, save_file
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
from transformers.utils import cached_file

from .configuration_capri import CapriConfig


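# Bridges pooled vision features into the text model's embedding space.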
class MLPProjector(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


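# Caption model assembled from three parts: a causal LM wrapped with a LoRA
# adapter (via PEFT), an optional vision tower, and an MLP projector that maps
# pooled image features into the LM's embedding space.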
class CapriForConditionalGeneration(PreTrainedModel):
    config_class = CapriConfig
    base_model_prefix = "capri"
    main_input_name = "input_ids"

    def __init__(self, config: CapriConfig):
        super().__init__(config)
        self.projector = MLPProjector(
            in_dim=config.projector_in_dim,
            hidden_dim=config.projector_hidden_dim,
            out_dim=config.projector_out_dim,
        )
        # Only the projector is instantiated eagerly; the tokenizer and the
        # text/vision sub-models are attached lazily in `from_pretrained`.
        self.text_model = None
        self.vision_model = None
        self.tokenizer = None
        self._repo_id_or_path = None
        self._hub_kwargs = {}
        self._text_model_kwargs = {}
        self._vision_model_kwargs = {}
        self.post_init()

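    # Loading is split across several artifacts: the tokenizer and projector
    # come from the Capri repo itself, the text model is rebuilt from its base
    # checkpoint plus the repo's LoRA adapter, and the vision tower is loaded
    # eagerly or on first use depending on `load_vision_tower` /
    # `config.load_vision_tower_by_default`.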
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, config=None, **kwargs):
        load_vision_tower = kwargs.pop("load_vision_tower", None)
        if config is None:
            config, model_kwargs = CapriConfig.from_pretrained(
                pretrained_model_name_or_path,
                return_unused_kwargs=True,
                **kwargs,
            )
        else:
            model_kwargs = dict(kwargs)

        model = cls(config, *model_args)
        model._repo_id_or_path = pretrained_model_name_or_path
        # Hub-related kwargs are reused for every follow-up download
        # (tokenizer, LoRA adapter, projector weights).
        model._hub_kwargs = {
            "cache_dir": model_kwargs.get("cache_dir"),
            "force_download": model_kwargs.get("force_download"),
            "local_files_only": model_kwargs.get("local_files_only"),
            "revision": model_kwargs.get("revision"),
            "token": model_kwargs.get("token"),
            "trust_remote_code": model_kwargs.get("trust_remote_code", True),
        }
        base_runtime = {
            "cache_dir": model_kwargs.get("cache_dir"),
            "force_download": model_kwargs.get("force_download"),
            "local_files_only": model_kwargs.get("local_files_only"),
            "revision": model_kwargs.get("revision"),
            "token": model_kwargs.get("token"),
            # Accept either the legacy `torch_dtype` or the newer `dtype` spelling.
            "torch_dtype": model_kwargs.get("torch_dtype", model_kwargs.get("dtype")),
            "device_map": model_kwargs.get("device_map"),
            "attn_implementation": model_kwargs.get("attn_implementation"),
        }
        model._text_model_kwargs = {k: v for k, v in base_runtime.items() if v is not None}
        # `attn_implementation` is forwarded only to the text model.
        model._vision_model_kwargs = {
            k: v for k, v in base_runtime.items() if k != "attn_implementation" and v is not None
        }

        model._load_tokenizer()
        model._load_text_model()
        model._load_projector_weights()

        should_load_vision = (
            config.load_vision_tower_by_default if load_vision_tower is None else load_vision_tower
        )
        if should_load_vision:
            model._load_vision_model()

        model.eval()
        return model

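    # Only the Capri-specific pieces are written out: config, projector
    # weights, LoRA adapter, and tokenizer. The base text and vision models
    # are referenced by name in the config rather than copied.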
    def save_pretrained(self, save_directory: str, **kwargs):
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory)
        save_file(
            self.projector.state_dict(),
            os.path.join(save_directory, "projector.safetensors"),
        )
        if self.text_model is not None:
            # `PeftModel.save_pretrained` writes only the adapter weights.
            self.text_model.save_pretrained(
                os.path.join(save_directory, self.config.adapter_subdir)
            )
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(save_directory)

    def _resolve_repo_file(self, filename: str, subfolder: str | None = None) -> str:
        # Local checkpoints are read directly from disk; anything else is
        # fetched through the Hub cache.
        if os.path.isdir(self._repo_id_or_path):
            parts = [self._repo_id_or_path]
            if subfolder:
                parts.append(subfolder)
            parts.append(filename)
            return os.path.join(*parts)
        return cached_file(self._repo_id_or_path, filename, subfolder=subfolder, **self._hub_kwargs)

    def _load_tokenizer(self):
        if self.tokenizer is not None:
            return

        if self.config.image_token_id is None or self.config.image_token is None:
            raise ValueError("`image_token_id` and `image_token` must be set in the config.")

        self.tokenizer = AutoTokenizer.from_pretrained(
            self._repo_id_or_path,
            **self._hub_kwargs,
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _load_text_model(self):
        if self.text_model is not None:
            return
        base_model = AutoModelForCausalLM.from_pretrained(
            self.config.text_model_name_or_path,
            **self._text_model_kwargs,
        )
        # Wrap the base LM with the LoRA adapter stored in the Capri repo.
        self.text_model = PeftModel.from_pretrained(
            base_model,
            self._repo_id_or_path,
            subfolder=self.config.adapter_subdir,
            is_trainable=False,
            **self._hub_kwargs,
        )
        self.text_model.eval()

    def _load_vision_model(self):
        if self.vision_model is not None:
            return
        model = AutoModel.from_pretrained(
            self.config.vision_model_name_or_path,
            **self._vision_model_kwargs,
        )
        # CLIP-style checkpoints expose the tower as `.vision_model`; plain
        # vision models are used as-is.
        self.vision_model = getattr(model, "vision_model", model)
        self.vision_model.eval()

    def _load_projector_weights(self):
        projector_path = self._resolve_repo_file("projector.safetensors")
        state_dict = load_file(projector_path)
        self.projector.load_state_dict(state_dict)

        # Keep the projector on the same device/dtype as the text embeddings
        # it feeds into.
        embed_weight = self.text_model.get_input_embeddings().weight
        self.projector.to(device=embed_weight.device, dtype=embed_weight.dtype)

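    # Small helpers: vision-tower status, module device/dtype introspection,
    # and list chunking for batched encoding and decoding.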
    @property
    def vision_loaded(self) -> bool:
        return self.vision_model is not None

    @staticmethod
    def _module_device_dtype(module: nn.Module) -> tuple[torch.device, torch.dtype]:
        param = next(module.parameters())
        return param.device, param.dtype

    @staticmethod
    def _chunk_list(items: list[Any], chunk_size: int) -> list[list[Any]]:
        return [items[i : i + chunk_size] for i in range(0, len(items), chunk_size)]

    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
        self._load_vision_model()
        vision_device, vision_dtype = self._module_device_dtype(self.vision_model)
        pixel_values = pixel_values.to(device=vision_device, dtype=vision_dtype)
        outputs = self.vision_model(pixel_values=pixel_values)
        # Prefer the pooled output; fall back to the CLS token of the last
        # hidden state for towers that do not provide one.
        pooled = getattr(outputs, "pooler_output", None)
        if pooled is None:
            last_hidden = getattr(outputs, "last_hidden_state", None)
            if last_hidden is None:
                raise ValueError("Vision model returned neither `pooler_output` nor `last_hidden_state`.")
            pooled = last_hidden[:, 0]
        return pooled

    def _prompt_inputs(self, batch_size: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
        encoded = self.tokenizer(
            [self.config.prompt_prefix] * batch_size,
            add_special_tokens=False,
            return_tensors="pt",
            padding=True,
        )
        return encoded["input_ids"].to(device), encoded["attention_mask"].to(device)

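    # Core fusion step: project the pooled image embedding and splice it over
    # the single image-token position in each sequence's input embeddings.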
    def _prepare_inputs(
        self,
        *,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        pooled_embeddings: torch.Tensor | None = None,
        pixel_values: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if pooled_embeddings is None:
            if pixel_values is None:
                raise ValueError("Provide either `pooled_embeddings` or `pixel_values`.")
            pooled_embeddings = self.encode_images(pixel_values)

        if pooled_embeddings.ndim == 1:
            pooled_embeddings = pooled_embeddings.unsqueeze(0)

        target_device = self.text_model.get_input_embeddings().weight.device
        if input_ids is None:
            # No prompt supplied: fall back to the configured prompt prefix.
            input_ids, attention_mask = self._prompt_inputs(pooled_embeddings.size(0), target_device)
        else:
            input_ids = input_ids.to(target_device)
            if attention_mask is None:
                attention_mask = torch.ones_like(input_ids)
            else:
                attention_mask = attention_mask.to(target_device)

        inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
        pooled_embeddings = pooled_embeddings.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
        projected = self.projector(pooled_embeddings)

        image_mask = input_ids.eq(self.config.image_token_id)
        image_count = image_mask.sum(dim=1)
        if not torch.all(image_count == 1):
            raise ValueError(
                f"Each sample must contain exactly one `{self.config.image_token}` token."
            )

        # Overwrite the embedding at each image-token position with the
        # projected image feature.
        token_positions = image_mask.float().argmax(dim=1)
        batch_positions = torch.arange(input_ids.size(0), device=input_ids.device)
        inputs_embeds[batch_positions, token_positions] = projected
        return inputs_embeds, attention_mask

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        pooled_embeddings: torch.Tensor | None = None,
        pixel_values: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        **kwargs: Any,
    ):
        if input_ids is None and labels is not None:
            raise ValueError("`input_ids` is required when passing `labels`.")

        inputs_embeds, attention_mask = self._prepare_inputs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pooled_embeddings=pooled_embeddings,
            pixel_values=pixel_values,
        )
        return self.text_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            **kwargs,
        )

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        pooled_embeddings: torch.Tensor | None = None,
        pixel_values: torch.Tensor | None = None,
        **generate_kwargs: Any,
    ) -> torch.Tensor:
        inputs_embeds, attention_mask = self._prepare_inputs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pooled_embeddings=pooled_embeddings,
            pixel_values=pixel_values,
        )

        # Greedy decoding by default; callers can override any of these.
        generate_kwargs.setdefault("do_sample", False)
        generate_kwargs.setdefault("eos_token_id", self.tokenizer.eos_token_id)
        generate_kwargs.setdefault("pad_token_id", self.tokenizer.pad_token_id)
        return self.text_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **generate_kwargs,
        )

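    # Batched end-to-end captioning: encode images in `vision_batch_size`
    # chunks (staging pooled features on CPU), then decode captions in
    # `decode_batch_size` chunks through the text model.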
    @torch.no_grad()
    def generate_captions(
        self,
        *,
        images: Any = None,
        pooled_embeddings: Any = None,
        processor=None,
        vision_batch_size: int = 64,
        decode_batch_size: int = 1024,
        **generate_kwargs: Any,
    ) -> list[str]:
        if processor is None:
            raise ValueError("`processor` is required for `generate_captions()`.")
        if images is None and pooled_embeddings is None:
            raise ValueError("Provide either `images` or `pooled_embeddings`.")
        if images is not None and pooled_embeddings is not None:
            raise ValueError("Provide only one of `images` or `pooled_embeddings`.")
        if vision_batch_size <= 0 or decode_batch_size <= 0:
            raise ValueError("Batch sizes must be positive integers.")

        if images is not None:
            image_items = processor.normalize_images(images)
            all_pooled = []
            for image_chunk in self._chunk_list(image_items, vision_batch_size):
                image_inputs = processor(images=image_chunk, return_tensors="pt")
                pooled_chunk = self.encode_images(image_inputs["pixel_values"]).detach().cpu()
                all_pooled.append(pooled_chunk)
            pooled_embeddings = torch.cat(all_pooled, dim=0)
        else:
            pooled_embeddings = processor.normalize_pooled_embeddings(pooled_embeddings).detach().cpu()

        captions = []
        total = pooled_embeddings.shape[0]
        for start in range(0, total, decode_batch_size):
            pooled_chunk = pooled_embeddings[start : start + decode_batch_size]
            model_inputs = dict(processor(
                pooled_embeddings=pooled_chunk,
                return_tensors="pt",
            ))
            sequences = self.generate(**model_inputs, **generate_kwargs)
            captions.extend(processor.batch_decode(sequences, skip_special_tokens=True))

        return captions
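
# Minimal usage sketch (illustrative only; the repo id and processor below are
# placeholders, not part of this module):
#
#   model = CapriForConditionalGeneration.from_pretrained(
#       "your-org/capri-checkpoint",
#       torch_dtype=torch.bfloat16,
#       device_map="auto",
#   )
#   processor = ...  # the matching Capri processor
#   captions = model.generate_captions(
#       images=pil_images,
#       processor=processor,
#       max_new_tokens=32,
#   )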