import math
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing_extensions import Self
from typing import Optional
from transformers.modeling_utils import PreTrainedModel
from torch.distributions import Categorical


@dataclass
class LLaMAHFConfig:
    """Hyper-parameters of the LLaMA-style autoregressive motion transformer.

    ``T5_xxl_dim`` is the width of the text features consumed by
    ``cond_embed`` / ``llama_proj``; the remaining fields size the stack.
    """

    block_size: int = 156
    n_layer: int = 32
    n_head: int = 32
    n_kv_head: Optional[int] = None
    n_embd: int = 4096
    rope_base: int = 500000
    T5_xxl_dim: int = 768

    @classmethod
    def from_name(cls, name: str) -> Self:
        """Instantiate a preset from ``llama_configs`` (KeyError for unknown names)."""
        return cls(**llama_configs[name])


# Named presets consumed by LLaMAHFConfig.from_name.
llama_configs = {
    "Normal_size": dict(n_layer=12, n_head=12, n_embd=768)
}


class LLaMAHF(nn.Module):
    """Decoder-only transformer over continuous motion latents with a diffusion head.

    Motion latents are embedded by ``transformer.wte``; text features are
    projected by ``transformer.cond_embed`` and used as prepended tokens and/or
    as the ``context`` passed to every ``Block``. The hidden state at a position
    is turned into the next latent by ``self.diff_loss.sample``.
    """

    def __init__(self, config: LLaMAHFConfig, num_diffusion_head_layers=6, n_diffusion_heads=4,
                 input_token_dim=16, device=torch.device('cuda'), width=512) -> None:
        """Build the transformer stack and the diffusion head.

        Args:
            config: transformer hyper-parameters.
            num_diffusion_head_layers: depth of the DiffLoss head.
            n_diffusion_heads: attention heads of the DiffLoss head.
            input_token_dim: width of one motion latent (also DiffLoss target_channels).
            device: device the DiffLoss head is moved to (only the head is moved here).
            width: hidden width of the DiffLoss head.
        """
        super().__init__()
        assert config.block_size is not None
        self.config = config
        cond_dim = config.T5_xxl_dim
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Linear(input_token_dim, config.n_embd),  # vector tokens -> embeddings
                cond_embed=nn.Linear(cond_dim, config.n_embd),  # text feature -> context emb
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=RMSNorm(config.n_embd),
            )
        )
        target_channels = input_token_dim
        from models.diffloss import DiffLoss
        self.diff_loss = DiffLoss(
            target_channels=target_channels,
            z_channels=config.n_embd,
            width=width,
            depth=num_diffusion_head_layers,
            num_sampling_steps='50',
            grad_checkpointing=False,
            n_heads=n_diffusion_heads,
            mlp_ratio=2.0,
        ).to(device)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd)
        self.use_out_proj = True
        # --- Persistent prompt cache & BOS token ---
        self._prompt_cached = False
        self._prompt_bsz = None
        self.bos = nn.Parameter(torch.zeros(1, 1, config.n_embd))
        # === Needed by several sampling/forward paths ===
        # projects raw text features when they are concatenated as tokens
        self.llama_proj = nn.Linear(config.T5_xxl_dim, config.n_embd)
        # special boundary-of-motion token used in forward_babel
        self.BOM_tag = nn.Parameter(torch.zeros(1, 1, config.n_embd))
        # (Optional) only if sample_for_eval_classification() is used:
        # self.classify_head = nn.Linear(config.n_embd, num_classes)

    @torch.no_grad()
    def set_prompt(self, feature: torch.Tensor):
        """Precompute and cache cross-attention K/V for the current prompt (feature).

        Call this when switching prompt (e.g. 'walk' -> 'crawl'). While the cache
        is active, ``forward`` ignores its ``feature`` argument.

        Raises:
            ValueError: if ``feature`` is None.
        """
        context = self._prepare_context(feature)
        if context is None:
            raise ValueError("set_prompt: feature cannot be None")
        self._prompt_bsz = context.size(0)
        for blk in self.transformer.h:
            blk.set_context_cache(context)
        self._prompt_cached = True

    @torch.no_grad()
    def clear_prompt(self):
        """Drop the cached prompt K/V in every block and reset the cache flags."""
        for blk in self.transformer.h:
            blk.clear_context_cache()
        self._prompt_cached = False
        self._prompt_bsz = None

    def _prepare_context(self, feature: Optional[torch.Tensor],
                         batch_size: Optional[int] = None) -> Optional[torch.Tensor]:
        """Project text features with ``cond_embed`` into a 3-D context tensor.

        Args:
            feature: text features (tensor or array-like) or None. A 1-D input is
                treated as a single prompt.
            batch_size: if given, a batch-1 context is broadcast (``expand``) to it.

        Returns:
            None when ``feature`` is None; otherwise the projection, with a
            sequence axis inserted at dim 1 when the projection is 2-D.

        Raises:
            ValueError: if the context batch differs from ``batch_size`` and is not 1.
        """
        if feature is None:
            return None
        weight = self.transformer.cond_embed.weight
        if not torch.is_tensor(feature):
            feature = torch.as_tensor(feature, dtype=weight.dtype, device=weight.device)
        else:
            feature = feature.to(dtype=weight.dtype, device=weight.device)
        if feature.dim() == 1:
            feature = feature.unsqueeze(0)
        context = self.transformer.cond_embed(feature)
        if context.dim() == 2:
            context = context.unsqueeze(1)
        if batch_size is not None and context.size(0) != batch_size:
            if context.size(0) == 1:
                context = context.expand(batch_size, -1, -1)
            else:
                raise ValueError(
                    f"Condition batch ({context.size(0)}) does not match token batch ({batch_size})."
                )
        return context

    def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
        """Tie ``output_embeddings.weight`` to ``input_embeddings.weight``.

        Despite the name, this always shares the Parameter (there is no clone
        path). A bias, if present, is zero-padded to the tied weight's row count.
        """
        output_embeddings.weight = input_embeddings.weight
        if getattr(output_embeddings, "bias", None) is not None:
            output_embeddings.bias.data = nn.functional.pad(
                output_embeddings.bias.data,
                (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
                "constant",
                0,
            )
        if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
            output_embeddings.out_features = input_embeddings.num_embeddings

    def get_input_embeddings(self):
        """Return the motion-token embedding layer (``transformer.wte``)."""
        return self.transformer.wte

    def set_input_embeddings(self, value):
        """Replace the motion-token embedding layer."""
        self.transformer.wte = value

    def get_output_embeddings(self):
        """Return the output projection applied after ``ln_f``."""
        return self.out_proj

    def set_output_embeddings(self, new_embeddings):
        """Replace the output projection."""
        self.out_proj = new_embeddings

    def _init_weights(self, module: nn.Module) -> None:
        """Normal init scaled by depth for Linear and Embedding weights (biases untouched)."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))

    def forward_sample(self, idx: torch.Tensor, clip_feature: torch.Tensor, y_mask) -> torch.Tensor:
        """Run the stack on [projected text prefix, embedded motion tokens].

        Args:
            idx: motion latents; an empty sequence (``len(idx) == 0``) means
                "text prefix only". Otherwise the sequence length is read from
                dim 1 (assumed (B, T, input_token_dim) — TODO confirm with callers).
            clip_feature: text features, projected by ``llama_proj`` (prefix
                tokens) and by ``cond_embed`` (block context).
            y_mask: text mask; the prefix keeps the first ``int(y_mask[0].sum())``
                text positions.

        Returns:
            Hidden states after ``ln_f`` (this path does not apply ``out_proj``).
        """
        context = self._prepare_context(clip_feature)
        text_tokens = self.llama_proj(clip_feature)[:, :int(y_mask[0].sum()), :]
        if len(idx) == 0:
            x = text_tokens
        else:
            # wte is a Linear over the last axis, so idx carries a feature axis;
            # unpacking exactly two sizes (as before) raised for (B, T, C) input.
            t = idx.size(1)
            assert (
                t <= self.config.block_size
            ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
            # forward the LLaMA model itself
            x = torch.cat((text_tokens, self.transformer.wte(idx)), dim=1)
        if context is not None and context.size(0) != x.size(0):
            if context.size(0) == 1:
                context = context.expand(x.size(0), -1, -1)
            else:
                raise ValueError("Conditioning batch size does not match token batch size.")
        for block in self.transformer.h:
            x = block(x, context=context)
        return self.transformer.ln_f(x)

    def sample_for_eval_CFG(self, text, length=196, tokenize_model=None, device=torch.device('cuda'),
                            unit_length=4, cfg=4.0):
        """Autoregressively sample ``length // unit_length`` latents with classifier-free guidance.

        The conditioned and empty prompts are swapped into the block caches via
        ``set_prompt`` before each forward (so the cache is rebuilt every step).
        The unconditioned branch is skipped when ``cfg == 1``. The conditioned
        prompt is left cached on return, even if sampling raises.

        Returns:
            Latents concatenated along dim 1, or None when no step runs.
        """
        max_token_len = length // unit_length
        feat_text = torch.from_numpy(tokenize_model.encode(text)).float().to(device)
        empty_feat_text = torch.from_numpy(tokenize_model.encode('')).float().unsqueeze(0).to(device)
        temperature = 1.0
        xs = None
        try:
            for _ in range(max_token_len):
                x = [] if xs is None else xs
                # conditioned next-step
                self.set_prompt(feat_text)
                conditions = self.forward(x, feature=None)[:, -1, :]
                if cfg != 1:
                    # unconditioned next-step, only needed for guidance
                    self.set_prompt(empty_feat_text)
                    empty_conditions = self.forward(x, feature=None)[:, -1, :]
                    mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
                    sampled = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)
                    next_latent, _ = sampled.chunk(2, dim=0)
                else:
                    next_latent = self.diff_loss.sample(conditions, temperature=temperature, cfg=1)
                next_latent = next_latent.unsqueeze(0)
                xs = next_latent if xs is None else torch.cat((xs, next_latent), dim=1)
        finally:
            # re-enable the conditioned prompt cache for whatever comes next
            self.set_prompt(feat_text)
        return xs

    # For inference, can stop sampling when the distance between the current token and the reference end token is less than the threshold.
def sample_for_eval_CFG_inference(self, text, length=312, tokenizer=None, device=torch.device('cuda'), unit_length=4, reference_end_latent=None, threshold=0.1, cfg=4.0, temperature=1.0): max_token_len = length // unit_length feat_text = torch.from_numpy(tokenizer.encode(text)).float().to(device) empty_feat_text = torch.from_numpy(tokenizer.encode('')).float().unsqueeze(0).to(device) def _use_cond(): self.set_prompt(feat_text) def _use_uncond(): self.set_prompt(empty_feat_text) xs = None for k in range(max_token_len): x = [] if k == 0 else xs _use_cond() conditions = self.forward_inference(x, feature=None)[:, -1, :] _use_uncond() empty_conditions = self.forward(x, feature=None)[:, -1, :] mix = torch.cat([conditions, empty_conditions], dim=0) sampled = self.diff_loss.sample(mix, temperature=temperature, cfg=cfg) scaled_logits, _ = sampled.chunk(2, dim=0) if cfg != 1 else (sampled, None) scaled_logits = scaled_logits.unsqueeze(0) if reference_end_latent is not None: dist = torch.sqrt(torch.sum((scaled_logits - reference_end_latent)**2)) if dist < threshold: break xs = scaled_logits if k == 0 else torch.cat((xs, scaled_logits), dim=1) # leave the cond cache active self.set_prompt(feat_text) return xs def sample_for_eval_CFG_inference2(self, feat_clip_text, empty_feat_clip_text, if_categorial=False, length=312, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, threshold=3, cfg=4.5, temperature=1.0): import clip max_token_len = length // unit_length for k in range(max_token_len): if k == 0: x = [] else: x = xs try: conditions = self.forward(x, feat_clip_text) except: conditions = self.forward(x, feat_clip_text.unsqueeze(0)) conditions = conditions[:, -1, :] empty_conditions = self.forward(x, empty_feat_clip_text) empty_conditions = empty_conditions[:, -1, :] mix_conditions = torch.cat([conditions, empty_conditions], dim=0) sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg) 
# chunk if cfg != 1: scaled_logits, _ = sampled_token_latent.chunk(2, dim=0) else: scaled_logits = sampled_token_latent scaled_logits = scaled_logits.unsqueeze(0) if reference_end_token is not None: distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token)**2)) print(distance_l2) if distance_l2 < threshold: break if k == 0: xs = scaled_logits else: xs = torch.cat((xs, scaled_logits), dim=1) return xs def sample_for_eval_CFG_inference_next_one(self, current_token=[], feat_clip_text=None, empty_feat_clip_text=None, if_categorial=False, length=312, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, threshold=3, cfg=4.5, temperature=1.0): import clip max_token_len = length // unit_length for k in range(1): if current_token == []: x = [] else: x = torch.cat(current_token, dim=1) try: conditions = self.forward(x, feat_clip_text) except: conditions = self.forward(x, feat_clip_text.unsqueeze(0)) conditions = conditions[:, -1, :] empty_conditions = self.forward(x, empty_feat_clip_text) empty_conditions = empty_conditions[:, -1, :] mix_conditions = torch.cat([conditions, empty_conditions], dim=0) sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg) # chunk if cfg != 1: scaled_logits, _ = sampled_token_latent.chunk(2, dim=0) else: scaled_logits = sampled_token_latent scaled_logits = scaled_logits.unsqueeze(0) if k == 0: xs = scaled_logits else: xs = torch.cat((xs, scaled_logits), dim=1) return xs def sample_for_eval_CFG_babel(self, A_text, B_text, A_motion, if_categorial=False, length=6400, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=7.0, threshold=3): import clip B_token_length = length // unit_length - A_motion.shape[0] if tokenizer == 'clip': A_text = clip.tokenize(A_text, truncate=True).to(device) A_feat_clip_text = clip_model.encode_text(A_text).float() B_text = clip.tokenize(B_text, 
truncate=True).to(device) B_feat_clip_text = clip_model.encode_text(B_text).float() elif tokenizer == 't5-xxl': A_feat_clip_text = torch.from_numpy(clip_model.encode(A_text)).float() A_feat_clip_text = A_feat_clip_text.to(device) B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float() B_feat_clip_text = B_feat_clip_text.to(device) A_text_embeddings = self.transformer.cond_embed(A_feat_clip_text).unsqueeze(0) B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0) A_motion = A_motion.unsqueeze(0) A_motion_embeddings = self.transformer.wte(A_motion) B_motion = torch.tensor([]).to(device) for k in range(B_token_length): if k == 0: x = torch.cat([A_text_embeddings, A_motion_embeddings, B_text_embeddings], dim=1) else: x = xs conditions = self.forward_babel_eval(x) conditions = conditions[:, -1, :] empty_clip_text = '' if tokenizer == 'clip': empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device) empty_feat_clip_text = clip_model.encode_text(empty_text).float() elif tokenizer == 't5-xxl': empty_feat_clip_text = torch.from_numpy(clip_model.encode(empty_clip_text)).float() empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0) empty_feat_clip_text = empty_feat_clip_text.to(device) empty_feat_clip_text_embedding = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0) if k == 0: empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, empty_feat_clip_text_embedding], dim=1) empty_conditions = self.forward_babel_eval(empty_input) else: B_motion_embeddings = self.transformer.wte(B_motion) empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, empty_feat_clip_text_embedding, B_motion_embeddings], dim=1) empty_conditions = self.forward_babel_eval(empty_input) empty_conditions = empty_conditions[:, -1, :] temperature = 1.0 mix_conditions = torch.cat([conditions, empty_conditions], dim=0) sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, 
cfg=cfg) # chunk if cfg != 1: scaled_logits, _ = sampled_token_latent.chunk(2, dim=0) else: scaled_logits = sampled_token_latent scaled_logits = scaled_logits.unsqueeze(0) B_motion = torch.cat((B_motion, scaled_logits), dim=1) scaled_logits_embedding = self.transformer.wte(scaled_logits) xs = torch.cat((x, scaled_logits_embedding), dim=1) return xs, B_motion def sample_for_eval_CFG_babel_inference(self, A_text, B_text, A_motion, if_categorial=False, length=6400, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=7.0, threshold=3): import clip B_token_length = length // unit_length - A_motion.shape[0] if tokenizer == 'clip': A_text = clip.tokenize(A_text, truncate=True).to(device) A_feat_clip_text = clip_model.encode_text(A_text).float() B_text = clip.tokenize(B_text, truncate=True).to(device) B_feat_clip_text = clip_model.encode_text(B_text).float() elif tokenizer == 't5-xxl': A_feat_clip_text = torch.from_numpy(clip_model.encode(A_text)).float() A_feat_clip_text = A_feat_clip_text.to(device) B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float() B_feat_clip_text = B_feat_clip_text.to(device) A_text_embeddings = self.transformer.cond_embed(A_feat_clip_text).unsqueeze(0) A_text_embeddings = A_text_embeddings.unsqueeze(0) B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0) B_text_embeddings = B_text_embeddings.unsqueeze(0) A_motion = A_motion.unsqueeze(0) A_motion_embeddings = self.transformer.wte(A_motion) B_motion = torch.tensor([]).to(device) attention_weights = [] for k in range(B_token_length): if k == 0: x = torch.cat([A_text_embeddings, A_motion_embeddings, B_text_embeddings], dim=1) else: x = xs conditions = self.forward_babel_eval(x, return_attention=False) conditions = conditions[:, -1, :] empty_clip_text = '' if tokenizer == 'clip': empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device) empty_feat_clip_text = 
clip_model.encode_text(empty_text).float() elif tokenizer == 't5-xxl': empty_feat_clip_text = torch.from_numpy(clip_model.encode(empty_clip_text)).float() empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0) empty_feat_clip_text = empty_feat_clip_text.to(device) empty_feat_clip_text_embedding = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0) if k == 0: empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, empty_feat_clip_text_embedding], dim=1) empty_conditions = self.forward_babel_eval(empty_input) else: B_motion_embeddings = self.transformer.wte(B_motion) empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, empty_feat_clip_text_embedding, B_motion_embeddings], dim=1) empty_conditions = self.forward_babel_eval(empty_input) empty_conditions = empty_conditions[:, -1, :] temperature = 1.0 mix_conditions = torch.cat([conditions, empty_conditions], dim=0) sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg) # chunk if cfg != 1: scaled_logits, _ = sampled_token_latent.chunk(2, dim=0) else: scaled_logits = sampled_token_latent scaled_logits = scaled_logits.unsqueeze(0) if reference_end_token is not None: distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token)**2)) print(distance_l2) if distance_l2 < threshold: break B_motion = torch.cat((B_motion, scaled_logits), dim=1) scaled_logits_embedding = self.transformer.wte(scaled_logits) xs = torch.cat((x, scaled_logits_embedding), dim=1) return xs, B_motion def sample_for_eval_CFG_babel_inference_new(self, B_text, A_motion, if_categorial=False, length=78, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=4.5, threshold=3): import clip B_token_length = length // unit_length if tokenizer == 'clip': A_text = clip.tokenize(A_text, truncate=True).to(device) A_feat_clip_text = clip_model.encode_text(A_text).float() B_text = clip.tokenize(B_text, 
truncate=True).to(device) B_feat_clip_text = clip_model.encode_text(B_text).float() elif tokenizer == 't5-xxl': B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float() B_feat_clip_text = B_feat_clip_text.to(device) empty_clip_text = '' if tokenizer == 'clip': empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device) empty_feat_clip_text = clip_model.encode_text(empty_text).float() elif tokenizer == 't5-xxl': empty_feat_clip_text = torch.from_numpy(clip_model.encode(empty_clip_text)).float() empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0) empty_feat_clip_text = empty_feat_clip_text.to(device) B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0) A_motion = A_motion.unsqueeze(0) A_motion_embeddings = self.transformer.wte(A_motion) B_motion = torch.tensor([]).to(device) attention_weights = [] for k in range(B_token_length): if k == 0: x = torch.cat([B_text_embeddings, A_motion_embeddings], dim=1) else: x = xs conditions = self.forward_babel_eval(x, return_attention=False) conditions = conditions[:, -1, :] empty_feat_clip_text_embedding = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0) if k == 0: empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings], dim=1) empty_conditions = self.forward_babel_eval(empty_input) else: B_motion_embeddings = self.transformer.wte(B_motion) empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, B_motion_embeddings], dim=1) empty_conditions = self.forward_babel_eval(empty_input) empty_conditions = empty_conditions[:, -1, :] temperature = 1.0 mix_conditions = torch.cat([conditions, empty_conditions], dim=0) sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg) # chunk if cfg != 1: scaled_logits, _ = sampled_token_latent.chunk(2, dim=0) else: scaled_logits = sampled_token_latent scaled_logits = scaled_logits.unsqueeze(0) if reference_end_token is not None: distance_l2 = 
torch.sqrt(torch.sum((scaled_logits - reference_end_token)**2)) print(distance_l2) if distance_l2 < threshold: break B_motion = torch.cat((B_motion, scaled_logits), dim=1) scaled_logits_embedding = self.transformer.wte(scaled_logits) xs = torch.cat((x, scaled_logits_embedding), dim=1) return xs, B_motion def sample_for_eval_CFG_babel_inference_new_demo(self, B_text, A_motion, if_categorial=False, length=312, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=4.5, threshold=3, temperature=1.0): import clip B_token_length = length // unit_length - A_motion.shape[0] if tokenizer == 'clip': A_text = clip.tokenize(A_text, truncate=True).to(device) A_feat_clip_text = clip_model.encode_text(A_text).float() B_text = clip.tokenize(B_text, truncate=True).to(device) B_feat_clip_text = clip_model.encode_text(B_text).float() elif tokenizer == 't5-xxl': B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float() B_feat_clip_text = B_feat_clip_text.to(device) empty_clip_text = '' if tokenizer == 'clip': empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device) empty_feat_clip_text = clip_model.encode_text(empty_text).float() elif tokenizer == 't5-xxl': empty_feat_clip_text = torch.from_numpy(clip_model.encode(empty_clip_text)).float() empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0) empty_feat_clip_text = empty_feat_clip_text.to(device) B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0) B_text_embeddings = B_text_embeddings.unsqueeze(0) A_motion = A_motion.unsqueeze(0) A_motion_embeddings = self.transformer.wte(A_motion) B_motion = torch.tensor([]).to(device) # 存储所有层的注意力权重 attention_weights = [] for k in range(B_token_length): if k == 0: x = torch.cat([B_text_embeddings, A_motion_embeddings], dim=1) else: x = xs conditions = self.forward_babel_eval(x, return_attention=False) conditions = conditions[:, -1, :] empty_feat_clip_text_embedding = 
self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0) if k == 0: empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings], dim=1) empty_conditions = self.forward_babel_eval(empty_input) else: B_motion_embeddings = self.transformer.wte(B_motion) empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, B_motion_embeddings], dim=1) empty_conditions = self.forward_babel_eval(empty_input) empty_conditions = empty_conditions[:, -1, :] mix_conditions = torch.cat([conditions, empty_conditions], dim=0) sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg) # chunk if cfg != 1: scaled_logits, _ = sampled_token_latent.chunk(2, dim=0) else: scaled_logits = sampled_token_latent scaled_logits = scaled_logits.unsqueeze(0) if reference_end_token is not None: distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token)**2)) print(distance_l2) if distance_l2 < threshold and k > 10: break B_motion = torch.cat((B_motion, scaled_logits), dim=1) scaled_logits_embedding = self.transformer.wte(scaled_logits) xs = torch.cat((x, scaled_logits_embedding), dim=1) return xs, B_motion def sample_for_eval_CFG_babel_inference_two_forward(self, B_text, A_motion, if_categorial=False, length=312, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=4.5, threshold=3, temperature=1.0): """ Inference loop that mimics the "Two-Forward" training strategy. This version correctly performs two full passes over the entire sequence. 
""" import clip B_token_length = length // unit_length - A_motion.shape[0] if tokenizer == 't5-xxl': B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float().to(device) else: raise NotImplementedError("Only t5-xxl is supported for this function.") empty_feat_clip_text = torch.from_numpy(clip_model.encode('')).float().unsqueeze(0).to(device) # --- Create 3D embeddings [batch, seq, dim] --- B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0).unsqueeze(0) empty_text_embeddings = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0) # This is [1, 1, 768] A_motion_embeddings = self.transformer.wte(A_motion.unsqueeze(0)) # === 1. First Forward Pass (Generate Rough Draft) === rough_motion_tokens = A_motion for k in range(B_token_length): current_rough_embeddings = self.transformer.wte(rough_motion_tokens.unsqueeze(0)) # Conditioned x_cond = torch.cat([B_text_embeddings, current_rough_embeddings], dim=1) conditions = self.forward_babel_eval(x_cond, return_attention=False)[:, -1, :] # Unconditioned x_uncond = torch.cat([empty_text_embeddings, current_rough_embeddings], dim=1) empty_conditions = self.forward_babel_eval(x_uncond, return_attention=False)[:, -1, :] # Sample a rough prediction for the next token mix_conditions = torch.cat([conditions, empty_conditions], dim=0) pred_xstart_rough = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg) if cfg != 1: pred_xstart_rough, _ = pred_xstart_rough.chunk(2, dim=0) rough_motion_tokens = torch.cat([rough_motion_tokens, pred_xstart_rough], dim=0) # === 2. Second Forward Pass (Generate Refined Motion) === # Now we have the full rough draft. We use it as the input for the second pass. 
refined_motion_tokens = A_motion for k in range(B_token_length): # The input to the transformer is the full rough sequence rough_embeddings = self.transformer.wte(rough_motion_tokens.unsqueeze(0)) # Conditioned x_cond_refined = torch.cat([B_text_embeddings, rough_embeddings], dim=1) # We take the condition corresponding to the token we want to predict conditions_refined = self.forward_babel_eval(x_cond_refined, return_attention=False)[:, A_motion.shape[0] + k, :] # Unconditioned x_uncond_refined = torch.cat([empty_text_embeddings, rough_embeddings], dim=1) empty_conditions_refined = self.forward_babel_eval(x_uncond_refined, return_attention=False)[:, A_motion.shape[0] + k, :] # Sample the final, refined token mix_conditions_refined = torch.cat([conditions_refined, empty_conditions_refined], dim=0) final_token, _ = self.diff_loss.sample(mix_conditions_refined, temperature=temperature, cfg=cfg).chunk(2, dim=0) # Append the refined token to our final output history refined_motion_tokens = torch.cat([refined_motion_tokens, final_token], dim=0) # IMPORTANT: For the next step, we must update the "rough draft" with our new refined token # This mimics the training where the input is a mix of GT and predictions. # Here, it's a mix of the initial rough draft and the new refined tokens. 
rough_motion_tokens[A_motion.shape[0] + k] = final_token.squeeze(0) # Return only the newly generated tokens (B_motion) B_motion = refined_motion_tokens[A_motion.shape[0]:, :].unsqueeze(0) return None, B_motion #--------------Test classification head-------------------- def sample_for_eval_classification(self, clip_text, if_categorial=False, length=196, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4): import clip for k in range(51): if k == 0: x = [] else: x = xs if tokenizer == 'clip': text = clip.tokenize(clip_text, truncate=True).to(device) feat_clip_text = clip_model.encode_text(text).float() elif tokenizer == 't5-xxl': feat_clip_text = torch.from_numpy(clip_model.module.encode(clip_text)).float() conditions = self.forward(x, feat_clip_text) conditions = conditions[:, -1, :] empty_clip_text = '' if tokenizer == 'clip': empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device) empty_feat_clip_text = clip_model.encode_text(empty_text).float() elif tokenizer == 't5-xxl': empty_feat_clip_text = torch.from_numpy(clip_model.module.encode(empty_clip_text)).float() empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0) empty_feat_clip_text = empty_feat_clip_text.to(device) empty_conditions = self.forward(x, empty_feat_clip_text) empty_conditions = empty_conditions[:, -1, :] temperature = 1.0 cfg = 7.5 mix_conditions = torch.cat([conditions, empty_conditions], dim=0) sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg) # chunk if cfg != 1: scaled_logits, _ = sampled_token_latent.chunk(2, dim=0) else: scaled_logits = sampled_token_latent prediction_logits = self.classify_head(conditions) probs = torch.sigmoid(prediction_logits) predicted_classes = torch.argmax(probs, dim=-1) scaled_logits = scaled_logits.unsqueeze(0) if k == 0: xs = scaled_logits else: xs = torch.cat((xs, scaled_logits), dim=1) if predicted_classes == 1: break return xs #--------------------Test 
CFG----------------------- def sample_for_eval_CFG_test(self, clip_text, if_categorial=False, length=196, clip_model=None, cfg=1, device=torch.device('cuda'), tokenizer='clip', unit_length=4): import clip max_token_len = length // unit_length for k in range(max_token_len): if k == 0: x = [] else: x = xs if cfg != 1: if tokenizer == 'clip': text = clip.tokenize(clip_text, truncate=True).to(device) feat_clip_text = clip_model.encode_text(text).float() elif tokenizer == 't5-xxl': feat_clip_text = torch.from_numpy(clip_model.module.encode(clip_text)).float() conditions = self.forward(x, feat_clip_text) conditions = conditions[:, -1, :] empty_clip_text = '' if tokenizer == 'clip': empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device) empty_feat_clip_text = clip_model.encode_text(empty_text).float() elif tokenizer == 't5-xxl': empty_feat_clip_text = torch.from_numpy(clip_model.module.encode(empty_clip_text)).float() empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0) empty_feat_clip_text = empty_feat_clip_text.to(device) empty_conditions = self.forward(x, empty_feat_clip_text) empty_conditions = empty_conditions[:, -1, :] temperature = 1.0 mix_conditions = torch.cat([conditions, empty_conditions], dim=0) sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg) # chunk scaled_logits, _ = sampled_token_latent.chunk(2, dim=0) else: if tokenizer == 'clip': text = clip.tokenize(clip_text, truncate=True).to(device) feat_clip_text = clip_model.encode_text(text).float() elif tokenizer == 't5-xxl': feat_clip_text = torch.from_numpy(clip_model.module.encode(clip_text)).float() feat_clip_text = feat_clip_text.to(device) conditions = self.forward(x, feat_clip_text) conditions = conditions[:, -1, :] temperature = 1.0 sampled_token_latent = self.diff_loss.sample(conditions, temperature=temperature, cfg=cfg) scaled_logits = sampled_token_latent scaled_logits = scaled_logits.unsqueeze(0) if k == 0: xs = scaled_logits else: xs = 
torch.cat((xs, scaled_logits), dim=1) return xs #-------------------------------------------------- def forward_discrete(self, idx: torch.Tensor, clip_feature: torch.Tensor, use_cache=False, past_key_values=None) -> torch.Tensor: """ Vector-token path: idx must be shape [B, T, input_token_dim]. If you want discrete IDs instead, you must switch wte to nn.Embedding. """ context = None if idx.numel() == 0: context = self._prepare_context(clip_feature) token_embeddings = context if token_embeddings is None: raise ValueError("Conditioning features are required when no motion tokens are provided.") else: b, t, _ = idx.size() assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" token_embeddings = self.transformer.wte(idx) # Linear -> [B, T, n_embd] context = self._prepare_context(clip_feature, batch_size=b) if context is not None: token_embeddings = torch.cat([context, token_embeddings], dim=1) x = token_embeddings if use_cache and past_key_values is None: past_key_values = [None] * len(self.transformer.h) for i, block in enumerate(self.transformer.h): if use_cache: last_past = past_key_values[i] x, presents = block(x, context=context, last_past=last_past, use_cache=use_cache) past_key_values[i] = list(presents) else: x = block(x, context=context) x = self.transformer.ln_f(x) logits = self.out_proj(x) return logits def forward(self, idx: torch.Tensor, feature: Optional[torch.Tensor]) -> torch.Tensor: """ If self._prompt_cached is True, we DO NOT concat context each call. Instead, blocks read the cached prompt KV. Otherwise we embed and concat context as before. 
""" context = None if len(idx) == 0: if self._prompt_cached: if self._prompt_bsz is None: raise ValueError("Prompt cache set but batch size unknown.") b = self._prompt_bsz token_embeddings = torch.empty(b, 0, self.config.n_embd, device=self.bos.device, dtype=self.bos.dtype) else: context = self._prepare_context(feature) token_embeddings = context if token_embeddings is None: raise ValueError("Conditioning features are required when no motion tokens are provided.") else: b, t, c = idx.size() idx = idx.float() assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" token_embeddings = self.transformer.wte(idx) if not self._prompt_cached: context = self._prepare_context(feature, batch_size=b) if context is not None: token_embeddings = torch.cat([context, token_embeddings], dim=1) # Always prepend BOS scene token bos = self.bos.expand(token_embeddings.size(0), 1, -1) x = torch.cat([bos, token_embeddings], dim=1) # blocks: if context is None -> use cached prompt kv (if set) for block in self.transformer.h: x = block(x, context=context) x = self.transformer.ln_f(x) logits = self.out_proj(x) return logits def forward_inference(self, idx: torch.Tensor, feature: Optional[torch.Tensor]) -> torch.Tensor: context = None if len(idx) == 0: if self._prompt_cached: if self._prompt_bsz is None: raise ValueError("Prompt cache set but batch size unknown.") b = self._prompt_bsz token_embeddings = torch.empty(b, 0, self.config.n_embd, device=self.bos.device, dtype=self.bos.dtype) else: context = self._prepare_context(feature) token_embeddings = context if token_embeddings is None: raise ValueError("Conditioning features are required when no motion tokens are provided.") else: b, t, c = idx.size() idx = idx.float() assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" token_embeddings = self.transformer.wte(idx) if not self._prompt_cached: context = 
self._prepare_context(feature, batch_size=b) if context is not None: token_embeddings = torch.cat([context, token_embeddings], dim=1) x = token_embeddings if len(x.shape) == 2: x = x.unsqueeze(0) # prepend BOS bos = self.bos.expand(x.size(0), 1, -1) x = torch.cat([bos, x], dim=1) if context is not None and context.size(0) != x.size(0): if context.size(0) == 1: context = context.expand(x.size(0), -1, -1) else: raise ValueError("Conditioning batch size does not match token batch size.") for block in self.transformer.h: x = block(x, context=context) x = self.transformer.ln_f(x) logits = self.out_proj(x) return logits def babel_long(self, idx: torch.Tensor, clip_feature: torch.Tensor, use_cache=False, past_key_values=None, num_subseq=None, length=None) -> torch.Tensor: b, t, c = idx.size() idx = idx.float() idx = self.transformer.wte(idx) assert ( t <= self.config.block_size ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" for i in range(b): length_i = length[i][:num_subseq[i]] clip_feature_i = clip_feature[i][:num_subseq[i]] pointer = 0 for j in range(num_subseq[i]): if j > 0: pointer += length_i[j].item() pointer += 1 pointer = int(pointer) clip_feature_i_j = self.transformer.cond_embed(clip_feature_i[j].unsqueeze(0)).unsqueeze(1) idx[i] = torch.cat([idx[i][:pointer].unsqueeze(0), clip_feature_i_j, idx[i][pointer:-1].unsqueeze(0)], dim=1)[0] x = idx context = None if use_cache: if past_key_values is None: past_key_values = [None] * len(self.transformer.h) for i,block in enumerate(self.transformer.h): if use_cache: last_past = past_key_values[i] x, presents = block(x, context=context, last_past=last_past, use_cache=use_cache) past_key_values[i] = list(presents) else: x = block(x, context=context) x = self.transformer.ln_f(x) logits = self.out_proj(x) return logits def forward_babel_eval(self, x, return_attention=False) -> torch.Tensor: layer_attentions = [] context = None for block in self.transformer.h: if return_attention: x, 
att = block(x, context=context, return_attention=True) layer_attentions.append(att) else: x = block(x, context=context) x = self.transformer.ln_f(x) if self.use_out_proj: logits = self.out_proj(x) else: logits = x if return_attention: return logits, layer_attentions return logits def forward_babel(self, idx: torch.Tensor, clip_feature: torch.Tensor, A_token_length) -> torch.Tensor: context = None if len(idx) == 0: # inference context = self._prepare_context(clip_feature) token_embeddings = context else: b, t, c = idx.size() idx = idx.float() assert ( t <= self.config.block_size ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" A_feature = clip_feature[:, 0, :] B_feature = clip_feature[:, 1, :] A_text_embeddings = self.transformer.cond_embed(A_feature).unsqueeze(1) B_text_embeddings = self.transformer.cond_embed(B_feature).unsqueeze(1) context = torch.cat([A_text_embeddings, B_text_embeddings], dim=1) token_embeddings = torch.zeros(b, self.config.block_size, self.config.n_embd).to(idx.device) for i in range(b): A_idx = idx[i, :A_token_length[i].item(), :] B_idx = idx[i, A_token_length[i].item():-2, :] token_embeddings[i, :, :] = torch.cat([A_text_embeddings[i], self.BOM_tag, self.transformer.wte(A_idx), B_text_embeddings[i], self.BOM_tag, self.transformer.wte(B_idx)], dim=0) #token_embeddings.shape = (b,t+1,1024) x = token_embeddings if context is not None and context.size(0) != x.size(0): if context.size(0) == 1: context = context.expand(x.size(0), -1, -1) else: raise ValueError("Conditioning batch size does not match token batch size.") for block in self.transformer.h: x = block(x, context=context) x = self.transformer.ln_f(x) if self.use_out_proj: logits = self.out_proj(x) else: logits = x return logits def forward_babel2(self, idx: torch.Tensor, clip_feature: torch.Tensor) -> torch.Tensor: context = None if idx.numel() == 0: # inference with only context context = self._prepare_context(clip_feature) token_embeddings = 
context else: b, t, c = idx.size() idx = idx.float() assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" B_feature = clip_feature # [B, D] or [B, 1, D] B_text_embeddings = self.transformer.cond_embed(B_feature) # [B, D] -> [B, D] if B_text_embeddings.dim() == 2: B_text_embeddings = B_text_embeddings.unsqueeze(1) # [B, 1, D] context = B_text_embeddings # [B, 1, D] idx_embeddings = self.transformer.wte(idx) # [B, T, D] token_embeddings = torch.cat([B_text_embeddings, idx_embeddings], dim=1) # [B, 1+T, D] x = token_embeddings if context is not None: if context.dim() == 2: context = context.unsqueeze(1) if context.size(0) != x.size(0): if context.size(0) == 1: context = context.expand(x.size(0), -1, -1) else: raise ValueError("Conditioning batch size does not match token batch size.") for block in self.transformer.h: x = block(x, context=context) x = self.transformer.ln_f(x) logits = self.out_proj(x) if self.use_out_proj else x return logits def resize_token_embeddings( self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None, using_old_initilization: bool = False ) -> nn.Embedding: """ Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`. Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method. Arguments: new_num_tokens (`int`, *optional*): The new number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything. pad_to_multiple_of (`int`, *optional*): If set will pad the embedding matrix to a multiple of the provided value.If `new_num_tokens` is set to `None` will just pad the embedding to a multiple of `pad_to_multiple_of`. 
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more details about this, or help on choosing the correct value for resizing, refer to this guide: https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc Return: `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model. """ model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of) if new_num_tokens is None and pad_to_multiple_of is None: return model_embeds # Update base model and current model config self.config.vocab_size = model_embeds.weight.shape[0] self.vocab_size = model_embeds.weight.shape[0] # Tie weights again if needed # self.tie_weights() return model_embeds def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None): old_embeddings = self.get_input_embeddings() new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of) old_embeddings_requires_grad = old_embeddings.weight.requires_grad new_embeddings.requires_grad_(old_embeddings_requires_grad) self.set_input_embeddings(new_embeddings) # Update new_num_tokens with the actual size of new_embeddings if pad_to_multiple_of is not None: # if is_deepspeed_zero3_enabled(): # import deepspeed # with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None): # new_num_tokens = new_embeddings.weight.shape[0] # else: new_num_tokens = new_embeddings.weight.shape[0] # if word embeddings are not tied, make sure that lm head is resized as well # if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings: if self.get_output_embeddings() is not None and not False: old_lm_head = self.get_output_embeddings() new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens) # if hasattr(old_lm_head, "_hf_hook"): # hook = 
old_lm_head._hf_hook # add_hook_to_module(new_lm_head, hook) old_lm_head_requires_grad = old_lm_head.weight.requires_grad new_lm_head.requires_grad_(old_lm_head_requires_grad) self.set_output_embeddings(new_lm_head) return self.get_input_embeddings() def _get_resized_embeddings( self, old_embeddings: nn.Embedding, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None, ) -> nn.Embedding: """ Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end Args: old_embeddings (`torch.nn.Embedding`): Old embeddings to be resized. new_num_tokens (`int`, *optional*): New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything. pad_to_multiple_of (`int`, *optional*): If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to `None` will just pad the embedding to a multiple of `pad_to_multiple_of`. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more details about this, or help on choosing the correct value for resizing, refer to this guide: https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc Return: `torch.nn.Embedding`: Pointer to the resized Embedding Module or the old Embedding Module if `new_num_tokens` is `None` """ if pad_to_multiple_of is not None: if not isinstance(pad_to_multiple_of, int): raise ValueError( f"Asking to pad the embedding matrix to a multiple of `{pad_to_multiple_of}`, which is not and integer. 
Please make sure to pass an integer" ) if new_num_tokens is None: new_num_tokens = old_embeddings.weight.shape[0] new_num_tokens = ((new_num_tokens + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of else: print( "You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding" f" dimension will be {new_num_tokens}. This might induce some performance reduction as *Tensor Cores* will not be available." " For more details about this, or help on choosing the correct value for resizing, refer to this guide:" " https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc" ) if new_num_tokens is None: return old_embeddings # if is_deepspeed_zero3_enabled(): if False: import deepspeed with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None): old_num_tokens, old_embedding_dim = old_embeddings.weight.size() else: old_num_tokens, old_embedding_dim = old_embeddings.weight.size() # if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled(): if old_num_tokens == new_num_tokens and not False: return old_embeddings if not isinstance(old_embeddings, nn.Embedding): raise TypeError( f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}. You" " should either use a different resize function or make sure that `old_embeddings` are an instance of" f" {nn.Embedding}." ) # Build new embeddings # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init # because the shape of the new embedding layer is used across various modeling files # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading # to errors when training. 
new_embeddings = nn.Embedding( new_num_tokens, old_embedding_dim, device=old_embeddings.weight.device, dtype=old_embeddings.weight.dtype, ) # initialize all new embeddings (in particular added tokens) self._init_weights(new_embeddings) # Copy token embeddings from the previous weights # numbers of tokens to copy n = min(old_num_tokens, new_num_tokens) # if is_deepspeed_zero3_enabled(): if False: import deepspeed params = [old_embeddings.weight, new_embeddings.weight] with deepspeed.zero.GatheredParameters(params, modifier_rank=0): new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :] else: new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :] return new_embeddings def _get_resized_lm_head( self, old_lm_head: nn.Linear, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False ) -> nn.Linear: """ Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end Args: old_lm_head (`torch.nn.Linear`): Old lm head liner layer to be resized. new_num_tokens (`int`, *optional*): New number of tokens in the linear matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just returns a pointer to the input tokens `torch.nn.Linear` module of the model without doing anything. transposed (`bool`, *optional*, defaults to `False`): Whether `old_lm_head` is transposed or not. If True `old_lm_head.size()` is `lm_head_dim, vocab_size` else `vocab_size, lm_head_dim`. 
Return: `torch.nn.Linear`: Pointer to the resized Linear Module or the old Linear Module if `new_num_tokens` is `None` """ if new_num_tokens is None: return old_lm_head # if is_deepspeed_zero3_enabled(): if False: import deepspeed with deepspeed.zero.GatheredParameters(old_lm_head.weight, modifier_rank=None): old_num_tokens, old_lm_head_dim = ( old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size() ) else: old_num_tokens, old_lm_head_dim = ( old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size() ) # if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled(): if old_num_tokens == new_num_tokens and not False: return old_lm_head if not isinstance(old_lm_head, nn.Linear): raise TypeError( f"Old language model head is of type {type(old_lm_head)}, which is not an instance of {nn.Linear}. You" " should either use a different resize function or make sure that `old_lm_head` are an instance of" f" {nn.Linear}." ) # Build new lm head new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim) has_new_lm_head_bias = old_lm_head.bias is not None # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init # because the shape of the new embedding layer is used across various modeling files # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading # to errors when training. 
new_lm_head = nn.Linear( *new_lm_head_shape, bias=has_new_lm_head_bias, device=old_lm_head.weight.device, dtype=old_lm_head.weight.dtype, ) # initialize new lm head (in particular added tokens) self._init_weights(new_lm_head) num_tokens_to_copy = min(old_num_tokens, new_num_tokens) # if is_deepspeed_zero3_enabled(): if False: import deepspeed params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias] with deepspeed.zero.GatheredParameters(params, modifier_rank=0): self._copy_lm_head_original_to_resized( new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias ) else: self._copy_lm_head_original_to_resized( new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias ) return new_lm_head def _copy_lm_head_original_to_resized( self, new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias ): # Copy old lm head weights to new lm head if not transposed: new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :] else: new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy] # Copy bias weights to new lm head if has_new_lm_head_bias: new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy] @classmethod def from_name(cls, name: str) -> Self: return cls(LLaMAHFConfig.from_name(name)) class Block(nn.Module): def __init__(self, config: LLaMAHFConfig) -> None: super().__init__() self.rms_1 = RMSNorm(config.n_embd) self.attn = CausalSelfAttention(config) self.rms_cross = RMSNorm(config.n_embd) self.cross_attn = CrossAttention(config) self.rms_2 = RMSNorm(config.n_embd) self.mlp = MLP(config) # cached prompt kv (precomputed by set_prompt) self._ctx_k_repeat = None self._ctx_v_repeat = None self._ctx_bsz = None @torch.no_grad() def set_context_cache(self, context: torch.Tensor): # Precompute KV for cross attention and repeat across kv groups B, S, D = context.shape ca = 
self.cross_attn k = ca.k_proj(context).view(B, S, ca.n_kv_head, ca.head_dim).transpose(1, 2) v = ca.v_proj(context).view(B, S, ca.n_kv_head, ca.head_dim).transpose(1, 2) k = ca.k_norm(k) # repeat K/V to match heads self._ctx_k_repeat = repeat_kv(k, ca.num_kv_groups) # [B, n_head, S, d] self._ctx_v_repeat = repeat_kv(v, ca.num_kv_groups) # [B, n_head, S, d] self._ctx_bsz = B @torch.no_grad() def clear_context_cache(self): self._ctx_k_repeat = None self._ctx_v_repeat = None self._ctx_bsz = None def _cross_attend_cached(self, x: torch.Tensor): # x: [B, T, D] if self._ctx_k_repeat is None or self._ctx_v_repeat is None: return x # no-op if no cached prompt B, T, _ = x.size() if self._ctx_bsz is not None and self._ctx_bsz != B: # different batch: ignore cache (or you could raise) return x ca = self.cross_attn q = ca.q_proj(x).view(B, T, ca.n_head, ca.head_dim).transpose(1, 2) q = ca.q_norm(q) y = F.scaled_dot_product_attention( q, self._ctx_k_repeat, self._ctx_v_repeat, attn_mask=None, dropout_p=0.0, is_causal=False, scale=ca.softmax_scale, ) y = y.transpose(1, 2).contiguous().view(B, T, ca.n_head * ca.head_dim) return x + ca.o_proj(y) def forward( self, x: torch.Tensor, context: Optional[torch.Tensor] = None, last_past=None, use_cache: bool = False, return_attention: bool = False, ) -> torch.Tensor: present = None # self-attn if use_cache: if return_attention: attn_output, attn = self.attn.forward_attn(self.rms_1(x), last_past, use_cache) else: attn_output, present = self.attn(self.rms_1(x), last_past, use_cache) x = x + attn_output else: if return_attention: attn_output, attn = self.attn.forward_attn(self.rms_1(x)) else: attn_output = self.attn(self.rms_1(x)) x = x + attn_output # cross-attn: prefer live context if provided; else use cached prompt kv if context is not None: x = x + self.cross_attn(self.rms_cross(x), context) else: x = self._cross_attend_cached(self.rms_cross(x)) # mlp x = x + self.mlp(self.rms_2(x)) if use_cache: if return_attention: return x, present, 
attn else: return x, present else: if return_attention: return x, attn else: return x class CausalSelfAttention(nn.Module): def __init__(self, config: LLaMAHFConfig) -> None: super().__init__() assert config.n_embd % config.n_head == 0 self.n_head = config.n_head self.n_kv_head = config.n_kv_head or max(1, config.n_head // 4) assert self.n_head % self.n_kv_head == 0, "n_head must be divisible by n_kv_head" self.head_dim = config.n_embd // config.n_head self.block_size = config.block_size self.rope_base = config.rope_base self.rope_cache = None self.num_kv_groups = self.n_head // self.n_kv_head self.q_proj = nn.Linear(config.n_embd, self.n_head * self.head_dim, bias=False) self.k_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False) self.v_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False) self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) self.q_norm = RMSNorm(self.head_dim) self.k_norm = RMSNorm(self.head_dim) self.softmax_scale = self.head_dim ** -0.5 def forward(self, x: torch.Tensor, last_past=None, use_cache=False) -> torch.Tensor: B, T, _ = x.size() q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2) k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2) v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2) q = self.q_norm(q) k = self.k_norm(k) if ( self.rope_cache is None or self.rope_cache.dtype != x.dtype or self.rope_cache.device != x.device ): self.rope_cache = build_rope_cache( seq_len=self.block_size, n_elem=self.head_dim, dtype=x.dtype, device=x.device, base=self.rope_base, ) q = apply_rope(q, self.rope_cache) k = apply_rope(k, self.rope_cache) if use_cache: if last_past is not None: past_key, past_value = last_past k = torch.cat([past_key, k], dim=-2) v = torch.cat([past_value, v], dim=-2) present = (k, v) else: present = None k_repeat = repeat_kv(k, self.num_kv_groups) v_repeat = repeat_kv(v, self.num_kv_groups) y = 
F.scaled_dot_product_attention( q, k_repeat, v_repeat, attn_mask=None, dropout_p=0.0, is_causal=True, scale=self.softmax_scale, ) y = y.transpose(1, 2).contiguous().view(B, T, self.n_head * self.head_dim) y = self.o_proj(y) if use_cache: return y, present return y def forward_attn(self, x: torch.Tensor, last_past=None, use_cache=False) -> torch.Tensor: B, T, _ = x.size() q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2) k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2) v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2) q = self.q_norm(q) k = self.k_norm(k) if ( self.rope_cache is None or self.rope_cache.dtype != x.dtype or self.rope_cache.device != x.device ): self.rope_cache = build_rope_cache( seq_len=self.block_size, n_elem=self.head_dim, dtype=x.dtype, device=x.device, base=self.rope_base, ) q = apply_rope(q, self.rope_cache) k = apply_rope(k, self.rope_cache) if use_cache: if last_past is not None: past_key, past_value = last_past k = torch.cat([past_key, k], dim=-2) v = torch.cat([past_value, v], dim=-2) k_repeat = repeat_kv(k, self.num_kv_groups) v_repeat = repeat_kv(v, self.num_kv_groups) att = torch.matmul(q, k_repeat.transpose(-2, -1)) * self.softmax_scale att = F.softmax(att, dim=-1) y = torch.matmul(att, v_repeat) y = y.transpose(1, 2).contiguous().view(B, T, self.n_head * self.head_dim) y = self.o_proj(y) return y, att class CrossAttention(nn.Module): def __init__(self, config: LLaMAHFConfig) -> None: super().__init__() assert config.n_embd % config.n_head == 0 self.n_head = config.n_head self.n_kv_head = config.n_kv_head or max(1, config.n_head // 4) assert self.n_head % self.n_kv_head == 0, "n_head must be divisible by n_kv_head" self.head_dim = config.n_embd // config.n_head self.num_kv_groups = self.n_head // self.n_kv_head self.q_proj = nn.Linear(config.n_embd, self.n_head * self.head_dim, bias=False) self.k_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, 
bias=False) self.v_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False) self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) self.q_norm = RMSNorm(self.head_dim) self.k_norm = RMSNorm(self.head_dim) self.softmax_scale = self.head_dim ** -0.5 def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor: B, T, _ = x.size() _, S, _ = context.size() q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2) k = self.k_proj(context).view(B, S, self.n_kv_head, self.head_dim).transpose(1, 2) v = self.v_proj(context).view(B, S, self.n_kv_head, self.head_dim).transpose(1, 2) q = self.q_norm(q) k = self.k_norm(k) k_repeat = repeat_kv(k, self.num_kv_groups) v_repeat = repeat_kv(v, self.num_kv_groups) y = F.scaled_dot_product_attention( q, k_repeat, v_repeat, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.softmax_scale, ) y = y.transpose(1, 2).contiguous().view(B, T, self.n_head * self.head_dim) return self.o_proj(y) def repeat_kv(hidden_states: torch.Tensor, num_groups: int) -> torch.Tensor: if num_groups == 1: return hidden_states bsz, n_kv, seq_len, head_dim = hidden_states.shape hidden_states = hidden_states.unsqueeze(2).expand(bsz, n_kv, num_groups, seq_len, head_dim) return hidden_states.reshape(bsz, n_kv * num_groups, seq_len, head_dim) class LengthCausalSelfAttention(nn.Module): def __init__(self, config: LLaMAHFConfig) -> None: super().__init__() assert config.n_embd % config.n_head == 0 self.n_head = config.n_head self.n_kv_head = config.n_kv_head or max(1, config.n_head // 4) assert self.n_head % self.n_kv_head == 0, "n_head must be divisible by n_kv_head" self.head_dim = config.n_embd // config.n_head self.block_size = config.block_size self.rope_base = config.rope_base self.rope_cache = None self.num_kv_groups = self.n_head // self.n_kv_head self.q_proj = nn.Linear(config.n_embd, self.n_head * self.head_dim, bias=False) self.k_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, 
bias=False) self.v_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False) self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) self.q_norm = RMSNorm(self.head_dim) self.k_norm = RMSNorm(self.head_dim) self.softmax_scale = self.head_dim ** -0.5 def forward(self, x: torch.Tensor, y_mask: torch.Tensor) -> torch.Tensor: B, T, _ = x.size() q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2) k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2) v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2) q = self.q_norm(q) k = self.k_norm(k) if ( self.rope_cache is None or self.rope_cache.dtype != x.dtype or self.rope_cache.device != x.device ): self.rope_cache = build_rope_cache( seq_len=self.block_size, n_elem=self.head_dim, dtype=x.dtype, device=x.device, base=self.rope_base, ) q = apply_rope(q, self.rope_cache) k = apply_rope(k, self.rope_cache) attn_mask = torch.ones(T, T, dtype=torch.bool, device=x.device) attn_mask = torch.tril(attn_mask) attn_mask = attn_mask.unsqueeze(0).expand(B, -1, -1) text_mask = y_mask.unsqueeze(2) * y_mask.unsqueeze(1) text_mask = F.pad(text_mask, (0, T - y_mask.shape[1], 0, T - y_mask.shape[1]), mode='constant', value=0) attn_mask = torch.logical_or(attn_mask, text_mask) k_repeat = repeat_kv(k, self.num_kv_groups) v_repeat = repeat_kv(v, self.num_kv_groups) y = F.scaled_dot_product_attention( q, k_repeat, v_repeat, attn_mask=attn_mask.unsqueeze(1), dropout_p=0.0, is_causal=False, scale=self.softmax_scale, ) y = y.transpose(1, 2).contiguous().view(B, T, self.n_head * self.head_dim) y = self.o_proj(y) return y class MLP(nn.Module): def __init__(self, config: LLaMAHFConfig) -> None: super().__init__() hidden_dim = 4 * config.n_embd n_hidden = int(2 * hidden_dim / 3) N = 256 # ensure n_hidden is multiple of N n_hidden = ((n_hidden - 1) // N) * N + N self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False) self.c_fc2 = nn.Linear(config.n_embd, 
n_hidden, bias=False) self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: x = F.silu(self.c_fc1(x)) * self.c_fc2(x) x = self.c_proj(x) return x class RMSNorm(nn.Module): """Root Mean Square Layer Normalization. Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License: https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE. """ def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None: super().__init__() self.scale = nn.Parameter(torch.ones(size)) self.eps = eps self.dim = dim def forward(self, x: torch.Tensor) -> torch.Tensor: # NOTE: the original RMSNorm paper implementation is not equivalent # norm_x = x.norm(2, dim=self.dim, keepdim=True) # rms_x = norm_x * d_x ** (-1. / 2) # x_normed = x / (rms_x + self.eps) norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) x_normed = x * torch.rsqrt(norm_x + self.eps) return self.scale * x_normed def build_rope_cache(seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000) -> torch.Tensor: """ Rotary-position cache with safe dtype handling. 
""" theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem)) seq_idx = torch.arange(seq_len, dtype=dtype, device=device) idx_theta = torch.outer(seq_idx, theta) # cast to float32 for torch.polar when needed dtypes_requiring_casting = [torch.float16, torch.bfloat16, torch.int8] working_dtype = torch.float32 if dtype in dtypes_requiring_casting else dtype complex_dtype = torch.complex64 # torch.complex32 does not exist cache = torch.polar(torch.ones_like(idx_theta, dtype=working_dtype, device=device), idx_theta.to(working_dtype)).to(complex_dtype) return cache def apply_rope(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: x = x.transpose(1, 2) # truncate to support variable sizes T = x.size(1) rope_cache = rope_cache[:T] # cast because `view_as_complex` does not support 16 bit tensors xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) rope_cache = rope_cache.view(1, xc.size(1), 1, xc.size(3)) x_out = torch.view_as_real(xc * rope_cache).flatten(3) return x_out.transpose(1, 2).type_as(x)