# vlm_inference.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import tiktoken
from huggingface_hub import hf_hub_download
from transformers import CLIPVisionModel, CLIPImageProcessor
from model import GPT
# =====================================================
# Constants
# =====================================================
REPO_ID = "HayatoHongo/everyoneschat-checkpoints"
FILENAME = "checkpoint_010000_vision_instructv2.pt"
VISION_ENCODER = "openai/clip-vit-large-patch14"
NUM_IMAGE_PATCHES = 256  # CLIP ViT-L/14 at 224 px: (224 / 14)^2 patches, CLS token excluded
PAD_TOKEN_ID = 50256  # GPT-2 <|endoftext|>; also used as the image-patch placeholder
IGNORE_INDEX = -100  # label value ignored by the training loss (unused at inference)
VISION_PROJECTOR_HIDDEN_DIM = 2048
tokenizer = tiktoken.get_encoding("gpt2")
image_processor = CLIPImageProcessor.from_pretrained(VISION_ENCODER)
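# The processor turns a PIL image into normalized CLIP pixel values. A minimal
# sketch, assuming `pil_image` is an RGB PIL.Image:
#   image_tensor = image_processor(images=pil_image, return_tensors="pt").pixel_values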
# =====================================================
# ModelConfig (same as Colab)
# =====================================================
from dataclasses import dataclass, fields
@dataclass
class ModelConfig:
input_sequence_length: int
max_sequence_length: int
embedding_dim: int
hidden_dim: int
num_attention_heads: int
layer_count: int
rope_theta: float
vocab_size: int
device_type: str
random_seed_value: int
autocast_dtype: torch.dtype
# =====================================================
# VLM wrapper
# =====================================================
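# Frozen GPT backbone + frozen CLIP vision tower, bridged by a small trainable
# MLP projector (a LLaVA-style layout).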
class VLM(nn.Module):
    def __init__(self, llm):
        super().__init__()
        # Frozen language model backbone.
        self.llm = llm
        for p in self.llm.parameters():
            p.requires_grad = False
        # Frozen CLIP vision tower.
        self.vision = CLIPVisionModel.from_pretrained(VISION_ENCODER)
        for p in self.vision.parameters():
            p.requires_grad = False
        # Two-layer MLP that maps CLIP patch features into the LLM's embedding
        # space; this is the only part of the wrapper left trainable.
        self.projector = nn.Sequential(
            nn.Linear(self.vision.config.hidden_size, VISION_PROJECTOR_HIDDEN_DIM),
            nn.GELU(),
            nn.Linear(VISION_PROJECTOR_HIDDEN_DIM, llm.config.embedding_dim),
        )
# =====================================================
# Load model (CPU)
# =====================================================
def load_vlm_model():
    ckpt_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=FILENAME,
        repo_type="model",
    )
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    config_dict = checkpoint["config"]
    # The checkpoint stores the dtype as a string (e.g. "torch.bfloat16");
    # convert it back to a real torch.dtype.
    if isinstance(config_dict.get("autocast_dtype"), str):
        config_dict["autocast_dtype"] = getattr(
            torch, config_dict["autocast_dtype"].split(".")[-1]
        )
    # Drop any checkpoint keys the dataclass does not define.
    model_config_fields = {f.name for f in fields(ModelConfig)}
    filtered = {k: v for k, v in config_dict.items() if k in model_config_fields}
    config = ModelConfig(**filtered)
    llm = GPT(config)
    model = VLM(llm)
    model.load_state_dict(checkpoint["model_state_dict"], strict=True)
    model.eval()
    return model
# =====================================================
# Inference helpers (same as Colab)
# =====================================================
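# Two-phase decoding: vlm_prefill runs image + prompt once to fill each
# block's KV cache; vlm_next_token then feeds one new token per step against
# the cached keys/values.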
@torch.no_grad()
def vlm_prefill(model, image_tensor, input_ids):
    x = model.llm.token_embedding_layer(input_ids)
    # Encode the image; drop the CLS token so NUM_IMAGE_PATCHES patch features
    # remain from the last hidden layer.
    v = model.vision(image_tensor, output_hidden_states=True)
    v = v.hidden_states[-1][:, 1:]
    v = model.projector(v)
    # Swap the NUM_IMAGE_PATCHES placeholder embeddings at the front of the
    # sequence for the projected vision features.
    x = torch.cat([v, x[:, NUM_IMAGE_PATCHES:]], dim=1)
    # One full forward pass to populate every block's KV cache.
    for block in model.llm.blocks:
        x = block(x, use_cache=True)
    return x
@torch.no_grad()
def vlm_next_token(model, input_ids, temperature, top_k, top_p):
    # Decode step: the caller passes only the newest token; earlier context
    # lives in the blocks' KV caches.
    x = model.llm.token_embedding_layer(input_ids)
    for block in model.llm.blocks:
        x = block(x, use_cache=True)
    logits = model.llm.vocab_projection(x)[:, -1, :] / temperature
    # Top-k: keep only the k highest-scoring tokens.
    if top_k:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits = torch.where(logits < v[:, -1:], -float("inf"), logits)
    # Top-p (nucleus): keep the smallest prefix of sorted tokens whose
    # cumulative probability exceeds top_p; the mask shift keeps the first
    # token past the threshold in the candidate set.
    if top_p:
        s_logits, s_idx = torch.sort(logits, descending=True)
        probs = F.softmax(s_logits, dim=-1)
        cum = probs.cumsum(dim=-1)
        mask = cum > top_p
        mask[..., 1:] = mask[..., :-1].clone()
        mask[..., 0] = False
        s_logits[mask] = -float("inf")
        logits = torch.zeros_like(logits).scatter(-1, s_idx, s_logits)
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, 1)
def vlm_infer_stream(
    model,
    image_tensor,
    prompt,
    max_new_tokens=256,
    temperature=0.7,
    top_k=None,
    top_p=None,
    stop_ids=frozenset({PAD_TOKEN_ID}),  # stop on <|endoftext|>
):
    device = next(model.parameters()).device
    prompt_ids = tokenizer.encode(prompt, allowed_special="all")
    # Prepend NUM_IMAGE_PATCHES pad tokens as placeholders; vlm_prefill swaps
    # their embeddings for the projected vision features.
    input_ids = [PAD_TOKEN_ID] * NUM_IMAGE_PATCHES + prompt_ids
    input_ids = torch.tensor(input_ids, device=device)[None]
    # Clear KV caches left over from a previous generation.
    for block in model.llm.blocks:
        block.multihead_attention.reset_cache()
    x = vlm_prefill(model, image_tensor, input_ids)
    # Sample the first token from the prefill logits (temperature only;
    # top-k/top-p kick in from the second token onward).
    logits = model.llm.vocab_projection(x)[:, -1, :] / temperature
    probs = F.softmax(logits, dim=-1)
    next_token = torch.multinomial(probs, 1)
    acc, last = [], ""
    for _ in range(max_new_tokens):
        tid = int(next_token.item())
        if tid in stop_ids:
            break
        acc.append(tid)
        # Only yield once the byte-level BPE decodes to complete UTF-8;
        # "�" marks an incomplete multi-byte sequence.
        text = tokenizer.decode(acc)
        if not text.endswith("�"):
            new = text[len(last):]
            if new:
                yield new
            last = text
        input_ids = torch.cat([input_ids, next_token], dim=1)
        # Feed only the newest token; the KV cache holds the rest.
        next_token = vlm_next_token(
            model,
            input_ids[:, -1:],
            temperature,
            top_k,
            top_p,
        )
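# =====================================================
# Example usage (a minimal sketch, not part of the original script). It
# assumes Pillow is installed and a local "cat.jpg" exists; the plain-text
# prompt format is an assumption, since the checkpoint's exact instruct
# template is not shown in this file.
# =====================================================
if __name__ == "__main__":
    from PIL import Image

    model = load_vlm_model()  # loads on CPU
    pil_image = Image.open("cat.jpg").convert("RGB")  # hypothetical image path
    image_tensor = image_processor(images=pil_image, return_tensors="pt").pixel_values
    for chunk in vlm_infer_stream(
        model,
        image_tensor,
        prompt="Describe this image.",
        temperature=0.7,
        top_p=0.9,
    ):
        print(chunk, end="", flush=True)
    print()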