""" Activation Avatars — Adapter models. Small neural networks that map LLM activations (from Qwen3-4B forward hooks) into FLUX.2-Klein prompt embedding space, producing real-time avatar expressions. Usage: from adapter import load_adapter adapter = load_adapter("adapters/xattn8tok_thinking.pt") # activation: [in_dim] tensor from LLM hidden state expression = adapter(activation, emotion_scale=6.0) # expression: [n_tokens, out_dim] — feed to Klein as prompt_embeds """ import torch import torch.nn as nn import torch.nn.functional as F class MultiTokenAdapter(nn.Module): def __init__(self, in_dim, out_dim, n_tokens=8, rank=128): super().__init__() self.in_dim, self.out_dim = in_dim, out_dim self.n_tokens, self.rank = n_tokens, rank self.encoder = nn.Sequential( nn.Linear(in_dim, rank), nn.GELU(), nn.Linear(rank, rank), nn.GELU(), ) self.token_queries = nn.Parameter(torch.randn(n_tokens, rank) * 0.02) self.project = nn.Linear(rank, out_dim) def forward(self, x): if x.dim() == 1: x = x.unsqueeze(0) h = self.encoder(x) combined = h.unsqueeze(1) + self.token_queries.unsqueeze(0) return self.project(combined) class CrossAttentionAdapter(nn.Module): def __init__(self, in_dim, out_dim, n_tokens=64, rank=128, n_input_tokens=4, n_heads=4, n_layers=2): super().__init__() self.in_dim, self.out_dim = in_dim, out_dim self.n_tokens, self.rank = n_tokens, rank self.n_input_tokens = n_input_tokens self.input_encoder = nn.Sequential( nn.Linear(in_dim, rank), nn.GELU(), nn.Linear(rank, n_input_tokens * rank), ) self.queries = nn.Parameter(torch.randn(n_tokens, rank) * 0.02) decoder_layer = nn.TransformerDecoderLayer( d_model=rank, nhead=n_heads, dim_feedforward=rank * 4, activation='gelu', batch_first=True, norm_first=True, ) self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_layers) self.project = nn.Linear(rank, out_dim) def forward(self, x): if x.dim() == 1: x = x.unsqueeze(0) B = x.shape[0] memory = self.input_encoder(x).reshape(B, self.n_input_tokens, self.rank) queries = self.queries.unsqueeze(0).expand(B, -1, -1) decoded = self.decoder(queries, memory) return self.project(decoded) class LayerWeightedInput(nn.Module): def __init__(self, n_layers=3, layer_dim=2560): super().__init__() self.n_layers, self.layer_dim = n_layers, layer_dim self.layer_logits = nn.Parameter(torch.zeros(n_layers)) def forward(self, x): chunks = x.reshape(x.shape[0], self.n_layers, self.layer_dim) weights = F.softmax(self.layer_logits, dim=0) return (chunks * weights[None, :, None]).sum(dim=1) def load_adapter(path, device="cpu", dtype=torch.float32): """Load an adapter checkpoint and return a callable wrapper. Args: path: Path to the .pt checkpoint file. device: Device to load onto. dtype: Dtype for normalization buffers. Returns: A callable that takes (activation, emotion_scale) and returns expression embeddings [n_tokens, out_dim]. """ ckpt = torch.load(path, map_location="cpu", weights_only=False) model_type = ckpt.get("model_type", "cross_attention") if model_type == "cross_attention": rank = ckpt["rank"] sd = ckpt["model_state_dict"] # Infer architecture from state dict enc_w = sd.get("_orig_mod.input_encoder.2.weight", sd.get("input_encoder.2.weight")) n_input_tokens = enc_w.shape[0] // rank if enc_w is not None else ckpt.get("n_input_tokens", 4) decoder_keys = [k for k in sd if "decoder.layers." in k] layer_indices = set(int(k.split("decoder.layers.")[1].split(".")[0]) for k in decoder_keys) n_attn_layers = max(layer_indices) + 1 if layer_indices else ckpt.get("n_attn_layers", 2) adapter = CrossAttentionAdapter( ckpt["in_dim"], ckpt["out_dim"], n_tokens=ckpt["n_tokens"], rank=rank, n_input_tokens=n_input_tokens, n_heads=ckpt.get("n_heads", 4), n_layers=n_attn_layers, ) else: adapter = MultiTokenAdapter( ckpt["in_dim"], ckpt["out_dim"], n_tokens=ckpt["n_tokens"], rank=ckpt["rank"], ) sd = ckpt["model_state_dict"] sd = {k.removeprefix("_orig_mod."): v for k, v in sd.items()} adapter.load_state_dict(sd) adapter.eval().to(device) layer_weight = None if "layer_weight_state_dict" in ckpt: layer_weight = LayerWeightedInput() layer_weight.load_state_dict(ckpt["layer_weight_state_dict"]) layer_weight.eval().to(device) act_mean = ckpt["act_mean"].to(device=device, dtype=dtype) act_std = ckpt["act_std"].to(device=device, dtype=dtype) target_center = ckpt.get("target_center", torch.zeros(1)).to(device=device, dtype=dtype) target_residual_std = ckpt.get("target_residual_std", torch.ones(1)).to(device=device, dtype=dtype) @torch.no_grad() def forward(activation, emotion_scale=1.0): act = activation.to(device=device, dtype=dtype) act_norm = (act - act_mean) / act_std if layer_weight is not None: act_norm = layer_weight(act_norm.unsqueeze(0)).squeeze(0) pred = adapter(act_norm.unsqueeze(0)).squeeze(0) return pred * target_residual_std * emotion_scale + target_center # Resolve hook layers from input_layers field input_layers = ckpt.get("input_layers", "layer_24") _LAYER_MAP = { "learned_weight": [9, 18, 27], "all_3": [9, 18, 27], "layer_9": [9], "layer_18": [18], "layer_27": [27], "layer_24": [24], } hook_layers = _LAYER_MAP.get(input_layers, [24]) forward.adapter = adapter forward.layer_weight = layer_weight forward.hook_layers = hook_layers forward.metadata = { "model_type": model_type, "in_dim": ckpt["in_dim"], "out_dim": ckpt["out_dim"], "n_tokens": ckpt["n_tokens"], "rank": ckpt["rank"], "input_layers": input_layers, "hook_layers": hook_layers, } return forward