import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.distributions import Categorical
from typing_extensions import Self

from transformers.modeling_utils import PreTrainedModel


@dataclass
class LLaMAHFConfig:
    block_size: int = 156
    n_layer: int = 32
    n_head: int = 32
    n_kv_head: Optional[int] = None
    n_embd: int = 4096
    rope_base: int = 500000
    T5_xxl_dim: int = 768

    @classmethod
    def from_name(cls, name: str) -> Self:
        return cls(**llama_configs[name])


llama_configs = {
    "Normal_size": dict(n_layer=12, n_head=12, n_embd=768),
}
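
# Minimal usage sketch (illustrative): "Normal_size" is the only preset defined
# above, so any other name passed to from_name() raises a KeyError.
#
#     config = LLaMAHFConfig.from_name("Normal_size")  # 12 layers, 12 heads, 768-dim
#     model = LLaMAHF(config)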
|
|
|
|
|
|
|
|
class LLaMAHF(nn.Module):
    def __init__(self, config: LLaMAHFConfig, num_diffusion_head_layers=6, n_diffusion_heads=4,
                 input_token_dim=16, device=torch.device('cuda'), width=512) -> None:
        super().__init__()
        assert config.block_size is not None
        self.config = config

        cond_dim = config.T5_xxl_dim

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Linear(input_token_dim, config.n_embd),
                cond_embed=nn.Linear(cond_dim, config.n_embd),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=RMSNorm(config.n_embd),
            )
        )

        target_channels = input_token_dim
        from models.diffloss import DiffLoss
        self.diff_loss = DiffLoss(
            target_channels=target_channels,
            z_channels=config.n_embd,
            width=width,
            depth=num_diffusion_head_layers,
            num_sampling_steps='50',
            grad_checkpointing=False,
            n_heads=n_diffusion_heads,
            mlp_ratio=2.0,
        ).to(device)
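
        # DiffLoss (models.diffloss) is the per-token diffusion head: it provides
        # the training loss and the `sample()` call used by the CFG samplers
        # below. `num_sampling_steps` is passed as a string because DiffLoss is
        # assumed to parse it into a sampling schedule internally (an assumption
        # based on its usage here, not verified against that module).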
|
|
|
|
|
        self.out_proj = nn.Linear(config.n_embd, config.n_embd)
        self.use_out_proj = True

        self._prompt_cached = False
        self._prompt_bsz = None
        self.bos = nn.Parameter(torch.zeros(1, 1, config.n_embd))

        self.llama_proj = nn.Linear(config.T5_xxl_dim, config.n_embd)

        self.BOM_tag = nn.Parameter(torch.zeros(1, 1, config.n_embd))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    @torch.no_grad()
    def set_prompt(self, feature: torch.Tensor):
        """
        Precompute and cache cross-attention K/V for the current prompt (feature).
        Call this once whenever the prompt changes (e.g., 'walk' -> 'crawl').
        """
        context = self._prepare_context(feature)
        if context is None:
            raise ValueError("set_prompt: feature cannot be None")

        self._prompt_bsz = context.size(0)
        for blk in self.transformer.h:
            blk.set_context_cache(context)
        self._prompt_cached = True

    @torch.no_grad()
    def clear_prompt(self):
        for blk in self.transformer.h:
            blk.clear_context_cache()
        self._prompt_cached = False
        self._prompt_bsz = None
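
    # Prompt-cache usage sketch (illustrative; `t5_encoder` is a hypothetical
    # text encoder returning a numpy array, as in the sampling helpers below):
    #
    #     feat = torch.from_numpy(t5_encoder.encode("a person walks")).float().cuda()
    #     model.set_prompt(feat)                     # K/V cached in every block
    #     out = model.forward(tokens, feature=None)  # blocks read the cached K/V
    #     model.clear_prompt()                       # drop the cache before a new prompt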
|
|
|
|
|
    def _prepare_context(self, feature: Optional[torch.Tensor], batch_size: Optional[int] = None) -> Optional[torch.Tensor]:
        if feature is None:
            return None
        if not torch.is_tensor(feature):
            feature = torch.as_tensor(
                feature,
                dtype=self.transformer.cond_embed.weight.dtype,
                device=self.transformer.cond_embed.weight.device,
            )
        else:
            feature = feature.to(
                dtype=self.transformer.cond_embed.weight.dtype,
                device=self.transformer.cond_embed.weight.device,
            )

        if feature.dim() == 1:
            feature = feature.unsqueeze(0)

        context = self.transformer.cond_embed(feature)
        if context.dim() == 2:
            context = context.unsqueeze(1)

        if batch_size is not None and context.size(0) != batch_size:
            if context.size(0) == 1:
                context = context.expand(batch_size, -1, -1)
            else:
                raise ValueError(
                    f"Condition batch ({context.size(0)}) does not match token batch ({batch_size})."
                )
        return context
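
    # _prepare_context normalizes any conditioning input to (B, S, n_embd):
    # a 1-D feature becomes (1, 1, n_embd), a (B, D) batch becomes (B, 1, n_embd),
    # and a singleton batch is broadcast up to `batch_size` when requested.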
|
|
|
|
|
    def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
        """Tie or clone module weights, depending on whether we are using TorchScript or not."""
        output_embeddings.weight = input_embeddings.weight

        if getattr(output_embeddings, "bias", None) is not None:
            output_embeddings.bias.data = nn.functional.pad(
                output_embeddings.bias.data,
                (
                    0,
                    output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],
                ),
                "constant",
                0,
            )
        if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
            output_embeddings.out_features = input_embeddings.num_embeddings
|
|
|
|
|
    def get_input_embeddings(self):
        return self.transformer.wte

    def set_input_embeddings(self, value):
        self.transformer.wte = value

    def get_output_embeddings(self):
        return self.out_proj

    def set_output_embeddings(self, new_embeddings):
        self.out_proj = new_embeddings

    def _init_weights(self, module: nn.Module) -> None:
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
|
|
|
|
|
|
|
|
|
|
|
    def forward_sample(self, idx: torch.Tensor, clip_feature: torch.Tensor, y_mask) -> torch.Tensor:
        text_length = clip_feature.shape[1]
        context = self._prepare_context(clip_feature)
        if len(idx) == 0:
            x = self.llama_proj(clip_feature)[:, :int(y_mask[0].sum()), :]
        else:
            _, t = idx.size()
            assert (
                t <= self.config.block_size
            ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

            x = self.transformer.wte(idx)
            x = torch.cat((self.llama_proj(clip_feature)[:, :int(y_mask[0].sum()), :], x), dim=1)

        if context is not None and context.size(0) != x.size(0):
            if context.size(0) == 1:
                context = context.expand(x.size(0), -1, -1)
            else:
                raise ValueError("Conditioning batch size does not match token batch size.")

        for block in self.transformer.h:
            x = block(x, context=context)
        x = self.transformer.ln_f(x)
        logits = x
        return logits
|
|
|
|
|
|
|
|
|
|
|
    def sample_for_eval_CFG(self, text, length=196, tokenize_model=None, device=torch.device('cuda'), unit_length=4, cfg=4.0):
        max_token_len = length // unit_length

        feat_text = torch.from_numpy(tokenize_model.encode(text)).float().to(device)
        self.set_prompt(feat_text)

        empty_feat_text = torch.from_numpy(tokenize_model.encode('')).float().unsqueeze(0).to(device)

        def _use_cond_cache():
            self.set_prompt(feat_text)

        def _use_uncond_cache():
            self.set_prompt(empty_feat_text)

        xs = None
        for k in range(max_token_len):
            x = [] if k == 0 else xs

            _use_cond_cache()
            conditions = self.forward(x, feature=None)[:, -1, :]

            _use_uncond_cache()
            empty_conditions = self.forward(x, feature=None)[:, -1, :]

            temperature = 1.0
            if cfg != 1:
                mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
                sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = self.diff_loss.sample(conditions, temperature=temperature, cfg=1)

            scaled_logits = scaled_logits.unsqueeze(0)
            xs = scaled_logits if k == 0 else torch.cat((xs, scaled_logits), dim=1)

        self.set_prompt(feat_text)
        return xs
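
    # Classifier-free guidance note: conditional and unconditional states are
    # batched together and (presumably inside DiffLoss.sample) combined with the
    # standard CFG rule, eps = eps_uncond + cfg * (eps_cond - eps_uncond), before
    # denoising -- this method only prepares the 2B-sized batch.
    # Usage sketch (illustrative; `t5` is a hypothetical text encoder):
    #
    #     latents = model.sample_for_eval_CFG("a person jumps", length=196,
    #                                         tokenize_model=t5, cfg=4.0)
    #     # latents: [1, length // unit_length, input_token_dim]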
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    def sample_for_eval_CFG_inference(self, text, length=312, tokenizer=None, device=torch.device('cuda'),
                                      unit_length=4, reference_end_latent=None, threshold=0.1, cfg=4.0, temperature=1.0):
        max_token_len = length // unit_length
        feat_text = torch.from_numpy(tokenizer.encode(text)).float().to(device)
        empty_feat_text = torch.from_numpy(tokenizer.encode('')).float().unsqueeze(0).to(device)

        def _use_cond(): self.set_prompt(feat_text)
        def _use_uncond(): self.set_prompt(empty_feat_text)

        xs = None
        for k in range(max_token_len):
            x = [] if k == 0 else xs

            _use_cond()
            conditions = self.forward_inference(x, feature=None)[:, -1, :]

            _use_uncond()
            # Use forward_inference for the unconditional pass as well, so both
            # branches see identical input layouts.
            empty_conditions = self.forward_inference(x, feature=None)[:, -1, :]

            mix = torch.cat([conditions, empty_conditions], dim=0)
            sampled = self.diff_loss.sample(mix, temperature=temperature, cfg=cfg)
            scaled_logits, _ = sampled.chunk(2, dim=0) if cfg != 1 else (sampled, None)
            scaled_logits = scaled_logits.unsqueeze(0)

            if reference_end_latent is not None:
                dist = torch.sqrt(torch.sum((scaled_logits - reference_end_latent) ** 2))
                if dist < threshold:
                    break

            xs = scaled_logits if k == 0 else torch.cat((xs, scaled_logits), dim=1)

        self.set_prompt(feat_text)
        return xs
|
|
|
|
|
|
|
|
|
|
|
    def sample_for_eval_CFG_inference2(self, feat_clip_text, empty_feat_clip_text, if_categorial=False, length=312,
                                       clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4,
                                       reference_end_token=None, threshold=3, cfg=4.5, temperature=1.0):
        max_token_len = length // unit_length

        xs = None
        for k in range(max_token_len):
            x = [] if k == 0 else xs

            # Add a batch dimension if the caller passed an unbatched feature,
            # instead of relying on a bare `except:` around the forward pass.
            if feat_clip_text.dim() == 1:
                feat_clip_text = feat_clip_text.unsqueeze(0)
            conditions = self.forward(x, feat_clip_text)
            conditions = conditions[:, -1, :]

            empty_conditions = self.forward(x, empty_feat_clip_text)
            empty_conditions = empty_conditions[:, -1, :]

            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

            if cfg != 1:
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = sampled_token_latent

            scaled_logits = scaled_logits.unsqueeze(0)

            if reference_end_token is not None:
                distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token) ** 2))
                print(distance_l2)
                if distance_l2 < threshold:
                    break

            xs = scaled_logits if k == 0 else torch.cat((xs, scaled_logits), dim=1)

        return xs
|
|
|
|
|
    def sample_for_eval_CFG_inference_next_one(self, current_token=[], feat_clip_text=None, empty_feat_clip_text=None,
                                               if_categorial=False, length=312, clip_model=None,
                                               device=torch.device('cuda'), tokenizer='clip', unit_length=4,
                                               reference_end_token=None, threshold=3, cfg=4.5, temperature=1.0):
        # Predicts a single next token given the tokens generated so far (the
        # original `for k in range(1)` wrapper ran exactly once, so it is
        # unrolled here).
        if current_token == []:
            x = []
        else:
            x = torch.cat(current_token, dim=1)

        if feat_clip_text.dim() == 1:
            feat_clip_text = feat_clip_text.unsqueeze(0)
        conditions = self.forward(x, feat_clip_text)
        conditions = conditions[:, -1, :]

        empty_conditions = self.forward(x, empty_feat_clip_text)
        empty_conditions = empty_conditions[:, -1, :]

        mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
        sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

        if cfg != 1:
            scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
        else:
            scaled_logits = sampled_token_latent

        xs = scaled_logits.unsqueeze(0)
        return xs
|
|
|
|
|
|
|
|
def sample_for_eval_CFG_babel(self, A_text, B_text, A_motion, if_categorial=False, length=6400, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=7.0, threshold=3): |
|
|
|
|
|
import clip |
|
|
B_token_length = length // unit_length - A_motion.shape[0] |
|
|
|
|
|
if tokenizer == 'clip': |
|
|
A_text = clip.tokenize(A_text, truncate=True).to(device) |
|
|
A_feat_clip_text = clip_model.encode_text(A_text).float() |
|
|
B_text = clip.tokenize(B_text, truncate=True).to(device) |
|
|
B_feat_clip_text = clip_model.encode_text(B_text).float() |
|
|
elif tokenizer == 't5-xxl': |
|
|
A_feat_clip_text = torch.from_numpy(clip_model.encode(A_text)).float() |
|
|
A_feat_clip_text = A_feat_clip_text.to(device) |
|
|
B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float() |
|
|
B_feat_clip_text = B_feat_clip_text.to(device) |
|
|
|
|
|
A_text_embeddings = self.transformer.cond_embed(A_feat_clip_text).unsqueeze(0) |
|
|
B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0) |
|
|
|
|
|
A_motion = A_motion.unsqueeze(0) |
|
|
A_motion_embeddings = self.transformer.wte(A_motion) |
|
|
B_motion = torch.tensor([]).to(device) |
|
|
|
|
|
for k in range(B_token_length): |
|
|
if k == 0: |
|
|
x = torch.cat([A_text_embeddings, A_motion_embeddings, B_text_embeddings], dim=1) |
|
|
else: |
|
|
x = xs |
|
|
|
|
|
|
|
|
conditions = self.forward_babel_eval(x) |
|
|
conditions = conditions[:, -1, :] |
|
|
|
|
|
empty_clip_text = '' |
|
|
if tokenizer == 'clip': |
|
|
empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device) |
|
|
empty_feat_clip_text = clip_model.encode_text(empty_text).float() |
|
|
elif tokenizer == 't5-xxl': |
|
|
empty_feat_clip_text = torch.from_numpy(clip_model.encode(empty_clip_text)).float() |
|
|
empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0) |
|
|
empty_feat_clip_text = empty_feat_clip_text.to(device) |
|
|
|
|
|
empty_feat_clip_text_embedding = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0) |
|
|
|
|
|
if k == 0: |
|
|
empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, empty_feat_clip_text_embedding], dim=1) |
|
|
empty_conditions = self.forward_babel_eval(empty_input) |
|
|
else: |
|
|
B_motion_embeddings = self.transformer.wte(B_motion) |
|
|
empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, empty_feat_clip_text_embedding, B_motion_embeddings], dim=1) |
|
|
empty_conditions = self.forward_babel_eval(empty_input) |
|
|
|
|
|
empty_conditions = empty_conditions[:, -1, :] |
|
|
temperature = 1.0 |
|
|
|
|
|
mix_conditions = torch.cat([conditions, empty_conditions], dim=0) |
|
|
sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg) |
|
|
|
|
|
|
|
|
if cfg != 1: |
|
|
scaled_logits, _ = sampled_token_latent.chunk(2, dim=0) |
|
|
else: |
|
|
scaled_logits = sampled_token_latent |
|
|
|
|
|
|
|
|
scaled_logits = scaled_logits.unsqueeze(0) |
|
|
|
|
|
|
|
|
B_motion = torch.cat((B_motion, scaled_logits), dim=1) |
|
|
|
|
|
scaled_logits_embedding = self.transformer.wte(scaled_logits) |
|
|
xs = torch.cat((x, scaled_logits_embedding), dim=1) |
|
|
|
|
|
|
|
|
return xs, B_motion |
|
|
|
|
|
def sample_for_eval_CFG_babel_inference(self, A_text, B_text, A_motion, if_categorial=False, length=6400, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=7.0, threshold=3): |
|
|
|
|
|
import clip |
|
|
B_token_length = length // unit_length - A_motion.shape[0] |
|
|
|
|
|
if tokenizer == 'clip': |
|
|
A_text = clip.tokenize(A_text, truncate=True).to(device) |
|
|
A_feat_clip_text = clip_model.encode_text(A_text).float() |
|
|
B_text = clip.tokenize(B_text, truncate=True).to(device) |
|
|
B_feat_clip_text = clip_model.encode_text(B_text).float() |
|
|
elif tokenizer == 't5-xxl': |
|
|
A_feat_clip_text = torch.from_numpy(clip_model.encode(A_text)).float() |
|
|
A_feat_clip_text = A_feat_clip_text.to(device) |
|
|
B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float() |
|
|
B_feat_clip_text = B_feat_clip_text.to(device) |
|
|
|
|
|
A_text_embeddings = self.transformer.cond_embed(A_feat_clip_text).unsqueeze(0) |
|
|
A_text_embeddings = A_text_embeddings.unsqueeze(0) |
|
|
B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0) |
|
|
B_text_embeddings = B_text_embeddings.unsqueeze(0) |
|
|
|
|
|
A_motion = A_motion.unsqueeze(0) |
|
|
A_motion_embeddings = self.transformer.wte(A_motion) |
|
|
B_motion = torch.tensor([]).to(device) |
|
|
|
|
|
attention_weights = [] |
|
|
|
|
|
for k in range(B_token_length): |
|
|
if k == 0: |
|
|
x = torch.cat([A_text_embeddings, A_motion_embeddings, B_text_embeddings], dim=1) |
|
|
|
|
|
else: |
|
|
x = xs |
|
|
|
|
|
|
|
|
|
|
|
conditions = self.forward_babel_eval(x, return_attention=False) |
|
|
conditions = conditions[:, -1, :] |
|
|
|
|
|
empty_clip_text = '' |
|
|
if tokenizer == 'clip': |
|
|
empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device) |
|
|
empty_feat_clip_text = clip_model.encode_text(empty_text).float() |
|
|
elif tokenizer == 't5-xxl': |
|
|
empty_feat_clip_text = torch.from_numpy(clip_model.encode(empty_clip_text)).float() |
|
|
empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0) |
|
|
empty_feat_clip_text = empty_feat_clip_text.to(device) |
|
|
|
|
|
empty_feat_clip_text_embedding = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0) |
|
|
|
|
|
if k == 0: |
|
|
empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, empty_feat_clip_text_embedding], dim=1) |
|
|
empty_conditions = self.forward_babel_eval(empty_input) |
|
|
else: |
|
|
B_motion_embeddings = self.transformer.wte(B_motion) |
|
|
empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, empty_feat_clip_text_embedding, B_motion_embeddings], dim=1) |
|
|
empty_conditions = self.forward_babel_eval(empty_input) |
|
|
|
|
|
empty_conditions = empty_conditions[:, -1, :] |
|
|
temperature = 1.0 |
|
|
|
|
|
mix_conditions = torch.cat([conditions, empty_conditions], dim=0) |
|
|
sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg) |
|
|
|
|
|
|
|
|
if cfg != 1: |
|
|
scaled_logits, _ = sampled_token_latent.chunk(2, dim=0) |
|
|
else: |
|
|
scaled_logits = sampled_token_latent |
|
|
|
|
|
scaled_logits = scaled_logits.unsqueeze(0) |
|
|
|
|
|
if reference_end_token is not None: |
|
|
distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token)**2)) |
|
|
print(distance_l2) |
|
|
if distance_l2 < threshold: |
|
|
break |
|
|
|
|
|
B_motion = torch.cat((B_motion, scaled_logits), dim=1) |
|
|
|
|
|
scaled_logits_embedding = self.transformer.wte(scaled_logits) |
|
|
xs = torch.cat((x, scaled_logits_embedding), dim=1) |
|
|
|
|
|
|
|
|
|
|
|
return xs, B_motion |
|
|
|
|
|
|
|
|
def sample_for_eval_CFG_babel_inference_new(self, B_text, A_motion, if_categorial=False, length=78, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=4.5, threshold=3): |
|
|
|
|
|
import clip |
|
|
B_token_length = length // unit_length |
|
|
|
|
|
        if tokenizer == 'clip':
            # Only B_text is available here (segment A is given as motion, not
            # text); the previous version referenced an undefined `A_text`.
            B_text = clip.tokenize(B_text, truncate=True).to(device)
            B_feat_clip_text = clip_model.encode_text(B_text).float()
        elif tokenizer == 't5-xxl':
            B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float()
            B_feat_clip_text = B_feat_clip_text.to(device)
|
|
|
|
|
empty_clip_text = '' |
|
|
if tokenizer == 'clip': |
|
|
empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device) |
|
|
empty_feat_clip_text = clip_model.encode_text(empty_text).float() |
|
|
elif tokenizer == 't5-xxl': |
|
|
empty_feat_clip_text = torch.from_numpy(clip_model.encode(empty_clip_text)).float() |
|
|
empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0) |
|
|
empty_feat_clip_text = empty_feat_clip_text.to(device) |
|
|
|
|
|
B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0) |
|
|
|
|
|
A_motion = A_motion.unsqueeze(0) |
|
|
A_motion_embeddings = self.transformer.wte(A_motion) |
|
|
B_motion = torch.tensor([]).to(device) |
|
|
|
|
|
|
|
|
attention_weights = [] |
|
|
|
|
|
for k in range(B_token_length): |
|
|
if k == 0: |
|
|
x = torch.cat([B_text_embeddings, A_motion_embeddings], dim=1) |
|
|
else: |
|
|
x = xs |
|
|
|
|
|
conditions = self.forward_babel_eval(x, return_attention=False) |
|
|
conditions = conditions[:, -1, :] |
|
|
|
|
|
|
|
|
empty_feat_clip_text_embedding = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0) |
|
|
|
|
|
if k == 0: |
|
|
empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings], dim=1) |
|
|
|
|
|
empty_conditions = self.forward_babel_eval(empty_input) |
|
|
else: |
|
|
B_motion_embeddings = self.transformer.wte(B_motion) |
|
|
empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, B_motion_embeddings], dim=1) |
|
|
empty_conditions = self.forward_babel_eval(empty_input) |
|
|
|
|
|
empty_conditions = empty_conditions[:, -1, :] |
|
|
temperature = 1.0 |
|
|
|
|
|
mix_conditions = torch.cat([conditions, empty_conditions], dim=0) |
|
|
sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg) |
|
|
|
|
|
|
|
|
if cfg != 1: |
|
|
scaled_logits, _ = sampled_token_latent.chunk(2, dim=0) |
|
|
else: |
|
|
scaled_logits = sampled_token_latent |
|
|
|
|
|
scaled_logits = scaled_logits.unsqueeze(0) |
|
|
|
|
|
if reference_end_token is not None: |
|
|
distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token)**2)) |
|
|
print(distance_l2) |
|
|
if distance_l2 < threshold: |
|
|
break |
|
|
|
|
|
B_motion = torch.cat((B_motion, scaled_logits), dim=1) |
|
|
|
|
|
scaled_logits_embedding = self.transformer.wte(scaled_logits) |
|
|
xs = torch.cat((x, scaled_logits_embedding), dim=1) |
|
|
|
|
|
|
|
|
|
|
|
return xs, B_motion |
|
|
|
|
|
|
|
|
def sample_for_eval_CFG_babel_inference_new_demo(self, B_text, A_motion, if_categorial=False, length=312, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=4.5, threshold=3, temperature=1.0): |
|
|
|
|
|
import clip |
|
|
B_token_length = length // unit_length - A_motion.shape[0] |
|
|
|
|
|
        if tokenizer == 'clip':
            # Only B_text is available here (segment A is given as motion, not
            # text); the previous version referenced an undefined `A_text`.
            B_text = clip.tokenize(B_text, truncate=True).to(device)
            B_feat_clip_text = clip_model.encode_text(B_text).float()
        elif tokenizer == 't5-xxl':
            B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float()
            B_feat_clip_text = B_feat_clip_text.to(device)
|
|
|
|
|
empty_clip_text = '' |
|
|
if tokenizer == 'clip': |
|
|
empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device) |
|
|
empty_feat_clip_text = clip_model.encode_text(empty_text).float() |
|
|
elif tokenizer == 't5-xxl': |
|
|
empty_feat_clip_text = torch.from_numpy(clip_model.encode(empty_clip_text)).float() |
|
|
empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0) |
|
|
empty_feat_clip_text = empty_feat_clip_text.to(device) |
|
|
|
|
|
B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0) |
|
|
B_text_embeddings = B_text_embeddings.unsqueeze(0) |
|
|
|
|
|
A_motion = A_motion.unsqueeze(0) |
|
|
A_motion_embeddings = self.transformer.wte(A_motion) |
|
|
B_motion = torch.tensor([]).to(device) |
|
|
|
|
|
|
|
|
attention_weights = [] |
|
|
|
|
|
for k in range(B_token_length): |
|
|
if k == 0: |
|
|
x = torch.cat([B_text_embeddings, A_motion_embeddings], dim=1) |
|
|
|
|
|
else: |
|
|
x = xs |
|
|
|
|
|
|
|
|
conditions = self.forward_babel_eval(x, return_attention=False) |
|
|
conditions = conditions[:, -1, :] |
|
|
|
|
|
|
|
|
empty_feat_clip_text_embedding = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0) |
|
|
|
|
|
if k == 0: |
|
|
empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings], dim=1) |
|
|
empty_conditions = self.forward_babel_eval(empty_input) |
|
|
else: |
|
|
B_motion_embeddings = self.transformer.wte(B_motion) |
|
|
empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, B_motion_embeddings], dim=1) |
|
|
empty_conditions = self.forward_babel_eval(empty_input) |
|
|
|
|
|
empty_conditions = empty_conditions[:, -1, :] |
|
|
|
|
|
mix_conditions = torch.cat([conditions, empty_conditions], dim=0) |
|
|
sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg) |
|
|
|
|
|
|
|
|
if cfg != 1: |
|
|
scaled_logits, _ = sampled_token_latent.chunk(2, dim=0) |
|
|
else: |
|
|
scaled_logits = sampled_token_latent |
|
|
|
|
|
scaled_logits = scaled_logits.unsqueeze(0) |
|
|
|
|
|
if reference_end_token is not None: |
|
|
distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token)**2)) |
|
|
print(distance_l2) |
|
|
if distance_l2 < threshold and k > 10: |
|
|
break |
|
|
|
|
|
B_motion = torch.cat((B_motion, scaled_logits), dim=1) |
|
|
|
|
|
scaled_logits_embedding = self.transformer.wte(scaled_logits) |
|
|
xs = torch.cat((x, scaled_logits_embedding), dim=1) |
|
|
|
|
|
|
|
|
|
|
|
return xs, B_motion |
|
|
|
|
|
def sample_for_eval_CFG_babel_inference_two_forward(self, B_text, A_motion, if_categorial=False, length=312, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=4.5, threshold=3, temperature=1.0): |
|
|
""" |
|
|
Inference loop that mimics the "Two-Forward" training strategy. |
|
|
This version correctly performs two full passes over the entire sequence. |
|
|
""" |
|
|
import clip |
|
|
B_token_length = length // unit_length - A_motion.shape[0] |
|
|
|
|
|
if tokenizer == 't5-xxl': |
|
|
B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float().to(device) |
|
|
else: |
|
|
raise NotImplementedError("Only t5-xxl is supported for this function.") |
|
|
empty_feat_clip_text = torch.from_numpy(clip_model.encode('')).float().unsqueeze(0).to(device) |
|
|
|
|
|
|
|
|
B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0).unsqueeze(0) |
|
|
empty_text_embeddings = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0) |
|
|
|
|
|
A_motion_embeddings = self.transformer.wte(A_motion.unsqueeze(0)) |
|
|
|
|
|
|
|
|
rough_motion_tokens = A_motion |
|
|
for k in range(B_token_length): |
|
|
current_rough_embeddings = self.transformer.wte(rough_motion_tokens.unsqueeze(0)) |
|
|
|
|
|
|
|
|
x_cond = torch.cat([B_text_embeddings, current_rough_embeddings], dim=1) |
|
|
conditions = self.forward_babel_eval(x_cond, return_attention=False)[:, -1, :] |
|
|
|
|
|
|
|
|
x_uncond = torch.cat([empty_text_embeddings, current_rough_embeddings], dim=1) |
|
|
empty_conditions = self.forward_babel_eval(x_uncond, return_attention=False)[:, -1, :] |
|
|
|
|
|
|
|
|
mix_conditions = torch.cat([conditions, empty_conditions], dim=0) |
|
|
pred_xstart_rough = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg) |
|
|
if cfg != 1: |
|
|
pred_xstart_rough, _ = pred_xstart_rough.chunk(2, dim=0) |
|
|
|
|
|
rough_motion_tokens = torch.cat([rough_motion_tokens, pred_xstart_rough], dim=0) |
|
|
|
|
|
|
|
|
|
|
|
refined_motion_tokens = A_motion |
|
|
for k in range(B_token_length): |
|
|
|
|
|
rough_embeddings = self.transformer.wte(rough_motion_tokens.unsqueeze(0)) |
|
|
|
|
|
|
|
|
x_cond_refined = torch.cat([B_text_embeddings, rough_embeddings], dim=1) |
|
|
|
|
|
conditions_refined = self.forward_babel_eval(x_cond_refined, return_attention=False)[:, A_motion.shape[0] + k, :] |
|
|
|
|
|
|
|
|
x_uncond_refined = torch.cat([empty_text_embeddings, rough_embeddings], dim=1) |
|
|
empty_conditions_refined = self.forward_babel_eval(x_uncond_refined, return_attention=False)[:, A_motion.shape[0] + k, :] |
|
|
|
|
|
|
|
|
mix_conditions_refined = torch.cat([conditions_refined, empty_conditions_refined], dim=0) |
|
|
final_token, _ = self.diff_loss.sample(mix_conditions_refined, temperature=temperature, cfg=cfg).chunk(2, dim=0) |
|
|
|
|
|
|
|
|
refined_motion_tokens = torch.cat([refined_motion_tokens, final_token], dim=0) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rough_motion_tokens[A_motion.shape[0] + k] = final_token.squeeze(0) |
|
|
|
|
|
|
|
|
B_motion = refined_motion_tokens[A_motion.shape[0]:, :].unsqueeze(0) |
|
|
return None, B_motion |
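
    # Two-forward inference sketch: pass 1 autoregressively drafts rough tokens;
    # pass 2 re-reads the full rough sequence and re-samples each B-token at its
    # final position, writing the refined token back into the rough sequence.
    # Illustrative call (names hypothetical):
    #
    #     _, B_motion = model.sample_for_eval_CFG_babel_inference_two_forward(
    #         "a person turns around", A_motion, clip_model=t5, tokenizer='t5-xxl')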
|
|
|
|
|
|
|
|
|
|
|
    def sample_for_eval_classification(self, clip_text, if_categorial=False, length=196, clip_model=None,
                                       device=torch.device('cuda'), tokenizer='clip', unit_length=4):
        import clip

        # Encode the (fixed) conditional and unconditional prompts once, outside
        # the sampling loop, instead of re-encoding them on every iteration.
        if tokenizer == 'clip':
            text = clip.tokenize(clip_text, truncate=True).to(device)
            feat_clip_text = clip_model.encode_text(text).float()
        elif tokenizer == 't5-xxl':
            feat_clip_text = torch.from_numpy(clip_model.module.encode(clip_text)).float()

        empty_clip_text = ''
        if tokenizer == 'clip':
            empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device)
            empty_feat_clip_text = clip_model.encode_text(empty_text).float()
        elif tokenizer == 't5-xxl':
            empty_feat_clip_text = torch.from_numpy(clip_model.module.encode(empty_clip_text)).float()
            empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
            empty_feat_clip_text = empty_feat_clip_text.to(device)

        for k in range(51):
            x = [] if k == 0 else xs

            conditions = self.forward(x, feat_clip_text)
            conditions = conditions[:, -1, :]

            empty_conditions = self.forward(x, empty_feat_clip_text)
            empty_conditions = empty_conditions[:, -1, :]

            temperature = 1.0
            cfg = 7.5

            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

            if cfg != 1:
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = sampled_token_latent

            # A separate classification head decides when the motion should end.
            prediction_logits = self.classify_head(conditions)
            probs = torch.sigmoid(prediction_logits)
            predicted_classes = torch.argmax(probs, dim=-1)

            scaled_logits = scaled_logits.unsqueeze(0)
            xs = scaled_logits if k == 0 else torch.cat((xs, scaled_logits), dim=1)

            if predicted_classes == 1:
                break

        return xs
|
|
|
|
|
|
|
|
|
|
|
def sample_for_eval_CFG_test(self, clip_text, if_categorial=False, length=196, clip_model=None, cfg=1, device=torch.device('cuda'), tokenizer='clip', unit_length=4): |
|
|
|
|
|
import clip |
|
|
max_token_len = length // unit_length |
|
|
|
|
|
|
|
|
for k in range(max_token_len): |
|
|
if k == 0: |
|
|
x = [] |
|
|
else: |
|
|
x = xs |
|
|
|
|
|
|
|
|
if cfg != 1: |
|
|
if tokenizer == 'clip': |
|
|
text = clip.tokenize(clip_text, truncate=True).to(device) |
|
|
|
|
|
feat_clip_text = clip_model.encode_text(text).float() |
|
|
elif tokenizer == 't5-xxl': |
|
|
feat_clip_text = torch.from_numpy(clip_model.module.encode(clip_text)).float() |
|
|
|
|
|
conditions = self.forward(x, feat_clip_text) |
|
|
|
|
|
conditions = conditions[:, -1, :] |
|
|
empty_clip_text = '' |
|
|
if tokenizer == 'clip': |
|
|
empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device) |
|
|
empty_feat_clip_text = clip_model.encode_text(empty_text).float() |
|
|
elif tokenizer == 't5-xxl': |
|
|
empty_feat_clip_text = torch.from_numpy(clip_model.module.encode(empty_clip_text)).float() |
|
|
empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0) |
|
|
empty_feat_clip_text = empty_feat_clip_text.to(device) |
|
|
|
|
|
empty_conditions = self.forward(x, empty_feat_clip_text) |
|
|
empty_conditions = empty_conditions[:, -1, :] |
|
|
temperature = 1.0 |
|
|
|
|
|
|
|
|
mix_conditions = torch.cat([conditions, empty_conditions], dim=0) |
|
|
sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg) |
|
|
|
|
|
|
|
|
scaled_logits, _ = sampled_token_latent.chunk(2, dim=0) |
|
|
|
|
|
else: |
|
|
if tokenizer == 'clip': |
|
|
text = clip.tokenize(clip_text, truncate=True).to(device) |
|
|
feat_clip_text = clip_model.encode_text(text).float() |
|
|
elif tokenizer == 't5-xxl': |
|
|
feat_clip_text = torch.from_numpy(clip_model.module.encode(clip_text)).float() |
|
|
feat_clip_text = feat_clip_text.to(device) |
|
|
|
|
|
|
|
|
conditions = self.forward(x, feat_clip_text) |
|
|
|
|
|
conditions = conditions[:, -1, :] |
|
|
temperature = 1.0 |
|
|
sampled_token_latent = self.diff_loss.sample(conditions, temperature=temperature, cfg=cfg) |
|
|
scaled_logits = sampled_token_latent |
|
|
|
|
|
scaled_logits = scaled_logits.unsqueeze(0) |
|
|
|
|
|
if k == 0: |
|
|
xs = scaled_logits |
|
|
else: |
|
|
xs = torch.cat((xs, scaled_logits), dim=1) |
|
|
|
|
|
return xs |
|
|
|
|
|
|
|
|
def forward_discrete(self, idx: torch.Tensor, clip_feature: torch.Tensor, use_cache=False, past_key_values=None) -> torch.Tensor: |
|
|
""" |
|
|
Vector-token path: idx must be shape [B, T, input_token_dim]. |
|
|
If you want discrete IDs instead, you must switch wte to nn.Embedding. |
|
|
""" |
|
|
context = None |
|
|
if idx.numel() == 0: |
|
|
context = self._prepare_context(clip_feature) |
|
|
token_embeddings = context |
|
|
if token_embeddings is None: |
|
|
raise ValueError("Conditioning features are required when no motion tokens are provided.") |
|
|
else: |
|
|
b, t, _ = idx.size() |
|
|
assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" |
|
|
token_embeddings = self.transformer.wte(idx) |
|
|
context = self._prepare_context(clip_feature, batch_size=b) |
|
|
if context is not None: |
|
|
token_embeddings = torch.cat([context, token_embeddings], dim=1) |
|
|
|
|
|
x = token_embeddings |
|
|
|
|
|
if use_cache and past_key_values is None: |
|
|
past_key_values = [None] * len(self.transformer.h) |
|
|
|
|
|
for i, block in enumerate(self.transformer.h): |
|
|
if use_cache: |
|
|
last_past = past_key_values[i] |
|
|
x, presents = block(x, context=context, last_past=last_past, use_cache=use_cache) |
|
|
past_key_values[i] = list(presents) |
|
|
else: |
|
|
x = block(x, context=context) |
|
|
|
|
|
x = self.transformer.ln_f(x) |
|
|
logits = self.out_proj(x) |
|
|
return logits |
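
    # KV-cache usage sketch for forward_discrete (illustrative):
    #
    #     past = [None] * len(model.transformer.h)   # one slot per block
    #     out = model.forward_discrete(tokens, feat, use_cache=True, past_key_values=past)
    #     # `past` is updated in place with per-block (key, value) tensors, so a
    #     # later call can feed just the newly generated token with the same list.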
|
|
|
|
|
|
|
|
def forward(self, idx: torch.Tensor, feature: Optional[torch.Tensor]) -> torch.Tensor: |
|
|
""" |
|
|
If self._prompt_cached is True, we DO NOT concat context each call. |
|
|
Instead, blocks read the cached prompt KV. |
|
|
Otherwise we embed and concat context as before. |
|
|
""" |
|
|
context = None |
|
|
if len(idx) == 0: |
|
|
if self._prompt_cached: |
|
|
if self._prompt_bsz is None: |
|
|
raise ValueError("Prompt cache set but batch size unknown.") |
|
|
b = self._prompt_bsz |
|
|
token_embeddings = torch.empty(b, 0, self.config.n_embd, device=self.bos.device, dtype=self.bos.dtype) |
|
|
else: |
|
|
context = self._prepare_context(feature) |
|
|
token_embeddings = context |
|
|
if token_embeddings is None: |
|
|
raise ValueError("Conditioning features are required when no motion tokens are provided.") |
|
|
else: |
|
|
b, t, c = idx.size() |
|
|
idx = idx.float() |
|
|
assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" |
|
|
token_embeddings = self.transformer.wte(idx) |
|
|
if not self._prompt_cached: |
|
|
context = self._prepare_context(feature, batch_size=b) |
|
|
if context is not None: |
|
|
token_embeddings = torch.cat([context, token_embeddings], dim=1) |
|
|
|
|
|
|
|
|
bos = self.bos.expand(token_embeddings.size(0), 1, -1) |
|
|
x = torch.cat([bos, token_embeddings], dim=1) |
|
|
|
|
|
|
|
|
for block in self.transformer.h: |
|
|
x = block(x, context=context) |
|
|
x = self.transformer.ln_f(x) |
|
|
logits = self.out_proj(x) |
|
|
return logits |
|
|
|
|
|
|
|
|
def forward_inference(self, idx: torch.Tensor, feature: Optional[torch.Tensor]) -> torch.Tensor: |
|
|
context = None |
|
|
if len(idx) == 0: |
|
|
if self._prompt_cached: |
|
|
if self._prompt_bsz is None: |
|
|
raise ValueError("Prompt cache set but batch size unknown.") |
|
|
b = self._prompt_bsz |
|
|
token_embeddings = torch.empty(b, 0, self.config.n_embd, device=self.bos.device, dtype=self.bos.dtype) |
|
|
else: |
|
|
context = self._prepare_context(feature) |
|
|
token_embeddings = context |
|
|
if token_embeddings is None: |
|
|
raise ValueError("Conditioning features are required when no motion tokens are provided.") |
|
|
else: |
|
|
b, t, c = idx.size() |
|
|
idx = idx.float() |
|
|
assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" |
|
|
token_embeddings = self.transformer.wte(idx) |
|
|
if not self._prompt_cached: |
|
|
context = self._prepare_context(feature, batch_size=b) |
|
|
if context is not None: |
|
|
token_embeddings = torch.cat([context, token_embeddings], dim=1) |
|
|
|
|
|
x = token_embeddings |
|
|
if len(x.shape) == 2: |
|
|
x = x.unsqueeze(0) |
|
|
|
|
|
|
|
|
bos = self.bos.expand(x.size(0), 1, -1) |
|
|
x = torch.cat([bos, x], dim=1) |
|
|
|
|
|
if context is not None and context.size(0) != x.size(0): |
|
|
if context.size(0) == 1: |
|
|
context = context.expand(x.size(0), -1, -1) |
|
|
else: |
|
|
raise ValueError("Conditioning batch size does not match token batch size.") |
|
|
|
|
|
for block in self.transformer.h: |
|
|
x = block(x, context=context) |
|
|
x = self.transformer.ln_f(x) |
|
|
logits = self.out_proj(x) |
|
|
return logits |
|
|
|
|
|
|
|
|
def babel_long(self, idx: torch.Tensor, clip_feature: torch.Tensor, use_cache=False, past_key_values=None, num_subseq=None, length=None) -> torch.Tensor: |
|
|
|
|
|
b, t, c = idx.size() |
|
|
idx = idx.float() |
|
|
idx = self.transformer.wte(idx) |
|
|
assert ( |
|
|
t <= self.config.block_size |
|
|
), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" |
|
|
for i in range(b): |
|
|
length_i = length[i][:num_subseq[i]] |
|
|
clip_feature_i = clip_feature[i][:num_subseq[i]] |
|
|
|
|
|
pointer = 0 |
|
|
for j in range(num_subseq[i]): |
|
|
if j > 0: |
|
|
pointer += length_i[j].item() |
|
|
pointer += 1 |
|
|
pointer = int(pointer) |
|
|
|
|
|
clip_feature_i_j = self.transformer.cond_embed(clip_feature_i[j].unsqueeze(0)).unsqueeze(1) |
|
|
idx[i] = torch.cat([idx[i][:pointer].unsqueeze(0), clip_feature_i_j, idx[i][pointer:-1].unsqueeze(0)], dim=1)[0] |
|
|
|
|
|
x = idx |
|
|
|
|
|
context = None |
|
|
|
|
|
|
|
|
if use_cache: |
|
|
if past_key_values is None: |
|
|
past_key_values = [None] * len(self.transformer.h) |
|
|
|
|
|
|
|
|
for i,block in enumerate(self.transformer.h): |
|
|
if use_cache: |
|
|
last_past = past_key_values[i] |
|
|
x, presents = block(x, context=context, last_past=last_past, use_cache=use_cache) |
|
|
past_key_values[i] = list(presents) |
|
|
else: |
|
|
x = block(x, context=context) |
|
|
x = self.transformer.ln_f(x) |
|
|
|
|
|
logits = self.out_proj(x) |
|
|
return logits |
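
    # babel_long interleaves each sub-sequence's text embedding into the motion
    # stream at that segment's start offset (dropping the last slot on every
    # insert so the total sequence length stays fixed), letting one forward pass
    # cover a multi-segment BABEL clip.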
|
|
|
|
|
|
|
|
def forward_babel_eval(self, x, return_attention=False) -> torch.Tensor: |
|
|
layer_attentions = [] |
|
|
context = None |
|
|
for block in self.transformer.h: |
|
|
if return_attention: |
|
|
x, att = block(x, context=context, return_attention=True) |
|
|
layer_attentions.append(att) |
|
|
else: |
|
|
x = block(x, context=context) |
|
|
|
|
|
x = self.transformer.ln_f(x) |
|
|
if self.use_out_proj: |
|
|
logits = self.out_proj(x) |
|
|
else: |
|
|
logits = x |
|
|
|
|
|
if return_attention: |
|
|
return logits, layer_attentions |
|
|
return logits |
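
    # Attention-inspection sketch (illustrative):
    #
    #     logits, atts = model.forward_babel_eval(x, return_attention=True)
    #     # `atts` is a list with one self-attention map per transformer block.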
|
|
|
|
|
def forward_babel(self, idx: torch.Tensor, clip_feature: torch.Tensor, A_token_length) -> torch.Tensor: |
|
|
context = None |
|
|
if len(idx) == 0: |
|
|
context = self._prepare_context(clip_feature) |
|
|
token_embeddings = context |
|
|
|
|
|
else: |
|
|
b, t, c = idx.size() |
|
|
idx = idx.float() |
|
|
assert ( |
|
|
t <= self.config.block_size |
|
|
), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" |
|
|
|
|
|
|
|
|
|
|
|
A_feature = clip_feature[:, 0, :] |
|
|
B_feature = clip_feature[:, 1, :] |
|
|
|
|
|
|
|
|
A_text_embeddings = self.transformer.cond_embed(A_feature).unsqueeze(1) |
|
|
B_text_embeddings = self.transformer.cond_embed(B_feature).unsqueeze(1) |
|
|
context = torch.cat([A_text_embeddings, B_text_embeddings], dim=1) |
|
|
|
|
|
            token_embeddings = torch.zeros(b, self.config.block_size, self.config.n_embd).to(idx.device)
            for i in range(b):
                A_idx = idx[i, :A_token_length[i].item(), :]
                B_idx = idx[i, A_token_length[i].item():-2, :]
                # self.BOM_tag has shape (1, 1, n_embd); take its leading slice so
                # every tensor in this concat is 2-D (seq, n_embd), as torch.cat
                # requires matching dimensionality.
                token_embeddings[i, :, :] = torch.cat(
                    [A_text_embeddings[i], self.BOM_tag[0], self.transformer.wte(A_idx),
                     B_text_embeddings[i], self.BOM_tag[0], self.transformer.wte(B_idx)],
                    dim=0,
                )
|
|
|
|
|
x = token_embeddings |
|
|
if context is not None and context.size(0) != x.size(0): |
|
|
if context.size(0) == 1: |
|
|
context = context.expand(x.size(0), -1, -1) |
|
|
else: |
|
|
raise ValueError("Conditioning batch size does not match token batch size.") |
|
|
for block in self.transformer.h: |
|
|
x = block(x, context=context) |
|
|
x = self.transformer.ln_f(x) |
|
|
|
|
|
if self.use_out_proj: |
|
|
logits = self.out_proj(x) |
|
|
else: |
|
|
logits = x |
|
|
|
|
|
|
|
|
return logits |
|
|
|
|
|
def forward_babel2(self, idx: torch.Tensor, clip_feature: torch.Tensor) -> torch.Tensor: |
|
|
context = None |
|
|
if idx.numel() == 0: |
|
|
context = self._prepare_context(clip_feature) |
|
|
token_embeddings = context |
|
|
else: |
|
|
b, t, c = idx.size() |
|
|
idx = idx.float() |
|
|
assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" |
|
|
|
|
|
B_feature = clip_feature |
|
|
B_text_embeddings = self.transformer.cond_embed(B_feature) |
|
|
if B_text_embeddings.dim() == 2: |
|
|
B_text_embeddings = B_text_embeddings.unsqueeze(1) |
|
|
context = B_text_embeddings |
|
|
|
|
|
idx_embeddings = self.transformer.wte(idx) |
|
|
token_embeddings = torch.cat([B_text_embeddings, idx_embeddings], dim=1) |
|
|
|
|
|
x = token_embeddings |
|
|
if context is not None: |
|
|
if context.dim() == 2: |
|
|
context = context.unsqueeze(1) |
|
|
if context.size(0) != x.size(0): |
|
|
if context.size(0) == 1: |
|
|
context = context.expand(x.size(0), -1, -1) |
|
|
else: |
|
|
raise ValueError("Conditioning batch size does not match token batch size.") |
|
|
|
|
|
for block in self.transformer.h: |
|
|
x = block(x, context=context) |
|
|
x = self.transformer.ln_f(x) |
|
|
|
|
|
logits = self.out_proj(x) if self.use_out_proj else x |
|
|
return logits |
|
|
|
|
|
|
|
|
def resize_token_embeddings( |
|
|
self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None, using_old_initilization: bool = False |
|
|
) -> nn.Embedding: |
|
|
""" |
|
|
Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`. |
|
|
|
|
|
Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method. |
|
|
|
|
|
Arguments: |
|
|
new_num_tokens (`int`, *optional*): |
|
|
The new number of tokens in the embedding matrix. Increasing the size will add newly initialized |
|
|
vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just |
|
|
returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything. |
|
|
            pad_to_multiple_of (`int`, *optional*):
                If set, will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is
                `None`, this will just pad the embedding to a multiple of `pad_to_multiple_of`.
|
|
|
|
|
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability |
|
|
`>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more |
|
|
details about this, or help on choosing the correct value for resizing, refer to this guide: |
|
|
https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc |
|
|
|
|
|
Return: |
|
|
`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model. |
|
|
""" |
|
|
model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of) |
|
|
if new_num_tokens is None and pad_to_multiple_of is None: |
|
|
return model_embeds |
|
|
|
|
|
|
|
|
self.config.vocab_size = model_embeds.weight.shape[0] |
|
|
self.vocab_size = model_embeds.weight.shape[0] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return model_embeds |
|
|
|
|
|
def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None): |
|
|
old_embeddings = self.get_input_embeddings() |
|
|
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of) |
|
|
old_embeddings_requires_grad = old_embeddings.weight.requires_grad |
|
|
new_embeddings.requires_grad_(old_embeddings_requires_grad) |
|
|
self.set_input_embeddings(new_embeddings) |
|
|
|
|
|
|
|
|
if pad_to_multiple_of is not None: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
new_num_tokens = new_embeddings.weight.shape[0] |
|
|
|
|
|
|
|
|
|
|
|
        if self.get_output_embeddings() is not None:
|
|
old_lm_head = self.get_output_embeddings() |
|
|
new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens) |
|
|
|
|
|
|
|
|
|
|
|
old_lm_head_requires_grad = old_lm_head.weight.requires_grad |
|
|
new_lm_head.requires_grad_(old_lm_head_requires_grad) |
|
|
self.set_output_embeddings(new_lm_head) |
|
|
|
|
|
return self.get_input_embeddings() |
|
|
|
|
|
def _get_resized_embeddings( |
|
|
self, |
|
|
old_embeddings: nn.Embedding, |
|
|
new_num_tokens: Optional[int] = None, |
|
|
pad_to_multiple_of: Optional[int] = None, |
|
|
) -> nn.Embedding: |
|
|
""" |
|
|
Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly |
|
|
initialized vectors at the end. Reducing the size will remove vectors from the end |
|
|
|
|
|
Args: |
|
|
old_embeddings (`torch.nn.Embedding`): |
|
|
Old embeddings to be resized. |
|
|
new_num_tokens (`int`, *optional*): |
|
|
New number of tokens in the embedding matrix. |
|
|
|
|
|
Increasing the size will add newly initialized vectors at the end. Reducing the size will remove |
|
|
vectors from the end. If not provided or `None`, just returns a pointer to the input tokens |
|
|
`torch.nn.Embedding` module of the model without doing anything. |
|
|
pad_to_multiple_of (`int`, *optional*): |
|
|
If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to |
|
|
`None` will just pad the embedding to a multiple of `pad_to_multiple_of`. |
|
|
|
|
|
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability |
|
|
`>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more |
|
|
details about this, or help on choosing the correct value for resizing, refer to this guide: |
|
|
https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc |
|
|
|
|
|
|
|
|
Return: |
|
|
`torch.nn.Embedding`: Pointer to the resized Embedding Module or the old Embedding Module if |
|
|
`new_num_tokens` is `None` |
|
|
""" |
|
|
|
|
|
if pad_to_multiple_of is not None: |
|
|
if not isinstance(pad_to_multiple_of, int): |
|
|
raise ValueError( |
|
|
f"Asking to pad the embedding matrix to a multiple of `{pad_to_multiple_of}`, which is not and integer. Please make sure to pass an integer" |
|
|
) |
|
|
if new_num_tokens is None: |
|
|
new_num_tokens = old_embeddings.weight.shape[0] |
|
|
new_num_tokens = ((new_num_tokens + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of |
|
|
else: |
|
|
print( |
|
|
"You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding" |
|
|
f" dimension will be {new_num_tokens}. This might induce some performance reduction as *Tensor Cores* will not be available." |
|
|
" For more details about this, or help on choosing the correct value for resizing, refer to this guide:" |
|
|
" https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc" |
|
|
) |
|
|
|
|
|
if new_num_tokens is None: |
|
|
return old_embeddings |
|
|
|
|
|
|
|
|
        old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
|
|
|
|
|
|
|
|
        if old_num_tokens == new_num_tokens:
|
|
return old_embeddings |
|
|
|
|
|
if not isinstance(old_embeddings, nn.Embedding): |
|
|
raise TypeError( |
|
|
f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}. You" |
|
|
" should either use a different resize function or make sure that `old_embeddings` are an instance of" |
|
|
f" {nn.Embedding}." |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
new_embeddings = nn.Embedding( |
|
|
new_num_tokens, |
|
|
old_embedding_dim, |
|
|
device=old_embeddings.weight.device, |
|
|
dtype=old_embeddings.weight.dtype, |
|
|
) |
|
|
|
|
|
|
|
|
self._init_weights(new_embeddings) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
n = min(old_num_tokens, new_num_tokens) |
|
|
|
|
|
|
|
|
        new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
|
|
|
|
|
return new_embeddings |
|
|
|
|
|
|
|
|
def _get_resized_lm_head( |
|
|
self, old_lm_head: nn.Linear, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False |
|
|
) -> nn.Linear: |
|
|
""" |
|
|
Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized |
|
|
vectors at the end. Reducing the size will remove vectors from the end |
|
|
|
|
|
Args: |
|
|
old_lm_head (`torch.nn.Linear`): |
|
|
                Old lm head linear layer to be resized.
|
|
new_num_tokens (`int`, *optional*): |
|
|
New number of tokens in the linear matrix. |
|
|
|
|
|
                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
                vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
                `torch.nn.Linear` module of the model without doing anything.
            transposed (`bool`, *optional*, defaults to `False`):
                Whether `old_lm_head` is transposed or not. If True, `old_lm_head.size()` is `lm_head_dim,
                vocab_size`; else `vocab_size, lm_head_dim`.
|
|
|
|
|
Return: |
|
|
`torch.nn.Linear`: Pointer to the resized Linear Module or the old Linear Module if `new_num_tokens` is |
|
|
`None` |
|
|
""" |
|
|
if new_num_tokens is None: |
|
|
return old_lm_head |
|
|
|
|
|
|
|
|
        old_num_tokens, old_lm_head_dim = (
            old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
        )
|
|
|
|
|
|
|
|
        if old_num_tokens == new_num_tokens:
|
|
return old_lm_head |
|
|
|
|
|
if not isinstance(old_lm_head, nn.Linear): |
|
|
raise TypeError( |
|
|
f"Old language model head is of type {type(old_lm_head)}, which is not an instance of {nn.Linear}. You" |
|
|
" should either use a different resize function or make sure that `old_lm_head` are an instance of" |
|
|
f" {nn.Linear}." |
|
|
) |
|
|
|
|
|
|
|
|
new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim) |
|
|
has_new_lm_head_bias = old_lm_head.bias is not None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
new_lm_head = nn.Linear( |
|
|
*new_lm_head_shape, |
|
|
bias=has_new_lm_head_bias, |
|
|
device=old_lm_head.weight.device, |
|
|
dtype=old_lm_head.weight.dtype, |
|
|
) |
|
|
|
|
|
|
|
|
self._init_weights(new_lm_head) |
|
|
|
|
|
num_tokens_to_copy = min(old_num_tokens, new_num_tokens) |
|
|
|
|
|
|
|
|
        self._copy_lm_head_original_to_resized(
            new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
        )
|
|
|
|
|
return new_lm_head |
|
|
|
|
|
def _copy_lm_head_original_to_resized( |
|
|
self, new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias |
|
|
): |
|
|
|
|
|
if not transposed: |
|
|
new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :] |
|
|
else: |
|
|
new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy] |
|
|
|
|
|
|
|
|
if has_new_lm_head_bias: |
|
|
new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy] |
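
    # Note: these resize helpers are inherited HF-style utilities that assume an
    # nn.Embedding input layer; in this model `transformer.wte` is an nn.Linear
    # over continuous motion tokens, so calling resize_token_embeddings here
    # would raise the TypeError above unless wte is swapped for an nn.Embedding.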
|
|
|
|
|
@classmethod |
|
|
def from_name(cls, name: str) -> Self: |
|
|
return cls(LLaMAHFConfig.from_name(name)) |
|
|
|
|
|
|
|
|
class Block(nn.Module):
    def __init__(self, config: LLaMAHFConfig) -> None:
        super().__init__()
        self.rms_1 = RMSNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.rms_cross = RMSNorm(config.n_embd)
        self.cross_attn = CrossAttention(config)
        self.rms_2 = RMSNorm(config.n_embd)
        self.mlp = MLP(config)

        # Cross-attention K/V for the currently cached prompt, already
        # repeated across KV groups; see set_context_cache().
        self._ctx_k_repeat = None
        self._ctx_v_repeat = None
        self._ctx_bsz = None

    @torch.no_grad()
    def set_context_cache(self, context: torch.Tensor):
        # Project the prompt once and keep the repeated K/V around, so
        # repeated decoding steps skip the k_proj/v_proj recompute.
        B, S, D = context.shape
        ca = self.cross_attn
        k = ca.k_proj(context).view(B, S, ca.n_kv_head, ca.head_dim).transpose(1, 2)
        v = ca.v_proj(context).view(B, S, ca.n_kv_head, ca.head_dim).transpose(1, 2)
        k = ca.k_norm(k)

        self._ctx_k_repeat = repeat_kv(k, ca.num_kv_groups)
        self._ctx_v_repeat = repeat_kv(v, ca.num_kv_groups)
        self._ctx_bsz = B

    @torch.no_grad()
    def clear_context_cache(self):
        self._ctx_k_repeat = None
        self._ctx_v_repeat = None
        self._ctx_bsz = None

    def _cross_attend_cached(self, x: torch.Tensor):
        # Pre-norm cross-attention against the cached prompt K/V, with the
        # residual taken from the un-normalized stream so it matches the
        # explicit-context branch in forward(). Identity when no prompt is
        # cached or the batch size no longer matches the cache.
        if self._ctx_k_repeat is None or self._ctx_v_repeat is None:
            return x
        B, T, _ = x.size()
        if self._ctx_bsz is not None and self._ctx_bsz != B:
            return x
        ca = self.cross_attn
        h = self.rms_cross(x)
        q = ca.q_proj(h).view(B, T, ca.n_head, ca.head_dim).transpose(1, 2)
        q = ca.q_norm(q)
        y = F.scaled_dot_product_attention(
            q, self._ctx_k_repeat, self._ctx_v_repeat,
            attn_mask=None, dropout_p=0.0, is_causal=False, scale=ca.softmax_scale,
        )
        y = y.transpose(1, 2).contiguous().view(B, T, ca.n_head * ca.head_dim)
        return x + ca.o_proj(y)

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        last_past=None,
        use_cache: bool = False,
        return_attention: bool = False,
    ):
        present = None

        if use_cache:
            if return_attention:
                # forward_attn returns attention weights but no KV cache,
                # so `present` stays None on this path.
                attn_output, attn = self.attn.forward_attn(self.rms_1(x), last_past, use_cache)
            else:
                attn_output, present = self.attn(self.rms_1(x), last_past, use_cache)
            x = x + attn_output
        else:
            if return_attention:
                attn_output, attn = self.attn.forward_attn(self.rms_1(x))
            else:
                attn_output = self.attn(self.rms_1(x))
            x = x + attn_output

        # Cross-attend to an explicit context when given; otherwise fall
        # back to the per-block cached prompt, if any.
        if context is not None:
            x = x + self.cross_attn(self.rms_cross(x), context)
        else:
            x = self._cross_attend_cached(x)

        x = x + self.mlp(self.rms_2(x))

        if use_cache:
            if return_attention:
                return x, present, attn
            else:
                return x, present
        else:
            if return_attention:
                return x, attn
            else:
                return x
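# A minimal usage sketch (illustrative only, not called by the model) of a
# Block's prompt cache: project the conditioning features once with
# set_context_cache(), then run forward passes without passing `context`.
# The batch/sequence sizes and the "Normal_size" config are assumptions.
def _demo_block_context_cache():
    cfg = LLaMAHFConfig.from_name("Normal_size")
    blk = Block(cfg).eval()
    x = torch.randn(2, 8, cfg.n_embd)    # (B, T, n_embd) token stream
    ctx = torch.randn(2, 5, cfg.n_embd)  # (B, S, n_embd) prompt features

    with torch.no_grad():
        y_explicit = blk(x, context=ctx)  # recomputes K/V every call
        blk.set_context_cache(ctx)        # project + repeat K/V once
        y_cached = blk(x)                 # reuses the cached K/V
        blk.clear_context_cache()
    assert torch.allclose(y_explicit, y_cached, atol=1e-5)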
|
|
|
|
|
|
|
|
|
|
|
class CausalSelfAttention(nn.Module):
    def __init__(self, config: LLaMAHFConfig) -> None:
        super().__init__()
        assert config.n_embd % config.n_head == 0

        self.n_head = config.n_head
        # Grouped-query attention: default to one KV head per 4 query heads.
        self.n_kv_head = config.n_kv_head or max(1, config.n_head // 4)
        assert self.n_head % self.n_kv_head == 0, "n_head must be divisible by n_kv_head"
        self.head_dim = config.n_embd // config.n_head
        self.block_size = config.block_size
        self.rope_base = config.rope_base
        self.rope_cache = None
        self._rope_dtype = None
        self.num_kv_groups = self.n_head // self.n_kv_head

        self.q_proj = nn.Linear(config.n_embd, self.n_head * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

        # Per-head QK normalization.
        self.q_norm = RMSNorm(self.head_dim)
        self.k_norm = RMSNorm(self.head_dim)

        self.softmax_scale = self.head_dim ** -0.5

    def _maybe_build_rope(self, x: torch.Tensor) -> None:
        # Rebuild the RoPE cache when the input dtype or device changes.
        # The cache itself is complex-valued, so remember the dtype it was
        # built for rather than comparing against the cache's own dtype
        # (which would force a rebuild on every call).
        if (
            self.rope_cache is None
            or self._rope_dtype != x.dtype
            or self.rope_cache.device != x.device
        ):
            self._rope_dtype = x.dtype
            self.rope_cache = build_rope_cache(
                seq_len=self.block_size,
                n_elem=self.head_dim,
                dtype=x.dtype,
                device=x.device,
                base=self.rope_base,
            )

    def forward(self, x: torch.Tensor, last_past=None, use_cache=False) -> torch.Tensor:
        B, T, _ = x.size()

        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)

        q = self.q_norm(q)
        k = self.k_norm(k)

        self._maybe_build_rope(x)

        # Rotate at the absolute positions of the new tokens: with a KV
        # cache the new block starts at past_len, not at position 0.
        past_len = last_past[0].size(-2) if last_past is not None else 0
        q = apply_rope(q, self.rope_cache[past_len:past_len + T])
        k = apply_rope(k, self.rope_cache[past_len:past_len + T])

        if use_cache:
            if last_past is not None:
                past_key, past_value = last_past
                k = torch.cat([past_key, k], dim=-2)
                v = torch.cat([past_value, v], dim=-2)
            present = (k, v)
        else:
            present = None

        k_repeat = repeat_kv(k, self.num_kv_groups)
        v_repeat = repeat_kv(v, self.num_kv_groups)

        # SDPA's is_causal aligns its triangular mask top-left, which is
        # wrong once cached keys precede the queries; during single-token
        # decoding every cached position is visible, so no mask is needed.
        y = F.scaled_dot_product_attention(
            q,
            k_repeat,
            v_repeat,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=last_past is None,
            scale=self.softmax_scale,
        )

        y = y.transpose(1, 2).contiguous().view(B, T, self.n_head * self.head_dim)
        y = self.o_proj(y)

        if use_cache:
            return y, present
        return y

    def forward_attn(self, x: torch.Tensor, last_past=None, use_cache=False):
        # Same as forward(), but materializes the attention matrix with an
        # explicit softmax so the weights can be returned for inspection.
        B, T, _ = x.size()

        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)

        q = self.q_norm(q)
        k = self.k_norm(k)

        self._maybe_build_rope(x)

        past_len = last_past[0].size(-2) if last_past is not None else 0
        q = apply_rope(q, self.rope_cache[past_len:past_len + T])
        k = apply_rope(k, self.rope_cache[past_len:past_len + T])

        if use_cache and last_past is not None:
            past_key, past_value = last_past
            k = torch.cat([past_key, k], dim=-2)
            v = torch.cat([past_value, v], dim=-2)

        k_repeat = repeat_kv(k, self.num_kv_groups)
        v_repeat = repeat_kv(v, self.num_kv_groups)

        att = torch.matmul(q, k_repeat.transpose(-2, -1)) * self.softmax_scale
        # Apply the causal mask explicitly (query i attends keys up to
        # past_len + i) so this path matches the SDPA path in forward().
        S = k_repeat.size(-2)
        causal = torch.ones(T, S, dtype=torch.bool, device=x.device).tril(diagonal=S - T)
        att = att.masked_fill(~causal, float("-inf"))
        att = F.softmax(att, dim=-1)

        y = torch.matmul(att, v_repeat)
        y = y.transpose(1, 2).contiguous().view(B, T, self.n_head * self.head_dim)
        y = self.o_proj(y)

        return y, att
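# A small sketch (illustrative only) of incremental decoding with the KV
# cache: prefill a prompt with use_cache=True, then feed one token at a
# time, passing the returned (k, v) back in as last_past. Sizes and the
# "Normal_size" config are assumptions for illustration.
def _demo_kv_cache_decoding():
    cfg = LLaMAHFConfig.from_name("Normal_size")
    attn = CausalSelfAttention(cfg).eval()
    x = torch.randn(1, 6, cfg.n_embd)

    with torch.no_grad():
        # Prefill the first 5 positions and keep the cache.
        y, present = attn(x[:, :5], use_cache=True)
        # Decode position 5: it attends to itself plus the cached prefix.
        y_step, present = attn(x[:, 5:6], last_past=present, use_cache=True)
        # The cached step should match processing all 6 positions at once.
        y_full = attn(x)
    assert torch.allclose(y_full[:, 5:], y_step, atol=1e-4)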
|
|
|
|
|
|
|
|
class CrossAttention(nn.Module):
    def __init__(self, config: LLaMAHFConfig) -> None:
        super().__init__()
        assert config.n_embd % config.n_head == 0

        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head or max(1, config.n_head // 4)
        assert self.n_head % self.n_kv_head == 0, "n_head must be divisible by n_kv_head"
        self.head_dim = config.n_embd // config.n_head
        self.num_kv_groups = self.n_head // self.n_kv_head

        self.q_proj = nn.Linear(config.n_embd, self.n_head * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

        self.q_norm = RMSNorm(self.head_dim)
        self.k_norm = RMSNorm(self.head_dim)

        self.softmax_scale = self.head_dim ** -0.5

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.size()
        _, S, _ = context.size()

        # Queries come from the token stream; keys/values come from the
        # conditioning context. Cross-attention uses no causal mask.
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k_proj(context).view(B, S, self.n_kv_head, self.head_dim).transpose(1, 2)
        v = self.v_proj(context).view(B, S, self.n_kv_head, self.head_dim).transpose(1, 2)

        q = self.q_norm(q)
        k = self.k_norm(k)

        k_repeat = repeat_kv(k, self.num_kv_groups)
        v_repeat = repeat_kv(v, self.num_kv_groups)

        y = F.scaled_dot_product_attention(
            q,
            k_repeat,
            v_repeat,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=False,
            scale=self.softmax_scale,
        )

        y = y.transpose(1, 2).contiguous().view(B, T, self.n_head * self.head_dim)
        return self.o_proj(y)
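# Shape sketch (illustrative only): in cross-attention the queries keep the
# token length T while keys/values take the context length S, so the output
# length always matches the token stream. Sizes are arbitrary assumptions.
def _demo_cross_attention_shapes():
    cfg = LLaMAHFConfig.from_name("Normal_size")
    ca = CrossAttention(cfg)
    x = torch.randn(2, 10, cfg.n_embd)   # (B, T, n_embd) token stream
    ctx = torch.randn(2, 3, cfg.n_embd)  # (B, S, n_embd) conditioning
    out = ca(x, ctx)
    assert out.shape == (2, 10, cfg.n_embd)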
|
|
|
|
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, num_groups: int) -> torch.Tensor:
    # Expand (B, n_kv, S, head_dim) -> (B, n_kv * num_groups, S, head_dim)
    # so grouped KV heads line up with the query heads.
    if num_groups == 1:
        return hidden_states
    bsz, n_kv, seq_len, head_dim = hidden_states.shape
    hidden_states = hidden_states.unsqueeze(2).expand(bsz, n_kv, num_groups, seq_len, head_dim)
    return hidden_states.reshape(bsz, n_kv * num_groups, seq_len, head_dim)
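# Sanity sketch (illustrative only): after repeat_kv, query head h reads a
# copy of KV head h // num_groups, i.e. consecutive query heads share one
# KV head. Shapes are arbitrary assumptions.
def _demo_repeat_kv():
    kv = torch.randn(2, 3, 7, 16)            # (B, n_kv_head=3, S, head_dim)
    out = repeat_kv(kv, num_groups=4)
    assert out.shape == (2, 12, 7, 16)       # 3 KV heads serve 12 query heads
    assert torch.equal(out[:, 0], kv[:, 0])  # query heads 0..3 -> KV head 0
    assert torch.equal(out[:, 3], kv[:, 0])
    assert torch.equal(out[:, 4], kv[:, 1])  # query heads 4..7 -> KV head 1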
|
|
|
|
|
|
|
|
class LengthCausalSelfAttention(nn.Module):
    def __init__(self, config: LLaMAHFConfig) -> None:
        super().__init__()
        assert config.n_embd % config.n_head == 0

        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head or max(1, config.n_head // 4)
        assert self.n_head % self.n_kv_head == 0, "n_head must be divisible by n_kv_head"
        self.head_dim = config.n_embd // config.n_head
        self.block_size = config.block_size
        self.rope_base = config.rope_base
        self.rope_cache = None
        self._rope_dtype = None
        self.num_kv_groups = self.n_head // self.n_kv_head

        self.q_proj = nn.Linear(config.n_embd, self.n_head * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

        self.q_norm = RMSNorm(self.head_dim)
        self.k_norm = RMSNorm(self.head_dim)

        self.softmax_scale = self.head_dim ** -0.5

    def forward(self, x: torch.Tensor, y_mask: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.size()

        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)

        q = self.q_norm(q)
        k = self.k_norm(k)

        # See CausalSelfAttention._maybe_build_rope for why the built dtype
        # is tracked separately from the (complex-valued) cache.
        if (
            self.rope_cache is None
            or self._rope_dtype != x.dtype
            or self.rope_cache.device != x.device
        ):
            self._rope_dtype = x.dtype
            self.rope_cache = build_rope_cache(
                seq_len=self.block_size,
                n_elem=self.head_dim,
                dtype=x.dtype,
                device=x.device,
                base=self.rope_base,
            )

        q = apply_rope(q, self.rope_cache)
        k = apply_rope(k, self.rope_cache)

        # Start from a standard causal mask over all T positions.
        attn_mask = torch.ones(T, T, dtype=torch.bool, device=x.device)
        attn_mask = torch.tril(attn_mask)
        attn_mask = attn_mask.unsqueeze(0).expand(B, -1, -1)

        # Prefix-LM style: valid text positions (y_mask) may attend to each
        # other bidirectionally, while the remaining tokens stay causal.
        text_mask = y_mask.unsqueeze(2) * y_mask.unsqueeze(1)
        text_mask = F.pad(text_mask, (0, T - y_mask.shape[1], 0, T - y_mask.shape[1]), mode='constant', value=0)
        attn_mask = torch.logical_or(attn_mask, text_mask)

        k_repeat = repeat_kv(k, self.num_kv_groups)
        v_repeat = repeat_kv(v, self.num_kv_groups)

        y = F.scaled_dot_product_attention(
            q,
            k_repeat,
            v_repeat,
            attn_mask=attn_mask.unsqueeze(1),  # broadcast over heads
            dropout_p=0.0,
            is_causal=False,
            scale=self.softmax_scale,
        )

        y = y.transpose(1, 2).contiguous().view(B, T, self.n_head * self.head_dim)
        y = self.o_proj(y)

        return y
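# A tiny sketch (illustrative only) of the combined mask built above:
# causal everywhere, except that valid text positions attend to each other
# bidirectionally. T=5 with a 3-slot text prefix is an arbitrary assumption.
def _demo_prefix_mask():
    T, text_len = 5, 3
    y_mask = torch.tensor([[True, True, False]])      # one padded text slot
    causal = torch.tril(torch.ones(T, T, dtype=torch.bool)).unsqueeze(0)
    text = y_mask.unsqueeze(2) * y_mask.unsqueeze(1)
    text = F.pad(text, (0, T - text_len, 0, T - text_len), value=False)
    mask = torch.logical_or(causal, text)
    assert mask[0, 0, 1]        # text position 0 sees text position 1
    assert not mask[0, 0, 2]    # but not the padded text slot
    assert not mask[0, 3, 4]    # non-text tokens remain strictly causal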
|
|
|
|
|
|
|
|
class MLP(nn.Module):
    def __init__(self, config: LLaMAHFConfig) -> None:
        super().__init__()
        # SwiGLU sizing: 2/3 of the usual 4x expansion, rounded up to the
        # next multiple of 256 for hardware-friendly matmul shapes.
        hidden_dim = 4 * config.n_embd
        n_hidden = int(2 * hidden_dim / 3)
        N = 256
        n_hidden = ((n_hidden - 1) // N) * N + N

        self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False)
        self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False)
        self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: SiLU-gated linear unit followed by the down-projection.
        x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
        x = self.c_proj(x)
        return x
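# Worked sizing example (illustrative only): for n_embd=768 the SwiGLU
# width is int(2 * 4 * 768 / 3) = 2048, already a multiple of 256, so the
# round-up leaves it unchanged; for n_embd=4096 it is 10922 -> 11008.
def _demo_mlp_width():
    N = 256
    for n_embd, expected in [(768, 2048), (4096, 11008)]:
        n_hidden = int(2 * (4 * n_embd) / 3)
        n_hidden = ((n_hidden - 1) // N) * N + N    # round up to multiple of N
        assert n_hidden == expected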
|
|
|
|
|
|
|
|
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
    https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
    """

    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
        super().__init__()
        self.scale = nn.Parameter(torch.ones(size))
        self.eps = eps
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x / sqrt(mean(x^2) + eps), followed by a learned per-channel scale.
        norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
        x_normed = x * torch.rsqrt(norm_x + self.eps)
        return self.scale * x_normed
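# Numeric sketch (illustrative only): RMSNorm rescales each vector to
# roughly unit root-mean-square without centering it, unlike LayerNorm.
def _demo_rmsnorm():
    norm = RMSNorm(4)
    x = torch.tensor([[2.0, -2.0, 2.0, -2.0]])  # rms(x) = 2
    out = norm(x)
    assert torch.allclose(out, torch.tensor([[1.0, -1.0, 1.0, -1.0]]), atol=1e-3)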
|
|
|
|
|
|
|
|
def build_rope_cache(seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000) -> torch.Tensor:
    """Rotary-position cache with safe dtype handling.

    Returns a complex tensor of shape (seq_len, n_elem // 2) holding
    e^{i * m * theta_j} for position m and frequency index j.
    """
    # theta_j = base^(-2j / n_elem)
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))
    seq_idx = torch.arange(seq_len, dtype=dtype, device=device)
    idx_theta = torch.outer(seq_idx, theta)

    # torch.polar needs a full-precision floating-point input, so
    # low-precision (and integer) dtypes are computed in float32.
    dtypes_requiring_casting = [torch.float16, torch.bfloat16, torch.int8]
    working_dtype = torch.float32 if dtype in dtypes_requiring_casting else dtype
    complex_dtype = torch.complex64

    cache = torch.polar(
        torch.ones_like(idx_theta, dtype=working_dtype, device=device),
        idx_theta.to(working_dtype),
    ).to(complex_dtype)
    return cache
|
|
|
|
|
|
|
|
def apply_rope(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    # x: (B, n_head, T, head_dim) -> (B, T, n_head, head_dim)
    x = x.transpose(1, 2)

    # Truncate the cache to the current sequence length.
    T = x.size(1)
    rope_cache = rope_cache[:T]

    # View consecutive channel pairs as complex numbers and rotate them by
    # the per-position unit phasors from the cache.
    xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    rope_cache = rope_cache.view(1, xc.size(1), 1, xc.size(3))
    x_out = torch.view_as_real(xc * rope_cache).flatten(3)
    return x_out.transpose(1, 2).type_as(x)
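# Property sketch (illustrative only): rotary embeddings are pure rotations
# of channel pairs, so they preserve per-vector norms; sizes are arbitrary.
def _demo_rope():
    cache = build_rope_cache(seq_len=16, n_elem=8, dtype=torch.float32,
                             device=torch.device("cpu"))
    assert cache.shape == (16, 4) and cache.is_complex()
    x = torch.randn(1, 2, 16, 8)  # (B, n_head, T, head_dim)
    x_rot = apply_rope(x, cache)
    assert x_rot.shape == x.shape
    assert torch.allclose(x.norm(dim=-1), x_rot.norm(dim=-1), atol=1e-4)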
|
|
|