# motion-stream / models/llama_model.py
# Uploaded by zirobtc ("Upload 2 files", commit 3c212d2, verified)
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing_extensions import Self
from typing import Optional
from transformers.modeling_utils import PreTrainedModel
from torch.distributions import Categorical
@dataclass
class LLaMAHFConfig:
    """Hyper-parameters of the LLaMA-style motion transformer.

    Field order is significant: it is the positional order of the
    generated ``__init__``.
    """

    # upper bound on the number of motion tokens accepted by the forward passes
    block_size: int = 156
    n_layer: int = 32  # number of transformer blocks
    n_head: int = 32  # attention heads per block
    n_kv_head: Optional[int] = None  # not read in the visible part of this file
    n_embd: int = 4096  # hidden width
    rope_base: int = 500000  # not read in the visible part of this file
    T5_xxl_dim: int = 768  # width of the text features fed to cond_embed / llama_proj

    @classmethod
    def from_name(cls, name: str) -> "Self":
        """Build a config from the named preset in ``llama_configs``."""
        preset = llama_configs[name]
        return cls(**preset)
# Named presets for LLaMAHFConfig.from_name; fields not listed keep their
# dataclass defaults.
llama_configs = {
    "Normal_size": {"n_layer": 12, "n_head": 12, "n_embd": 768},
}
class LLaMAHF(nn.Module):
def __init__(self, config: LLaMAHFConfig, num_diffusion_head_layers=6, n_diffusion_heads=4, input_token_dim=16, device=torch.device('cuda'), width=512) -> None:
    """Build the embeddings, the transformer stack and the diffusion head.

    Args:
        config: model hyper-parameters (``n_layer``, ``n_embd``,
            ``T5_xxl_dim``, ``block_size`` ...).
        num_diffusion_head_layers: ``depth`` of the ``DiffLoss`` head.
        n_diffusion_heads: ``n_heads`` of the ``DiffLoss`` head.
        input_token_dim: channels of one continuous motion token; it is the
            input width of ``wte`` and the ``target_channels`` of the head.
        device: device the diffusion head is moved to. Only ``diff_loss`` is
            moved here; every other submodule stays on its creation device.
        width: hidden ``width`` of the ``DiffLoss`` head.
    """
    super().__init__()
    assert config.block_size is not None
    self.config = config
    cond_dim = config.T5_xxl_dim
    # NOTE: creation order fixes the state_dict key order and the order in
    # which parameter initialisation consumes the global RNG; keep it stable.
    self.transformer = nn.ModuleDict(
        dict(
            wte=nn.Linear(input_token_dim, config.n_embd), # vector tokens -> embeddings
            cond_embed=nn.Linear(cond_dim, config.n_embd), # text feature -> context emb
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=RMSNorm(config.n_embd),
        )
    )
    target_channels = input_token_dim
    # Project-local import kept inside __init__ (reason not visible here;
    # possibly to avoid an import cycle).
    from models.diffloss import DiffLoss
    self.diff_loss = DiffLoss(
        target_channels=target_channels,
        z_channels=config.n_embd,  # conditions are n_embd-wide transformer outputs
        width=width,
        depth=num_diffusion_head_layers,
        num_sampling_steps='50',  # passed as a string; presumably a respacing spec — confirm in DiffLoss
        grad_checkpointing=False,
        n_heads=n_diffusion_heads,
        mlp_ratio=2.0
    ).to(device)
    self.out_proj = nn.Linear(config.n_embd, config.n_embd)
    # NOTE(review): use_out_proj is not read in the visible code; forward()
    # always applies out_proj.
    self.use_out_proj = True
    # --- Persistent prompt cache & BOS token ---
    # Set by set_prompt() and reset by clear_prompt().
    self._prompt_cached = False
    self._prompt_bsz = None
    # Zero-initialised learnable token that forward() prepends to every sequence.
    self.bos = nn.Parameter(torch.zeros(1, 1, config.n_embd))
    # === Needed by several sampling/forward paths ===
    # projects raw text features when they are concatenated as tokens
    # (used by forward_sample)
    self.llama_proj = nn.Linear(config.T5_xxl_dim, config.n_embd)
    # special boundary-of-motion token used in forward_babel
    # (not referenced in the visible part of this file)
    self.BOM_tag = nn.Parameter(torch.zeros(1, 1, config.n_embd))
    # (Optional) only if sample_for_eval_classification() is used:
    # self.classify_head = nn.Linear(config.n_embd, num_classes)
@torch.no_grad()
def set_prompt(self, feature: torch.Tensor):
    """Embed a prompt once and install it as every block's cached context.

    Call this whenever the prompt changes (e.g. 'walk' -> 'crawl'). Until
    ``clear_prompt`` is called, ``forward`` does not concatenate a context
    and calls the blocks with ``context=None``.

    Raises:
        ValueError: if ``feature`` is None.
    """
    ctx = self._prepare_context(feature)
    if ctx is None:
        raise ValueError("set_prompt: feature cannot be None")
    self._prompt_bsz = ctx.size(0)
    for layer in self.transformer.h:
        layer.set_context_cache(ctx)
    self._prompt_cached = True
@torch.no_grad()
def clear_prompt(self):
    """Drop every block's cached prompt context and reset the prompt flags."""
    for layer in self.transformer.h:
        layer.clear_context_cache()
    self._prompt_cached, self._prompt_bsz = False, None
def _prepare_context(self, feature: Optional[torch.Tensor], batch_size: Optional[int] = None) -> Optional[torch.Tensor]:
if feature is None:
return None
if not torch.is_tensor(feature):
feature = torch.as_tensor(
feature,
dtype=self.transformer.cond_embed.weight.dtype,
device=self.transformer.cond_embed.weight.device,
)
else:
feature = feature.to(
dtype=self.transformer.cond_embed.weight.dtype,
device=self.transformer.cond_embed.weight.device,
)
if feature.dim() == 1:
feature = feature.unsqueeze(0)
context = self.transformer.cond_embed(feature)
if context.dim() == 2:
context = context.unsqueeze(1)
if batch_size is not None and context.size(0) != batch_size:
if context.size(0) == 1:
context = context.expand(batch_size, -1, -1)
else:
raise ValueError(
f"Condition batch ({context.size(0)}) does not match token batch ({batch_size})."
)
return context
def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
"""Tie or clone module weights depending of whether we are using TorchScript or not"""
output_embeddings.weight = input_embeddings.weight
if getattr(output_embeddings, "bias", None) is not None:
output_embeddings.bias.data = nn.functional.pad(
output_embeddings.bias.data,
(
0,
output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],
),
"constant",
0,
)
if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
output_embeddings.out_features = input_embeddings.num_embeddings
def get_input_embeddings(self):
    """Return the token-embedding module (``transformer.wte``)."""
    embeddings = self.transformer.wte
    return embeddings
def set_input_embeddings(self, value):
    """Replace the token-embedding module ``transformer.wte`` with ``value``."""
    transformer = self.transformer
    transformer.wte = value
def get_output_embeddings(self):
    """Return the output projection module (``out_proj``)."""
    projection = self.out_proj
    return projection
def set_output_embeddings(self, new_embeddings):
    """Replace the output projection module ``out_proj``."""
    setattr(self, "out_proj", new_embeddings)
