File size: 6,449 Bytes
02a5dc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""
Activation Avatars — Adapter models.

Small neural networks that map LLM activations (from Qwen3-4B forward hooks)
into FLUX.2-Klein prompt embedding space, producing real-time avatar expressions.

Usage:
    from adapter import load_adapter
    adapter = load_adapter("adapters/xattn8tok_thinking.pt")
    # activation: [in_dim] tensor from LLM hidden state
    expression = adapter(activation, emotion_scale=6.0)
    # expression: [n_tokens, out_dim] — feed to Klein as prompt_embeds
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiTokenAdapter(nn.Module):
    """Map one activation vector to ``n_tokens`` output embedding tokens.

    The activation is encoded into a single shared latent of width ``rank``;
    each output token is that latent shifted by a learned per-token query,
    then projected to ``out_dim``.
    """

    def __init__(self, in_dim, out_dim, n_tokens=8, rank=128):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.n_tokens = n_tokens
        self.rank = rank
        # Two-layer MLP shared by every output token.
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, rank), nn.GELU(),
            nn.Linear(rank, rank), nn.GELU(),
        )
        # One learned offset per token; small init keeps tokens similar at start.
        self.token_queries = nn.Parameter(torch.randn(n_tokens, rank) * 0.02)
        self.project = nn.Linear(rank, out_dim)

    def forward(self, x):
        """x: [in_dim] or [B, in_dim] -> [B, n_tokens, out_dim]."""
        batched = x if x.dim() != 1 else x[None, :]
        latent = self.encoder(batched)                       # [B, rank]
        # Broadcast-add: every token query shifts the shared latent.
        per_token = latent[:, None, :] + self.token_queries[None, :, :]
        return self.project(per_token)


class CrossAttentionAdapter(nn.Module):
    """Decode ``n_tokens`` embeddings from one activation via cross-attention.

    The flat activation is lifted into a short "memory" sequence of
    ``n_input_tokens`` slots; a bank of learned queries then attends over
    that memory through a Transformer decoder stack and is projected to
    ``out_dim``.
    """

    def __init__(self, in_dim, out_dim, n_tokens=64, rank=128,
                 n_input_tokens=4, n_heads=4, n_layers=2):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.n_tokens = n_tokens
        self.rank = rank
        self.n_input_tokens = n_input_tokens
        # Expands the activation into n_input_tokens memory slots of width rank.
        self.input_encoder = nn.Sequential(
            nn.Linear(in_dim, rank), nn.GELU(),
            nn.Linear(rank, n_input_tokens * rank),
        )
        self.queries = nn.Parameter(torch.randn(n_tokens, rank) * 0.02)
        layer = nn.TransformerDecoderLayer(
            d_model=rank, nhead=n_heads,
            dim_feedforward=rank * 4, activation='gelu',
            batch_first=True, norm_first=True,
        )
        self.decoder = nn.TransformerDecoder(layer, num_layers=n_layers)
        self.project = nn.Linear(rank, out_dim)

    def forward(self, x):
        """x: [in_dim] or [B, in_dim] -> [B, n_tokens, out_dim]."""
        if x.dim() == 1:
            x = x[None, :]
        batch = x.shape[0]
        # [B, n_input_tokens, rank] memory the decoder attends over.
        memory = self.input_encoder(x).reshape(batch, self.n_input_tokens, self.rank)
        # The same learned query bank is shared across the batch.
        tgt = self.queries.expand(batch, -1, -1)
        decoded = self.decoder(tgt, memory)
        return self.project(decoded)


class LayerWeightedInput(nn.Module):
    """Collapse concatenated per-layer activations into one vector.

    The input ``[B, n_layers * layer_dim]`` is split into ``n_layers`` chunks
    and combined as a softmax-weighted sum. The mixing logits are learnable
    and start at zero, i.e. an exact uniform average.
    """

    def __init__(self, n_layers=3, layer_dim=2560):
        super().__init__()
        self.n_layers = n_layers
        self.layer_dim = layer_dim
        # Zero init -> softmax gives equal weight to each layer initially.
        self.layer_logits = nn.Parameter(torch.zeros(n_layers))

    def forward(self, x):
        """x: [B, n_layers * layer_dim] -> [B, layer_dim]."""
        batch = x.shape[0]
        per_layer = x.reshape(batch, self.n_layers, self.layer_dim)
        mix = F.softmax(self.layer_logits, dim=0).reshape(1, -1, 1)
        return (per_layer * mix).sum(dim=1)


