Spaces:
Sleeping
Sleeping
| """MiniLLaVA β CLIP-ViT + MultiModalProjector + Qwen2.5 Causal LM. | |
| LLaVA-1.5μ ν΅μ¬ μν€ν μ²λ₯Ό μ§μ ꡬν. HuggingFaceμ LlavaForConditionalGeneration | |
| κ°μ κ³ μμ€ ν΄λμ€λ₯Ό μ¬μ©νμ§ μκ³ , ν μ€νΈ/μ΄λ―Έμ§ μλ² λ© μ΅ν©κ³Ό splice λ‘μ§μ | |
| μ μμ€μμ μ§μ λ€λ£¬λ€. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| from typing import Optional | |
| import torch | |
| import torch.nn as nn | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| CLIPImageProcessor, | |
| CLIPVisionModel, | |
| ) | |
| from .config import IGNORE_INDEX, IMAGE_TOKEN, LLM_MODEL, VISION_MODEL | |
class MultiModalProjector(nn.Module):
    """Two-layer MLP that maps CLIP visual features into the LLM embedding space.

    Follows LLaVA-1.5's ``mlp2x_gelu`` projector layout:
    ``Linear(vision -> llm) -> GELU -> Linear(llm -> llm)``.

    Args:
        vision_hidden_size: Width of the vision encoder's output features.
        llm_hidden_size: Width of the language model's token embeddings.
    """

    def __init__(self, vision_hidden_size: int, llm_hidden_size: int):
        super().__init__()
        # The attribute names fc1 / act / fc2 (and their registration order)
        # define the state_dict keys, so saved projector checkpoints depend on
        # them staying exactly as they are.
        self.fc1 = nn.Linear(in_features=vision_hidden_size, out_features=llm_hidden_size)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(in_features=llm_hidden_size, out_features=llm_hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` (last dim = vision_hidden_size) to llm_hidden_size."""
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        projected = self.fc2(hidden)
        return projected