def _init_weights(self, module: nn.Module) -> None:
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
def forward_sample(self, idx: torch.Tensor, clip_feature: torch.Tensor, y_mask) -> torch.Tensor:
    """Run the backbone with projected text features as a token prefix.

    The prefix is ``llama_proj(clip_feature)`` truncated to
    ``int(y_mask[0].sum())`` positions. The mask of the first sample is used
    for the whole batch. The same ``clip_feature`` is also embedded with
    ``_prepare_context`` and passed to every block as ``context``.

    Args:
        idx: continuous motion tokens ``[B, T, input_token_dim]`` (the input
            of the ``wte`` Linear), or an empty sequence for prefix-only use.
        clip_feature: text features; 3-D ``[B, L, T5_xxl_dim]``, because of
            the ``[:, :n, :]`` slice after ``llama_proj``.
        y_mask: text mask; only ``y_mask[0]`` is read.

    Returns:
        Hidden states after ``ln_f``, shape ``[B, n_text + T, n_embd]``.
        Unlike ``forward``, ``out_proj`` is not applied.

    Raises:
        AssertionError: ``T`` exceeds ``config.block_size``.
        ValueError: the conditioning batch cannot be broadcast to the token batch.
    """
    context = self._prepare_context(clip_feature)
    n_text = int(y_mask[0].sum())
    text_tokens = self.llama_proj(clip_feature)[:, :n_text, :]
    if len(idx) == 0:
        x = text_tokens
    else:
        # idx is [B, T, C]. The former `_, t = idx.size()` unpacking could never
        # succeed for the 3-D input that the wte Linear requires.
        t = idx.size(1)
        assert (
            t <= self.config.block_size
        ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        # Cast like forward() does, so integer tokens reach the Linear as floats.
        x = self.transformer.wte(idx.to(self.transformer.wte.weight.dtype))
        x = torch.cat((text_tokens, x), dim=1)
    if context is not None and context.size(0) != x.size(0):
        if context.size(0) == 1:
            context = context.expand(x.size(0), -1, -1)
        else:
            raise ValueError("Conditioning batch size does not match token batch size.")
    for block in self.transformer.h:
        x = block(x, context=context)
    return self.transformer.ln_f(x)
def sample_for_eval_CFG(self, text, length=196, tokenize_model=None, device=torch.device('cuda'), unit_length=4, cfg=4.0):
    """Autoregressively sample ``length // unit_length`` motion latents with CFG.

    The conditional and unconditional prompts are swapped in through the
    block prompt cache (``set_prompt``) before each ``forward`` call.

    Args:
        text: prompt passed to ``tokenize_model.encode``.
        length: target length in frames; ``length // unit_length`` tokens
            are sampled.
        tokenize_model: text encoder whose ``encode(str)`` returns a numpy array.
        device: device the encoded features are moved to.
        unit_length: frames per latent token.
        cfg: guidance scale. ``1`` disables guidance, and the unconditional
            pass is then skipped.

    Returns:
        Sampled latents ``[1, n_tokens, C]`` for a batch-1 prompt, or None
        when ``length // unit_length`` is 0. The conditional prompt is left
        cached on return, also when an exception escapes the loop.
    """
    max_token_len = length // unit_length
    feat_text = torch.from_numpy(tokenize_model.encode(text)).float().to(device)
    empty_feat_text = torch.from_numpy(tokenize_model.encode('')).float().unsqueeze(0).to(device)
    temperature = 1.0
    xs = None
    try:
        for _ in range(max_token_len):
            x = [] if xs is None else xs
            # conditioned next-step
            self.set_prompt(feat_text)
            conditions = self.forward(x, feature=None)[:, -1, :]
            if cfg != 1:
                # unconditioned next-step (only needed when guidance is on)
                self.set_prompt(empty_feat_text)
                empty_conditions = self.forward(x, feature=None)[:, -1, :]
                mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
                sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = self.diff_loss.sample(conditions, temperature=temperature, cfg=1)
            scaled_logits = scaled_logits.unsqueeze(0)
            xs = scaled_logits if xs is None else torch.cat((xs, scaled_logits), dim=1)
    finally:
        # re-enable the conditioned prompt cache for whatever comes next
        self.set_prompt(feat_text)
    return xs
# For inference, sampling can stop once the distance between the current token and the reference end token falls below the threshold.
def sample_for_eval_CFG_inference(self, text, length=312, tokenizer=None, device=torch.device('cuda'),
                                  unit_length=4, reference_end_latent=None, threshold=0.1, cfg=4.0, temperature=1.0):
    """Sample up to ``length // unit_length`` latents with CFG and optional early stop.

    Args:
        text: prompt passed to ``tokenizer.encode`` (returns a numpy array).
        length, unit_length: at most ``length // unit_length`` tokens.
        device: device the encoded features are moved to.
        reference_end_latent: if given, stop (without appending) once a
            sampled latent lies within ``threshold`` (L2) of it.
        cfg: guidance scale. ``1`` samples from the conditional pass only.
        temperature: forwarded to ``diff_loss.sample``.

    Returns:
        ``[1, n, C]`` latents, or None if nothing was appended. The
        conditional prompt is left cached, also when an exception escapes.
    """
    max_token_len = length // unit_length
    feat_text = torch.from_numpy(tokenizer.encode(text)).float().to(device)
    empty_feat_text = torch.from_numpy(tokenizer.encode('')).float().unsqueeze(0).to(device)
    xs = None
    try:
        for _ in range(max_token_len):
            x = [] if xs is None else xs
            self.set_prompt(feat_text)
            conditions = self.forward_inference(x, feature=None)[:, -1, :]
            if cfg != 1:
                self.set_prompt(empty_feat_text)
                # NOTE(review): the unconditional branch uses forward() while the
                # conditional one uses forward_inference(); kept as-is — confirm intended.
                empty_conditions = self.forward(x, feature=None)[:, -1, :]
                mix = torch.cat([conditions, empty_conditions], dim=0)
                scaled_logits, _ = self.diff_loss.sample(mix, temperature=temperature, cfg=cfg).chunk(2, dim=0)
            else:
                # Sample from the conditional rows only. Sampling the [cond; uncond]
                # mix with cfg=1 (presumably one sample per row) appended the
                # unconditional sample as an extra token.
                scaled_logits = self.diff_loss.sample(conditions, temperature=temperature, cfg=1)
            scaled_logits = scaled_logits.unsqueeze(0)
            if reference_end_latent is not None:
                dist = torch.sqrt(torch.sum((scaled_logits - reference_end_latent) ** 2))
                if dist < threshold:
                    break
            xs = scaled_logits if xs is None else torch.cat((xs, scaled_logits), dim=1)
    finally:
        # leave the cond cache active
        self.set_prompt(feat_text)
    return xs
def sample_for_eval_CFG_inference2(self, feat_clip_text, empty_feat_clip_text, if_categorial=False, length=312, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, threshold=3, cfg=4.5, temperature=1.0):
    """Sample motion latents from precomputed text features with CFG.

    The features are passed to ``forward`` directly. ``forward`` ignores
    ``feature`` while a prompt is cached via ``set_prompt``, so any cached
    prompt is cleared first.

    Args:
        feat_clip_text: conditional features. If the first ``forward`` call
            raises, it is retried with a leading batch dim.
        empty_feat_clip_text: unconditional (empty-prompt) features.
        length, unit_length: at most ``length // unit_length`` tokens.
        reference_end_token, threshold: stop (without appending) once a
            sampled latent's L2 distance to the reference is below
            ``threshold``. The distance is printed each step.
        cfg: guidance scale. ``1`` samples from the conditional pass only.
        temperature: forwarded to ``diff_loss.sample``.
        if_categorial, clip_model, device, tokenizer: unused; kept for callers.

    Returns:
        ``[1, n, C]`` latents, or None if nothing was appended (the original
        raised UnboundLocalError in that case).
    """
    if getattr(self, "_prompt_cached", False):
        self.clear_prompt()
    max_token_len = length // unit_length
    xs = None
    for _ in range(max_token_len):
        x = [] if xs is None else xs
        try:
            conditions = self.forward(x, feat_clip_text)
        except Exception:
            # Retry with an explicit batch dim. This was a bare `except:`, which
            # also swallowed KeyboardInterrupt/SystemExit.
            conditions = self.forward(x, feat_clip_text.unsqueeze(0))
        conditions = conditions[:, -1, :]
        if cfg != 1:
            empty_conditions = self.forward(x, empty_feat_clip_text)[:, -1, :]
            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)
            scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
        else:
            # Sample from the conditional rows only. With cfg=1 the unchunked
            # [cond; uncond] sample was appended as two tokens.
            scaled_logits = self.diff_loss.sample(conditions, temperature=temperature, cfg=cfg)
        scaled_logits = scaled_logits.unsqueeze(0)
        if reference_end_token is not None:
            distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token) ** 2))
            print(distance_l2)
            if distance_l2 < threshold:
                break
        xs = scaled_logits if xs is None else torch.cat((xs, scaled_logits), dim=1)
    return xs
def sample_for_eval_CFG_inference_next_one(self, current_token=None, feat_clip_text=None, empty_feat_clip_text=None, if_categorial=False, length=312, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, threshold=3, cfg=4.5, temperature=1.0):
    """Sample exactly one next motion latent with CFG.

    Any prompt cached via ``set_prompt`` is cleared first, because
    ``forward`` would otherwise ignore the features passed here.

    Args:
        current_token: sequence of previously generated latents (each
            ``[1, n_i, C]``), concatenated along dim 1. ``None`` or an empty
            sequence starts from the prompt alone. The default used to be a
            shared mutable ``[]``; ``None`` behaves the same.
        feat_clip_text: conditional features. If the first ``forward`` call
            raises, it is retried with a leading batch dim.
        empty_feat_clip_text: unconditional features.
        cfg: guidance scale. ``1`` samples from the conditional pass only.
        temperature: forwarded to ``diff_loss.sample``.
        if_categorial, length, clip_model, device, tokenizer, unit_length,
        reference_end_token, threshold: unused; kept for callers.

    Returns:
        The new latent, ``[1, 1, C]`` for a batch-1 prompt.
    """
    if getattr(self, "_prompt_cached", False):
        self.clear_prompt()
    if current_token is None or len(current_token) == 0:
        x = []
    else:
        x = torch.cat(current_token, dim=1)
    try:
        conditions = self.forward(x, feat_clip_text)
    except Exception:
        # Retry with an explicit batch dim (previously a bare `except:`).
        conditions = self.forward(x, feat_clip_text.unsqueeze(0))
    conditions = conditions[:, -1, :]
    if cfg != 1:
        empty_conditions = self.forward(x, empty_feat_clip_text)[:, -1, :]
        mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
        sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)
        scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
    else:
        # Conditional rows only. With cfg=1 the unchunked mix produced two tokens.
        scaled_logits = self.diff_loss.sample(conditions, temperature=temperature, cfg=cfg)
    return scaled_logits.unsqueeze(0)
def sample_for_eval_CFG_babel(self, A_text, B_text, A_motion, if_categorial=False, length=6400, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=7.0, threshold=3):
    """Continue ``A_motion`` (described by ``A_text``) with a segment for ``B_text``.

    The conditional input is ``[A_text, A_motion, B_text, B_motion...]``. The
    unconditional input replaces both text slots with the embedded empty
    prompt. Exactly ``length // unit_length - A_motion.shape[0]`` latents are
    sampled; there is no early stop.

    Args:
        A_text, B_text: prompts, encoded with ``clip_model``.
        A_motion: prefix latents ``[T_A, C]``; a batch dim of 1 is added.
        tokenizer: ``'clip'`` (``clip.tokenize`` + ``clip_model.encode_text``)
            or ``'t5-xxl'`` (``clip_model.encode`` returning a numpy array).
        cfg: guidance scale. ``1`` samples from the conditional pass only.
        if_categorial, reference_end_token, threshold: unused.

    Returns:
        ``(xs, B_motion)``. ``xs`` is the conditional embedding sequence
        including the sampled tokens (just the prompt when nothing is
        sampled). ``B_motion`` holds the sampled latents ``[1, n, C]``, or an
        empty tensor.

    Raises:
        ValueError: for an unsupported ``tokenizer``.
    """
    def _encode(text):
        if tokenizer == 'clip':
            import clip  # only required on the CLIP path
            return clip_model.encode_text(clip.tokenize(text, truncate=True).to(device)).float()
        if tokenizer == 't5-xxl':
            return torch.from_numpy(clip_model.encode(text)).float().to(device)
        raise ValueError(f"Unsupported tokenizer: {tokenizer!r}")

    def _embed_text(text):
        # cond_embed the features and lay them out as a batch-1 sequence
        # [1, n, n_embd], whether the encoder returned [D], [1, D] or [n, D].
        emb = self.transformer.cond_embed(_encode(text))
        return emb.reshape(1, -1, emb.size(-1))

    B_token_length = length // unit_length - A_motion.shape[0]
    A_text_embeddings = _embed_text(A_text)
    B_text_embeddings = _embed_text(B_text)
    # The empty prompt does not change between steps: encode and embed it once.
    empty_embedding = _embed_text('')
    A_motion_embeddings = self.transformer.wte(A_motion.unsqueeze(0))
    B_motion = torch.tensor([]).to(device)
    temperature = 1.0

    xs = torch.cat([A_text_embeddings, A_motion_embeddings, B_text_embeddings], dim=1)
    for k in range(B_token_length):
        x = xs
        conditions = self.forward_babel_eval(x)[:, -1, :]
        if cfg != 1:
            empty_parts = [empty_embedding, A_motion_embeddings, empty_embedding]
            if k > 0:
                empty_parts.append(self.transformer.wte(B_motion))
            empty_conditions = self.forward_babel_eval(torch.cat(empty_parts, dim=1))[:, -1, :]
            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)
            scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
        else:
            # Conditional rows only. With cfg=1 the unchunked [cond; uncond]
            # sample (presumably one per row) was appended as an extra token.
            scaled_logits = self.diff_loss.sample(conditions, temperature=temperature, cfg=cfg)
        scaled_logits = scaled_logits.unsqueeze(0)
        B_motion = torch.cat((B_motion, scaled_logits), dim=1)
        xs = torch.cat((x, self.transformer.wte(scaled_logits)), dim=1)
    return xs, B_motion
