Alogotron commited on
Commit
f0366ca
·
verified ·
1 Parent(s): 5df13d2

Upload sdxl/sdxl_adapter.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. sdxl/sdxl_adapter.py +112 -0
sdxl/sdxl_adapter.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SDXL Adapter - Maps Qwen3-4B activations to SDXL prompt embedding space.
3
+
4
+ Input: [B, 7680] - Qwen3-4B hidden states from layers [9, 18, 27]
5
+ Output: [B, 77, 2048] prompt_embeds + [B, 1280] pooled_prompt_embeds
6
+ """
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+
12
class LayerWeightedInput(nn.Module):
    """Collapse concatenated per-layer activations into a single vector.

    A softmax over learnable per-layer logits yields a convex combination
    of the ``n_layers`` equally sized chunks, so the module learns how
    much each source layer should contribute to the mixed representation.
    """

    def __init__(self, n_layers=3, layer_dim=2560):
        super().__init__()
        self.n_layers = n_layers
        self.layer_dim = layer_dim
        # Zero-initialized logits -> uniform (1/n_layers) weighting at start.
        self.layer_logits = nn.Parameter(torch.zeros(n_layers))

    def forward(self, x):
        """Mix layers: [B, n_layers * layer_dim] -> [B, layer_dim]."""
        batch = x.shape[0]
        stacked = x.reshape(batch, self.n_layers, self.layer_dim)
        mix = F.softmax(self.layer_logits, dim=0)
        # Broadcast the per-layer weights over batch and feature dims.
        weighted = stacked * mix.unsqueeze(0).unsqueeze(-1)
        return weighted.sum(dim=1)
25
+
26
+
27
class SDXLCrossAttentionAdapter(nn.Module):
    """Map a single LLM activation vector into SDXL's conditioning space.

    The flat input is expanded into a short "memory" sequence; a set of
    learnable queries (one per output token) cross-attends to that memory
    through a pre-norm transformer decoder. Two heads then produce the
    pair SDXL expects: per-token prompt embeddings
    ``[B, n_output_tokens, main_dim]`` and a pooled vector
    ``[B, pooled_dim]``.
    """

    def __init__(self, in_dim=2560, rank=256, n_input_tokens=8,
                 n_heads=8, n_layers=3, n_output_tokens=77,
                 main_dim=2048, pooled_dim=1280):
        super().__init__()
        self.in_dim = in_dim
        self.rank = rank
        self.n_input_tokens = n_input_tokens
        self.n_output_tokens = n_output_tokens
        self.main_dim = main_dim
        self.pooled_dim = pooled_dim

        # Expand the flat activation into n_input_tokens memory tokens.
        self.input_encoder = nn.Sequential(
            nn.Linear(in_dim, rank), nn.GELU(),
            nn.Linear(rank, n_input_tokens * rank),
        )

        # One learnable query per output token; small init (0.02 std)
        # keeps early attention well-behaved.
        self.queries = nn.Parameter(torch.randn(n_output_tokens, rank) * 0.02)

        # Pre-norm transformer decoder: queries cross-attend to the memory.
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=rank, nhead=n_heads,
            dim_feedforward=rank * 4, activation='gelu',
            batch_first=True, norm_first=True,
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_layers)

        # rank -> SDXL main (per-token) embedding width.
        self.main_project = nn.Sequential(
            nn.LayerNorm(rank),
            nn.Linear(rank, main_dim),
        )

        # Mean of the decoded tokens -> single pooled conditioning vector.
        self.pooled_head = nn.Sequential(
            nn.Linear(rank, rank), nn.GELU(),
            nn.Linear(rank, pooled_dim),
        )

    def forward(self, x):
        """x: [B, in_dim] (or [in_dim]) -> (main [B, T, main_dim], pooled [B, pooled_dim])."""
        if x.dim() == 1:
            # Promote a single sample to a batch of one.
            x = x.unsqueeze(0)
        batch = x.shape[0]

        # Flat activation -> memory token sequence.
        memory = self.input_encoder(x)
        memory = memory.reshape(batch, self.n_input_tokens, self.rank)

        # Same query bank for every batch element (expand = no copy).
        tgt = self.queries.unsqueeze(0).expand(batch, -1, -1)
        decoded = self.decoder(tgt, memory)  # [B, n_output_tokens, rank]

        pooled = self.pooled_head(decoded.mean(dim=1))  # [B, pooled_dim]
        main_embeds = self.main_project(decoded)        # [B, T, main_dim]
        return main_embeds, pooled
90
+
91
+
92
def count_params(model):
    """Return the total number of elements across all of *model*'s parameters."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
94
+
95
+
96
+ if __name__ == "__main__":
97
+ # Quick test
98
+ layer_weight = LayerWeightedInput(n_layers=3, layer_dim=2560)
99
+ adapter = SDXLCrossAttentionAdapter(
100
+ in_dim=2560, rank=256, n_input_tokens=8,
101
+ n_heads=8, n_layers=3,
102
+ )
103
+ x = torch.randn(2, 7680) # batch of 2, concat of 3 layers
104
+ x_weighted = layer_weight(x) # [2, 2560]
105
+ main, pooled = adapter(x_weighted)
106
+ print(f"LayerWeightedInput params: {count_params(layer_weight):,}")
107
+ print(f"SDXLAdapter params: {count_params(adapter):,}")
108
+ print(f"Total params: {count_params(layer_weight) + count_params(adapter):,}")
109
+ print(f"Input: {x.shape}")
110
+ print(f"Weighted: {x_weighted.shape}")
111
+ print(f"Main embeds: {main.shape}")
112
+ print(f"Pooled embeds: {pooled.shape}")