"""
Activation Avatars — Adapter models.
Small neural networks that map LLM activations (from Qwen3-4B forward hooks)
into FLUX.2-Klein prompt embedding space, producing real-time avatar expressions.
Usage:
    from adapter import load_adapter
    adapter = load_adapter("adapters/xattn8tok_thinking.pt")
    # activation: [in_dim] tensor from LLM hidden state
    expression = adapter(activation, emotion_scale=6.0)
    # expression: [n_tokens, out_dim] — feed to Klein as prompt_embeds
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiTokenAdapter(nn.Module):
    """Expand one activation vector into a fixed bank of output tokens.

    The input is passed through a small MLP encoder, the encoded vector is
    added to each of ``n_tokens`` learned query embeddings, and the result is
    projected to ``out_dim`` — yielding ``[B, n_tokens, out_dim]``.
    """

    def __init__(self, in_dim, out_dim, n_tokens=8, rank=128):
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        self.n_tokens, self.rank = n_tokens, rank
        # Two-layer GELU MLP mapping the raw activation into rank-dim space.
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, rank),
            nn.GELU(),
            nn.Linear(rank, rank),
            nn.GELU(),
        )
        # Learned per-token offsets, small-scale init.
        self.token_queries = nn.Parameter(torch.randn(n_tokens, rank) * 0.02)
        self.project = nn.Linear(rank, out_dim)

    def forward(self, x):
        """Return token embeddings [B, n_tokens, out_dim]; 1-D input gets B=1."""
        batch = x.unsqueeze(0) if x.dim() == 1 else x
        encoded = self.encoder(batch)                       # [B, rank]
        # Broadcast-add the shared encoding onto every token query.
        tokens = encoded[:, None, :] + self.token_queries[None, :, :]
        return self.project(tokens)
class CrossAttentionAdapter(nn.Module):
    """Decode learned queries against an encoded activation via cross-attention.

    The activation vector is expanded into ``n_input_tokens`` memory tokens;
    ``n_tokens`` learned queries then attend to that memory through a stack of
    pre-norm transformer decoder layers, and the result is projected to
    ``out_dim`` — yielding ``[B, n_tokens, out_dim]``.
    """

    def __init__(self, in_dim, out_dim, n_tokens=64, rank=128,
                 n_input_tokens=4, n_heads=4, n_layers=2):
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        self.n_tokens, self.rank = n_tokens, rank
        self.n_input_tokens = n_input_tokens
        # Produces n_input_tokens * rank features, later reshaped into memory tokens.
        self.input_encoder = nn.Sequential(
            nn.Linear(in_dim, rank),
            nn.GELU(),
            nn.Linear(rank, n_input_tokens * rank),
        )
        # Learned decoder queries, small-scale init.
        self.queries = nn.Parameter(torch.randn(n_tokens, rank) * 0.02)
        layer = nn.TransformerDecoderLayer(
            d_model=rank,
            nhead=n_heads,
            dim_feedforward=rank * 4,
            activation='gelu',
            batch_first=True,
            norm_first=True,
        )
        self.decoder = nn.TransformerDecoder(layer, num_layers=n_layers)
        self.project = nn.Linear(rank, out_dim)

    def forward(self, x):
        """Return token embeddings [B, n_tokens, out_dim]; 1-D input gets B=1."""
        batch = x.unsqueeze(0) if x.dim() == 1 else x
        n = batch.shape[0]
        memory = self.input_encoder(batch).reshape(n, self.n_input_tokens, self.rank)
        tiled_queries = self.queries.unsqueeze(0).expand(n, -1, -1)
        attended = self.decoder(tiled_queries, memory)
        return self.project(attended)
class LayerWeightedInput(nn.Module):
    """Collapse concatenated per-layer activations into one vector.

    Input is ``[B, n_layers * layer_dim]``; the layers are mixed with a
    learned softmax weighting (uniform at init, since logits start at zero),
    producing ``[B, layer_dim]``.
    """

    def __init__(self, n_layers=3, layer_dim=2560):
        super().__init__()
        self.n_layers, self.layer_dim = n_layers, layer_dim
        # Unnormalized mixing weights; softmax-ed at every forward pass.
        self.layer_logits = nn.Parameter(torch.zeros(n_layers))

    def forward(self, x):
        """Return the softmax-weighted combination of the per-layer chunks."""
        mix = F.softmax(self.layer_logits, dim=0)
        stacked = x.reshape(x.shape[0], self.n_layers, self.layer_dim)
        # Weighted sum over the layer axis: [B, L, D] x [L] -> [B, D].
        return torch.einsum("bld,l->bd", stacked, mix)
def load_adapter(path, device="cpu", dtype=torch.float32):
    """Load an adapter checkpoint and return a callable wrapper.

    Args:
        path: Path to the .pt checkpoint file.
        device: Device to load onto.
        dtype: Dtype for normalization buffers.

    Returns:
        A callable ``forward(activation, emotion_scale=1.0)`` returning
        expression embeddings [n_tokens, out_dim]. The callable also carries
        ``.adapter``, ``.layer_weight``, ``.hook_layers`` and ``.metadata``.
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from trusted sources.
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    model_type = ckpt.get("model_type", "cross_attention")

    if model_type == "cross_attention":
        rank = ckpt["rank"]
        state = ckpt["model_state_dict"]
        # Recover n_input_tokens from the encoder's output width; fall back
        # to the checkpoint field (default 4) if the key is absent.
        enc_w = state.get("_orig_mod.input_encoder.2.weight",
                          state.get("input_encoder.2.weight"))
        if enc_w is not None:
            n_input_tokens = enc_w.shape[0] // rank
        else:
            n_input_tokens = ckpt.get("n_input_tokens", 4)
        # Count decoder layers actually present in the state dict.
        layer_ids = {
            int(key.split("decoder.layers.")[1].split(".")[0])
            for key in state if "decoder.layers." in key
        }
        n_attn_layers = max(layer_ids) + 1 if layer_ids else ckpt.get("n_attn_layers", 2)
        adapter = CrossAttentionAdapter(
            ckpt["in_dim"], ckpt["out_dim"],
            n_tokens=ckpt["n_tokens"], rank=rank,
            n_input_tokens=n_input_tokens,
            n_heads=ckpt.get("n_heads", 4),
            n_layers=n_attn_layers,
        )
    else:
        adapter = MultiTokenAdapter(
            ckpt["in_dim"], ckpt["out_dim"],
            n_tokens=ckpt["n_tokens"], rank=ckpt["rank"],
        )

    # Strip torch.compile's "_orig_mod." wrapper prefix before loading.
    raw_sd = ckpt["model_state_dict"]
    adapter.load_state_dict({k.removeprefix("_orig_mod."): v for k, v in raw_sd.items()})
    adapter.eval().to(device)

    layer_weight = None
    if "layer_weight_state_dict" in ckpt:
        layer_weight = LayerWeightedInput()
        layer_weight.load_state_dict(ckpt["layer_weight_state_dict"])
        layer_weight.eval().to(device)

    # Normalization statistics; center/residual-std default to identity-like
    # values when the checkpoint predates those fields.
    act_mean = ckpt["act_mean"].to(device=device, dtype=dtype)
    act_std = ckpt["act_std"].to(device=device, dtype=dtype)
    target_center = ckpt.get("target_center", torch.zeros(1)).to(device=device, dtype=dtype)
    target_residual_std = ckpt.get("target_residual_std", torch.ones(1)).to(device=device, dtype=dtype)

    @torch.no_grad()
    def forward(activation, emotion_scale=1.0):
        act = activation.to(device=device, dtype=dtype)
        normed = (act - act_mean) / act_std
        if layer_weight is not None:
            normed = layer_weight(normed.unsqueeze(0)).squeeze(0)
        pred = adapter(normed.unsqueeze(0)).squeeze(0)
        # Denormalize; emotion_scale amplifies the predicted residual.
        return pred * target_residual_std * emotion_scale + target_center

    # Map the checkpoint's input_layers tag to the LLM layers to hook.
    input_layers = ckpt.get("input_layers", "layer_24")
    layer_map = {
        "learned_weight": [9, 18, 27],
        "all_3": [9, 18, 27],
        "layer_9": [9],
        "layer_18": [18],
        "layer_27": [27],
        "layer_24": [24],
    }
    hook_layers = layer_map.get(input_layers, [24])

    forward.adapter = adapter
    forward.layer_weight = layer_weight
    forward.hook_layers = hook_layers
    forward.metadata = {
        "model_type": model_type,
        "in_dim": ckpt["in_dim"],
        "out_dim": ckpt["out_dim"],
        "n_tokens": ckpt["n_tokens"],
        "rank": ckpt["rank"],
        "input_layers": input_layers,
        "hook_layers": hook_layers,
    }
    return forward