import json
import os
import tempfile
from dataclasses import asdict
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_model, save_model

from data.processors import get_tokenizer
from models.config import VLMConfig
from models.language_model import LanguageModel
from models.modality_projector import ModalityProjector
from models.utils import top_k_top_p_filtering
from models.vision_transformer import ViT


class VisionLanguageModel(nn.Module):
    def __init__(self, cfg: VLMConfig, load_backbone=True):
        super().__init__()
        self.cfg = cfg
        if load_backbone:
            # Initialize the vision and language towers from pretrained backbone weights.
            print("Loading from backbone weights")
            self.vision_encoder = ViT.from_pretrained(cfg)
            self.decoder = LanguageModel.from_pretrained(cfg)
        else:
            self.vision_encoder = ViT(cfg)
            self.decoder = LanguageModel(cfg)
        # The modality projector maps vision features into the LM embedding space.
        self.MP = ModalityProjector(cfg)
        self.load_backbone = load_backbone
        self.tokenizer = get_tokenizer(cfg.lm_tokenizer, cfg.vlm_extra_tokens, cfg.lm_chat_template)

    def _replace_img_tokens_with_embd(self, input_ids, token_embd, image_embd):
        """
        Replace every image-token placeholder in `input_ids` with the corresponding slice
        from `image_embd`. Supports an arbitrary number of image-token placeholders per
        sample; e.g. the first example in a batch may contain two images and the second none.
        """
        updated_token_embd = token_embd.clone()

        # Boolean (B, T) mask over the placeholder positions. The masked assignment
        # below requires mask.sum() to equal the number of rows in the flattened
        # image embeddings: exactly one embedding row per placeholder token.
        mask = (input_ids == self.tokenizer.image_token_id)
        updated_token_embd[mask] = image_embd.view(-1, image_embd.size(-1)).to(updated_token_embd.dtype)

        return updated_token_embd
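
    # A minimal sketch of the replacement semantics above (shapes illustrative,
    # with hypothetical hidden size D = 8 and IMG = tokenizer.image_token_id):
    #   input_ids  = [[5, IMG, IMG, 9]]     # (1, 4)
    #   token_embd                          # (1, 4, 8)
    #   image_embd                          # (1, 2, 8): two projected image rows
    # The two IMG positions receive image_embd.view(-1, 8) row by row; the text
    # embeddings at positions 0 and 3 pass through unchanged.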

    def _process_images(self, images, device):
        # Accept either a flat list of image tensors or a nested list with one
        # sub-list per sample; flatten the nested case before concatenating.
        if isinstance(images, list):
            if images and isinstance(images[0], list):
                images = [img for sublist in images for img in sublist]

            if not images:
                return None
            else:
                return torch.cat(images, dim=0).to(device)
        # A tensor is returned unchanged; the caller is responsible for its device.
        return images
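
    # A minimal sketch of the accepted `images` formats (shapes illustrative):
    #   t = torch.randn(1, 3, 224, 224)
    #   self._process_images([t, t], device)      # flat list   -> (2, 3, 224, 224)
    #   self._process_images([[t], [t]], device)  # nested list -> (2, 3, 224, 224)
    #   self._process_images([], device)          # no images   -> None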

    def forward(self, input_ids, images, attention_mask=None, targets=None):
        images_tensor = self._process_images(images, input_ids.device)
        token_embd = self.decoder.token_embedding(input_ids)

        if images_tensor is not None:
            image_embd = self.vision_encoder(images_tensor)
            image_embd = self.MP(image_embd)
            token_embd = self._replace_img_tokens_with_embd(input_ids, token_embd, image_embd)

        logits, _ = self.decoder(token_embd, attention_mask=attention_mask)

        loss = None
        if targets is not None:
            # The decoder returns hidden states here, so apply the LM head before
            # computing the loss; positions labeled -100 are ignored.
            logits = self.decoder.head(logits)
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1), ignore_index=-100)

        return logits, loss
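
    # A minimal sketch of a training step (batch layout assumed):
    #   input_ids: (B, T) token IDs with image placeholders already inserted
    #   images:    list of image tensors, or a nested list per sample
    #   targets:   (B, T) labels with -100 on positions to ignore
    #
    #   logits, loss = model(input_ids, images, attention_mask=mask, targets=targets)
    #   loss.backward()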

    @torch.inference_mode()
    def generate(self, input_ids, images, attention_mask=None, max_new_tokens=5, top_k=50, top_p=0.9, temperature=0.5, greedy=False):
        images_tensor = self._process_images(images, input_ids.device)
        token_embd = self.decoder.token_embedding(input_ids)

        if images_tensor is not None:
            image_embd = self.vision_encoder(images_tensor)
            image_embd = self.MP(image_embd)
            token_embd = self._replace_img_tokens_with_embd(input_ids, token_embd, image_embd)

        current_total_seq_len = token_embd.size(1)
        batch_size = input_ids.size(0)

        # Prefill: run the whole prompt once to populate the KV cache.
        prefill_output, kv_cache_list = self.decoder(
            token_embd,
            attention_mask=attention_mask,
            kv_cache=None,
            start_pos=0,
        )

        last_token_output_from_prefill = prefill_output[:, -1, :]

        if not self.decoder.lm_use_tokens:
            current_logits = self.decoder.head(last_token_output_from_prefill)
        else:
            current_logits = last_token_output_from_prefill

        newly_generated_ids_list = []

        # Decode: feed one token at a time, reusing the KV cache.
        for _ in range(max_new_tokens):
            if greedy:
                next_token_id = torch.argmax(current_logits, dim=-1, keepdim=True)
            else:
                filtered_logits = top_k_top_p_filtering(current_logits, top_k=top_k, top_p=top_p)
                probs = torch.softmax(filtered_logits / temperature, dim=-1)
                next_token_id = torch.multinomial(probs, num_samples=1)

            newly_generated_ids_list.append(next_token_id)

            next_token_embed = self.decoder.token_embedding(next_token_id)

            current_token_start_pos = current_total_seq_len
            current_total_seq_len += 1

            # Extend the attention mask to cover the newly generated position.
            if attention_mask is not None:
                attention_mask = torch.cat((attention_mask, torch.ones((batch_size, 1), device=attention_mask.device, dtype=attention_mask.dtype)), dim=1)

            decode_step_output, kv_cache_list = self.decoder(
                next_token_embed,
                attention_mask=attention_mask,
                kv_cache=kv_cache_list,
                start_pos=current_token_start_pos,
            )

            last_token_output = decode_step_output[:, -1, :]

            if not self.decoder.lm_use_tokens:
                current_logits = self.decoder.head(last_token_output)
            else:
                current_logits = last_token_output

        if not newly_generated_ids_list:
            return torch.empty((batch_size, 0), dtype=torch.long, device=input_ids.device)

        generated_ids = torch.cat(newly_generated_ids_list, dim=1)

        # Post-process: overwrite everything after each sequence's first EOS with
        # EOS, so downstream decoding can stop there.
        if self.tokenizer.eos_token_id is not None and generated_ids.numel() > 0:
            seq_len = generated_ids.size(1)
            device = generated_ids.device

            eos_mask = (generated_ids == self.tokenizer.eos_token_id)

            col_indices_for_min = torch.arange(seq_len, device=device)

            # Rows without EOS get the sentinel seq_len + 1, so min() exceeds
            # every valid column index for them.
            masked_col_indices = torch.where(eos_mask, col_indices_for_min.unsqueeze(0).expand_as(generated_ids), seq_len + 1)

            first_eos_indices_values = torch.min(masked_col_indices, dim=1).values

            # Clamp the sentinel to seq_len: rows without EOS then point past the
            # last column and nothing in them is overwritten.
            actual_first_eos_indices = torch.clamp(first_eos_indices_values, max=seq_len)

            col_indices_for_comparison = torch.arange(seq_len, device=device).unsqueeze(0).expand_as(generated_ids)

            replace_mask = col_indices_for_comparison > actual_first_eos_indices.unsqueeze(1)

            generated_ids[replace_mask] = self.tokenizer.eos_token_id

        return generated_ids
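
    # A minimal sketch of decoding (prompt construction assumed; see the
    # repository's generation script for the exact pipeline):
    #   gen = model.generate(input_ids, images, attention_mask=mask,
    #                        max_new_tokens=64, greedy=True)
    #   texts = model.tokenizer.batch_decode(gen, skip_special_tokens=True)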

    @classmethod
    def from_pretrained(
        cls, repo_id_or_path: str, *, revision: Optional[str] = None
    ) -> "VisionLanguageModel":
        """
        Load a VisionLanguageModel from a local directory or a repo on the Hugging Face Hub.

        Args:
            repo_id_or_path (str): The path to the local directory or the Hugging Face Hub repo ID.
            revision (str, optional): The revision (branch, tag, or commit) to download from the Hub.

        Returns:
            VisionLanguageModel: The loaded model.
        """
        if os.path.exists(repo_id_or_path):
            config_path = os.path.join(repo_id_or_path, "config.json")
            weights_path = os.path.join(repo_id_or_path, "model.safetensors")

            if not os.path.exists(config_path):
                raise ValueError(
                    f"Config file not found at {config_path}. Please provide a valid path."
                )
            if not os.path.exists(weights_path):
                raise ValueError(
                    f"Weights file not found at {weights_path}. Please provide a valid path."
                )
        else:
            from huggingface_hub import hf_hub_download

            config_path = hf_hub_download(
                repo_id=repo_id_or_path, filename="config.json", revision=revision
            )
            weights_path = hf_hub_download(
                repo_id=repo_id_or_path, filename="model.safetensors", revision=revision
            )

        with open(config_path, "r") as f:
            cfg = VLMConfig(**json.load(f))

        # Build the model without fetching backbone weights; the full state dict
        # comes from the safetensors checkpoint instead.
        model = cls(cfg, load_backbone=False)
        load_model(model, weights_path)

        return model
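
    # A minimal sketch of loading (the Hub repo is the base model referenced in
    # the model card below; the local path is hypothetical):
    #   model = VisionLanguageModel.from_pretrained("lusxvr/nanoVLM-222M")
    #   model = VisionLanguageModel.from_pretrained("./checkpoints/my_run")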

    def save_pretrained(self, save_directory: str) -> None:
        """
        Save the model weights and configuration to a directory.

        Args:
            save_directory (str): The directory to save the model and config.
        """
        os.makedirs(save_directory, exist_ok=True)

        # Serialize the dataclass config next to the safetensors weights so that
        # from_pretrained can rebuild the model.
        with open(os.path.join(save_directory, "config.json"), "w") as f:
            f.write(json.dumps(asdict(self.cfg), indent=4))

        save_model(self, os.path.join(save_directory, "model.safetensors"))
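
    # A minimal round-trip sketch (directory name hypothetical):
    #   model.save_pretrained("./nanovlm-checkpoint")
    #   reloaded = VisionLanguageModel.from_pretrained("./nanovlm-checkpoint")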

    def push_to_hub(self, repo_id: str, private: bool = False):
        """
        Push the model and configuration to the Hugging Face Hub.

        Args:
            repo_id (str): The repo ID on the Hugging Face Hub.
            private (bool): Whether to create the repo as private. Defaults to False.
        """
        from huggingface_hub import create_repo, upload_folder

        # Create (or reuse) the repo and resolve the canonical repo ID.
        repo_url = create_repo(repo_id=repo_id, private=private, exist_ok=True)
        repo_id = repo_url.repo_id
        print("Created repo:", repo_url)

        with tempfile.TemporaryDirectory() as save_path:
            # Save weights and config, write a model card, and upload everything.
            self.save_pretrained(save_path)

            with open(os.path.join(save_path, "README.md"), "w") as f:
                f.write(MODEL_CARD_TEMPLATE.format(repo_id=repo_id))

            return upload_folder(
                repo_id=repo_id,
                repo_type="model",
                folder_path=save_path,
                commit_message="Upload nanoVLM using push_to_hub",
            )
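
    # A minimal sketch (repo name hypothetical):
    #   model.push_to_hub("your-username/my-nanovlm", private=True)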


MODEL_CARD_TEMPLATE = """
---
# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
# Doc / guide: https://huggingface.co/docs/hub/model-cards
library_name: nanovlm
license: mit
pipeline_tag: image-text-to-text
tags:
- vision-language
- multimodal
- research
---

**nanoVLM** is a minimal and lightweight Vision-Language Model (VLM) designed for efficient training and experimentation. Built using pure PyTorch, the entire model architecture and training logic fits within ~750 lines of code. It combines a ViT-based image encoder (SigLIP-B/16-224-85M) with a lightweight causal language model (SmolLM2-135M), resulting in a compact 222M-parameter model.

For more information, check out the base model on https://huggingface.co/lusxvr/nanoVLM-222M.

**Usage:**

Clone the nanoVLM repository: https://github.com/huggingface/nanoVLM.
Follow the install instructions and run the following code:

```python
from models.vision_language_model import VisionLanguageModel

model = VisionLanguageModel.from_pretrained("{repo_id}")
```
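
To generate from the loaded model, here is a minimal sketch (the `input_ids` and `images` inputs are illustrative; see the repository's generation script and data processors for the exact way to build them):

```python
# Illustrative only: build input_ids with the image placeholder tokens and a
# preprocessed image tensor the way the repository's processors do.
generated = model.generate(input_ids, images, max_new_tokens=32)
print(model.tokenizer.batch_decode(generated, skip_special_tokens=True))
```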
"""