class MiniLLaVA(nn.Module):
    """Vision-Language Model: CLIP vision tower + MLP projector + causal LLM.

    - The CLIP-ViT tower is frozen by default (``freeze_vision=True``) to reuse
      its pretrained visual representations; while frozen it is kept in eval
      mode, even across ``model.train()`` calls.
    - The LLM is frozen by default (``freeze_llm=True``), matching LLaVA
      Stage-1 feature alignment.
    - With both frozen, only the projector is trainable. The original note
      cites ~1.6M params; the exact count depends on the hidden sizes of the
      configured vision and language models.
    """

    def __init__(
        self,
        vision_model_name: str = VISION_MODEL,
        llm_model_name: str = LLM_MODEL,
        freeze_vision: bool = True,
        freeze_llm: bool = True,
        torch_dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        self.vision = CLIPVisionModel.from_pretrained(vision_model_name)
        self.image_processor = CLIPImageProcessor.from_pretrained(vision_model_name)
        self.llm = AutoModelForCausalLM.from_pretrained(
            llm_model_name, torch_dtype=torch_dtype
        )
        self.tokenizer = AutoTokenizer.from_pretrained(llm_model_name)

        # Register the <image> placeholder token if the tokenizer lacks it and
        # grow the LLM embedding table to the new vocabulary size.
        if IMAGE_TOKEN not in self.tokenizer.get_vocab():
            self.tokenizer.add_special_tokens(
                {"additional_special_tokens": [IMAGE_TOKEN]}
            )
            self.llm.resize_token_embeddings(len(self.tokenizer))
        self.image_token_id = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
        # Fall back to EOS for padding when the tokenizer defines no pad token.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        vision_hidden = self.vision.config.hidden_size
        llm_hidden = self.llm.config.hidden_size
        self.projector = MultiModalProjector(vision_hidden, llm_hidden)

        # Remembered so that train() can keep a frozen vision tower in eval mode.
        self._vision_frozen = freeze_vision
        if freeze_vision:
            for p in self.vision.parameters():
                p.requires_grad = False
            self.vision.eval()
        if freeze_llm:
            for p in self.llm.parameters():
                p.requires_grad = False

    def train(self, mode: bool = True):
        """Set train/eval mode, keeping a frozen vision tower in eval mode.

        ``nn.Module.train`` recurses into every submodule, which would undo the
        ``self.vision.eval()`` done in ``__init__`` for a frozen tower.
        """
        super().train(mode)
        if getattr(self, "_vision_frozen", False):
            self.vision.eval()
        return self

    # ------------------------------------------------------------------
    # Encoding
    # ------------------------------------------------------------------
    def encode_image(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode images into LLM-space patch embeddings.

        Args:
            pixel_values: Image batch for the CLIP vision model
                (assumed ``[B, 3, H, W]``, e.g. from ``self.image_processor``).

        Returns:
            ``[B, N_patches, D_llm]``: the CLIP last hidden state with the
            leading CLS token dropped, passed through the projector.
        """
        outputs = self.vision(pixel_values=pixel_values)
        # Position 0 is the CLS token; keep only the patch tokens.
        patch_features = outputs.last_hidden_state[:, 1:, :]
        return self.projector(patch_features)

    # ------------------------------------------------------------------
    # Embedding fusion: splice patch tokens in at the <image> position
    # ------------------------------------------------------------------
    def _merge(
        self,
        text_embeds: torch.Tensor,
        attention_mask: torch.Tensor,
        image_embeds: torch.Tensor,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
    ):
        """Replace each row's ``<image>`` token with its N image patch embeddings.

        Every row grows from ``L`` to ``L - 1 + N`` positions. Embeddings, the
        attention mask and labels are rearranged consistently. Patch positions
        get attention mask 1 and label ``IGNORE_INDEX``.

        Args:
            text_embeds: ``[B, L, D]`` embeddings of ``input_ids``.
            attention_mask: ``[B, L]`` mask for ``input_ids``.
            image_embeds: ``[B, N, D]`` projected patch embeddings, one image
                per row.
            input_ids: ``[B, L]`` token ids. Each row must contain exactly one
                ``<image>`` token.
            labels: Optional ``[B, L]`` LM targets.

        Returns:
            ``(merged_embeds [B, L-1+N, D], merged_mask [B, L-1+N],
            merged_labels [B, L-1+N] or None)``.

        Raises:
            ValueError: If the image and text batch sizes differ, or if a row
                does not contain exactly one ``<image>`` token.
        """
        B, L, D = text_embeds.shape
        if image_embeds.shape[0] != B:
            raise ValueError(
                f"image_embeds batch size {image_embeds.shape[0]} does not match "
                f"text batch size {B}."
            )
        N = image_embeds.shape[1]
        new_L = L - 1 + N
        device = text_embeds.device

        merged_embeds = torch.zeros(B, new_L, D, dtype=text_embeds.dtype, device=device)
        merged_mask = torch.zeros(B, new_L, dtype=attention_mask.dtype, device=device)
        merged_labels = (
            torch.full((B, new_L), IGNORE_INDEX, dtype=torch.long, device=device)
            if labels is not None
            else None
        )

        for b in range(B):
            img_pos = (input_ids[b] == self.image_token_id).nonzero(as_tuple=True)[0]
            if len(img_pos) != 1:
                raise ValueError(
                    f"sample {b}는 <image> 토큰이 {len(img_pos)}개 — 정확히 1개여야 합니다."
                )
            p = img_pos.item()

            # Splice in order: text before <image> / image patches / text after.
            merged_embeds[b, :p] = text_embeds[b, :p]
            merged_embeds[b, p : p + N] = image_embeds[b]
            merged_embeds[b, p + N :] = text_embeds[b, p + 1 :]

            merged_mask[b, :p] = attention_mask[b, :p]
            merged_mask[b, p : p + N] = 1
            merged_mask[b, p + N :] = attention_mask[b, p + 1 :]

            if labels is not None:
                merged_labels[b, :p] = labels[b, :p]
                # Patch positions keep IGNORE_INDEX (pre-filled above).
                merged_labels[b, p + N :] = labels[b, p + 1 :]

        return merged_embeds, merged_mask, merged_labels

    # ------------------------------------------------------------------
    # Forward (training)
    # ------------------------------------------------------------------
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
    ):
        """Run the LLM on text embeddings with the image patches spliced in.

        The merged attention mask and merged labels (``IGNORE_INDEX`` on patch
        positions) are forwarded to the LLM.

        Returns:
            The LLM's output object (called with ``return_dict=True``).
        """
        text_embeds = self.llm.get_input_embeddings()(input_ids)
        image_embeds = self.encode_image(pixel_values)
        merged_embeds, merged_mask, merged_labels = self._merge(
            text_embeds, attention_mask, image_embeds, input_ids, labels
        )
        return self.llm(
            inputs_embeds=merged_embeds,
            attention_mask=merged_mask,
            labels=merged_labels,
            return_dict=True,
        )

    # ------------------------------------------------------------------
    # Generation (inference)
    # ------------------------------------------------------------------
    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
        max_new_tokens: int = 128,
        temperature: float = 0.7,
        top_p: float = 0.9,
        do_sample: bool = True,
    ) -> torch.Tensor:
        """Generate token ids conditioned on the prompt and the image.

        Runs under ``torch.no_grad()``. The projector has trainable parameters,
        so encoding the image here would otherwise build an autograd graph that
        inference never uses.

        Returns:
            Whatever ``self.llm.generate`` returns for an ``inputs_embeds``
            prompt.
        """
        text_embeds = self.llm.get_input_embeddings()(input_ids)
        image_embeds = self.encode_image(pixel_values)
        merged_embeds, merged_mask, _ = self._merge(
            text_embeds, attention_mask, image_embeds, input_ids, labels=None
        )
        return self.llm.generate(
            inputs_embeds=merged_embeds,
            attention_mask=merged_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=do_sample,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )

    # ------------------------------------------------------------------
    # Checkpoint I/O: only the projector is saved; the LLM and CLIP weights
    # are re-created via from_pretrained in __init__.
    # ------------------------------------------------------------------
    def save_projector(self, path: str) -> None:
        """Save the projector's state_dict to ``path``, creating parent dirs."""
        parent = os.path.dirname(path)
        # os.makedirs("") raises FileNotFoundError, so skip it when `path`
        # is a bare filename such as "projector.pt".
        if parent:
            os.makedirs(parent, exist_ok=True)
        torch.save(self.projector.state_dict(), path)

    def load_projector(self, path: str, map_location: str = "cpu") -> None:
        """Load a projector state_dict previously written by ``save_projector``."""
        # weights_only=True restricts unpickling to tensors and plain containers,
        # which is all a state_dict needs. A tampered checkpoint cannot run code.
        state = torch.load(path, map_location=map_location, weights_only=True)
        self.projector.load_state_dict(state)

    def load_lora_adapter(self, adapter_path: str) -> None:
        """Attach a trained LoRA adapter on top of the LLM and switch it to eval."""
        from peft import PeftModel

        self.llm = PeftModel.from_pretrained(self.llm, adapter_path)
        self.llm.eval()

    def trainable_parameters(self):
        """Return the list of parameters with ``requires_grad=True``."""
        return [p for p in self.parameters() if p.requires_grad]

    def num_trainable(self) -> int:
        """Return the total element count of all trainable parameters."""
        return sum(p.numel() for p in self.trainable_parameters())