def sample_for_eval_CFG_babel_inference(self, A_text, B_text, A_motion, if_categorial=False, length=6400, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=7.0, threshold=3):
    """Continue ``A_motion`` (``A_text``) for ``B_text`` with CFG and an optional early stop.

    The conditional input is ``[A_text, A_motion, B_text, B_motion...]``. The
    unconditional input replaces both text slots with the embedded empty
    prompt. Up to ``length // unit_length - A_motion.shape[0]`` latents are
    sampled.

    Args:
        A_text, B_text: prompts, encoded with ``clip_model``.
        A_motion: prefix latents ``[T_A, C]``; a batch dim of 1 is added.
        tokenizer: ``'clip'`` or ``'t5-xxl'`` (``clip_model.encode`` -> numpy).
        reference_end_token, threshold: stop (without appending) once a
            sampled latent's L2 distance to the reference is below
            ``threshold``. The distance is printed each step.
        cfg: guidance scale. ``1`` samples from the conditional pass only.
        if_categorial: unused.

    Returns:
        ``(xs, B_motion)``: the conditional embedding sequence (just the
        prompt when nothing is appended) and the latents ``[1, n, C]``, or
        an empty tensor.

    Raises:
        ValueError: for an unsupported ``tokenizer``.
    """
    def _encode(text):
        if tokenizer == 'clip':
            import clip  # only required on the CLIP path
            return clip_model.encode_text(clip.tokenize(text, truncate=True).to(device)).float()
        if tokenizer == 't5-xxl':
            return torch.from_numpy(clip_model.encode(text)).float().to(device)
        raise ValueError(f"Unsupported tokenizer: {tokenizer!r}")

    def _embed_text(text):
        # Batch-1 sequence [1, n, n_embd]. This replaces the double unsqueeze,
        # which produced a 4-D tensor on the CLIP path.
        emb = self.transformer.cond_embed(_encode(text))
        return emb.reshape(1, -1, emb.size(-1))

    B_token_length = length // unit_length - A_motion.shape[0]
    A_text_embeddings = _embed_text(A_text)
    B_text_embeddings = _embed_text(B_text)
    # The empty prompt does not change between steps: encode and embed it once.
    empty_embedding = _embed_text('')
    A_motion_embeddings = self.transformer.wte(A_motion.unsqueeze(0))
    B_motion = torch.tensor([]).to(device)
    temperature = 1.0

    xs = torch.cat([A_text_embeddings, A_motion_embeddings, B_text_embeddings], dim=1)
    for k in range(B_token_length):
        x = xs
        conditions = self.forward_babel_eval(x, return_attention=False)[:, -1, :]
        if cfg != 1:
            empty_parts = [empty_embedding, A_motion_embeddings, empty_embedding]
            if k > 0:
                empty_parts.append(self.transformer.wte(B_motion))
            empty_conditions = self.forward_babel_eval(torch.cat(empty_parts, dim=1))[:, -1, :]
            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)
            scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
        else:
            # Conditional rows only. With cfg=1 the unchunked mix added an extra token.
            scaled_logits = self.diff_loss.sample(conditions, temperature=temperature, cfg=cfg)
        scaled_logits = scaled_logits.unsqueeze(0)
        if reference_end_token is not None:
            distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token) ** 2))
            print(distance_l2)
            if distance_l2 < threshold:
                break
        B_motion = torch.cat((B_motion, scaled_logits), dim=1)
        xs = torch.cat((x, self.transformer.wte(scaled_logits)), dim=1)
    return xs, B_motion
def sample_for_eval_CFG_babel_inference_new(self, B_text, A_motion, if_categorial=False, length=78, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=4.5, threshold=3):
    """Continue ``A_motion`` under the single prompt ``B_text`` with CFG.

    The conditional input is ``[B_text, A_motion, B_motion...]``; the
    unconditional input is ``[empty prompt, A_motion, B_motion...]``. Up to
    ``length // unit_length`` latents are sampled (``A_motion`` is not
    subtracted here).

    Args:
        B_text: prompt, encoded with ``clip_model``.
        A_motion: prefix latents ``[T_A, C]``; a batch dim of 1 is added.
        tokenizer: ``'clip'`` or ``'t5-xxl'`` (``clip_model.encode`` -> numpy).
            The CLIP branch used to reference an undefined ``A_text``
            (UnboundLocalError); it now only encodes ``B_text``.
        reference_end_token, threshold: stop (without appending) once a
            sampled latent's L2 distance to the reference is below
            ``threshold``. The distance is printed each step.
        cfg: guidance scale. ``1`` samples from the conditional pass only.
        if_categorial: unused.

    Returns:
        ``(xs, B_motion)``: the conditional embedding sequence (just the
        prompt when nothing is appended) and the latents ``[1, n, C]``, or
        an empty tensor.

    Raises:
        ValueError: for an unsupported ``tokenizer``.
    """
    def _encode(text):
        if tokenizer == 'clip':
            import clip  # only required on the CLIP path
            return clip_model.encode_text(clip.tokenize(text, truncate=True).to(device)).float()
        if tokenizer == 't5-xxl':
            return torch.from_numpy(clip_model.encode(text)).float().to(device)
        raise ValueError(f"Unsupported tokenizer: {tokenizer!r}")

    def _embed_text(text):
        # Batch-1 sequence [1, n, n_embd], whatever rank the encoder returned.
        emb = self.transformer.cond_embed(_encode(text))
        return emb.reshape(1, -1, emb.size(-1))

    B_token_length = length // unit_length
    B_text_embeddings = _embed_text(B_text)
    # The empty prompt does not change between steps: embed it once.
    empty_embedding = _embed_text('')
    A_motion_embeddings = self.transformer.wte(A_motion.unsqueeze(0))
    B_motion = torch.tensor([]).to(device)
    temperature = 1.0

    xs = torch.cat([B_text_embeddings, A_motion_embeddings], dim=1)
    for k in range(B_token_length):
        x = xs
        conditions = self.forward_babel_eval(x, return_attention=False)[:, -1, :]
        if cfg != 1:
            empty_parts = [empty_embedding, A_motion_embeddings]
            if k > 0:
                empty_parts.append(self.transformer.wte(B_motion))
            empty_conditions = self.forward_babel_eval(torch.cat(empty_parts, dim=1))[:, -1, :]
            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)
            scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
        else:
            # Conditional rows only. With cfg=1 the unchunked mix added an extra token.
            scaled_logits = self.diff_loss.sample(conditions, temperature=temperature, cfg=cfg)
        scaled_logits = scaled_logits.unsqueeze(0)
        if reference_end_token is not None:
            distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token) ** 2))
            print(distance_l2)
            if distance_l2 < threshold:
                break
        B_motion = torch.cat((B_motion, scaled_logits), dim=1)
        xs = torch.cat((x, self.transformer.wte(scaled_logits)), dim=1)
    return xs, B_motion
def sample_for_eval_CFG_babel_inference_new_demo(self, B_text, A_motion, if_categorial=False, length=312, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=4.5, threshold=3, temperature=1.0):
    """Demo variant: continue ``A_motion`` under ``B_text`` with CFG.

    The conditional input is ``[B_text, A_motion, B_motion...]``; the
    unconditional input is ``[empty prompt, A_motion, B_motion...]``. Up to
    ``length // unit_length - A_motion.shape[0]`` latents are sampled. The
    early stop is only allowed after step 10.

    Args:
        B_text: prompt, encoded with ``clip_model``.
        A_motion: prefix latents ``[T_A, C]``; a batch dim of 1 is added.
        tokenizer: ``'clip'`` or ``'t5-xxl'`` (``clip_model.encode`` -> numpy).
            The CLIP branch used to reference an undefined ``A_text``
            (UnboundLocalError); it now only encodes ``B_text``.
        reference_end_token, threshold: for step ``k > 10``, stop (without
            appending) once a sampled latent's L2 distance to the reference
            is below ``threshold``. The distance is printed each step.
        cfg: guidance scale. ``1`` samples from the conditional pass only.
        temperature: forwarded to ``diff_loss.sample``.
        if_categorial: unused.

    Returns:
        ``(xs, B_motion)``: the conditional embedding sequence (just the
        prompt when nothing is appended) and the latents ``[1, n, C]``, or
        an empty tensor.

    Raises:
        ValueError: for an unsupported ``tokenizer``.
    """
    def _encode(text):
        if tokenizer == 'clip':
            import clip  # only required on the CLIP path
            return clip_model.encode_text(clip.tokenize(text, truncate=True).to(device)).float()
        if tokenizer == 't5-xxl':
            return torch.from_numpy(clip_model.encode(text)).float().to(device)
        raise ValueError(f"Unsupported tokenizer: {tokenizer!r}")

    def _embed_text(text):
        # Batch-1 sequence [1, n, n_embd]. This replaces the double unsqueeze,
        # which produced a 4-D tensor on the CLIP path.
        emb = self.transformer.cond_embed(_encode(text))
        return emb.reshape(1, -1, emb.size(-1))

    B_token_length = length // unit_length - A_motion.shape[0]
    B_text_embeddings = _embed_text(B_text)
    # The empty prompt does not change between steps: embed it once.
    empty_embedding = _embed_text('')
    A_motion_embeddings = self.transformer.wte(A_motion.unsqueeze(0))
    B_motion = torch.tensor([]).to(device)

    xs = torch.cat([B_text_embeddings, A_motion_embeddings], dim=1)
    for k in range(B_token_length):
        x = xs
        conditions = self.forward_babel_eval(x, return_attention=False)[:, -1, :]
        if cfg != 1:
            empty_parts = [empty_embedding, A_motion_embeddings]
            if k > 0:
                empty_parts.append(self.transformer.wte(B_motion))
            empty_conditions = self.forward_babel_eval(torch.cat(empty_parts, dim=1))[:, -1, :]
            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)
            scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
        else:
            # Conditional rows only. With cfg=1 the unchunked mix added an extra token.
            scaled_logits = self.diff_loss.sample(conditions, temperature=temperature, cfg=cfg)
        scaled_logits = scaled_logits.unsqueeze(0)
        if reference_end_token is not None:
            distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token) ** 2))
            print(distance_l2)
            if distance_l2 < threshold and k > 10:
                break
        B_motion = torch.cat((B_motion, scaled_logits), dim=1)
        xs = torch.cat((x, self.transformer.wte(scaled_logits)), dim=1)
    return xs, B_motion
