import torch
import torch.nn as nn
import torch.nn.functional as F
import tiktoken
from dataclasses import dataclass, fields
from huggingface_hub import hf_hub_download
from transformers import CLIPVisionModel, CLIPImageProcessor

from model import GPT


# Checkpoint and model constants.
REPO_ID = "HayatoHongo/everyoneschat-checkpoints"
FILENAME = "checkpoint_010000_vision_instructv2.pt"
VISION_ENCODER = "openai/clip-vit-large-patch14"
NUM_IMAGE_PATCHES = 256  # 16 x 16 patches from CLIP ViT-L/14 at 224 px
PAD_TOKEN_ID = 50256  # GPT-2 <|endoftext|>, reused as the pad token
IGNORE_INDEX = -100
VISION_PROJECTOR_HIDDEN_DIM = 2048

tokenizer = tiktoken.get_encoding("gpt2")
image_processor = CLIPImageProcessor.from_pretrained(VISION_ENCODER)


@dataclass
class ModelConfig:
    input_sequence_length: int
    max_sequence_length: int
    embedding_dim: int
    hidden_dim: int
    num_attention_heads: int
    layer_count: int
    rope_theta: float
    vocab_size: int
    device_type: str
    random_seed_value: int
    autocast_dtype: torch.dtype


class VLM(nn.Module):
    def __init__(self, llm):
        super().__init__()
        self.llm = llm

        # Freeze the language model.
        for p in self.llm.parameters():
            p.requires_grad = False

        # Frozen CLIP vision tower.
        self.vision = CLIPVisionModel.from_pretrained(VISION_ENCODER)
        for p in self.vision.parameters():
            p.requires_grad = False

        # Two-layer MLP that projects CLIP patch features into the LLM's
        # token-embedding space; the only trainable component.
        self.projector = nn.Sequential(
            nn.Linear(self.vision.config.hidden_size, VISION_PROJECTOR_HIDDEN_DIM),
            nn.GELU(),
            nn.Linear(VISION_PROJECTOR_HIDDEN_DIM, llm.config.embedding_dim),
        )
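

# A minimal shape sketch of the projector path, assuming CLIP ViT-L/14's
# vision hidden size of 1024 and a hypothetical embedding_dim of 768 (the
# real values come from the loaded configs). Illustration only.
def _projector_shape_demo(embedding_dim: int = 768) -> None:
    clip_hidden = 1024
    proj = nn.Sequential(
        nn.Linear(clip_hidden, VISION_PROJECTOR_HIDDEN_DIM),
        nn.GELU(),
        nn.Linear(VISION_PROJECTOR_HIDDEN_DIM, embedding_dim),
    )
    patches = torch.randn(1, NUM_IMAGE_PATCHES, clip_hidden)
    out = proj(patches)
    assert out.shape == (1, NUM_IMAGE_PATCHES, embedding_dim)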


def load_vlm_model():
    ckpt_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=FILENAME,
        repo_type="model",
    )

    checkpoint = torch.load(ckpt_path, map_location="cpu")
    config_dict = checkpoint["config"]

    # The checkpoint stores the autocast dtype as a string such as
    # "torch.bfloat16"; convert it back to a torch.dtype.
    if isinstance(config_dict.get("autocast_dtype"), str):
        config_dict["autocast_dtype"] = getattr(
            torch, config_dict["autocast_dtype"].split(".")[-1]
        )

    # Keep only the keys that ModelConfig actually defines.
    model_config_fields = {f.name for f in fields(ModelConfig)}
    filtered = {k: v for k, v in config_dict.items() if k in model_config_fields}
    config = ModelConfig(**filtered)

    llm = GPT(config)
    model = VLM(llm)

    model.load_state_dict(checkpoint["model_state_dict"], strict=True)
    model.eval()
    return model


@torch.no_grad()
def vlm_prefill(model, image_tensor, input_ids):
    x = model.llm.token_embedding_layer(input_ids)

    # Encode the image and drop the CLS token, keeping only patch tokens.
    v = model.vision(image_tensor, output_hidden_states=True)
    v = v.hidden_states[-1][:, 1:]
    v = model.projector(v)

    # The first NUM_IMAGE_PATCHES positions of the prompt are pad
    # placeholders; replace them with the projected image embeddings.
    x = torch.cat([v, x[:, NUM_IMAGE_PATCHES:]], dim=1)

    # Run the full sequence through the blocks, populating the KV cache.
    for block in model.llm.blocks:
        x = block(x, use_cache=True)

    return x
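

# A toy illustration of the placeholder swap above: the first
# NUM_IMAGE_PATCHES embeddings are replaced by image features
# (hypothetical dimensions, illustration only).
def _placeholder_swap_demo() -> None:
    d = 8  # hypothetical embedding dim
    text_emb = torch.zeros(1, NUM_IMAGE_PATCHES + 5, d)  # pads + 5 prompt tokens
    img_emb = torch.ones(1, NUM_IMAGE_PATCHES, d)  # stands in for projector output
    x = torch.cat([img_emb, text_emb[:, NUM_IMAGE_PATCHES:]], dim=1)
    assert x.shape == text_emb.shape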


@torch.no_grad()
def vlm_next_token(model, input_ids, temperature, top_k, top_p):
    x = model.llm.token_embedding_layer(input_ids)

    # Only the newest token is fed in; earlier context lives in the KV cache.
    for block in model.llm.blocks:
        x = block(x, use_cache=True)

    logits = model.llm.vocab_projection(x)[:, -1, :] / temperature

    # Top-k: mask everything below the k-th highest logit.
    if top_k:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits = torch.where(logits < v[:, -1:], -float("inf"), logits)

    # Top-p (nucleus): keep the smallest set of top tokens whose cumulative
    # probability reaches top_p; the shift keeps the first token that
    # crosses the threshold.
    if top_p:
        s_logits, s_idx = torch.sort(logits, descending=True)
        probs = F.softmax(s_logits, dim=-1)
        cum = probs.cumsum(dim=-1)
        mask = cum > top_p
        mask[..., 1:] = mask[..., :-1].clone()
        mask[..., 0] = False
        s_logits[mask] = -float("inf")
        logits = torch.zeros_like(logits).scatter(-1, s_idx, s_logits)

    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, 1)
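

# A self-contained sketch of the same top-k filtering on toy logits
# (hypothetical values, illustration only; not used by the pipeline).
def _top_k_filter_demo() -> None:
    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
    v, _ = torch.topk(logits, 2)
    masked = torch.where(logits < v[:, -1:], -float("inf"), logits)
    # Only the two highest-scoring tokens keep nonzero probability.
    print(F.softmax(masked, dim=-1))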


def vlm_infer_stream(
    model,
    image_tensor,
    prompt,
    max_new_tokens=256,
    temperature=0.7,
    top_k=None,
    top_p=None,
    stop_ids=frozenset({PAD_TOKEN_ID}),
):
    device = next(model.parameters()).device
    prompt_ids = tokenizer.encode(prompt, allowed_special="all")

    # Prepend one pad placeholder per image patch; vlm_prefill swaps these
    # positions for the projected image embeddings.
    input_ids = [PAD_TOKEN_ID] * NUM_IMAGE_PATCHES + prompt_ids
    input_ids = torch.tensor(input_ids, device=device)[None]

    # Start generation from an empty KV cache.
    for block in model.llm.blocks:
        block.multihead_attention.reset_cache()

    # Prefill the cache, then sample the first token from the last position
    # (temperature only; top-k/top-p apply from the second token on).
    x = vlm_prefill(model, image_tensor, input_ids)
    logits = model.llm.vocab_projection(x)[:, -1, :] / temperature
    probs = F.softmax(logits, dim=-1)
    next_token = torch.multinomial(probs, 1)

    acc, last = [], ""

    for _ in range(max_new_tokens):
        tid = int(next_token.item())
        if tid in stop_ids:
            break

        # Decode incrementally, holding back chunks that end mid-way through
        # a multi-byte UTF-8 character (tiktoken renders those as the
        # replacement character).
        acc.append(tid)
        text = tokenizer.decode(acc)
        if not text.endswith("�"):
            new = text[len(last):]
            if new:
                yield new
            last = text

        input_ids = torch.cat([input_ids, next_token], dim=1)
        next_token = vlm_next_token(
            model,
            input_ids[:, -1:],
            temperature,
            top_k,
            top_p,
        )
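

# A minimal end-to-end usage sketch, assuming a local image at "example.jpg"
# (hypothetical path) and a CUDA device when available; the GPT block and
# cache interfaces are assumed to match model.py.
if __name__ == "__main__":
    from PIL import Image

    model = load_vlm_model()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    image = Image.open("example.jpg").convert("RGB")
    pixel_values = image_processor(images=image, return_tensors="pt")["pixel_values"]
    pixel_values = pixel_values.to(device)

    for chunk in vlm_infer_stream(model, pixel_values, "Describe this image."):
        print(chunk, end="", flush=True)
    print()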