lfw3 committed
Commit 942ca9e · verified · 1 Parent(s): 0d08f16

Updated model to match weight shape

Files changed (1)
  1. clip_mlp.py +118 -65
clip_mlp.py CHANGED
@@ -1,65 +1,118 @@
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- class CLIPEmbeddingMLP(nn.Module):
-     def __init__(
-         self,
-         clip_dim=512, # CLIP text embedding dimension (512 for ViT-B/32, 768 for ViT-L/14)
-         string_embed_dim=512, # dimension for string embeddings from CLIP
-         hidden_dims=[1024, 1024, 512], # hidden layer dimensions
-     ):
-         super().__init__()
-
-         # Calculate total input dimension
-         # 3 string embeddings + 4 categorical embeddings
-         input_dim = 3 * string_embed_dim + 4 * 5
-
-         # Build MLP layers
-         layers = []
-         prev_dim = input_dim
-         for hidden_dim in hidden_dims:
-             layers.extend([
-                 nn.Linear(prev_dim, hidden_dim),
-                 nn.LayerNorm(hidden_dim),
-                 nn.ReLU(),
-                 nn.Dropout(0.1)
-             ])
-             prev_dim = hidden_dim
-
-         # Final projection to CLIP dimension
-         layers.append(nn.Linear(prev_dim, clip_dim))
-
-         self.mlp = nn.Sequential(*layers)
-
-     def forward(self, string_embeds, categorical_inputs):
-         """
-         Args:
-             string_embeds: tensor of shape (batch_size, 3, string_embed_dim)
-                 Pre-computed embeddings for the 3 string inputs
-             categorical_inputs: tensor of shape (batch_size, 4) with values in [0, 4]
-                 Integer indices for the 4 categorical inputs
-
-         Returns:
-             clip_embeddings: tensor of shape (batch_size, clip_dim)
-         """
-         batch_size = string_embeds.shape[0]
-
-         # Flatten string embeddings
-         string_flat = string_embeds.reshape(batch_size, -1)
-
-         # Convert categorical inputs to one-hot vectors
-         # Each categorical input becomes a one-hot vector of size 5
-         cat_onehot = F.one_hot(categorical_inputs.long(), num_classes=5) # (batch_size, 4, 5)
-         cat_flat = cat_onehot.reshape(batch_size, -1).float() # (batch_size, 20)
-
-         # Concatenate all inputs
-         combined = torch.cat([string_flat, cat_flat], dim=1)
-
-         # Pass through MLP
-         output = self.mlp(combined)
-
-         # L2 normalize to match CLIP embeddings
-         output = output / output.norm(dim=-1, keepdim=True)
-
-         return output
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class CLIPOffsetMLP(nn.Module):
+     """
+     MLP that predicts an offset in CLIP embedding space.
+     Architecture: concatenated [one-hot vectors, CLIP text embeddings] -> MLP -> offset vector
+     Final embedding: E_pred = E_base + E_offset
+     """
+     def __init__(
+         self,
+         clip_dim=512, # CLIP embedding dimension (512 for ViT-B/32, 768 for ViT-L/14)
+         string_embed_dim=512, # dimension for string embeddings from CLIP
+         num_categories_per_attr=[6, 7, 5, 5], # number of categories for each discrete attribute
+         hidden_dims=[1024, 1024, 512], # hidden layer dimensions
+         normalize_inputs=True, # normalize components before concatenation
+     ):
+         super().__init__()
+
+         self.clip_dim = clip_dim
+         self.string_embed_dim = string_embed_dim
+         self.num_categories_per_attr = num_categories_per_attr
+         self.normalize_inputs = normalize_inputs
+
+         # Calculate input dimensions
+         num_discrete_attrs = len(num_categories_per_attr)
+         total_onehot_dim = sum(num_categories_per_attr)
+
+         # Assuming 3 textual attributes (constellation, affiliation, etc.)
+         num_text_attrs = 3
+         total_text_dim = num_text_attrs * string_embed_dim
+
+         # Total input dimension after concatenation
+         input_dim = total_onehot_dim + total_text_dim
+
+         # Build MLP layers
+         layers = []
+         prev_dim = input_dim
+         for hidden_dim in hidden_dims:
+             layers.extend([
+                 nn.Linear(prev_dim, hidden_dim),
+                 nn.LayerNorm(hidden_dim),
+                 nn.ReLU(),
+                 nn.Dropout(0.1)
+             ])
+             prev_dim = hidden_dim
+
+         # Final projection to CLIP dimension (offset vector)
+         layers.append(nn.Linear(prev_dim, clip_dim))
+
+         self.mlp = nn.Sequential(*layers)
+
+     def forward(self, string_embeds, categorical_inputs, base_text_embed):
+         """
+         Args:
+             string_embeds: tensor of shape (batch_size, num_text_attrs, string_embed_dim)
+                 Pre-computed CLIP embeddings for textual attributes
+             categorical_inputs: tensor of shape (batch_size, num_discrete_attrs)
+                 Integer indices for discrete attributes
+             base_text_embed: tensor of shape (batch_size, clip_dim) or (1, clip_dim)
+                 Base text embedding for "Genshin-style character"
+
+         Returns:
+             pred_embeddings: tensor of shape (batch_size, clip_dim)
+                 E_pred = E_base + E_offset
+         """
+         batch_size = string_embeds.shape[0]
+
+         # 1. Process one-hot vectors for discrete attributes
+         onehot_vectors = []
+         for i, num_cats in enumerate(self.num_categories_per_attr):
+             onehot = F.one_hot(categorical_inputs[:, i].long(), num_classes=num_cats)
+             onehot_vectors.append(onehot.float())
+
+         x_onehot = torch.cat(onehot_vectors, dim=1) # (batch_size, total_onehot_dim)
+
+         # 2. Process text embeddings
+         x_text = string_embeds.reshape(batch_size, -1) # (batch_size, num_text_attrs * embed_dim)
+
+         # 3. Normalize components before concatenation (as per spec)
+         if self.normalize_inputs:
+             # Normalize one-hot vector (L2 norm)
+             x_onehot = F.normalize(x_onehot, p=2, dim=1)
+             # Normalize text embeddings (L2 norm)
+             x_text = F.normalize(x_text, p=2, dim=1)
+
+         # 4. Concatenate: x_input = [x_onehot, E_text_attr]
+         x_input = torch.cat([x_onehot, x_text], dim=1)
+
+         # 5. Pass through MLP to get offset vector
+         offset = self.mlp(x_input)
+
+         # 6. Add offset to base embedding: E_pred = E_base + E_offset
+         # Handle broadcasting if base_text_embed is (1, clip_dim)
+         if base_text_embed.shape[0] == 1 and batch_size > 1:
+             base_text_embed = base_text_embed.expand(batch_size, -1)
+
+         pred_embeddings = base_text_embed + offset
+
+         # 7. Normalize final embedding (CLIP embeddings are typically normalized)
+         pred_embeddings = F.normalize(pred_embeddings, p=2, dim=1)
+
+         return pred_embeddings
+
+     def inference(self, string_embeds, categorical_inputs, base_text_embed):
+         """
+         Inference mode - identical to the forward pass but explicitly named for clarity.
+
+         Args:
+             string_embeds: CLIP embeddings of textual attributes
+             categorical_inputs: Integer indices for discrete attributes
+             base_text_embed: Base embedding for "Genshin-style character"
+
+         Returns:
+             E_star: Conditioning vector for diffusion model
+         """
+         return self.forward(string_embeds, categorical_inputs, base_text_embed)
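
For a quick shape check, a minimal usage sketch of the updated module (not part of the commit): it instantiates CLIPOffsetMLP with the defaults above and pushes a dummy batch through forward. The random tensors are stand-ins for real CLIP text embeddings, which would normally come from a CLIP text encoder such as ViT-B/32.

import torch
from clip_mlp import CLIPOffsetMLP

# Defaults from the commit: clip_dim=512 (ViT-B/32), three textual attributes,
# four discrete attributes with 6/7/5/5 categories.
model = CLIPOffsetMLP()
model.eval()

batch_size = 2

# Dummy stand-ins; in practice these come from a CLIP text encoder.
string_embeds = torch.randn(batch_size, 3, 512)      # (B, num_text_attrs, string_embed_dim)
categorical_inputs = torch.stack(
    [torch.randint(0, n, (batch_size,)) for n in [6, 7, 5, 5]], dim=1
)                                                    # (B, num_discrete_attrs)
base_text_embed = torch.randn(1, 512)                # broadcast across the batch

with torch.no_grad():
    e_pred = model(string_embeds, categorical_inputs, base_text_embed)

print(e_pred.shape)        # torch.Size([2, 512])
print(e_pred.norm(dim=-1)) # ~1.0 per row, since the output is L2-normalized

Note that the MLP input width is now sum([6, 7, 5, 5]) + 3 * 512 = 1559, versus 3 * 512 + 4 * 5 = 1556 in the old CLIPEmbeddingMLP, consistent with the commit message about matching the checkpoint's weight shapes.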