"""
Activation Avatars — Adapter models.
Small neural networks that map LLM activations (from Qwen3-4B forward hooks)
into FLUX.2-Klein prompt embedding space, producing real-time avatar expressions.
Usage:
from adapter import load_adapter
adapter = load_adapter("adapters/xattn8tok_thinking.pt")
# activation: [in_dim] tensor from LLM hidden state
expression = adapter(activation, emotion_scale=6.0)
# expression: [n_tokens, out_dim] — feed to Klein as prompt_embeds
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiTokenAdapter(nn.Module):
    """Map a single activation vector to n_tokens embedding tokens.

    A shared MLP encodes the activation, then learned per-token queries
    are added to the encoding so each output token can specialize.
    """

    def __init__(self, in_dim, out_dim, n_tokens=8, rank=128):
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        self.n_tokens, self.rank = n_tokens, rank
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, rank), nn.GELU(),
            nn.Linear(rank, rank), nn.GELU(),
        )
        self.token_queries = nn.Parameter(torch.randn(n_tokens, rank) * 0.02)
        self.project = nn.Linear(rank, out_dim)

    def forward(self, x):
        if x.dim() == 1:
            x = x.unsqueeze(0)
        h = self.encoder(x)  # [B, rank]
        # Broadcast-add the learned queries: [B, 1, rank] + [1, n_tokens, rank]
        combined = h.unsqueeze(1) + self.token_queries.unsqueeze(0)
        return self.project(combined)  # [B, n_tokens, out_dim]
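

# Illustrative shape check for MultiTokenAdapter. The dims below are
# assumptions chosen for the sketch (e.g. a 2560-dim hidden state), not
# values read from a shipped checkpoint, and _demo_multi_token is a
# hypothetical helper rather than part of the released API.
def _demo_multi_token():
    adapter = MultiTokenAdapter(in_dim=2560, out_dim=4096, n_tokens=8)
    out = adapter(torch.randn(2560))  # 1-D input is auto-batched
    assert out.shape == (1, 8, 4096)  # [B, n_tokens, out_dim]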


class CrossAttentionAdapter(nn.Module):
    """Decode n_tokens embedding tokens from an activation via cross-attention.

    The activation is expanded into a short memory sequence of n_input_tokens
    tokens; learned queries then attend to that memory through a small
    transformer decoder.
    """

    def __init__(self, in_dim, out_dim, n_tokens=64, rank=128,
                 n_input_tokens=4, n_heads=4, n_layers=2):
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        self.n_tokens, self.rank = n_tokens, rank
        self.n_input_tokens = n_input_tokens
        self.input_encoder = nn.Sequential(
            nn.Linear(in_dim, rank), nn.GELU(),
            nn.Linear(rank, n_input_tokens * rank),
        )
        self.queries = nn.Parameter(torch.randn(n_tokens, rank) * 0.02)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=rank, nhead=n_heads,
            dim_feedforward=rank * 4, activation='gelu',
            batch_first=True, norm_first=True,
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_layers)
        self.project = nn.Linear(rank, out_dim)

    def forward(self, x):
        if x.dim() == 1:
            x = x.unsqueeze(0)
        B = x.shape[0]
        # Expand the activation into a short memory sequence for the decoder.
        memory = self.input_encoder(x).reshape(B, self.n_input_tokens, self.rank)
        queries = self.queries.unsqueeze(0).expand(B, -1, -1)
        decoded = self.decoder(queries, memory)  # [B, n_tokens, rank]
        return self.project(decoded)  # [B, n_tokens, out_dim]
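

# The same kind of illustrative shape check for CrossAttentionAdapter; the
# dims are placeholder assumptions and _demo_cross_attention is a
# hypothetical helper for this sketch.
def _demo_cross_attention():
    adapter = CrossAttentionAdapter(in_dim=2560, out_dim=4096, n_tokens=64)
    out = adapter(torch.randn(2, 2560))  # batched input works unchanged
    assert out.shape == (2, 64, 4096)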


class LayerWeightedInput(nn.Module):
    """Learned softmax-weighted mix of activations from several LLM layers.

    Expects inputs of shape [B, n_layers * layer_dim] (per-layer activations
    concatenated along the feature dim) and returns [B, layer_dim].
    """

    def __init__(self, n_layers=3, layer_dim=2560):
        super().__init__()
        self.n_layers, self.layer_dim = n_layers, layer_dim
        self.layer_logits = nn.Parameter(torch.zeros(n_layers))

    def forward(self, x):
        chunks = x.reshape(x.shape[0], self.n_layers, self.layer_dim)
        weights = F.softmax(self.layer_logits, dim=0)
        return (chunks * weights[None, :, None]).sum(dim=1)  # [B, layer_dim]
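

# A minimal sketch of the expected input layout; _demo_layer_weighted is a
# hypothetical helper and the concrete layer choice is illustrative.
def _demo_layer_weighted():
    mix = LayerWeightedInput(n_layers=3, layer_dim=2560)
    stacked = torch.randn(1, 3 * 2560)  # e.g. hidden states from 3 hook layers
    assert mix(stacked).shape == (1, 2560)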


def load_adapter(path, device="cpu", dtype=torch.float32):
    """Load an adapter checkpoint and return a callable wrapper.

    Args:
        path: Path to the .pt checkpoint file.
        device: Device to load onto.
        dtype: Dtype for the adapter weights and normalization buffers.

    Returns:
        A callable that takes (activation, emotion_scale) and returns
        expression embeddings [n_tokens, out_dim].
    """
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    model_type = ckpt.get("model_type", "cross_attention")
    # Strip any torch.compile prefix once so key lookups below are uniform.
    sd = {k.removeprefix("_orig_mod."): v
          for k, v in ckpt["model_state_dict"].items()}

    if model_type == "cross_attention":
        rank = ckpt["rank"]
        # Infer architecture hyperparameters from the state dict so older
        # checkpoints without explicit config fields still load.
        enc_w = sd.get("input_encoder.2.weight")
        n_input_tokens = (enc_w.shape[0] // rank if enc_w is not None
                          else ckpt.get("n_input_tokens", 4))
        decoder_keys = [k for k in sd if "decoder.layers." in k]
        layer_indices = {int(k.split("decoder.layers.")[1].split(".")[0])
                         for k in decoder_keys}
        n_attn_layers = (max(layer_indices) + 1 if layer_indices
                         else ckpt.get("n_attn_layers", 2))
        adapter = CrossAttentionAdapter(
            ckpt["in_dim"], ckpt["out_dim"],
            n_tokens=ckpt["n_tokens"], rank=rank,
            n_input_tokens=n_input_tokens,
            n_heads=ckpt.get("n_heads", 4),
            n_layers=n_attn_layers,
        )
    else:
        adapter = MultiTokenAdapter(
            ckpt["in_dim"], ckpt["out_dim"],
            n_tokens=ckpt["n_tokens"], rank=ckpt["rank"],
        )
    adapter.load_state_dict(sd)
    adapter.eval().to(device=device, dtype=dtype)

    layer_weight = None
    if "layer_weight_state_dict" in ckpt:
        layer_weight = LayerWeightedInput()
        layer_weight.load_state_dict(ckpt["layer_weight_state_dict"])
        layer_weight.eval().to(device=device, dtype=dtype)

    # Normalization statistics saved at training time: activations are
    # standardized on the way in, predictions de-standardized on the way out.
    act_mean = ckpt["act_mean"].to(device=device, dtype=dtype)
    act_std = ckpt["act_std"].to(device=device, dtype=dtype)
    target_center = ckpt.get("target_center", torch.zeros(1)).to(device=device, dtype=dtype)
    target_residual_std = ckpt.get("target_residual_std", torch.ones(1)).to(device=device, dtype=dtype)

    @torch.no_grad()
    def forward(activation, emotion_scale=1.0):
        act = activation.to(device=device, dtype=dtype)
        act_norm = (act - act_mean) / act_std
        if layer_weight is not None:
            act_norm = layer_weight(act_norm.unsqueeze(0)).squeeze(0)
        pred = adapter(act_norm.unsqueeze(0)).squeeze(0)
        # emotion_scale scales the predicted residual around the target
        # center, exaggerating (>1) or damping (<1) the expression.
        return pred * target_residual_std * emotion_scale + target_center

    # Resolve which LLM layers to hook from the checkpoint's input_layers
    # field; unknown values fall back to layer 24.
    input_layers = ckpt.get("input_layers", "layer_24")
    _LAYER_MAP = {
        "learned_weight": [9, 18, 27],
        "all_3": [9, 18, 27],
        "layer_9": [9],
        "layer_18": [18],
        "layer_27": [27],
        "layer_24": [24],
    }
    hook_layers = _LAYER_MAP.get(input_layers, [24])

    # Expose the underlying modules and config as attributes on the wrapper.
    forward.adapter = adapter
    forward.layer_weight = layer_weight
    forward.hook_layers = hook_layers
    forward.metadata = {
        "model_type": model_type,
        "in_dim": ckpt["in_dim"],
        "out_dim": ckpt["out_dim"],
        "n_tokens": ckpt["n_tokens"],
        "rank": ckpt["rank"],
        "input_layers": input_layers,
        "hook_layers": hook_layers,
    }
    return forward
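

if __name__ == "__main__":
    # End-to-end smoke test, assuming the checkpoint path from the module
    # docstring exists locally. The random tensor stands in for a hooked LLM
    # hidden state; note that for multi-layer checkpoints (input_layers
    # "all_3" or "learned_weight") the expected activation dim is the
    # concatenated per-layer size rather than metadata["in_dim"].
    wrapper = load_adapter("adapters/xattn8tok_thinking.pt")
    fake_activation = torch.randn(wrapper.metadata["in_dim"])
    expression = wrapper(fake_activation, emotion_scale=6.0)
    print("expression:", tuple(expression.shape))  # (n_tokens, out_dim)
    print("hook layers:", wrapper.hook_layers)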