def load_adapter(path, device="cpu", dtype=torch.float32):
    """Load an adapter checkpoint and return a callable wrapper.

    Args:
        path: Path to the .pt checkpoint file.
        device: Device to load onto.
        dtype: Dtype for normalization buffers.

    Returns:
        A callable that takes (activation, emotion_scale) and returns
        expression embeddings [n_tokens, out_dim].
    """
    # NOTE(review): weights_only=False runs the full pickle machinery —
    # only load checkpoints from trusted sources.
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    model_type = ckpt.get("model_type", "cross_attention")

    if model_type == "cross_attention":
        rank = ckpt["rank"]
        sd = ckpt["model_state_dict"]
        # Infer architecture from state dict. The "_orig_mod." prefix appears
        # when the model was saved after torch.compile() wrapping.
        enc_w = sd.get("_orig_mod.input_encoder.2.weight",
                       sd.get("input_encoder.2.weight"))
        # The final encoder Linear has out_features == n_input_tokens * rank,
        # so its weight's first dim recovers n_input_tokens.
        n_input_tokens = enc_w.shape[0] // rank if enc_w is not None else ckpt.get("n_input_tokens", 4)
        # Count decoder layers by the highest "decoder.layers.<i>." index present.
        decoder_keys = [k for k in sd if "decoder.layers." in k]
        layer_indices = set(int(k.split("decoder.layers.")[1].split(".")[0]) for k in decoder_keys)
        n_attn_layers = max(layer_indices) + 1 if layer_indices else ckpt.get("n_attn_layers", 2)

        adapter = CrossAttentionAdapter(
            ckpt["in_dim"], ckpt["out_dim"],
            n_tokens=ckpt["n_tokens"], rank=rank,
            n_input_tokens=n_input_tokens,
            n_heads=ckpt.get("n_heads", 4),
            n_layers=n_attn_layers,
        )
    else:
        adapter = MultiTokenAdapter(
            ckpt["in_dim"], ckpt["out_dim"],
            n_tokens=ckpt["n_tokens"], rank=ckpt["rank"],
        )

    # Strip any torch.compile() prefix so keys match the freshly built module.
    sd = ckpt["model_state_dict"]
    sd = {k.removeprefix("_orig_mod."): v for k, v in sd.items()}
    adapter.load_state_dict(sd)
    adapter.eval().to(device)

    # Optional learned mixing of multiple hook layers (see LayerWeightedInput).
    layer_weight = None
    if "layer_weight_state_dict" in ckpt:
        layer_weight = LayerWeightedInput()
        layer_weight.load_state_dict(ckpt["layer_weight_state_dict"])
        layer_weight.eval().to(device)

    # Normalization statistics saved at training time: inputs are standardized,
    # outputs are de-standardized around the training target center.
    act_mean = ckpt["act_mean"].to(device=device, dtype=dtype)
    act_std = ckpt["act_std"].to(device=device, dtype=dtype)
    target_center = ckpt.get("target_center", torch.zeros(1)).to(device=device, dtype=dtype)
    target_residual_std = ckpt.get("target_residual_std", torch.ones(1)).to(device=device, dtype=dtype)

    @torch.no_grad()
    def forward(activation, emotion_scale=1.0):
        # activation: [in_dim] hidden-state vector; returns [n_tokens, out_dim].
        act = activation.to(device=device, dtype=dtype)
        act_norm = (act - act_mean) / act_std
        if layer_weight is not None:
            act_norm = layer_weight(act_norm.unsqueeze(0)).squeeze(0)
        pred = adapter(act_norm.unsqueeze(0)).squeeze(0)
        # emotion_scale amplifies the predicted residual around the center.
        return pred * target_residual_std * emotion_scale + target_center

    # Resolve which LLM layers to hook from the checkpoint's input_layers tag.
    input_layers = ckpt.get("input_layers", "layer_24")
    _LAYER_MAP = {
        "learned_weight": [9, 18, 27],
        "all_3": [9, 18, 27],
        "layer_9": [9],
        "layer_18": [18],
        "layer_27": [27],
        "layer_24": [24],
    }
    hook_layers = _LAYER_MAP.get(input_layers, [24])

    # Expose internals and metadata as attributes on the returned callable.
    forward.adapter = adapter
    forward.layer_weight = layer_weight
    forward.hook_layers = hook_layers
    forward.metadata = {
        "model_type": model_type,
        "in_dim": ckpt["in_dim"],
        "out_dim": ckpt["out_dim"],
        "n_tokens": ckpt["n_tokens"],
        "rank": ckpt["rank"],
        "input_layers": input_layers,
        "hook_layers": hook_layers,
    }
    return forward