Alogotron commited on
Commit
02a5dc3
·
verified ·
1 Parent(s): 40ba5ae

Upload adapter.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. adapter.py +170 -0
adapter.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Activation Avatars — Adapter models.
3
+
4
+ Small neural networks that map LLM activations (from Qwen3-4B forward hooks)
5
+ into FLUX.2-Klein prompt embedding space, producing real-time avatar expressions.
6
+
7
+ Usage:
8
+ from adapter import load_adapter
9
+ adapter = load_adapter("adapters/xattn8tok_thinking.pt")
10
+ # activation: [in_dim] tensor from LLM hidden state
11
+ expression = adapter(activation, emotion_scale=6.0)
12
+ # expression: [n_tokens, out_dim] — feed to Klein as prompt_embeds
13
+ """
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+
19
+
20
class MultiTokenAdapter(nn.Module):
    """Map a single activation vector to a fixed bank of output tokens.

    The input is squeezed through a small two-layer MLP bottleneck; the
    resulting code is then broadcast-added onto each of the ``n_tokens``
    learned query vectors and projected up to the output embedding size.
    """

    def __init__(self, in_dim, out_dim, n_tokens=8, rank=128):
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        self.n_tokens, self.rank = n_tokens, rank
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, rank), nn.GELU(),
            nn.Linear(rank, rank), nn.GELU(),
        )
        self.token_queries = nn.Parameter(torch.randn(n_tokens, rank) * 0.02)
        self.project = nn.Linear(rank, out_dim)

    def forward(self, x):
        # Promote a bare [in_dim] vector to a batch of one.
        batched = x.unsqueeze(0) if x.dim() == 1 else x
        code = self.encoder(batched)                            # [B, rank]
        # Offset every learned query by the shared bottleneck code.
        tokens = code[:, None, :] + self.token_queries[None, :, :]
        return self.project(tokens)                             # [B, n_tokens, out_dim]
38
+
39
+
40
class CrossAttentionAdapter(nn.Module):
    """Expand one activation vector into a short memory sequence and let a
    bank of learned queries cross-attend to it via a small transformer
    decoder, producing ``n_tokens`` output embeddings.
    """

    def __init__(self, in_dim, out_dim, n_tokens=64, rank=128,
                 n_input_tokens=4, n_heads=4, n_layers=2):
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        self.n_tokens, self.rank = n_tokens, rank
        self.n_input_tokens = n_input_tokens
        # Fans the single input vector out into n_input_tokens memory slots.
        self.input_encoder = nn.Sequential(
            nn.Linear(in_dim, rank), nn.GELU(),
            nn.Linear(rank, n_input_tokens * rank),
        )
        self.queries = nn.Parameter(torch.randn(n_tokens, rank) * 0.02)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=rank, nhead=n_heads,
            dim_feedforward=rank * 4, activation='gelu',
            batch_first=True, norm_first=True,
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_layers)
        self.project = nn.Linear(rank, out_dim)

    def forward(self, x):
        # Promote a bare [in_dim] vector to a batch of one.
        if x.dim() == 1:
            x = x.unsqueeze(0)
        batch = x.shape[0]
        mem = self.input_encoder(x).reshape(batch, self.n_input_tokens, self.rank)
        # Share one learned query bank across the batch (expand is view-only).
        qry = self.queries.expand(batch, -1, -1)
        attended = self.decoder(qry, mem)
        return self.project(attended)                           # [B, n_tokens, out_dim]
68
+
69
+
70
class LayerWeightedInput(nn.Module):
    """Collapse concatenated per-layer activations into one vector using a
    learned softmax-weighted average over the layers.

    Zero-initialized logits make the initial mix a plain uniform mean.
    """

    def __init__(self, n_layers=3, layer_dim=2560):
        super().__init__()
        self.n_layers, self.layer_dim = n_layers, layer_dim
        self.layer_logits = nn.Parameter(torch.zeros(n_layers))

    def forward(self, x):
        # [B, n_layers * layer_dim] -> [B, n_layers, layer_dim]
        per_layer = x.reshape(x.shape[0], self.n_layers, self.layer_dim)
        mix = F.softmax(self.layer_logits, dim=0).reshape(1, -1, 1)
        return torch.sum(per_layer * mix, dim=1)
80
+
81
+
82
def load_adapter(path, device="cpu", dtype=torch.float32):
    """Load an adapter checkpoint and return a callable wrapper.

    Args:
        path: Path to the .pt checkpoint file.
        device: Device to load onto.
        dtype: Dtype for normalization buffers.

    Returns:
        A callable that takes (activation, emotion_scale) and returns
        expression embeddings [n_tokens, out_dim].
    """
    # NOTE(review): weights_only=False unpickles arbitrary Python objects —
    # only load checkpoints from trusted sources.
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    model_type = ckpt.get("model_type", "cross_attention")

    if model_type == "cross_attention":
        rank = ckpt["rank"]
        sd = ckpt["model_state_dict"]
        # Infer architecture from state dict
        # (keys may carry a torch.compile "_orig_mod." prefix, so try both).
        enc_w = sd.get("_orig_mod.input_encoder.2.weight",
                       sd.get("input_encoder.2.weight"))
        # The final encoder linear emits n_input_tokens * rank features, so
        # its row count recovers n_input_tokens; fall back to the stored field.
        n_input_tokens = enc_w.shape[0] // rank if enc_w is not None else ckpt.get("n_input_tokens", 4)
        # Count decoder layers from the highest "decoder.layers.<i>." index.
        decoder_keys = [k for k in sd if "decoder.layers." in k]
        layer_indices = set(int(k.split("decoder.layers.")[1].split(".")[0]) for k in decoder_keys)
        n_attn_layers = max(layer_indices) + 1 if layer_indices else ckpt.get("n_attn_layers", 2)

        adapter = CrossAttentionAdapter(
            ckpt["in_dim"], ckpt["out_dim"],
            n_tokens=ckpt["n_tokens"], rank=rank,
            n_input_tokens=n_input_tokens,
            n_heads=ckpt.get("n_heads", 4),
            n_layers=n_attn_layers,
        )
    else:
        adapter = MultiTokenAdapter(
            ckpt["in_dim"], ckpt["out_dim"],
            n_tokens=ckpt["n_tokens"], rank=ckpt["rank"],
        )

    sd = ckpt["model_state_dict"]
    # Strip any torch.compile wrapper prefix so keys match the fresh module.
    sd = {k.removeprefix("_orig_mod."): v for k, v in sd.items()}
    adapter.load_state_dict(sd)
    adapter.eval().to(device)

    # Optional learned mixing of several hooked LLM layers (see
    # LayerWeightedInput); only present for multi-layer checkpoints.
    layer_weight = None
    if "layer_weight_state_dict" in ckpt:
        layer_weight = LayerWeightedInput()
        layer_weight.load_state_dict(ckpt["layer_weight_state_dict"])
        layer_weight.eval().to(device)

    # Normalization statistics computed at training time: inputs are
    # standardized with act_mean/act_std, outputs are de-standardized with
    # target_center/target_residual_std (identity fallbacks when absent).
    act_mean = ckpt["act_mean"].to(device=device, dtype=dtype)
    act_std = ckpt["act_std"].to(device=device, dtype=dtype)
    target_center = ckpt.get("target_center", torch.zeros(1)).to(device=device, dtype=dtype)
    target_residual_std = ckpt.get("target_residual_std", torch.ones(1)).to(device=device, dtype=dtype)

    @torch.no_grad()
    def forward(activation, emotion_scale=1.0):
        # Standardize the raw activation, optionally mix layers, run the
        # adapter, then un-standardize; emotion_scale amplifies the residual
        # around target_center before it is re-added.
        act = activation.to(device=device, dtype=dtype)
        act_norm = (act - act_mean) / act_std
        if layer_weight is not None:
            act_norm = layer_weight(act_norm.unsqueeze(0)).squeeze(0)
        pred = adapter(act_norm.unsqueeze(0)).squeeze(0)
        return pred * target_residual_std * emotion_scale + target_center

    # Resolve hook layers from input_layers field
    input_layers = ckpt.get("input_layers", "layer_24")
    # Maps the checkpoint's input_layers tag to LLM layer indices to hook.
    _LAYER_MAP = {
        "learned_weight": [9, 18, 27],
        "all_3": [9, 18, 27],
        "layer_9": [9],
        "layer_18": [18],
        "layer_27": [27],
        "layer_24": [24],
    }
    hook_layers = _LAYER_MAP.get(input_layers, [24])

    # Expose the underlying modules and config on the closure so callers can
    # introspect the loaded adapter without reopening the checkpoint.
    forward.adapter = adapter
    forward.layer_weight = layer_weight
    forward.hook_layers = hook_layers
    forward.metadata = {
        "model_type": model_type,
        "in_dim": ckpt["in_dim"],
        "out_dim": ckpt["out_dim"],
        "n_tokens": ckpt["n_tokens"],
        "rank": ckpt["rank"],
        "input_layers": input_layers,
        "hook_layers": hook_layers,
    }
    return forward