def sample_for_eval_CFG_babel_inference_two_forward(self, B_text, A_motion, if_categorial=False, length=312, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=4.5, threshold=3, temperature=1.0):
    """
    Inference loop that mimics the "Two-Forward" training strategy.
    This version performs two full passes over the entire sequence.

    Pass 1 autoregressively samples a rough continuation of ``A_motion`` for
    ``B_text``. Pass 2 re-samples each new position from the hidden state at
    that position over the full rough sequence. Each refined latent is
    written back into the rough draft before the next position.

    Args:
        B_text: prompt, encoded with ``clip_model.encode`` (numpy array).
        A_motion: prefix latents ``[T_A, C]``.
        tokenizer: only ``'t5-xxl'`` is supported.
        cfg: guidance scale. ``1`` samples from the conditional pass only;
            previously the first pass appended two rows per step in that case.
        temperature: forwarded to ``diff_loss.sample``.
        if_categorial, reference_end_token, threshold: unused.

    Returns:
        ``(None, B_motion)``, where ``B_motion`` holds the refined latents
        ``[1, length // unit_length - T_A, C]``.

    Raises:
        NotImplementedError: for any tokenizer other than ``'t5-xxl'``.
    """
    B_token_length = length // unit_length - A_motion.shape[0]
    if tokenizer == 't5-xxl':
        B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float().to(device)
    else:
        raise NotImplementedError("Only t5-xxl is supported for this function.")
    empty_feat_clip_text = torch.from_numpy(clip_model.encode('')).float().unsqueeze(0).to(device)

    def _as_seq(emb):
        # Batch-1 sequence [1, n, n_embd], whatever rank the encoder returned.
        return emb.reshape(1, -1, emb.size(-1))

    B_text_embeddings = _as_seq(self.transformer.cond_embed(B_feat_clip_text))
    empty_text_embeddings = _as_seq(self.transformer.cond_embed(empty_feat_clip_text))
    n_prefix = A_motion.shape[0]

    def _guided_sample(motion_tokens, position):
        # CFG-sample one latent from the hidden state at `position` of
        # [text, motion_tokens]; returns [1, C].
        motion_embeddings = self.transformer.wte(motion_tokens.unsqueeze(0))
        x_cond = torch.cat([B_text_embeddings, motion_embeddings], dim=1)
        conditions = self.forward_babel_eval(x_cond, return_attention=False)[:, position, :]
        if cfg == 1:
            return self.diff_loss.sample(conditions, temperature=temperature, cfg=cfg)
        x_uncond = torch.cat([empty_text_embeddings, motion_embeddings], dim=1)
        empty_conditions = self.forward_babel_eval(x_uncond, return_attention=False)[:, position, :]
        mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
        token, _ = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg).chunk(2, dim=0)
        return token

    # === 1. First Forward Pass (Generate Rough Draft) ===
    rough_motion_tokens = A_motion
    for _ in range(B_token_length):
        pred_xstart_rough = _guided_sample(rough_motion_tokens, -1)
        rough_motion_tokens = torch.cat([rough_motion_tokens, pred_xstart_rough], dim=0)

    # === 2. Second Forward Pass (Generate Refined Motion) ===
    refined_motion_tokens = A_motion
    for k in range(B_token_length):
        # The output at index n_prefix + k is the hidden state of the token just
        # before the target position (offset by the single text slot).
        final_token = _guided_sample(rough_motion_tokens, n_prefix + k)
        refined_motion_tokens = torch.cat([refined_motion_tokens, final_token], dim=0)
        # Feed the refined token back into the draft for the following positions.
        rough_motion_tokens[n_prefix + k] = final_token.squeeze(0)

    # Return only the newly generated tokens (B_motion)
    B_motion = refined_motion_tokens[n_prefix:, :].unsqueeze(0)
    return None, B_motion
#--------------Test classification head--------------------
def sample_for_eval_classification(self, clip_text, if_categorial=False, length=196, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4):
    """Sample up to 51 latents with CFG (scale 7.5) and stop on an end classifier.

    After each step, ``self.classify_head`` is applied to the conditional
    hidden state. Sampling stops, after appending that step's latent, once
    the argmax of its sigmoid equals 1. ``__init__`` does not create
    ``classify_head`` (it is commented out there), so it must be attached
    before calling this method.

    Args:
        clip_text: prompt.
        clip_model: for ``tokenizer='clip'``, a CLIP model (``encode_text``);
            for ``'t5-xxl'``, a wrapper whose ``.module.encode`` returns a
            numpy array.
        device: device for the tokenized / unconditional features.
        if_categorial, length, unit_length: unused (the step count is fixed at 51).

    Returns:
        ``[1, n, C]`` latents.

    Raises:
        AttributeError: ``classify_head`` is missing.
        ValueError: for an unsupported ``tokenizer``.
    """
    if not hasattr(self, "classify_head"):
        raise AttributeError("sample_for_eval_classification requires self.classify_head, which __init__ does not create.")
    # forward() ignores `feature` while a prompt is cached, which would make the
    # conditional and unconditional passes identical.
    if getattr(self, "_prompt_cached", False):
        self.clear_prompt()

    def _encode(text):
        if tokenizer == 'clip':
            import clip  # only required on the CLIP path
            return clip_model.encode_text(clip.tokenize(text, truncate=True).to(device)).float()
        if tokenizer == 't5-xxl':
            return torch.from_numpy(clip_model.module.encode(text)).float()
        raise ValueError(f"Unsupported tokenizer: {tokenizer!r}")

    # Prompts are loop-invariant: encode them once instead of every step.
    feat_clip_text = _encode(clip_text)
    empty_feat_clip_text = _encode('')
    if tokenizer == 't5-xxl':
        empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
    empty_feat_clip_text = empty_feat_clip_text.to(device)

    temperature = 1.0
    cfg = 7.5
    xs = None
    for _ in range(51):
        x = [] if xs is None else xs
        conditions = self.forward(x, feat_clip_text)[:, -1, :]
        empty_conditions = self.forward(x, empty_feat_clip_text)[:, -1, :]
        mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
        sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)
        scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)  # cfg is fixed at 7.5 (!= 1)
        prediction_logits = self.classify_head(conditions)
        predicted_classes = torch.argmax(torch.sigmoid(prediction_logits), dim=-1)
        scaled_logits = scaled_logits.unsqueeze(0)
        xs = scaled_logits if xs is None else torch.cat((xs, scaled_logits), dim=1)
        if predicted_classes == 1:
            break
    return xs
#--------------------Test CFG-----------------------
def sample_for_eval_CFG_test(self, clip_text, if_categorial=False, length=196, clip_model=None, cfg=1, device=torch.device('cuda'), tokenizer='clip', unit_length=4):
    """Sample ``length // unit_length`` latents, optionally with CFG (test helper).

    With ``cfg != 1``, each step runs a conditional and an unconditional
    (empty-prompt) ``forward`` and guides ``diff_loss.sample``. With
    ``cfg == 1``, only the conditional pass runs.

    Args:
        clip_text: prompt.
        clip_model: for ``tokenizer='clip'``, a CLIP model (``encode_text``);
            for ``'t5-xxl'``, a wrapper whose ``.module.encode`` returns a
            numpy array.
        device: device the encoded features are moved to.
        if_categorial: unused.

    Returns:
        ``[1, n, C]`` latents, or None when ``length // unit_length`` is 0
        (previously an UnboundLocalError).

    Raises:
        ValueError: for an unsupported ``tokenizer``.
    """
    # forward() ignores `feature` while a prompt is cached, which would make the
    # conditional and unconditional passes identical.
    if getattr(self, "_prompt_cached", False):
        self.clear_prompt()

    def _encode(text):
        if tokenizer == 'clip':
            import clip  # only required on the CLIP path
            return clip_model.encode_text(clip.tokenize(text, truncate=True).to(device)).float()
        if tokenizer == 't5-xxl':
            return torch.from_numpy(clip_model.module.encode(text)).float()
        raise ValueError(f"Unsupported tokenizer: {tokenizer!r}")

    # Prompts are loop-invariant: encode them once instead of every step.
    feat_clip_text = _encode(clip_text).to(device)
    empty_feat_clip_text = None
    if cfg != 1:
        empty_feat_clip_text = _encode('')
        if tokenizer == 't5-xxl':
            empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
        empty_feat_clip_text = empty_feat_clip_text.to(device)

    temperature = 1.0
    xs = None
    for _ in range(length // unit_length):
        x = [] if xs is None else xs
        conditions = self.forward(x, feat_clip_text)[:, -1, :]
        if cfg != 1:
            empty_conditions = self.forward(x, empty_feat_clip_text)[:, -1, :]
            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)
            scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
        else:
            scaled_logits = self.diff_loss.sample(conditions, temperature=temperature, cfg=cfg)
        scaled_logits = scaled_logits.unsqueeze(0)
        xs = scaled_logits if xs is None else torch.cat((xs, scaled_logits), dim=1)
    return xs
#--------------------------------------------------
def forward_discrete(self, idx: torch.Tensor, clip_feature: torch.Tensor, use_cache=False, past_key_values=None) -> torch.Tensor:
    """
    Vector-token path: idx must be shape [B, T, input_token_dim].
    If you want discrete IDs instead, you must switch wte to nn.Embedding.

    Args:
        idx: motion tokens ``[B, T, input_token_dim]``, cast to ``wte``'s
            weight dtype. An empty tensor or empty sequence (such as the
            ``[]`` the sampling loops pass) runs on the conditioning alone.
        clip_feature: text features for ``_prepare_context``; required when
            ``idx`` is empty.
        use_cache: call the blocks with ``last_past``/``use_cache`` and store
            the ``presents`` they return.
        past_key_values: per-block cache list, updated IN PLACE. When it is
            None and ``use_cache`` is set, a fresh list is built but not
            returned, so pass your own list to keep the cache.

    Returns:
        ``out_proj(ln_f(...))`` over ``[context?, tokens]``. No BOS token is
        prepended, unlike in ``forward``.

    Raises:
        ValueError: no tokens and no conditioning features.
        AssertionError: ``T`` exceeds ``config.block_size``.
    """
    # `idx.numel()` crashed on the plain `[]` list the sampling loops pass.
    is_empty = idx.numel() == 0 if torch.is_tensor(idx) else len(idx) == 0
    if is_empty:
        context = self._prepare_context(clip_feature)
        if context is None:
            raise ValueError("Conditioning features are required when no motion tokens are provided.")
        token_embeddings = context
    else:
        b, t, _ = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        # Cast like forward() does, so integer tokens reach the Linear as floats.
        token_embeddings = self.transformer.wte(idx.to(self.transformer.wte.weight.dtype))  # Linear -> [B, T, n_embd]
        context = self._prepare_context(clip_feature, batch_size=b)
        if context is not None:
            token_embeddings = torch.cat([context, token_embeddings], dim=1)
    x = token_embeddings
    if use_cache and past_key_values is None:
        past_key_values = [None] * len(self.transformer.h)
    for i, block in enumerate(self.transformer.h):
        if use_cache:
            x, presents = block(x, context=context, last_past=past_key_values[i], use_cache=use_cache)
            past_key_values[i] = list(presents)
        else:
            x = block(x, context=context)
    x = self.transformer.ln_f(x)
    return self.out_proj(x)
