import os
from typing import Any

import torch
import torch.nn as nn
from peft import PeftModel
from safetensors.torch import load_file, save_file
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
from transformers.utils import cached_file

from .configuration_capri import CapriConfig


class MLPProjector(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class CapriForConditionalGeneration(PreTrainedModel):
    config_class = CapriConfig
    base_model_prefix = "capri"
    main_input_name = "input_ids"

    def __init__(self, config: CapriConfig):
        super().__init__(config)
        self.projector = MLPProjector(
            in_dim=config.projector_in_dim,
            hidden_dim=config.projector_hidden_dim,
            out_dim=config.projector_out_dim,
        )
        self.text_model = None
        self.vision_model = None
        self.tokenizer = None
        self._repo_id_or_path = None
        self._hub_kwargs = {}
        self._text_model_kwargs = {}
        self._vision_model_kwargs = {}
        self.post_init()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, config=None, **kwargs):
        load_vision_tower = kwargs.pop("load_vision_tower", None)
        if config is None:
            config, model_kwargs = CapriConfig.from_pretrained(
                pretrained_model_name_or_path,
                return_unused_kwargs=True,
                **kwargs,
            )
        else:
            model_kwargs = dict(kwargs)

        model = cls(config, *model_args)
        model._repo_id_or_path = pretrained_model_name_or_path
        model._hub_kwargs = {
            "cache_dir": model_kwargs.get("cache_dir"),
            "force_download": model_kwargs.get("force_download"),
            "local_files_only": model_kwargs.get("local_files_only"),
            "revision": model_kwargs.get("revision"),
            "token": model_kwargs.get("token"),
            "trust_remote_code": model_kwargs.get("trust_remote_code", True),
        }
        base_runtime = {
            "cache_dir": model_kwargs.get("cache_dir"),
            "force_download": model_kwargs.get("force_download"),
            "local_files_only": model_kwargs.get("local_files_only"),
            "revision": model_kwargs.get("revision"),
            "token": model_kwargs.get("token"),
            "torch_dtype": model_kwargs.get("torch_dtype", model_kwargs.get("dtype")),
            "device_map": model_kwargs.get("device_map"),
            "attn_implementation": model_kwargs.get("attn_implementation"),
        }
        model._text_model_kwargs = {k: v for k, v in base_runtime.items() if v is not None}
        model._vision_model_kwargs = {
            k: v for k, v in base_runtime.items() if k != "attn_implementation" and v is not None
        }

        model._load_tokenizer()
        model._load_text_model()
        model._load_projector_weights()

        should_load_vision = (
            config.load_vision_tower_by_default if load_vision_tower is None else load_vision_tower
        )
        if should_load_vision:
            model._load_vision_model()

        model.eval()
        return model

    def save_pretrained(self, save_directory: str, **kwargs):
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory)
        save_file(
            self.projector.state_dict(),
            os.path.join(save_directory, "projector.safetensors"),
        )
        if self.text_model is not None:
            self.text_model.save_pretrained(
                os.path.join(save_directory, self.config.adapter_subdir)
            )
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(save_directory)

    def _resolve_repo_file(self, filename: str, subfolder: str | None = None) -> str:
        if os.path.isdir(self._repo_id_or_path):
            parts = [self._repo_id_or_path]
            if subfolder:
                parts.append(subfolder)
            parts.append(filename)
            return os.path.join(*parts)
        return cached_file(self._repo_id_or_path, filename, subfolder=subfolder, **self._hub_kwargs)

    def _load_tokenizer(self):
        if self.tokenizer is not None:
            return
        if self.config.image_token_id is None or self.config.image_token is None:
            raise ValueError("`image_token_id` and `image_token` must be set in the config.")
        self.tokenizer = AutoTokenizer.from_pretrained(
            self._repo_id_or_path,
            **self._hub_kwargs,
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _load_text_model(self):
        if self.text_model is not None:
            return
        base_model = AutoModelForCausalLM.from_pretrained(
            self.config.text_model_name_or_path,
            **self._text_model_kwargs,
        )
        self.text_model = PeftModel.from_pretrained(
            base_model,
            self._repo_id_or_path,
            subfolder=self.config.adapter_subdir,
            is_trainable=False,
            **self._hub_kwargs,
        )
        self.text_model.eval()

    def _load_vision_model(self):
        if self.vision_model is not None:
            return
        model = AutoModel.from_pretrained(
            self.config.vision_model_name_or_path,
            **self._vision_model_kwargs,
        )
        self.vision_model = getattr(model, "vision_model", model)
        self.vision_model.eval()

    def _load_projector_weights(self):
        projector_path = self._resolve_repo_file("projector.safetensors")
        state_dict = load_file(projector_path)
        self.projector.load_state_dict(state_dict)
        embed_weight = self.text_model.get_input_embeddings().weight
        self.projector.to(device=embed_weight.device, dtype=embed_weight.dtype)

    @property
    def vision_loaded(self) -> bool:
        return self.vision_model is not None

    @staticmethod
    def _module_device_dtype(module: nn.Module) -> tuple[torch.device, torch.dtype]:
        param = next(module.parameters())
        return param.device, param.dtype

    @staticmethod
    def _chunk_list(items: list[Any], chunk_size: int) -> list[list[Any]]:
        return [items[i : i + chunk_size] for i in range(0, len(items), chunk_size)]

    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
        self._load_vision_model()
        vision_device, vision_dtype = self._module_device_dtype(self.vision_model)
        pixel_values = pixel_values.to(device=vision_device, dtype=vision_dtype)
        outputs = self.vision_model(pixel_values=pixel_values)
        pooled = getattr(outputs, "pooler_output", None)
        if pooled is None:
            last_hidden = getattr(outputs, "last_hidden_state", None)
            if last_hidden is None:
                raise ValueError("Vision model did not return pooler_output or last_hidden_state.")
            pooled = last_hidden[:, 0]
        return pooled

    def _prompt_inputs(self, batch_size: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
        encoded = self.tokenizer(
            [self.config.prompt_prefix] * batch_size,
            add_special_tokens=False,
            return_tensors="pt",
            padding=True,
        )
        return encoded["input_ids"].to(device), encoded["attention_mask"].to(device)

    def _prepare_inputs(
        self,
        *,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        pooled_embeddings: torch.Tensor | None = None,
        pixel_values: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if pooled_embeddings is None:
            if pixel_values is None:
                raise ValueError("Provide either `pooled_embeddings` or `pixel_values`.")
            pooled_embeddings = self.encode_images(pixel_values)
        if pooled_embeddings.ndim == 1:
            pooled_embeddings = pooled_embeddings.unsqueeze(0)

        target_device = self.text_model.get_input_embeddings().weight.device
        if input_ids is None:
            input_ids, attention_mask = self._prompt_inputs(pooled_embeddings.size(0), target_device)
        else:
            input_ids = input_ids.to(target_device)
            if attention_mask is None:
                attention_mask = torch.ones_like(input_ids, device=target_device)
            else:
                attention_mask = attention_mask.to(target_device)

        inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
        pooled_embeddings = pooled_embeddings.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
        projected = self.projector(pooled_embeddings)

        # Splice the projected image embedding into the slot occupied by the image token.
        image_mask = input_ids.eq(self.config.image_token_id)
        image_count = image_mask.sum(dim=1)
        if not torch.all(image_count == 1):
            raise ValueError(
                f"Each sample must contain exactly one {self.config.image_token!r} token."
            )
        token_positions = image_mask.float().argmax(dim=1)
        batch_positions = torch.arange(input_ids.size(0), device=input_ids.device)
        inputs_embeds[batch_positions, token_positions] = projected
        return inputs_embeds, attention_mask

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        pooled_embeddings: torch.Tensor | None = None,
        pixel_values: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        **kwargs: Any,
    ):
        if input_ids is None and labels is not None:
            raise ValueError("`input_ids` are required when passing `labels`.")
        inputs_embeds, attention_mask = self._prepare_inputs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pooled_embeddings=pooled_embeddings,
            pixel_values=pixel_values,
        )
        return self.text_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            **kwargs,
        )

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        pooled_embeddings: torch.Tensor | None = None,
        pixel_values: torch.Tensor | None = None,
        **generate_kwargs: Any,
    ) -> torch.Tensor:
        inputs_embeds, attention_mask = self._prepare_inputs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pooled_embeddings=pooled_embeddings,
            pixel_values=pixel_values,
        )
        generate_kwargs.setdefault("do_sample", False)
        generate_kwargs.setdefault("eos_token_id", self.tokenizer.eos_token_id)
        generate_kwargs.setdefault("pad_token_id", self.tokenizer.pad_token_id)
        return self.text_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **generate_kwargs,
        )

    @torch.no_grad()
    def generate_captions(
        self,
        *,
        images: Any = None,
        pooled_embeddings: Any = None,
        processor=None,
        vision_batch_size: int = 64,
        decode_batch_size: int = 1024,
        **generate_kwargs: Any,
    ) -> list[str]:
        if processor is None:
            raise ValueError("`processor` is required for `generate_captions()`.")
        if images is None and pooled_embeddings is None:
            raise ValueError("Provide either `images` or `pooled_embeddings`.")
        if images is not None and pooled_embeddings is not None:
            raise ValueError("Provide only one of `images` or `pooled_embeddings`.")
        if vision_batch_size <= 0 or decode_batch_size <= 0:
            raise ValueError("Batch sizes must be positive integers.")

        if images is not None:
            # Encode images in chunks so the vision tower never sees an oversized batch.
            image_items = processor.normalize_images(images)
            all_pooled = []
            for image_chunk in self._chunk_list(image_items, vision_batch_size):
                image_inputs = processor(images=image_chunk, return_tensors="pt")
                pooled_chunk = self.encode_images(image_inputs["pixel_values"]).detach().cpu()
                all_pooled.append(pooled_chunk)
            pooled_embeddings = torch.cat(all_pooled, dim=0)
        else:
            pooled_embeddings = processor.normalize_pooled_embeddings(pooled_embeddings).detach().cpu()

        captions = []
        total = pooled_embeddings.shape[0]
        for start in range(0, total, decode_batch_size):
            pooled_chunk = pooled_embeddings[start : start + decode_batch_size]
            model_inputs = dict(processor(
                pooled_embeddings=pooled_chunk,
                return_tensors="pt",
            ))
            sequences = self.generate(**model_inputs, **generate_kwargs)
            captions.extend(processor.batch_decode(sequences, skip_special_tokens=True))
        return captions
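

if __name__ == "__main__":
    # Minimal usage sketch, not part of the model implementation. The repo id and the
    # image path below are hypothetical placeholders, and loading the Capri processor
    # through AutoProcessor with trust_remote_code is an assumption; adapt these to the
    # actual published checkpoint and processor class.
    from PIL import Image
    from transformers import AutoProcessor

    repo_id = "org/capri-checkpoint"  # hypothetical repo id
    model = CapriForConditionalGeneration.from_pretrained(repo_id, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)

    image = Image.open("example.jpg")  # hypothetical local image
    captions = model.generate_captions(
        images=[image],
        processor=processor,
        max_new_tokens=64,
    )
    print(captions[0])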