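"""Modeling code for Capri, an image-captioning model that conditions a PEFT-adapted
causal language model on a single pooled image embedding via an MLP projector.

Minimal usage sketch (the repo id "Ligul/capri", the `AutoProcessor` entry point, and
`pil_image` below are illustrative assumptions, not guaranteed by this file alone):

    from transformers import AutoProcessor

    model = CapriForConditionalGeneration.from_pretrained("Ligul/capri", trust_remote_code=True)
    processor = AutoProcessor.from_pretrained("Ligul/capri", trust_remote_code=True)
    captions = model.generate_captions(images=[pil_image], processor=processor, max_new_tokens=32)
"""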
import os
from typing import Any
import torch
import torch.nn as nn
from peft import PeftModel
from safetensors.torch import load_file, save_file
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
from transformers.utils import cached_file
from .configuration_capri import CapriConfig


class MLPProjector(nn.Module):
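    """Two-layer GELU MLP that projects pooled vision features to the text embedding width."""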

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class CapriForConditionalGeneration(PreTrainedModel):
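    """Caption generator that couples a vision tower, an MLP projector, and a PEFT-adapted
    causal language model.

    The tokenizer, text model, and vision tower are loaded lazily from the Hub or a local
    directory by `from_pretrained` rather than in `__init__`.
    """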

    config_class = CapriConfig
    base_model_prefix = "capri"
    main_input_name = "input_ids"

    def __init__(self, config: CapriConfig):
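        # Only the projector holds weights at construction time; the heavy submodules
        # are attached later by `from_pretrained` via the `_load_*` helpers below.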
        super().__init__(config)
        self.projector = MLPProjector(
            in_dim=config.projector_in_dim,
            hidden_dim=config.projector_hidden_dim,
            out_dim=config.projector_out_dim,
        )
        self.text_model = None
        self.vision_model = None
        self.tokenizer = None
        self._repo_id_or_path = None
        self._hub_kwargs = {}
        self._text_model_kwargs = {}
        self._vision_model_kwargs = {}
        self.post_init()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, config=None, **kwargs):
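        """Load the config, tokenizer, PEFT text model, projector weights, and (optionally)
        the vision tower from `pretrained_model_name_or_path`.
        """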
        load_vision_tower = kwargs.pop("load_vision_tower", None)
        if config is None:
            config, model_kwargs = CapriConfig.from_pretrained(
                pretrained_model_name_or_path,
                return_unused_kwargs=True,
                **kwargs,
            )
        else:
            model_kwargs = dict(kwargs)
        model = cls(config, *model_args)
        model._repo_id_or_path = pretrained_model_name_or_path
        model._hub_kwargs = {
            "cache_dir": model_kwargs.get("cache_dir"),
            "force_download": model_kwargs.get("force_download"),
            "local_files_only": model_kwargs.get("local_files_only"),
            "revision": model_kwargs.get("revision"),
            "token": model_kwargs.get("token"),
            "trust_remote_code": model_kwargs.get("trust_remote_code", True),
        }
        base_runtime = {
            "cache_dir": model_kwargs.get("cache_dir"),
            "force_download": model_kwargs.get("force_download"),
            "local_files_only": model_kwargs.get("local_files_only"),
            "revision": model_kwargs.get("revision"),
            "token": model_kwargs.get("token"),
            "torch_dtype": model_kwargs.get("torch_dtype", model_kwargs.get("dtype")),
            "device_map": model_kwargs.get("device_map"),
            "attn_implementation": model_kwargs.get("attn_implementation"),
        }
        model._text_model_kwargs = {k: v for k, v in base_runtime.items() if v is not None}
        model._vision_model_kwargs = {
            k: v for k, v in base_runtime.items() if k != "attn_implementation" and v is not None
        }
        model._load_tokenizer()
        model._load_text_model()
        model._load_projector_weights()
        should_load_vision = (
            config.load_vision_tower_by_default if load_vision_tower is None else load_vision_tower
        )
        if should_load_vision:
            model._load_vision_model()
        model.eval()
        return model

    def save_pretrained(self, save_directory: str, **kwargs):
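        """Write the config, projector weights, PEFT adapter, and tokenizer to `save_directory`."""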
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory)
        save_file(
            self.projector.state_dict(),
            os.path.join(save_directory, "projector.safetensors"),
        )
        if self.text_model is not None:
            self.text_model.save_pretrained(
                os.path.join(save_directory, self.config.adapter_subdir)
            )
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(save_directory)

    def _resolve_repo_file(self, filename: str, subfolder: str | None = None) -> str:
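        """Return a local path for `filename`, downloading it from the Hub when needed."""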
        if os.path.isdir(self._repo_id_or_path):
            parts = [self._repo_id_or_path]
            if subfolder:
                parts.append(subfolder)
            parts.append(filename)
            return os.path.join(*parts)
        return cached_file(self._repo_id_or_path, filename, subfolder=subfolder, **self._hub_kwargs)

    def _load_tokenizer(self):
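        """Load the tokenizer shipped with the repo, falling back to EOS as the pad token."""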
        if self.tokenizer is not None:
            return
        if self.config.image_token_id is None or self.config.image_token is None:
            raise ValueError("`image_token_id` and `image_token` must be set in the config.")
        self.tokenizer = AutoTokenizer.from_pretrained(
            self._repo_id_or_path,
            **self._hub_kwargs,
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _load_text_model(self):
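        """Load the base causal LM and attach the PEFT adapter stored in the repo."""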
        if self.text_model is not None:
            return
        base_model = AutoModelForCausalLM.from_pretrained(
            self.config.text_model_name_or_path,
            **self._text_model_kwargs,
        )
        self.text_model = PeftModel.from_pretrained(
            base_model,
            self._repo_id_or_path,
            subfolder=self.config.adapter_subdir,
            is_trainable=False,
            **self._hub_kwargs,
        )
        self.text_model.eval()

    def _load_vision_model(self):
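        """Load the vision tower, unwrapping a `.vision_model` attribute when present (e.g. CLIP)."""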
        if self.vision_model is not None:
            return
        model = AutoModel.from_pretrained(
            self.config.vision_model_name_or_path,
            **self._vision_model_kwargs,
        )
        self.vision_model = getattr(model, "vision_model", model)
        self.vision_model.eval()

    def _load_projector_weights(self):
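        """Load projector.safetensors and move the projector to the text embeddings' device/dtype."""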
        projector_path = self._resolve_repo_file("projector.safetensors")
        state_dict = load_file(projector_path)
        self.projector.load_state_dict(state_dict)
        embed_weight = self.text_model.get_input_embeddings().weight
        self.projector.to(device=embed_weight.device, dtype=embed_weight.dtype)

    @property
    def vision_loaded(self) -> bool:
        return self.vision_model is not None

    @staticmethod
    def _module_device_dtype(module: nn.Module) -> tuple[torch.device, torch.dtype]:
        param = next(module.parameters())
        return param.device, param.dtype

    @staticmethod
    def _chunk_list(items: list[Any], chunk_size: int) -> list[list[Any]]:
        return [items[i : i + chunk_size] for i in range(0, len(items), chunk_size)]

    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
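        """Run the vision tower and return one pooled embedding per image.

        Uses `pooler_output` when the backbone provides it, otherwise the first (CLS)
        token of `last_hidden_state`.
        """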
        self._load_vision_model()
        vision_device, vision_dtype = self._module_device_dtype(self.vision_model)
        pixel_values = pixel_values.to(device=vision_device, dtype=vision_dtype)
        outputs = self.vision_model(pixel_values=pixel_values)
        pooled = getattr(outputs, "pooler_output", None)
        if pooled is None:
            last_hidden = getattr(outputs, "last_hidden_state", None)
            if last_hidden is None:
                raise ValueError("Vision model did not return pooler_output or last_hidden_state.")
            pooled = last_hidden[:, 0]
        return pooled

    def _prompt_inputs(self, batch_size: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
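        """Tokenize the configured prompt prefix once per sample in the batch."""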
        encoded = self.tokenizer(
            [self.config.prompt_prefix] * batch_size,
            add_special_tokens=False,
            return_tensors="pt",
            padding=True,
        )
        return encoded["input_ids"].to(device), encoded["attention_mask"].to(device)

    def _prepare_inputs(
        self,
        *,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        pooled_embeddings: torch.Tensor | None = None,
        pixel_values: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
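        """Build `inputs_embeds` by splicing the projected image embedding into the prompt.

        Each sample's single image-token position is replaced with the projector output;
        all other positions keep the text model's own token embeddings.
        """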
        if pooled_embeddings is None:
            if pixel_values is None:
                raise ValueError("Provide either `pooled_embeddings` or `pixel_values`.")
            pooled_embeddings = self.encode_images(pixel_values)
        if pooled_embeddings.ndim == 1:
            pooled_embeddings = pooled_embeddings.unsqueeze(0)
        target_device = self.text_model.get_input_embeddings().weight.device
        if input_ids is None:
            input_ids, attention_mask = self._prompt_inputs(pooled_embeddings.size(0), target_device)
        else:
            input_ids = input_ids.to(target_device)
            if attention_mask is None:
                attention_mask = torch.ones_like(input_ids, device=target_device)
            else:
                attention_mask = attention_mask.to(target_device)
        inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
        pooled_embeddings = pooled_embeddings.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
        projected = self.projector(pooled_embeddings)
        image_mask = input_ids.eq(self.config.image_token_id)
        image_count = image_mask.sum(dim=1)
        if not torch.all(image_count == 1):
            raise ValueError("Each sample must contain exactly one `<image>` token.")
        token_positions = image_mask.float().argmax(dim=1)
        batch_positions = torch.arange(input_ids.size(0), device=input_ids.device)
        inputs_embeds[batch_positions, token_positions] = projected
        return inputs_embeds, attention_mask

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        pooled_embeddings: torch.Tensor | None = None,
        pixel_values: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        **kwargs: Any,
    ):
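        """Run the language-modeling forward pass on the image-conditioned prompt."""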
        if input_ids is None and labels is not None:
            raise ValueError("`input_ids` are required when passing `labels`.")
        inputs_embeds, attention_mask = self._prepare_inputs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pooled_embeddings=pooled_embeddings,
            pixel_values=pixel_values,
        )
        return self.text_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            **kwargs,
        )

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        pooled_embeddings: torch.Tensor | None = None,
        pixel_values: torch.Tensor | None = None,
        **generate_kwargs: Any,
    ) -> torch.Tensor:
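        """Generate token ids; greedy decoding by default, with `generate_kwargs` forwarded
        to the text model's `generate`."""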
        inputs_embeds, attention_mask = self._prepare_inputs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pooled_embeddings=pooled_embeddings,
            pixel_values=pixel_values,
        )
        generate_kwargs.setdefault("do_sample", False)
        generate_kwargs.setdefault("eos_token_id", self.tokenizer.eos_token_id)
        generate_kwargs.setdefault("pad_token_id", self.tokenizer.pad_token_id)
        return self.text_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **generate_kwargs,
        )

    @torch.no_grad()
    def generate_captions(
        self,
        *,
        images: Any = None,
        pooled_embeddings: Any = None,
        processor=None,
        vision_batch_size: int = 64,
        decode_batch_size: int = 1024,
        **generate_kwargs: Any,
    ) -> list[str]:
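        """End-to-end captioning helper: encode images in chunks, then decode captions in chunks.

        Requires the repo's processor, which supplies image preprocessing, prompt building,
        and token decoding.
        """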
        if processor is None:
            raise ValueError("`processor` is required for `generate_captions()`.")
        if images is None and pooled_embeddings is None:
            raise ValueError("Provide either `images` or `pooled_embeddings`.")
        if images is not None and pooled_embeddings is not None:
            raise ValueError("Provide only one of `images` or `pooled_embeddings`.")
        if vision_batch_size <= 0 or decode_batch_size <= 0:
            raise ValueError("Batch sizes must be positive integers.")
        if images is not None:
            image_items = processor.normalize_images(images)
            all_pooled = []
            for image_chunk in self._chunk_list(image_items, vision_batch_size):
                image_inputs = processor(images=image_chunk, return_tensors="pt")
                pooled_chunk = self.encode_images(image_inputs["pixel_values"]).detach().cpu()
                all_pooled.append(pooled_chunk)
            pooled_embeddings = torch.cat(all_pooled, dim=0)
        else:
            pooled_embeddings = processor.normalize_pooled_embeddings(pooled_embeddings).detach().cpu()
        captions = []
        total = pooled_embeddings.shape[0]
        for start in range(0, total, decode_batch_size):
            pooled_chunk = pooled_embeddings[start : start + decode_batch_size]
            model_inputs = dict(processor(
                pooled_embeddings=pooled_chunk,
                return_tensors="pt",
            ))
            sequences = self.generate(**model_inputs, **generate_kwargs)
            captions.extend(processor.batch_decode(sequences, skip_special_tokens=True))
        return captions