def forward(self, idx: torch.Tensor, feature: Optional[torch.Tensor]) -> torch.Tensor:
    """Main forward pass over continuous motion tokens; ``self.bos`` is always prepended.

    Two conditioning modes:

    * ``self._prompt_cached`` is True: the text context is *not* concatenated to
      the sequence and ``context`` stays None, so every block is called without a
      live context and reads the prompt K/V it cached earlier instead.
    * Otherwise ``self._prepare_context`` embeds ``feature``. The result is
      prepended to the token embeddings and also passed to every block as
      cross-attention context.

    Args:
        idx: motion tokens of shape [B, T, C] (cast to float before ``wte``), or an
            empty sequence (``len(idx) == 0``) to run on the prompt alone.
        feature: raw conditioning features for ``_prepare_context``; not read
            while the prompt cache is active.

    Returns:
        ``out_proj(ln_f(hidden))`` for the sequence ``[BOS, (context), tokens]``.

    Raises:
        ValueError: if the prompt cache is active but ``_prompt_bsz`` is unset, or
            if ``idx`` is empty and no context can be prepared.
    """
    context = None
    if len(idx) == 0 and self._prompt_cached:
        if self._prompt_bsz is None:
            raise ValueError("Prompt cache set but batch size unknown.")
        # Zero-length token block with the cached batch size so BOS can still be expanded.
        embeddings = torch.empty(
            self._prompt_bsz, 0, self.config.n_embd, device=self.bos.device, dtype=self.bos.dtype
        )
    elif len(idx) == 0:
        context = self._prepare_context(feature)
        if context is None:
            raise ValueError("Conditioning features are required when no motion tokens are provided.")
        embeddings = context
    else:
        batch, seq_len, _ = idx.size()
        assert seq_len <= self.config.block_size, (
            f"Cannot forward sequence of length {seq_len}, block size is only {self.config.block_size}"
        )
        embeddings = self.transformer.wte(idx.float())
        if not self._prompt_cached:
            context = self._prepare_context(feature, batch_size=batch)
            if context is not None:
                embeddings = torch.cat([context, embeddings], dim=1)

    bos = self.bos.expand(embeddings.size(0), 1, -1)
    hidden = torch.cat([bos, embeddings], dim=1)
    # With context None, the blocks fall back to their cached prompt K/V (if any).
    for block in self.transformer.h:
        hidden = block(hidden, context=context)
    return self.out_proj(self.transformer.ln_f(hidden))
def forward_inference(self, idx: torch.Tensor, feature: Optional[torch.Tensor]) -> torch.Tensor:
    """Inference variant of ``forward`` that is lenient about batch shapes.

    It builds the same ``[BOS, (context), tokens]`` sequence as ``forward``, and
    additionally:

    * adds a leading batch axis when the assembled embeddings are 2-D;
    * broadcasts a batch-1 ``context`` to the token batch size.

    Args:
        idx: motion tokens [B, T, C], or an empty sequence to run on the prompt only.
        feature: raw conditioning features for ``_prepare_context``; not read
            while the prompt cache is active.

    Returns:
        ``out_proj(ln_f(hidden))`` for the assembled sequence.

    Raises:
        ValueError: if the prompt cache is active without ``_prompt_bsz``, if
            ``idx`` is empty and no context can be prepared, or if the context
            batch is neither 1 nor the token batch size.
    """
    context = None
    if len(idx) != 0:
        batch, seq_len, _ = idx.size()
        assert seq_len <= self.config.block_size, (
            f"Cannot forward sequence of length {seq_len}, block size is only {self.config.block_size}"
        )
        tokens = self.transformer.wte(idx.float())
        if not self._prompt_cached:
            context = self._prepare_context(feature, batch_size=batch)
            if context is not None:
                tokens = torch.cat([context, tokens], dim=1)
    elif self._prompt_cached:
        if self._prompt_bsz is None:
            raise ValueError("Prompt cache set but batch size unknown.")
        tokens = torch.empty(
            self._prompt_bsz, 0, self.config.n_embd, device=self.bos.device, dtype=self.bos.dtype
        )
    else:
        context = self._prepare_context(feature)
        if context is None:
            raise ValueError("Conditioning features are required when no motion tokens are provided.")
        tokens = context

    if tokens.dim() == 2:
        tokens = tokens.unsqueeze(0)
    n_batch = tokens.size(0)
    hidden = torch.cat([self.bos.expand(n_batch, 1, -1), tokens], dim=1)

    if context is not None and context.size(0) != n_batch:
        if context.size(0) != 1:
            raise ValueError("Conditioning batch size does not match token batch size.")
        context = context.expand(n_batch, -1, -1)

    for block in self.transformer.h:
        hidden = block(hidden, context=context)
    return self.out_proj(self.transformer.ln_f(hidden))
