Upload adapter.py with huggingface_hub
Browse files- adapter.py +170 -0
adapter.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Activation Avatars — Adapter models.
|
| 3 |
+
|
| 4 |
+
Small neural networks that map LLM activations (from Qwen3-4B forward hooks)
|
| 5 |
+
into FLUX.2-Klein prompt embedding space, producing real-time avatar expressions.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
from adapter import load_adapter
|
| 9 |
+
adapter = load_adapter("adapters/xattn8tok_thinking.pt")
|
| 10 |
+
# activation: [in_dim] tensor from LLM hidden state
|
| 11 |
+
expression = adapter(activation, emotion_scale=6.0)
|
| 12 |
+
# expression: [n_tokens, out_dim] — feed to Klein as prompt_embeds
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class MultiTokenAdapter(nn.Module):
    """Map a single activation vector to ``n_tokens`` embedding tokens.

    A shared two-layer GELU MLP encodes the input into a rank-dim latent;
    each output token adds its own learned query vector to that latent
    before a final linear projection into the embedding space.
    """

    def __init__(self, in_dim, out_dim, n_tokens=8, rank=128):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.n_tokens = n_tokens
        self.rank = rank
        # Encoder shared across all output tokens.
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, rank),
            nn.GELU(),
            nn.Linear(rank, rank),
            nn.GELU(),
        )
        # Per-token learned offsets; small init keeps early outputs tame.
        self.token_queries = nn.Parameter(torch.randn(n_tokens, rank) * 0.02)
        self.project = nn.Linear(rank, out_dim)

    def forward(self, x):
        """Return [B, n_tokens, out_dim] embeddings for input [B, in_dim] (or [in_dim])."""
        if x.dim() == 1:
            x = x.unsqueeze(0)
        latent = self.encoder(x)                                    # [B, rank]
        per_token = latent[:, None, :] + self.token_queries[None]   # [B, n_tokens, rank]
        return self.project(per_token)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class CrossAttentionAdapter(nn.Module):
    """Decode learned query tokens against an encoded activation memory.

    The flat activation vector is expanded into ``n_input_tokens`` memory
    tokens; a small transformer decoder cross-attends ``n_tokens`` learned
    queries over that memory, and the result is projected into the output
    embedding space.
    """

    def __init__(self, in_dim, out_dim, n_tokens=64, rank=128,
                 n_input_tokens=4, n_heads=4, n_layers=2):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.n_tokens = n_tokens
        self.rank = rank
        self.n_input_tokens = n_input_tokens
        # Expand the single activation vector into a short memory sequence.
        self.input_encoder = nn.Sequential(
            nn.Linear(in_dim, rank),
            nn.GELU(),
            nn.Linear(rank, n_input_tokens * rank),
        )
        self.queries = nn.Parameter(torch.randn(n_tokens, rank) * 0.02)
        layer = nn.TransformerDecoderLayer(
            d_model=rank,
            nhead=n_heads,
            dim_feedforward=rank * 4,
            activation='gelu',
            batch_first=True,
            norm_first=True,
        )
        self.decoder = nn.TransformerDecoder(layer, num_layers=n_layers)
        self.project = nn.Linear(rank, out_dim)

    def forward(self, x):
        """Return [B, n_tokens, out_dim] embeddings for input [B, in_dim] (or [in_dim])."""
        if x.dim() == 1:
            x = x.unsqueeze(0)
        batch = x.shape[0]
        memory = self.input_encoder(x).reshape(batch, self.n_input_tokens, self.rank)
        # Broadcast the shared learned queries over the batch.
        tgt = self.queries.expand(batch, -1, -1)
        return self.project(self.decoder(tgt, memory))
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class LayerWeightedInput(nn.Module):
    """Collapse concatenated per-layer activations into one weighted vector.

    Input is [B, n_layers * layer_dim]; output is [B, layer_dim], a
    softmax-weighted combination of the per-layer chunks. The logits start
    at zero, i.e. the initial weighting is uniform.
    """

    def __init__(self, n_layers=3, layer_dim=2560):
        super().__init__()
        self.n_layers = n_layers
        self.layer_dim = layer_dim
        # Zero init -> uniform softmax weights before any training.
        self.layer_logits = nn.Parameter(torch.zeros(n_layers))

    def forward(self, x):
        """Return the learned convex combination of the per-layer chunks."""
        chunked = x.reshape(x.shape[0], self.n_layers, self.layer_dim)
        w = F.softmax(self.layer_logits, dim=0).view(1, -1, 1)
        return (chunked * w).sum(dim=1)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def load_adapter(path, device="cpu", dtype=torch.float32):
    """Load an adapter checkpoint and return a callable wrapper.

    Args:
        path: Path to the .pt checkpoint file.
        device: Device to load onto.
        dtype: Dtype for the adapter weights and normalization buffers.
            (Fix: the adapter itself is now cast too — previously only the
            buffers were, so any non-float32 dtype crashed at inference.)

    Returns:
        A callable that takes (activation, emotion_scale) and returns
        expression embeddings [n_tokens, out_dim]. The callable also exposes
        .adapter, .layer_weight, .hook_layers and .metadata attributes.
    """
    # NOTE(security): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from trusted sources.
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    model_type = ckpt.get("model_type", "cross_attention")

    # torch.compile prefixes saved parameter keys with "_orig_mod."; strip it
    # once, up front, so both architecture inference and loading see clean keys.
    sd = {k.removeprefix("_orig_mod."): v
          for k, v in ckpt["model_state_dict"].items()}

    if model_type == "cross_attention":
        rank = ckpt["rank"]
        # Infer architecture from the state dict so checkpoints that predate
        # the explicit metadata fields still load correctly.
        enc_w = sd.get("input_encoder.2.weight")
        n_input_tokens = (enc_w.shape[0] // rank if enc_w is not None
                          else ckpt.get("n_input_tokens", 4))
        layer_indices = {int(k.split("decoder.layers.")[1].split(".")[0])
                         for k in sd if "decoder.layers." in k}
        n_attn_layers = (max(layer_indices) + 1 if layer_indices
                         else ckpt.get("n_attn_layers", 2))

        adapter = CrossAttentionAdapter(
            ckpt["in_dim"], ckpt["out_dim"],
            n_tokens=ckpt["n_tokens"], rank=rank,
            n_input_tokens=n_input_tokens,
            n_heads=ckpt.get("n_heads", 4),
            n_layers=n_attn_layers,
        )
    else:
        adapter = MultiTokenAdapter(
            ckpt["in_dim"], ckpt["out_dim"],
            n_tokens=ckpt["n_tokens"], rank=ckpt["rank"],
        )

    adapter.load_state_dict(sd)
    adapter.eval().to(device=device, dtype=dtype)

    layer_weight = None
    if "layer_weight_state_dict" in ckpt:
        lw_sd = ckpt["layer_weight_state_dict"]
        # Size the module from the checkpoint instead of assuming 3 layers,
        # so checkpoints trained with a different hook count still load.
        # NOTE(review): layer_dim keeps its 2560 default — confirm it matches
        # the hook hidden size if checkpoints from other base models appear.
        layer_weight = LayerWeightedInput(n_layers=lw_sd["layer_logits"].shape[0])
        layer_weight.load_state_dict(lw_sd)
        layer_weight.eval().to(device=device, dtype=dtype)

    # Normalization statistics saved at training time.
    act_mean = ckpt["act_mean"].to(device=device, dtype=dtype)
    act_std = ckpt["act_std"].to(device=device, dtype=dtype)
    # The adapter predicts a residual around a fixed target center.
    target_center = ckpt.get("target_center", torch.zeros(1)).to(device=device, dtype=dtype)
    target_residual_std = ckpt.get("target_residual_std", torch.ones(1)).to(device=device, dtype=dtype)

    @torch.no_grad()
    def forward(activation, emotion_scale=1.0):
        """Map one raw activation vector to [n_tokens, out_dim] embeddings.

        Args:
            activation: 1-D activation tensor from the hooked LLM layer(s).
            emotion_scale: Multiplier on the predicted residual; values > 1
                exaggerate the expression.
        """
        act = activation.to(device=device, dtype=dtype)
        act_norm = (act - act_mean) / act_std
        if layer_weight is not None:
            act_norm = layer_weight(act_norm.unsqueeze(0)).squeeze(0)
        pred = adapter(act_norm.unsqueeze(0)).squeeze(0)
        # De-normalize and scale the residual back into embedding space.
        return pred * target_residual_std * emotion_scale + target_center

    # Resolve which LLM layers must be hooked from the checkpoint's tag.
    input_layers = ckpt.get("input_layers", "layer_24")
    _LAYER_MAP = {
        "learned_weight": [9, 18, 27],
        "all_3": [9, 18, 27],
        "layer_9": [9],
        "layer_18": [18],
        "layer_27": [27],
        "layer_24": [24],
    }
    hook_layers = _LAYER_MAP.get(input_layers, [24])

    # Expose internals for callers that need more than the plain callable.
    forward.adapter = adapter
    forward.layer_weight = layer_weight
    forward.hook_layers = hook_layers
    forward.metadata = {
        "model_type": model_type,
        "in_dim": ckpt["in_dim"],
        "out_dim": ckpt["out_dim"],
        "n_tokens": ckpt["n_tokens"],
        "rank": ckpt["rank"],
        "input_layers": input_layers,
        "hook_layers": hook_layers,
    }
    return forward
|