Spaces:
Running
Running
| """ | |
| models/custom_vlm.py | |
| ===================== | |
| Advanced Master-Hack β Visual Prefix-Tuning (Shakespeare + ViT) | |
| Architecture: A frozen pre-trained ViT (google/vit-base-patch16-224-in21k) | |
| is fused with a custom character-level causal Transformer decoder trained on | |
| Shakespeare text. A trainable MLP projection layer bridges the ViT's | |
| 768-dim output to the decoder's 384-dim embedding space. | |
| MODALITY FUSION: | |
| ViT β Project(768β384) β [visual_prefix | char_embeddings] β CausalSelfAttention | |
| TRAINING REGIME: | |
| - ViT: FROZEN (always) | |
| - Shakespeare Decoder: UNFROZEN during fine-tuning (adapts to COCO captions) | |
| - visual_projection: TRAINABLE (learned bridge) | |
| Weight Loading Strategy: | |
| The Shakespeare checkpoint uses a custom per-head architecture with keys like: | |
| blocks.N.sa_head.heads.M.{key,query,value}.weight | |
| These are remapped to PyTorch nn.TransformerEncoder's fused format: | |
| decoder_blocks.layers.N.self_attn.in_proj_weight | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import ViTModel | |
# ─────────────────────────────────────────────────────────────────────────────
# Character Vocabulary Helper
# ─────────────────────────────────────────────────────────────────────────────
def build_char_vocab(text_corpus: str):
    """
    Construct a character-level vocabulary from a raw text corpus string.

    Returns:
        chars       : sorted list of unique characters
        char_to_idx : dict mapping char -> int index
        idx_to_char : dict mapping int index -> char
        vocab_size  : int
    """
    chars = sorted(set(text_corpus))
    # Forward and reverse lookup tables share the same enumeration order.
    char_to_idx = dict((ch, pos) for pos, ch in enumerate(chars))
    idx_to_char = dict(enumerate(chars))
    return chars, char_to_idx, idx_to_char, len(chars)
# ─────────────────────────────────────────────────────────────────────────────
# Model Definition
# ─────────────────────────────────────────────────────────────────────────────
class CustomVLM(nn.Module):
    """
    Visual Prefix-Tuning VLM.

    Combines:
        1. Frozen ViT image encoder (768-dim output)
        2. Trainable MLP projection (768 -> text_embed_dim)
        3. Character-level causal Transformer decoder
           (initialized from shakespeare_transformer.pt, then fine-tuned)
    """

    # ViT-base at 224x224 with 16x16 patches: 196 patch tokens + 1 [CLS].
    NUM_VISUAL_TOKENS = 197

    def __init__(self, vocab_size, text_embed_dim=384, n_heads=8, n_layers=8,
                 block_size=256, dropout=0.1):
        """
        Args:
            vocab_size     : number of characters in the text vocabulary
            text_embed_dim : decoder embedding width (Shakespeare ckpt uses 384)
            n_heads        : attention heads per decoder layer
            n_layers       : number of decoder layers
            block_size     : maximum TEXT sequence length (visual prefix excluded)
            dropout        : dropout probability inside the decoder layers

        NOTE(review): from_pretrained downloads the ViT on first use — needs
        network access or a populated HF cache; verify in deployment.
        """
        super().__init__()
        # ── 1. Vision Encoder (frozen — never receives gradients) ──────────
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        for param in self.vit.parameters():
            param.requires_grad = False
        vit_hidden_size = self.vit.config.hidden_size  # 768
        # ── 2. Trainable Bridge (2-layer MLP — like LLaVA) ──────────────────
        self.visual_projection = nn.Sequential(
            nn.Linear(vit_hidden_size, vit_hidden_size * 2),
            nn.GELU(),
            nn.Linear(vit_hidden_size * 2, text_embed_dim)
        )
        # ── 3. Character-Level Causal Transformer Decoder ───────────────────
        self.token_embedding_table = nn.Embedding(vocab_size, text_embed_dim)
        # Position table covers visual prefix (197) + max text (block_size).
        self.position_embedding_table = nn.Embedding(
            self.NUM_VISUAL_TOKENS + block_size, text_embed_dim
        )
        # An nn.TransformerEncoder is used AS a decoder: causality is imposed
        # by the explicit attention mask built in forward(), not by the layer.
        decoder_layer = nn.TransformerEncoderLayer(
            d_model=text_embed_dim,
            nhead=n_heads,
            dim_feedforward=4 * text_embed_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.decoder_blocks = nn.TransformerEncoder(decoder_layer, num_layers=n_layers)
        self.ln_f = nn.LayerNorm(text_embed_dim)
        self.lm_head = nn.Linear(text_embed_dim, vocab_size)
        # Hyperparameters retained for mask building and weight remapping.
        self.block_size = block_size
        self.text_embed_dim = text_embed_dim
        self.vocab_size = vocab_size
        self.n_heads = n_heads
        self.n_layers = n_layers
    # ─────────────────────────────────────────────────────────────────────────
    # Weight Loading — with architecture remapping
    # ─────────────────────────────────────────────────────────────────────────
    def load_shakespeare_weights(self, path: str, device: str = "cpu") -> dict:
        """
        Load pre-trained Shakespeare Transformer weights with full key remapping.

        The Shakespeare checkpoint uses a custom per-head architecture:
            blocks.N.sa_head.heads.M.{key,query,value}.weight (head_dim, embed_dim)
            blocks.N.sa_head.proj.{weight,bias}
            blocks.N.ffwd.net.{0,2}.{weight,bias}
            blocks.N.ln{1,2}.{weight,bias}
        These are remapped into PyTorch nn.TransformerEncoder's fused format:
            decoder_blocks.layers.N.self_attn.in_proj_weight (3*embed_dim, embed_dim)
            decoder_blocks.layers.N.self_attn.out_proj.{weight,bias}
            decoder_blocks.layers.N.linear1.{weight,bias}
            decoder_blocks.layers.N.linear2.{weight,bias}
            decoder_blocks.layers.N.norm1.{weight,bias}
            decoder_blocks.layers.N.norm2.{weight,bias}

        Args:
            path   : filesystem path to the checkpoint (.pt) file
            device : map_location passed to torch.load

        Returns:
            dict with keys "loaded" (tensor names copied in), "missing" and
            "unexpected" (from load_state_dict with strict=False).
        """
        print(f"π Loading Shakespeare weights from: {path}")
        # NOTE(review): torch.load without weights_only=True unpickles
        # arbitrary objects — only load checkpoints from trusted sources.
        raw = torch.load(path, map_location=device)
        # Unwrap common checkpoint wrapper structures.
        if isinstance(raw, dict):
            if "model_state" in raw:
                state_dict = raw["model_state"]
            elif "model" in raw:
                state_dict = raw["model"]
            elif "state_dict" in raw:
                state_dict = raw["state_dict"]
            else:
                state_dict = raw
        else:
            raise TypeError(f"Unexpected checkpoint type: {type(raw)}")
        # ── Discover Shakespeare architecture from the key names ─────────────
        shk_blocks = set()
        shk_heads = set()
        for key in state_dict:
            if key.startswith("blocks."):
                parts = key.split(".")
                shk_blocks.add(int(parts[1]))
                # Key layout "blocks.N.sa_head.heads.M...." -> parts[4] is M.
                if "heads" in key:
                    shk_heads.add(int(parts[4]))
        n_shk_blocks = len(shk_blocks)
        n_shk_heads = len(shk_heads) if shk_heads else self.n_heads
        # NOTE(review): head_dim is derived from THIS model's n_heads, so the
        # "Shakespeare arch" print below is only accurate when the checkpoint's
        # head count matches the model's — confirm if they can differ.
        head_dim = self.text_embed_dim // self.n_heads
        print(f" π Shakespeare arch: {n_shk_blocks} blocks, {n_shk_heads} heads, "
              f"head_dim={head_dim}")
        print(f" π Model arch: {self.n_layers} layers, {self.n_heads} heads")
        # How many layers/heads can actually be transferred (min of the two).
        n_load = min(n_shk_blocks, self.n_layers)
        n_heads_load = min(n_shk_heads, self.n_heads)
        remapped = {}
        # ── Remap decoder blocks ─────────────────────────────────────────────
        for layer_idx in range(n_load):
            prefix_src = f"blocks.{layer_idx}"
            prefix_dst = f"decoder_blocks.layers.{layer_idx}"
            # 1. Self-Attention: fuse per-head Q, K, V into in_proj_weight.
            #    Shakespeare: heads.M.query.weight (head_dim, embed_dim)
            #    Target:      self_attn.in_proj_weight (3*embed_dim, embed_dim)
            q_parts, k_parts, v_parts = [], [], []
            for h in range(n_heads_load):
                qk = f"{prefix_src}.sa_head.heads.{h}.query.weight"
                kk = f"{prefix_src}.sa_head.heads.{h}.key.weight"
                vk = f"{prefix_src}.sa_head.heads.{h}.value.weight"
                # Only fuse when all three projections exist for this head.
                if qk in state_dict and kk in state_dict and vk in state_dict:
                    q_parts.append(state_dict[qk])
                    k_parts.append(state_dict[kk])
                    v_parts.append(state_dict[vk])
            if q_parts:
                # Concatenate heads: each (head_dim, embed_dim) -> (embed_dim, embed_dim)
                Q_full = torch.cat(q_parts, dim=0)  # (n_heads*head_dim, embed_dim)
                K_full = torch.cat(k_parts, dim=0)
                V_full = torch.cat(v_parts, dim=0)
                # Fuse into in_proj_weight: [Q; K; V] -> (3*embed_dim, embed_dim)
                in_proj_weight = torch.cat([Q_full, K_full, V_full], dim=0)
                remapped[f"{prefix_dst}.self_attn.in_proj_weight"] = in_proj_weight
                # Zero bias: the Shakespeare attention has no Q/K/V bias terms.
                remapped[f"{prefix_dst}.self_attn.in_proj_bias"] = torch.zeros(
                    3 * self.text_embed_dim
                )
            # 2. Attention output projection.
            proj_w = f"{prefix_src}.sa_head.proj.weight"
            proj_b = f"{prefix_src}.sa_head.proj.bias"
            if proj_w in state_dict:
                remapped[f"{prefix_dst}.self_attn.out_proj.weight"] = state_dict[proj_w]
            if proj_b in state_dict:
                remapped[f"{prefix_dst}.self_attn.out_proj.bias"] = state_dict[proj_b]
            # 3. Feed-forward: ffwd.net.0 -> linear1, ffwd.net.2 -> linear2
            #    (net.1 is the activation and has no weights).
            for shk_idx, tgt_name in [("0", "linear1"), ("2", "linear2")]:
                wk = f"{prefix_src}.ffwd.net.{shk_idx}.weight"
                bk = f"{prefix_src}.ffwd.net.{shk_idx}.bias"
                if wk in state_dict:
                    remapped[f"{prefix_dst}.{tgt_name}.weight"] = state_dict[wk]
                if bk in state_dict:
                    remapped[f"{prefix_dst}.{tgt_name}.bias"] = state_dict[bk]
            # 4. Layer norms: ln1 -> norm1, ln2 -> norm2.
            for shk_ln, tgt_ln in [("ln1", "norm1"), ("ln2", "norm2")]:
                for suffix in ("weight", "bias"):
                    sk = f"{prefix_src}.{shk_ln}.{suffix}"
                    if sk in state_dict:
                        remapped[f"{prefix_dst}.{tgt_ln}.{suffix}"] = state_dict[sk]
        # ── Non-decoder module weights ───────────────────────────────────────
        # token_embedding_table: direct copy, or row-wise copy when the vocab
        # sizes differ but the embedding width matches.
        if "token_embedding_table.weight" in state_dict:
            shk_emb = state_dict["token_embedding_table.weight"]
            own_emb = self.token_embedding_table.weight
            if shk_emb.shape == own_emb.shape:
                remapped["token_embedding_table.weight"] = shk_emb
            elif shk_emb.shape[1] == own_emb.shape[1]:
                # Vocab size difference: copy what fits, keep own init elsewhere.
                n_copy = min(shk_emb.shape[0], own_emb.shape[0])
                new_emb = own_emb.data.clone()
                new_emb[:n_copy] = shk_emb[:n_copy]
                remapped["token_embedding_table.weight"] = new_emb
        # position_embedding_table: Shakespeare (256, 384) -> Model (453, 384)
        if "position_embedding_table.weight" in state_dict:
            shk_pos = state_dict["position_embedding_table.weight"]  # (256, 384)
            own_pos = self.position_embedding_table.weight  # (197+block_size, 384)
            if shk_pos.shape == own_pos.shape:
                remapped["position_embedding_table.weight"] = shk_pos
            else:
                # Expand: zero-init the full table, then copy Shakespeare positions
                # into the TEXT portion (positions 197..197+256).
                new_pos = torch.zeros_like(own_pos.data)
                # Visual positions (0..196) get small random init.
                nn.init.normal_(new_pos[:self.NUM_VISUAL_TOKENS], std=0.02)
                # Text positions: copy Shakespeare's first N positions.
                n_text_slots = own_pos.shape[0] - self.NUM_VISUAL_TOKENS
                n_copy = min(shk_pos.shape[0], n_text_slots)
                new_pos[self.NUM_VISUAL_TOKENS:self.NUM_VISUAL_TOKENS + n_copy] = shk_pos[:n_copy]
                remapped["position_embedding_table.weight"] = new_pos
                print(f" π Position embeddings expanded: {shk_pos.shape} β {own_pos.shape}")
        # ln_f (final layer norm): copy only on exact shape match.
        for suffix in ("weight", "bias"):
            k = f"ln_f.{suffix}"
            if k in state_dict:
                own_shape = getattr(self.ln_f, suffix).shape
                if state_dict[k].shape == own_shape:
                    remapped[k] = state_dict[k]
        # lm_head: same row-wise copy strategy as the token embedding.
        if "lm_head.weight" in state_dict:
            shk_lm = state_dict["lm_head.weight"]
            own_lm = self.lm_head.weight
            if shk_lm.shape == own_lm.shape:
                remapped["lm_head.weight"] = shk_lm
            elif shk_lm.shape[1] == own_lm.shape[1]:
                n_copy = min(shk_lm.shape[0], own_lm.shape[0])
                new_lm = own_lm.data.clone()
                new_lm[:n_copy] = shk_lm[:n_copy]
                remapped["lm_head.weight"] = new_lm
        if "lm_head.bias" in state_dict:
            shk_b = state_dict["lm_head.bias"]
            own_b = self.lm_head.bias
            if own_b is not None and shk_b.shape == own_b.shape:
                remapped["lm_head.bias"] = shk_b
            elif own_b is not None:
                n_copy = min(shk_b.shape[0], own_b.shape[0])
                new_b = own_b.data.clone()
                new_b[:n_copy] = shk_b[:n_copy]
                remapped["lm_head.bias"] = new_b
        # ── Load remapped weights ─────────────────────────────────────────────
        # Verify shapes before loading so load_state_dict never hard-fails.
        own_state = self.state_dict()
        valid_remapped = {}
        shape_mismatches = []
        for k, v in remapped.items():
            if k in own_state:
                if own_state[k].shape == v.shape:
                    valid_remapped[k] = v
                else:
                    shape_mismatches.append(
                        f" {k}: ckpt={v.shape} vs model={own_state[k].shape}"
                    )
            else:
                shape_mismatches.append(f" {k}: not in model state_dict")
        result = self.load_state_dict(valid_remapped, strict=False)
        print(f" β Successfully loaded {len(valid_remapped)} weight tensors (of {len(state_dict)} in checkpoint)")
        if shape_mismatches:
            print(f" β οΈ {len(shape_mismatches)} shape mismatches (skipped):")
            for msg in shape_mismatches[:5]:
                print(msg)
        # Report how much of the decoder stack was actually covered.
        decoder_loaded = sum(1 for k in valid_remapped if k.startswith("decoder_blocks"))
        total_decoder = sum(1 for k in own_state if k.startswith("decoder_blocks"))
        print(f" π Decoder coverage: {decoder_loaded}/{total_decoder} tensors loaded")
        return {
            "loaded": list(valid_remapped.keys()),
            "missing": result.missing_keys,
            "unexpected": result.unexpected_keys,
        }
    # ─────────────────────────────────────────────────────────────────────────
    # Freezing / Unfreezing / Parameter Counting
    # ─────────────────────────────────────────────────────────────────────────
| def freeze_decoder(self): | |
| """Freeze the Shakespeare decoder so only visual_projection trains.""" | |
| for name, param in self.named_parameters(): | |
| if not name.startswith("visual_projection"): | |
| param.requires_grad = False | |
| # Ensure ViT is frozen | |
| for param in self.vit.parameters(): | |
| param.requires_grad = False | |
| def unfreeze_decoder(self): | |
| """ | |
| Unfreeze the decoder for fine-tuning while keeping ViT frozen. | |
| This allows the decoder to adapt from Shakespeare text to COCO captions. | |
| The visual_projection is also trainable. | |
| """ | |
| # First, freeze everything | |
| for param in self.parameters(): | |
| param.requires_grad = False | |
| # Unfreeze visual_projection (always trainable) | |
| for param in self.visual_projection.parameters(): | |
| param.requires_grad = True | |
| # Unfreeze ALL decoder components | |
| for param in self.token_embedding_table.parameters(): | |
| param.requires_grad = True | |
| for param in self.position_embedding_table.parameters(): | |
| param.requires_grad = True | |
| for param in self.decoder_blocks.parameters(): | |
| param.requires_grad = True | |
| for param in self.ln_f.parameters(): | |
| param.requires_grad = True | |
| for param in self.lm_head.parameters(): | |
| param.requires_grad = True | |
| # ViT stays FROZEN | |
| for param in self.vit.parameters(): | |
| param.requires_grad = False | |
| def get_param_groups(self, projection_lr=1e-4, decoder_lr=5e-5): | |
| """ | |
| Return optimizer param groups with discriminative learning rates. | |
| - visual_projection: higher LR (learning from scratch) | |
| - decoder: lower LR (gentle adaptation from Shakespeare) | |
| """ | |
| projection_params = [] | |
| decoder_params = [] | |
| for name, param in self.named_parameters(): | |
| if not param.requires_grad: | |
| continue | |
| if name.startswith("visual_projection"): | |
| projection_params.append(param) | |
| else: | |
| decoder_params.append(param) | |
| return [ | |
| {"params": projection_params, "lr": projection_lr}, | |
| {"params": decoder_params, "lr": decoder_lr}, | |
| ] | |
| def trainable_params(self): | |
| """Return count of trainable parameters.""" | |
| return sum(p.numel() for p in self.parameters() if p.requires_grad) | |
    # ─────────────────────────────────────────────────────────────────────────
    # Forward Pass
    # ─────────────────────────────────────────────────────────────────────────
    def forward(self, pixel_values, text_input_ids, text_targets=None):
        """
        Compute next-character logits (and optionally the LM loss).

        Args:
            pixel_values   : image batch for the ViT — presumably
                             (B, 3, 224, 224); TODO confirm against caller
            text_input_ids : (B, T) long tensor of character indices
            text_targets   : optional (B, T) long tensor of target indices;
                             positions equal to 0 are ignored by the loss

        Returns:
            (logits, loss): logits is (B, min(T, block_size), vocab_size);
            loss is None when text_targets is None.
        """
        B, T = text_input_ids.shape
        # ── Image Encoding (frozen ViT — no grads into the encoder) ─────────
        with torch.no_grad():
            vit_outputs = self.vit(pixel_values=pixel_values)
            image_embeds = vit_outputs.last_hidden_state  # (B, 197, 768)
        # ── Project to text embedding space (trainable bridge) ──────────────
        visual_prefix = self.visual_projection(image_embeds)  # (B, 197, 384)
        num_visual = visual_prefix.shape[1]  # 197
        # ── Text Embeddings (clip text to the decoder context window) ────────
        T_clipped = min(T, self.block_size)
        text_in = text_input_ids[:, :T_clipped]
        tok_emb = self.token_embedding_table(text_in)  # (B, T, 384)
        # ── Positional Embeddings (covers full combined sequence) ────────────
        # Positions 0..196 -> visual prefix, 197..197+T -> text tokens.
        total_len = num_visual + T_clipped
        pos_ids = torch.arange(total_len, device=text_in.device)
        pos_emb = self.position_embedding_table(pos_ids)  # (num_visual+T, 384)
        vis_pos = pos_emb[:num_visual]  # (197, 384)
        txt_pos = pos_emb[num_visual:]  # (T, 384)
        visual_emb = visual_prefix + vis_pos  # (B, 197, 384)
        text_emb = tok_emb + txt_pos  # (B, T, 384)
        # ── Fusion: [visual_prefix | text_emb] ───────────────────────────────
        combined = torch.cat([visual_emb, text_emb], dim=1)  # (B, 197+T, 384)
        tot = combined.shape[1]
        # ── Causal Attention Mask (additive float mask) ──────────────────────
        # Visual tokens attend to each other freely; text tokens attend to all
        # visual tokens + causally to previous text. Visual rows keep -inf on
        # text columns, so the prefix never attends to the caption.
        mask = torch.full((tot, tot), float("-inf"), device=text_in.device)
        mask[:num_visual, :num_visual] = 0.0  # visual->visual: free
        mask[num_visual:, :num_visual] = 0.0  # text->visual: free
        causal = torch.triu(
            torch.full((T_clipped, T_clipped), float("-inf"), device=text_in.device),
            diagonal=1,
        )
        mask[num_visual:, num_visual:] = causal  # text->text: causal
        # ── Decoder (TransformerEncoder stack driven by the explicit mask) ───
        x = self.decoder_blocks(combined, mask=mask, is_causal=False)
        # Only the text positions produce next-character predictions.
        text_out = x[:, num_visual:, :]
        text_out = self.ln_f(text_out)
        logits = self.lm_head(text_out)  # (B, T, vocab)
        # ── Loss (ignore padding index 0) ────────────────────────────────────
        # NOTE(review): index 0 is both "padding" and a real character from the
        # sorted vocab (likely "\n") — confirm this overlap is intentional.
        loss = None
        if text_targets is not None:
            tgt = text_targets[:, :T_clipped]
            loss = F.cross_entropy(
                logits.reshape(B * T_clipped, -1),
                tgt.reshape(B * T_clipped),
                ignore_index=0,
            )
        return logits, loss
    # ─────────────────────────────────────────────────────────────────────────
    # Generation
    # ─────────────────────────────────────────────────────────────────────────
| def generate(self, pixel_values, char_to_idx, idx_to_char, | |
| max_new_tokens=100, temperature=0.8): | |
| """ | |
| Autoregressive character-level caption generation (temperature sampling). | |
| Args: | |
| pixel_values : (1, 3, H, W) pre-processed image tensor | |
| char_to_idx : character β index mapping | |
| idx_to_char : index β character mapping | |
| max_new_tokens : how many characters to generate | |
| temperature : sampling temperature (0.8 = slightly sharper than uniform) | |
| Returns: | |
| generated_text : str | |
| """ | |
| self.eval() | |
| device = pixel_values.device | |
| bos_idx = char_to_idx.get("\n", 0) | |
| idx_seq = torch.tensor([[bos_idx]], dtype=torch.long, device=device) | |
| for _ in range(max_new_tokens): | |
| # Clip text to block_size β the forward method handles the visual | |
| # prefix separately, so we only need to limit the text portion. | |
| idx_cond = idx_seq[:, -self.block_size:] | |
| logits, _ = self(pixel_values, idx_cond) | |
| # Take the last time step | |
| logits_last = logits[:, -1, :] / max(temperature, 1e-5) | |
| probs = F.softmax(logits_last, dim=-1) | |
| next_idx = torch.multinomial(probs, num_samples=1) | |
| idx_seq = torch.cat([idx_seq, next_idx], dim=1) | |
| # Decode, skip the leading BOS | |
| generated = "".join( | |
| idx_to_char.get(i.item(), "?") for i in idx_seq[0, 1:] | |
| ) | |
| return generated | |
| def generate_beam(self, pixel_values, char_to_idx, idx_to_char, | |
| max_new_tokens=100, num_beams=4, length_penalty=1.0): | |
| """ | |
| Beam-search character-level caption generation. | |
| At each step we keep the top `num_beams` partial sequences ranked by | |
| cumulative log-probability (with optional length penalty). | |
| Args: | |
| pixel_values : (1, 3, H, W) image tensor | |
| char_to_idx : char β idx mapping | |
| idx_to_char : idx β char mapping | |
| max_new_tokens : max characters to generate | |
| num_beams : beam width (1 = greedy) | |
| length_penalty : >1 favors longer sequences; <1 favors shorter | |
| Returns: | |
| generated_text : str (best beam) | |
| """ | |
| self.eval() | |
| device = pixel_values.device | |
| bos_idx = char_to_idx.get("\n", 0) | |
| # Each beam: (score, token_sequence_tensor) | |
| beams = [(0.0, torch.tensor([[bos_idx]], dtype=torch.long, device=device))] | |
| for _ in range(max_new_tokens): | |
| candidates = [] | |
| for score, seq in beams: | |
| idx_cond = seq[:, -self.block_size:] | |
| logits, _ = self(pixel_values, idx_cond) | |
| log_probs = F.log_softmax(logits[:, -1, :], dim=-1) # (1, vocab) | |
| topk_probs, topk_ids = log_probs.topk(num_beams, dim=-1) | |
| for k in range(num_beams): | |
| new_score = score + topk_probs[0, k].item() | |
| new_seq = torch.cat( | |
| [seq, topk_ids[:, k:k+1]], dim=1 | |
| ) | |
| candidates.append((new_score, new_seq)) | |
| # Apply length penalty and keep top beams | |
| candidates.sort( | |
| key=lambda x: x[0] / (x[1].shape[1] ** length_penalty), | |
| reverse=True, | |
| ) | |
| beams = candidates[:num_beams] | |
| best_seq = beams[0][1] | |
| return "".join(idx_to_char.get(i.item(), "?") for i in best_seq[0, 1:]) | |