| | """ |
| | Activation Avatars — Adapter models. |
| | |
| | Small neural networks that map LLM activations (from Qwen3-4B forward hooks) |
| | into FLUX.2-Klein prompt embedding space, producing real-time avatar expressions. |
| | |
| | Usage: |
| | from adapter import load_adapter |
| | adapter = load_adapter("adapters/xattn8tok_thinking.pt") |
| | # activation: [in_dim] tensor from LLM hidden state |
| | expression = adapter(activation, emotion_scale=6.0) |
| | # expression: [n_tokens, out_dim] — feed to Klein as prompt_embeds |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| |
|
class MultiTokenAdapter(nn.Module):
    """Map one activation vector to a bank of `n_tokens` embedding tokens.

    The activation is squeezed through a low-rank MLP bottleneck, offset by a
    learned per-token query bank, and projected up to the output embedding dim.
    """

    def __init__(self, in_dim, out_dim, n_tokens=8, rank=128):
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        self.n_tokens, self.rank = n_tokens, rank
        # Two GELU layers compress the activation into rank-dim space.
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, rank),
            nn.GELU(),
            nn.Linear(rank, rank),
            nn.GELU(),
        )
        # One learned offset per output token; small init keeps all tokens
        # near the shared encoding at the start of training.
        self.token_queries = nn.Parameter(torch.randn(n_tokens, rank) * 0.02)
        self.project = nn.Linear(rank, out_dim)

    def forward(self, x):
        """[in_dim] or [B, in_dim] -> [B, n_tokens, out_dim]."""
        if x.dim() == 1:
            x = x.unsqueeze(0)
        encoded = self.encoder(x)                                   # [B, rank]
        # Broadcast-add the token bank: [B, 1, rank] + [1, T, rank].
        per_token = encoded[:, None, :] + self.token_queries[None, :, :]
        return self.project(per_token)
| |
|
| |
|
class CrossAttentionAdapter(nn.Module):
    """Decode `n_tokens` output embeddings from one activation via cross-attention.

    The flat activation is expanded into a short `n_input_tokens` memory
    sequence; a bank of learned queries attends over that memory through a
    small pre-norm transformer decoder stack, then projects to the output dim.
    """

    def __init__(self, in_dim, out_dim, n_tokens=64, rank=128,
                 n_input_tokens=4, n_heads=4, n_layers=2):
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        self.n_tokens, self.rank = n_tokens, rank
        self.n_input_tokens = n_input_tokens
        # Expand the activation into n_input_tokens rank-dim memory slots.
        self.input_encoder = nn.Sequential(
            nn.Linear(in_dim, rank),
            nn.GELU(),
            nn.Linear(rank, n_input_tokens * rank),
        )
        self.queries = nn.Parameter(torch.randn(n_tokens, rank) * 0.02)
        layer = nn.TransformerDecoderLayer(
            d_model=rank,
            nhead=n_heads,
            dim_feedforward=rank * 4,
            activation='gelu',
            batch_first=True,
            norm_first=True,
        )
        self.decoder = nn.TransformerDecoder(layer, num_layers=n_layers)
        self.project = nn.Linear(rank, out_dim)

    def forward(self, x):
        """[in_dim] or [B, in_dim] -> [B, n_tokens, out_dim]."""
        if x.dim() == 1:
            x = x.unsqueeze(0)
        batch = x.size(0)
        # Linear output is contiguous, so view is safe here.
        memory = self.input_encoder(x).view(batch, self.n_input_tokens, self.rank)
        # Share one query bank across the batch (expand adds the batch dim
        # without copying).
        tgt = self.queries.expand(batch, self.n_tokens, self.rank)
        decoded = self.decoder(tgt, memory)
        return self.project(decoded)
| |
|
| |
|
class LayerWeightedInput(nn.Module):
    """Collapse concatenated per-layer activations into a single vector.

    Input is [B, n_layers * layer_dim]; output is [B, layer_dim] — a
    softmax-weighted mix of the layer chunks, with learned mixing logits.
    """

    def __init__(self, n_layers=3, layer_dim=2560):
        super().__init__()
        self.n_layers, self.layer_dim = n_layers, layer_dim
        # Zero-init logits -> uniform weighting over layers at the start.
        self.layer_logits = nn.Parameter(torch.zeros(n_layers))

    def forward(self, x):
        """[B, n_layers * layer_dim] -> [B, layer_dim]."""
        per_layer = x.reshape(x.shape[0], self.n_layers, self.layer_dim)
        mix = F.softmax(self.layer_logits, dim=0)
        # Weighted sum over the layer axis.
        return torch.einsum('bld,l->bd', per_layer, mix)
| |
|
| |
|
def _build_adapter(model_type, ckpt, sd):
    """Construct an (unloaded) adapter matching the checkpoint hyperparameters.

    `sd` is the model state dict with any ``_orig_mod.`` prefix already
    stripped; hyperparameters missing from the checkpoint metadata are
    inferred from the weight shapes where possible.
    """
    if model_type != "cross_attention":
        return MultiTokenAdapter(
            ckpt["in_dim"], ckpt["out_dim"],
            n_tokens=ckpt["n_tokens"], rank=ckpt["rank"],
        )

    rank = ckpt["rank"]
    # The second encoder Linear emits n_input_tokens * rank features, so its
    # weight's output dim recovers n_input_tokens.
    enc_w = sd.get("input_encoder.2.weight")
    if enc_w is not None:
        n_input_tokens = enc_w.shape[0] // rank
    else:
        n_input_tokens = ckpt.get("n_input_tokens", 4)

    # Decoder depth = highest "decoder.layers.<i>." index present, + 1.
    layer_indices = {
        int(k.split("decoder.layers.")[1].split(".")[0])
        for k in sd if "decoder.layers." in k
    }
    n_attn_layers = max(layer_indices) + 1 if layer_indices else ckpt.get("n_attn_layers", 2)

    return CrossAttentionAdapter(
        ckpt["in_dim"], ckpt["out_dim"],
        n_tokens=ckpt["n_tokens"], rank=rank,
        n_input_tokens=n_input_tokens,
        n_heads=ckpt.get("n_heads", 4),
        n_layers=n_attn_layers,
    )


def _build_layer_weight(ckpt, device):
    """Rebuild the learned layer-mixing module if the checkpoint has one.

    Returns None when the checkpoint was trained on a single layer's
    activations. Dims are inferred from the checkpoint instead of assuming
    the class defaults: the logits length gives n_layers, and the mixer's
    output must match the adapter's input width, so layer_dim == in_dim.
    """
    if "layer_weight_state_dict" not in ckpt:
        return None
    lw_sd = {k.removeprefix("_orig_mod."): v
             for k, v in ckpt["layer_weight_state_dict"].items()}
    layer_weight = LayerWeightedInput(
        n_layers=lw_sd["layer_logits"].shape[0],
        layer_dim=ckpt["in_dim"],
    )
    layer_weight.load_state_dict(lw_sd)
    layer_weight.eval().to(device)
    return layer_weight


def load_adapter(path, device="cpu", dtype=torch.float32):
    """Load an adapter checkpoint and return a callable wrapper.

    Args:
        path: Path to the .pt checkpoint file.
        device: Device to load onto.
        dtype: Dtype for normalization buffers.

    Returns:
        A callable that takes (activation, emotion_scale) and returns
        expression embeddings [n_tokens, out_dim]. The callable also carries
        ``.adapter``, ``.layer_weight``, ``.hook_layers`` and ``.metadata``
        attributes for introspection.
    """
    # SECURITY: weights_only=False unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    model_type = ckpt.get("model_type", "cross_attention")

    # Checkpoints saved from a torch.compile()d model prefix every key with
    # "_orig_mod."; strip it once so all downstream lookups agree.
    sd = {k.removeprefix("_orig_mod."): v
          for k, v in ckpt["model_state_dict"].items()}

    adapter = _build_adapter(model_type, ckpt, sd)
    adapter.load_state_dict(sd)
    adapter.eval().to(device)

    layer_weight = _build_layer_weight(ckpt, device)

    # Normalization statistics captured at training time.
    act_mean = ckpt["act_mean"].to(device=device, dtype=dtype)
    act_std = ckpt["act_std"].to(device=device, dtype=dtype)
    target_center = ckpt.get("target_center", torch.zeros(1)).to(device=device, dtype=dtype)
    target_residual_std = ckpt.get("target_residual_std", torch.ones(1)).to(device=device, dtype=dtype)

    @torch.no_grad()
    def forward(activation, emotion_scale=1.0):
        """Map one activation vector to [n_tokens, out_dim] embeddings."""
        act = activation.to(device=device, dtype=dtype)
        act_norm = (act - act_mean) / act_std
        if layer_weight is not None:
            act_norm = layer_weight(act_norm.unsqueeze(0)).squeeze(0)
        pred = adapter(act_norm.unsqueeze(0)).squeeze(0)
        # Un-normalize the predicted residual; emotion_scale > 1 exaggerates
        # the expression away from the (neutral) target center.
        return pred * target_residual_std * emotion_scale + target_center

    # Map the checkpoint's input-layer tag to the LLM layers the activation
    # hooks must capture; unknown tags fall back to layer 24.
    input_layers = ckpt.get("input_layers", "layer_24")
    _LAYER_MAP = {
        "learned_weight": [9, 18, 27],
        "all_3": [9, 18, 27],
        "layer_9": [9],
        "layer_18": [18],
        "layer_27": [27],
        "layer_24": [24],
    }
    hook_layers = _LAYER_MAP.get(input_layers, [24])

    forward.adapter = adapter
    forward.layer_weight = layer_weight
    forward.hook_layers = hook_layers
    forward.metadata = {
        "model_type": model_type,
        "in_dim": ckpt["in_dim"],
        "out_dim": ckpt["out_dim"],
        "n_tokens": ckpt["n_tokens"],
        "rank": ckpt["rank"],
        "input_layers": input_layers,
        "hook_layers": hook_layers,
    }
    return forward
| |
|