def babel_long(self, idx: torch.Tensor, clip_feature: torch.Tensor, use_cache=False, past_key_values=None, num_subseq=None, length=None) -> torch.Tensor:
    """Forward pass for multi-segment sequences with a text token inserted per segment.

    The motion tokens are embedded with ``wte``. Then, for each batch row ``i``
    and each of its ``num_subseq[i]`` segments ``j``, the ``cond_embed``
    projection of ``clip_feature[i][j]`` is inserted as one token at a running
    ``pointer`` offset. The row is shifted right and its last element is dropped,
    so the row length stays ``t``. The stack then runs with ``context=None``, so
    the blocks receive no live cross-attention context.

    Args:
        idx: motion tokens [B, t, C]; cast to float before ``wte``.
        clip_feature: per-row, per-segment text features, indexed as
            ``clip_feature[i][j]``.
        use_cache / past_key_values: same in-place per-layer cache handling as
            ``forward_discrete``; ``past_key_values`` is not returned.
        num_subseq: per-row segment counts. Required despite the None default.
        length: per-row segment lengths. Required despite the None default.

    Returns:
        ``out_proj(ln_f(hidden))`` of shape [B, t, n_embd].
    """
    b, t, c = idx.size()
    idx = idx.float()
    idx = self.transformer.wte(idx)
    assert (
        t <= self.config.block_size
    ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
    for i in range(b):
        length_i = length[i][:num_subseq[i]]
        clip_feature_i = clip_feature[i][:num_subseq[i]]
        pointer = 0
        for j in range(num_subseq[i]):
            if j > 0:
                # Skip over the previous segment and the text token inserted for it.
                # NOTE(review): this advances by length_i[j], not length_i[j - 1];
                # confirm that `length` is laid out so this is intended.
                pointer += length_i[j].item()
                pointer += 1
            pointer = int(pointer)
            clip_feature_i_j = self.transformer.cond_embed(clip_feature_i[j].unsqueeze(0)).unsqueeze(1)
            # Insert the text token at `pointer`; `[pointer:-1]` drops the last
            # element so the row keeps its length t. This writes in place into the
            # wte output.
            idx[i] = torch.cat([idx[i][:pointer].unsqueeze(0), clip_feature_i_j, idx[i][pointer:-1].unsqueeze(0)], dim=1)[0]
    x = idx
    context = None
    if use_cache:
        if past_key_values is None:
            past_key_values = [None] * len(self.transformer.h)
    for i,block in enumerate(self.transformer.h):
        if use_cache:
            last_past = past_key_values[i]
            x, presents = block(x, context=context, last_past=last_past, use_cache=use_cache)
            past_key_values[i] = list(presents)
        else:
            x = block(x, context=context)
    x = self.transformer.ln_f(x)
    logits = self.out_proj(x)
    return logits
def forward_babel_eval(self, x, return_attention=False) -> torch.Tensor:
    """Run the block stack on already-embedded inputs ``x`` without live context.

    Args:
        x: input embeddings passed straight to the first block.
        return_attention: if True, also collect each block's attention output.

    Returns:
        ``logits`` (``out_proj(ln_f(h))``, or ``ln_f(h)`` when ``use_out_proj`` is
        False), or ``(logits, per_layer_attention_list)`` when ``return_attention``
        is set.
    """
    attention_maps = []
    for block in self.transformer.h:
        if return_attention:
            x, layer_attn = block(x, context=None, return_attention=True)
            attention_maps.append(layer_attn)
        else:
            x = block(x, context=None)
    x = self.transformer.ln_f(x)
    logits = self.out_proj(x) if self.use_out_proj else x
    return (logits, attention_maps) if return_attention else logits
def forward_babel(self, idx: torch.Tensor, clip_feature: torch.Tensor, A_token_length) -> torch.Tensor:
    """Two-segment (A then B) forward pass with inline text tokens.

    For each row, the sequence is built as
    ``[text_A, BOM_tag, wte(A tokens), text_B, BOM_tag, wte(B tokens)]``.
    The A tokens are the first ``A_token_length[i]`` positions of ``idx[i]``; the
    B tokens run from there up to (excluding) the last two positions. The two text
    embeddings are also stacked as a 2-token cross-attention ``context``.

    NOTE(review): ``self.BOM_tag`` is not defined in the visible code. If it is a
    single [1, n_embd] row, each assembled row has t + 2 entries, and the
    assignment into a ``block_size``-row buffer therefore requires
    ``t + 2 == block_size``. Confirm against the caller.

    Args:
        idx: motion tokens [B, t, C], or an empty sequence for context-only inference.
        clip_feature: [B, 2, T5_xxl_dim]; index 0 is segment A's text, index 1 is B's.
        A_token_length: per-row number of A-segment motion tokens (tensor elements).

    Returns:
        ``out_proj(ln_f(hidden))``, or ``ln_f(hidden)`` when ``use_out_proj`` is False.

    Raises:
        ValueError: if the context batch is neither 1 nor the token batch size.
    """
    context = None
    if len(idx) == 0: # inference
        context = self._prepare_context(clip_feature)
        token_embeddings = context
    else:
        b, t, c = idx.size()
        idx = idx.float()
        assert (
            t <= self.config.block_size
        ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        A_feature = clip_feature[:, 0, :]
        B_feature = clip_feature[:, 1, :]
        A_text_embeddings = self.transformer.cond_embed(A_feature).unsqueeze(1)
        B_text_embeddings = self.transformer.cond_embed(B_feature).unsqueeze(1)
        context = torch.cat([A_text_embeddings, B_text_embeddings], dim=1)
        # Buffer uses the default float dtype, not the model's parameter dtype.
        token_embeddings = torch.zeros(b, self.config.block_size, self.config.n_embd).to(idx.device)
        for i in range(b):
            A_idx = idx[i, :A_token_length[i].item(), :]
            # The last two positions are excluded, presumably to make room for the two inserted BOM_tag rows.
            B_idx = idx[i, A_token_length[i].item():-2, :]
            token_embeddings[i, :, :] = torch.cat([A_text_embeddings[i], self.BOM_tag, self.transformer.wte(A_idx), B_text_embeddings[i], self.BOM_tag, self.transformer.wte(B_idx)], dim=0) # row shape: (block_size, n_embd); see NOTE(review) in the docstring
    x = token_embeddings
    if context is not None and context.size(0) != x.size(0):
        if context.size(0) == 1:
            context = context.expand(x.size(0), -1, -1)
        else:
            raise ValueError("Conditioning batch size does not match token batch size.")
    for block in self.transformer.h:
        x = block(x, context=context)
    x = self.transformer.ln_f(x)
    if self.use_out_proj:
        logits = self.out_proj(x)
    else:
        logits = x
    return logits
def forward_babel2(self, idx: torch.Tensor, clip_feature: torch.Tensor) -> torch.Tensor:
    """Single-text forward pass: ``[cond_embed(text), wte(tokens)]``, text also used as context.

    Args:
        idx: motion tokens [B, T, C], or an empty tensor for context-only inference.
        clip_feature: text features [B, D] or [B, 1, D] for ``cond_embed``. In the
            empty-``idx`` case they are passed to ``_prepare_context`` instead.

    Returns:
        ``out_proj(ln_f(hidden))``, or ``ln_f(hidden)`` when ``use_out_proj`` is False.

    Raises:
        ValueError: if the context batch is neither 1 nor the token batch size.
    """
    context = None
    if idx.numel() == 0:
        context = self._prepare_context(clip_feature)
        hidden = context
    else:
        _, seq_len, _ = idx.size()
        assert seq_len <= self.config.block_size, (
            f"Cannot forward sequence of length {seq_len}, block size is only {self.config.block_size}"
        )
        text_emb = self.transformer.cond_embed(clip_feature)
        if text_emb.dim() == 2:
            text_emb = text_emb.unsqueeze(1)  # [B, D] -> [B, 1, D]
        context = text_emb
        hidden = torch.cat([text_emb, self.transformer.wte(idx.float())], dim=1)  # [B, 1+T, D]

    if context is not None:
        if context.dim() == 2:
            context = context.unsqueeze(1)
        n_batch = hidden.size(0)
        if context.size(0) != n_batch:
            if context.size(0) != 1:
                raise ValueError("Conditioning batch size does not match token batch size.")
            context = context.expand(n_batch, -1, -1)

    for block in self.transformer.h:
        hidden = block(hidden, context=context)
    hidden = self.transformer.ln_f(hidden)
    return self.out_proj(hidden) if self.use_out_proj else hidden
def resize_token_embeddings(
    self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None, using_old_initilization: bool = False
) -> nn.Embedding:
    """Resize the input token embedding matrix (and output head, if any).

    Adapted from Hugging Face ``PreTrainedModel.resize_token_embeddings``. Unlike
    the original, weights are NOT re-tied afterwards (the ``tie_weights`` call is
    disabled).

    Args:
        new_num_tokens: desired vocabulary size. Growing appends freshly
            initialised rows; shrinking drops rows from the end.
        pad_to_multiple_of: if given, round the size up to a multiple of this
            integer (e.g. to suit Tensor Cores). With ``new_num_tokens=None``, the
            current size is padded.
        using_old_initilization: accepted for API compatibility; not read here.

    Returns:
        The (possibly new) input embedding module. When both size arguments are
        None, nothing is resized and neither ``config.vocab_size`` nor
        ``self.vocab_size`` is touched.
    """
    model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
    if new_num_tokens is not None or pad_to_multiple_of is not None:
        # Keep the config and the instance attribute in sync with the actual size.
        vocab = model_embeds.weight.shape[0]
        self.config.vocab_size = vocab
        self.vocab_size = vocab
    return model_embeds
def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
    """Swap in resized input embeddings (and a resized output head, if present).

    ``requires_grad`` of each replaced module's weight is carried over to its
    replacement. The output head is resized to the final embedding row count when
    padding was requested, otherwise to ``new_num_tokens``.

    Returns:
        ``self.get_input_embeddings()`` after the swap.
    """
    old_embeddings = self.get_input_embeddings()
    new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of)
    new_embeddings.requires_grad_(old_embeddings.weight.requires_grad)
    self.set_input_embeddings(new_embeddings)

    if pad_to_multiple_of is not None:
        # Padding may have rounded the size up; use the real row count for the head.
        new_num_tokens = new_embeddings.weight.shape[0]

    # Word embeddings are treated as untied, so an existing head is resized too.
    old_lm_head = self.get_output_embeddings()
    if old_lm_head is not None:
        new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
        new_lm_head.requires_grad_(old_lm_head.weight.requires_grad)
        self.set_output_embeddings(new_lm_head)

    return self.get_input_embeddings()
def _get_resized_embeddings(
    self,
    old_embeddings: nn.Embedding,
    new_num_tokens: Optional[int] = None,
    pad_to_multiple_of: Optional[int] = None,
) -> nn.Embedding:
    """Build a resized copy of an ``nn.Embedding`` (adapted from Hugging Face).

    Growing appends rows initialised by ``self._init_weights``; shrinking drops
    rows from the end. Overlapping rows are copied from ``old_embeddings``.

    Args:
        old_embeddings: embedding module to resize.
        new_num_tokens: target row count; None means "keep the current size"
            (only meaningful together with ``pad_to_multiple_of``).
        pad_to_multiple_of: if given, round the target up to a multiple of it.

    Returns:
        ``old_embeddings`` itself when no resize is needed, else a new module on
        the same device/dtype.

    Raises:
        ValueError: if ``pad_to_multiple_of`` is not an ``int``.
        TypeError: if a resize is needed but ``old_embeddings`` is not an ``nn.Embedding``.
    """
    if pad_to_multiple_of is None:
        print(
            "You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding"
            f" dimension will be {new_num_tokens}. This might induce some performance reduction as *Tensor Cores* will not be available."
            " For more details about this, or help on choosing the correct value for resizing, refer to this guide:"
            " https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc"
        )
    else:
        if not isinstance(pad_to_multiple_of, int):
            raise ValueError(
                f"Asking to pad the embedding matrix to a multiple of `{pad_to_multiple_of}`, which is not and integer. Please make sure to pass an integer"
            )
        if new_num_tokens is None:
            new_num_tokens = old_embeddings.weight.shape[0]
        new_num_tokens = ((new_num_tokens + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of

    if new_num_tokens is None:
        return old_embeddings

    old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
    if old_num_tokens == new_num_tokens:
        return old_embeddings

    if not isinstance(old_embeddings, nn.Embedding):
        raise TypeError(
            f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}. You"
            " should either use a different resize function or make sure that `old_embeddings` are an instance of"
            f" {nn.Embedding}."
        )

    new_embeddings = nn.Embedding(
        new_num_tokens,
        old_embedding_dim,
        device=old_embeddings.weight.device,
        dtype=old_embeddings.weight.dtype,
    )
    # Initialise everything first so that rows beyond the old vocabulary are valid.
    self._init_weights(new_embeddings)

    n_copy = min(old_num_tokens, new_num_tokens)
    new_embeddings.weight.data[:n_copy, :] = old_embeddings.weight.data[:n_copy, :]
    return new_embeddings
def _get_resized_lm_head(
    self, old_lm_head: nn.Linear, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False
) -> nn.Linear:
    """Build a resized copy of a linear LM head (adapted from Hugging Face).

    Args:
        old_lm_head: head to resize. Its weight is read as ``(vocab, dim)``, or as
            ``(dim, vocab)`` when ``transposed`` is True.
        new_num_tokens: target vocabulary size; None returns ``old_lm_head``.
        transposed: whether ``old_lm_head``'s weight is stored transposed.

    Returns:
        ``old_lm_head`` when no resize is needed, else a new ``nn.Linear`` on the
        same device/dtype. It is initialised by ``self._init_weights``, and the
        overlapping token rows/columns (and bias) are copied over.

    Raises:
        TypeError: if a resize is needed but ``old_lm_head`` is not an ``nn.Linear``.
    """
    if new_num_tokens is None:
        return old_lm_head

    head_weight = old_lm_head.weight.t() if transposed else old_lm_head.weight
    old_num_tokens, old_lm_head_dim = head_weight.size()
    if old_num_tokens == new_num_tokens:
        return old_lm_head

    if not isinstance(old_lm_head, nn.Linear):
        raise TypeError(
            f"Old language model head is of type {type(old_lm_head)}, which is not an instance of {nn.Linear}. You"
            " should either use a different resize function or make sure that `old_lm_head` are an instance of"
            f" {nn.Linear}."
        )

    if transposed:
        new_lm_head_shape = (new_num_tokens, old_lm_head_dim)
    else:
        new_lm_head_shape = (old_lm_head_dim, new_num_tokens)
    has_bias = old_lm_head.bias is not None

    new_lm_head = nn.Linear(
        *new_lm_head_shape,
        bias=has_bias,
        device=old_lm_head.weight.device,
        dtype=old_lm_head.weight.dtype,
    )
    self._init_weights(new_lm_head)
    self._copy_lm_head_original_to_resized(
        new_lm_head, old_lm_head, min(old_num_tokens, new_num_tokens), transposed, has_bias
    )
    return new_lm_head
def _copy_lm_head_original_to_resized(
    self, new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
):
    """Copy the first ``num_tokens_to_copy`` token entries of ``old_lm_head`` into ``new_lm_head``.

    Token entries are weight rows normally, or weight columns when ``transposed``.
    The bias slice is copied too when ``has_new_lm_head_bias`` is set. Writes go
    through ``.data``, so autograd does not record them.
    """
    n = num_tokens_to_copy
    dst = new_lm_head.weight.data
    src = old_lm_head.weight.data
    if transposed:
        dst[:, :n] = src[:, :n]
    else:
        dst[:n, :] = src[:n, :]
    if has_new_lm_head_bias:
        new_lm_head.bias.data[:n] = old_lm_head.bias.data[:n]
@classmethod
def from_name(cls, name: str) -> Self:
    """Construct the model from the named preset in ``llama_configs``.

    Every other constructor argument keeps its default value.
    """
    config = LLaMAHFConfig.from_name(name)
    return cls(config)
class Block(nn.Module):
    """Pre-norm transformer block: causal self-attention, cross-attention, then MLP.

    Each sub-layer reads an RMSNorm-ed copy of the residual stream and adds its
    output back to the stream. Cross-attention uses the live ``context`` when one
    is passed to :meth:`forward`. Otherwise it uses the key/value tensors
    precomputed by :meth:`set_context_cache`, and is skipped (residual left
    unchanged) when no usable cache exists.
    """

    def __init__(self, config: LLaMAHFConfig) -> None:
        super().__init__()
        self.rms_1 = RMSNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.rms_cross = RMSNorm(config.n_embd)
        self.cross_attn = CrossAttention(config)
        self.rms_2 = RMSNorm(config.n_embd)
        self.mlp = MLP(config)
        # Cached prompt K/V (filled by set_context_cache), already repeated to n_head heads.
        self._ctx_k_repeat = None
        self._ctx_v_repeat = None
        self._ctx_bsz = None

    @torch.no_grad()
    def set_context_cache(self, context: torch.Tensor):
        """Precompute and store cross-attention K/V for ``context`` of shape [B, S, n_embd].

        Uses ``cross_attn``'s projections and K-norm, repeats K/V across the
        key/value groups, and records the batch size ``B``.
        """
        B, S, _ = context.shape
        ca = self.cross_attn
        k = ca.k_proj(context).view(B, S, ca.n_kv_head, ca.head_dim).transpose(1, 2)
        v = ca.v_proj(context).view(B, S, ca.n_kv_head, ca.head_dim).transpose(1, 2)
        k = ca.k_norm(k)
        self._ctx_k_repeat = repeat_kv(k, ca.num_kv_groups)  # [B, n_head, S, d]
        self._ctx_v_repeat = repeat_kv(v, ca.num_kv_groups)  # [B, n_head, S, d]
        self._ctx_bsz = B

    @torch.no_grad()
    def clear_context_cache(self):
        """Drop the cached prompt K/V."""
        self._ctx_k_repeat = None
        self._ctx_v_repeat = None
        self._ctx_bsz = None

    def _cached_cross_update(self, h: torch.Tensor) -> Optional[torch.Tensor]:
        """Cross-attention output for normalized input ``h`` using the cached prompt K/V.

        Returns None when there is no cache, or when the cache was built for a
        different batch size than ``h`` has.
        """
        if self._ctx_k_repeat is None or self._ctx_v_repeat is None:
            return None
        B, T, _ = h.size()
        if self._ctx_bsz is not None and self._ctx_bsz != B:
            return None
        ca = self.cross_attn
        q = ca.q_proj(h).view(B, T, ca.n_head, ca.head_dim).transpose(1, 2)
        q = ca.q_norm(q)
        y = F.scaled_dot_product_attention(
            q, self._ctx_k_repeat, self._ctx_v_repeat,
            attn_mask=None, dropout_p=0.0, is_causal=False, scale=ca.softmax_scale,
        )
        y = y.transpose(1, 2).contiguous().view(B, T, ca.n_head * ca.head_dim)
        return ca.o_proj(y)

    def _cross_attend_cached(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the cached cross-attention output (``x`` unchanged if unavailable)."""
        update = self._cached_cross_update(x)
        return x if update is None else x + update

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        last_past=None,
        use_cache: bool = False,
        return_attention: bool = False,
    ) -> torch.Tensor:
        """Apply the block to ``x`` ([B, T, n_embd]).

        Returns ``x``, plus ``present`` when ``use_cache`` is set, plus the
        self-attention weights when ``return_attention`` is set, in the order
        ``(x, present, attn)``. ``present`` is None when both flags are set,
        because the attention-returning path does not produce a cache.
        """
        present = None
        attn = None
        h = self.rms_1(x)

        # --- self-attention ---
        if return_attention:
            if use_cache:
                attn_output, attn = self.attn.forward_attn(h, last_past, use_cache)
            else:
                attn_output, attn = self.attn.forward_attn(h)
        elif use_cache:
            attn_output, present = self.attn(h, last_past, use_cache)
        else:
            attn_output = self.attn(h)
        x = x + attn_output

        # --- cross-attention: live context if given, else cached prompt K/V ---
        if context is not None:
            x = x + self.cross_attn(self.rms_cross(x), context)
        else:
            # Add the cached update residually. Previously the residual stream was
            # replaced by its normalized copy, which also happened when no cache existed.
            update = self._cached_cross_update(self.rms_cross(x))
            if update is not None:
                x = x + update

        # --- MLP ---
        x = x + self.mlp(self.rms_2(x))

        if use_cache:
            return (x, present, attn) if return_attention else (x, present)
        return (x, attn) if return_attention else x
class CausalSelfAttention(nn.Module):
    """Causal grouped-query self-attention with RoPE and per-head Q/K RMSNorm.

    Queries use ``n_head`` heads. Keys/values use ``n_kv_head`` heads (default
    ``max(1, n_head // 4)``) and are repeated ``n_head // n_kv_head`` times.

    With ``use_cache=True``, the post-RoPE keys and the values are returned as
    ``present = (k, v)`` and may be passed back as ``last_past``. New tokens are
    then rotated at absolute positions ``past_len, past_len + 1, ...`` and attend
    causally to every cached key.
    """

    def __init__(self, config: LLaMAHFConfig) -> None:
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head or max(1, config.n_head // 4)
        assert self.n_head % self.n_kv_head == 0, "n_head must be divisible by n_kv_head"
        self.head_dim = config.n_embd // config.n_head
        self.block_size = config.block_size
        self.rope_base = config.rope_base
        self.rope_cache = None
        # Real dtype the RoPE table was built for. The table itself is complex64,
        # so comparing rope_cache.dtype to x.dtype would always differ and would
        # rebuild the table on every call.
        self._rope_dtype = None
        self.num_kv_groups = self.n_head // self.n_kv_head
        self.q_proj = nn.Linear(config.n_embd, self.n_head * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.q_norm = RMSNorm(self.head_dim)
        self.k_norm = RMSNorm(self.head_dim)
        self.softmax_scale = self.head_dim ** -0.5

    def _get_rope_cache(self, x: torch.Tensor, needed_len: int) -> torch.Tensor:
        """Return a RoPE table for ``x``'s dtype/device with at least ``needed_len`` rows."""
        cache = self.rope_cache
        if (
            cache is None
            or getattr(self, "_rope_dtype", None) != x.dtype
            or cache.device != x.device
            or cache.size(0) < needed_len
        ):
            cache = build_rope_cache(
                seq_len=max(self.block_size, needed_len),
                n_elem=self.head_dim,
                dtype=x.dtype,
                device=x.device,
                base=self.rope_base,
            )
            self.rope_cache = cache
            self._rope_dtype = x.dtype
        return cache

    @staticmethod
    def _causal_mask(q_len: int, k_len: int, device) -> torch.Tensor:
        """Boolean [q_len, k_len] mask (True = may attend), aligned to the bottom-right.

        Query ``i`` sits at absolute position ``k_len - q_len + i`` and may see
        every key at or before that position.
        """
        return torch.ones(q_len, k_len, dtype=torch.bool, device=device).tril(diagonal=k_len - q_len)

    def _qkv(self, x: torch.Tensor, last_past, use_cache: bool):
        """Project, normalize and rotate q/k/v; prepend cached k/v when caching.

        Returns ``(q, k, v, past_len)``. ``q`` is [B, n_head, T, d]; ``k``/``v``
        are [B, n_kv_head, past_len + T, d].
        """
        B, T, _ = x.size()
        past_len = last_past[0].size(-2) if (use_cache and last_past is not None) else 0
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
        q = self.q_norm(q)
        k = self.k_norm(k)
        # Offset the rotation by the number of cached positions.
        rope = self._get_rope_cache(x, past_len + T)[past_len:]
        q = apply_rope(q, rope)
        k = apply_rope(k, rope)
        if use_cache and last_past is not None:
            past_key, past_value = last_past
            k = torch.cat([past_key, k], dim=-2)
            v = torch.cat([past_value, v], dim=-2)
        return q, k, v, past_len

    def forward(self, x: torch.Tensor, last_past=None, use_cache=False) -> torch.Tensor:
        """Causal attention over ``x`` ([B, T, n_embd]).

        Returns ``y`` ([B, T, n_embd]), or ``(y, present)`` when ``use_cache`` is
        set, where ``present = (k, v)`` includes any ``last_past`` entries.
        """
        B, T, _ = x.size()
        q, k, v, past_len = self._qkv(x, last_past, use_cache)
        present = (k, v) if use_cache else None
        k_repeat = repeat_kv(k, self.num_kv_groups)
        v_repeat = repeat_kv(v, self.num_kv_groups)
        if past_len == 0:
            # Square case: SDPA's built-in causal mask is exact.
            attn_mask, is_causal = None, True
        else:
            # SDPA's is_causal aligns its mask to the top-left when q_len != k_len,
            # which would hide the cached keys; use an explicit bottom-right mask.
            attn_mask, is_causal = self._causal_mask(T, k.size(-2), x.device), False
        y = F.scaled_dot_product_attention(
            q,
            k_repeat,
            v_repeat,
            attn_mask=attn_mask,
            dropout_p=0.0,
            is_causal=is_causal,
            scale=self.softmax_scale,
        )
        y = y.transpose(1, 2).contiguous().view(B, T, self.n_head * self.head_dim)
        y = self.o_proj(y)
        if use_cache:
            return y, present
        return y

    def forward_attn(self, x: torch.Tensor, last_past=None, use_cache=False) -> torch.Tensor:
        """Like :meth:`forward`, but materializes the attention weights.

        The same causal mask as ``forward`` is applied; the previous version
        applied none, so future tokens leaked into ``y``.

        Returns:
            ``(y, att)`` with ``att`` of shape [B, n_head, T, past_len + T]. No
            cache is returned, even when ``use_cache`` is set.
        """
        B, T, _ = x.size()
        q, k, v, _ = self._qkv(x, last_past, use_cache)
        k_repeat = repeat_kv(k, self.num_kv_groups)
        v_repeat = repeat_kv(v, self.num_kv_groups)
        att = torch.matmul(q, k_repeat.transpose(-2, -1)) * self.softmax_scale
        mask = self._causal_mask(T, k.size(-2), x.device)
        att = att.masked_fill(~mask, float("-inf"))
        att = F.softmax(att, dim=-1)
        y = torch.matmul(att, v_repeat)
        y = y.transpose(1, 2).contiguous().view(B, T, self.n_head * self.head_dim)
        y = self.o_proj(y)
        return y, att
class CrossAttention(nn.Module):
    """Grouped-query attention from a token sequence to a conditioning sequence.

    No mask and no rotary embedding are used. Q and K get per-head RMSNorm, and
    K/V heads are repeated to match the ``n_head`` query heads.
    """

    def __init__(self, config: LLaMAHFConfig) -> None:
        super().__init__()
        assert config.n_embd % config.n_head == 0
        n_head = config.n_head
        n_kv_head = config.n_kv_head or max(1, n_head // 4)
        assert n_head % n_kv_head == 0, "n_head must be divisible by n_kv_head"
        head_dim = config.n_embd // n_head

        self.n_head = n_head
        self.n_kv_head = n_kv_head
        self.head_dim = head_dim
        self.num_kv_groups = n_head // n_kv_head
        # Submodules are created in the same order as before, so parameter
        # registration order and initialisation are unchanged.
        self.q_proj = nn.Linear(config.n_embd, n_head * head_dim, bias=False)
        self.k_proj = nn.Linear(config.n_embd, n_kv_head * head_dim, bias=False)
        self.v_proj = nn.Linear(config.n_embd, n_kv_head * head_dim, bias=False)
        self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.q_norm = RMSNorm(head_dim)
        self.k_norm = RMSNorm(head_dim)
        self.softmax_scale = head_dim ** -0.5

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        """Attend from ``x`` [B, T, n_embd] to ``context`` [B, S, n_embd]; returns [B, T, n_embd]."""
        bsz, q_len, _ = x.size()
        _, ctx_len, _ = context.size()

        q = self.q_proj(x).view(bsz, q_len, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k_proj(context).view(bsz, ctx_len, self.n_kv_head, self.head_dim).transpose(1, 2)
        v = self.v_proj(context).view(bsz, ctx_len, self.n_kv_head, self.head_dim).transpose(1, 2)
        q, k = self.q_norm(q), self.k_norm(k)

        out = F.scaled_dot_product_attention(
            q,
            repeat_kv(k, self.num_kv_groups),
            repeat_kv(v, self.num_kv_groups),
            attn_mask=None,
            dropout_p=0.0,
            is_causal=False,
            scale=self.softmax_scale,
        )
        out = out.transpose(1, 2).contiguous().view(bsz, q_len, self.n_head * self.head_dim)
        return self.o_proj(out)
def repeat_kv(hidden_states: torch.Tensor, num_groups: int) -> torch.Tensor:
    """Repeat each key/value head ``num_groups`` times along the head axis.

    Input is [bsz, n_kv, seq_len, head_dim]; output is
    [bsz, n_kv * num_groups, seq_len, head_dim]. Head ``h`` of the output is input
    head ``h // num_groups``. With ``num_groups == 1`` the input object is
    returned unchanged.
    """
    if num_groups == 1:
        return hidden_states
    return torch.repeat_interleave(hidden_states, repeats=num_groups, dim=1)
class LengthCausalSelfAttention(nn.Module):
    """Grouped-query self-attention with RoPE and a causal mask relaxed over a text prefix.

    ``forward(x, y_mask)`` lets query ``i`` attend to key ``j`` when ``j <= i``,
    or when ``y_mask[:, i] * y_mask[:, j]`` is non-zero. ``y_mask`` covers the
    first ``y_mask.shape[1]`` positions and is zero-padded up to ``T``.
    """

    def __init__(self, config: LLaMAHFConfig) -> None:
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head or max(1, config.n_head // 4)
        assert self.n_head % self.n_kv_head == 0, "n_head must be divisible by n_kv_head"
        self.head_dim = config.n_embd // config.n_head
        self.block_size = config.block_size
        self.rope_base = config.rope_base
        self.rope_cache = None
        # Real dtype the RoPE table was built for. The table is complex64, so
        # comparing it to x.dtype directly would force a rebuild on every call.
        self._rope_dtype = None
        self.num_kv_groups = self.n_head // self.n_kv_head
        self.q_proj = nn.Linear(config.n_embd, self.n_head * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.q_norm = RMSNorm(self.head_dim)
        self.k_norm = RMSNorm(self.head_dim)
        self.softmax_scale = self.head_dim ** -0.5

    def _get_rope_cache(self, x: torch.Tensor) -> torch.Tensor:
        """Return the RoPE table, rebuilding it only when ``x``'s dtype or device changes."""
        if (
            self.rope_cache is None
            or getattr(self, "_rope_dtype", None) != x.dtype
            or self.rope_cache.device != x.device
        ):
            self.rope_cache = build_rope_cache(
                seq_len=self.block_size,
                n_elem=self.head_dim,
                dtype=x.dtype,
                device=x.device,
                base=self.rope_base,
            )
            self._rope_dtype = x.dtype
        return self.rope_cache

    def forward(self, x: torch.Tensor, y_mask: torch.Tensor) -> torch.Tensor:
        """Attend within ``x`` [B, T, n_embd] under the causal-or-text mask.

        Args:
            x: input sequence.
            y_mask: [B, L] text-position mask (non-zero = text position). It is
                padded (or cropped, if L > T) to T.

        Returns:
            [B, T, n_embd] attention output after ``o_proj``.
        """
        B, T, _ = x.size()
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
        q = self.q_norm(q)
        k = self.k_norm(k)
        rope_cache = self._get_rope_cache(x)
        q = apply_rope(q, rope_cache)
        k = apply_rope(k, rope_cache)

        attn_mask = torch.ones(T, T, dtype=torch.bool, device=x.device)
        attn_mask = torch.tril(attn_mask)
        attn_mask = attn_mask.unsqueeze(0).expand(B, -1, -1)
        # Text positions may attend to each other bidirectionally.
        text_mask = y_mask.unsqueeze(2) * y_mask.unsqueeze(1)
        text_mask = F.pad(text_mask, (0, T - y_mask.shape[1], 0, T - y_mask.shape[1]), mode='constant', value=0)
        attn_mask = torch.logical_or(attn_mask, text_mask)

        k_repeat = repeat_kv(k, self.num_kv_groups)
        v_repeat = repeat_kv(v, self.num_kv_groups)
        y = F.scaled_dot_product_attention(
            q,
            k_repeat,
            v_repeat,
            attn_mask=attn_mask.unsqueeze(1),  # broadcast over heads
            dropout_p=0.0,
            is_causal=False,
            scale=self.softmax_scale,
        )
        y = y.transpose(1, 2).contiguous().view(B, T, self.n_head * self.head_dim)
        y = self.o_proj(y)
        return y
class MLP(nn.Module):
    """SwiGLU feed-forward block: ``c_proj(silu(c_fc1(x)) * c_fc2(x))``.

    The hidden width is ``int(2 * 4 * n_embd / 3)``, rounded up to a multiple of 256.
    """

    def __init__(self, config: LLaMAHFConfig) -> None:
        super().__init__()
        multiple = 256
        n_hidden = int(2 * (4 * config.n_embd) / 3)
        # Ceiling division, then scale back up: the smallest multiple of 256 >= n_hidden.
        n_hidden = -(-n_hidden // multiple) * multiple
        self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False)
        self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False)
        self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.c_fc1(x)) * self.c_fc2(x)
        return self.c_proj(gated)
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
    https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.

    Computes ``scale * x / sqrt(mean(x * x, dim) + eps)``, where ``scale`` is a
    learnable per-channel vector of length ``size``, initialised to ones.
    """

    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
        super().__init__()
        self.scale = nn.Parameter(torch.ones(size))
        self.eps = eps
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # eps goes inside the square root (added to the mean square). The
        # reference implementation instead added it to the RMS itself.
        mean_square = (x * x).mean(dim=self.dim, keepdim=True)
        return self.scale * (x * torch.rsqrt(mean_square + self.eps))
def build_rope_cache(seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000) -> torch.Tensor:
    """Build a rotary-position-embedding table of unit complex numbers.

    Entry ``[t, i]`` equals ``exp(1j * t * base ** (-2 * i / n_elem))`` for
    positions ``t < seq_len`` and frequency index ``i < ceil(n_elem / 2)``.

    All angles are computed in float32 (float64 when ``dtype`` is float64),
    whatever ``dtype`` is. Previously the frequencies were computed in ``dtype``
    itself. In float16 with a large ``base`` (e.g. 500000), ``base ** x``
    overflowed to inf and the highest frequencies collapsed to exactly zero, so
    those channels carried no positional information.

    Args:
        seq_len: number of positions.
        n_elem: rotary (head) dimension.
        dtype: dtype of the activations the table will be applied to.
        device: device to build the table on.
        base: RoPE frequency base.

    Returns:
        complex64 tensor of shape (seq_len, ceil(n_elem / 2)).
    """
    working_dtype = torch.float64 if dtype == torch.float64 else torch.float32
    exponents = torch.arange(0, n_elem, 2, dtype=working_dtype, device=device) / n_elem
    theta = 1.0 / (base ** exponents)
    seq_idx = torch.arange(seq_len, dtype=working_dtype, device=device)
    idx_theta = torch.outer(seq_idx, theta)
    # torch.complex32 is not usable here, so the table is always complex64.
    cache = torch.polar(torch.ones_like(idx_theta), idx_theta).to(torch.complex64)
    return cache
def apply_rope(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    """Rotate adjacent feature pairs of ``x`` by the RoPE table.

    Args:
        x: [B, H, T, D] tensor with even D.
        rope_cache: complex table with at least T rows of D / 2 entries; only
            the first T rows are used.

    Returns:
        Tensor shaped like ``x`` with the same dtype. The rotation itself is
        done in float32/complex64, because ``view_as_complex`` does not support
        16-bit tensors.
    """
    seq_first = x.transpose(1, 2)  # [B, T, H, D]
    seq_len = seq_first.size(1)
    as_complex = torch.view_as_complex(seq_first.float().reshape(*seq_first.shape[:-1], -1, 2))
    freqs = rope_cache[:seq_len].view(1, as_complex.size(1), 1, as_complex.size(3))
    rotated = torch.view_as_real(as_complex * freqs).flatten(3)
    return rotated.transpose(1, 2).type_as(seq_first)