Leacb4 commited on
Commit
1994248
·
verified ·
1 Parent(s): 4a5d61a

Upload color_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. color_model.py +120 -4
color_model.py CHANGED
@@ -1,3 +1,11 @@
 
 
 
 
 
 
 
 
1
  import config
2
  import os
3
  import json
@@ -21,11 +29,21 @@ logger = logging.getLogger(__name__)
21
  # Dataset Classes
22
  # -------------------------------
23
  class ColorDataset(Dataset):
 
 
 
 
 
 
 
24
  def __init__(self, dataframe, tokenizer, transform=None):
25
  """
26
- dataframe : pd.DataFrame with columns image and text columns
27
- tokenizer : function that converts text -> list of integers (tokens)
28
- transform : transformations on the image
 
 
 
29
  """
30
  self.df = dataframe.reset_index(drop=True)
31
  self.tokenizer = tokenizer
@@ -37,9 +55,19 @@ class ColorDataset(Dataset):
37
  ])
38
 
39
  def __len__(self):
 
40
  return len(self.df)
41
 
42
  def __getitem__(self, idx):
 
 
 
 
 
 
 
 
 
43
  row = self.df.iloc[idx]
44
  img = Image.open(config.column_local_image_path).convert("RGB")
45
  img = self.transform(img)
@@ -50,13 +78,34 @@ class ColorDataset(Dataset):
50
  # Tokenizer
51
  # -------------------------------
52
  class Tokenizer:
 
 
 
 
 
 
 
53
  def __init__(self):
 
 
 
 
 
 
54
  self.word2idx = defaultdict(lambda: 0) # 0 = pad/unknown
55
  self.idx2word = {}
56
  self.counter = 1
57
 
58
  def preprocess_text(self, text):
59
- """Extract color-related keywords from text"""
 
 
 
 
 
 
 
 
60
  # Color-related keywords to keep
61
  color_keywords = ['red', 'blue', 'green', 'yellow', 'purple', 'pink', 'orange',
62
  'brown', 'black', 'white', 'gray', 'navy', 'beige', 'aqua', 'lime',
@@ -76,6 +125,12 @@ class Tokenizer:
76
  return ' '.join(filtered_words) if filtered_words else text.lower()
77
 
78
  def fit(self, texts):
 
 
 
 
 
 
79
  for text in texts:
80
  processed_text = self.preprocess_text(text)
81
  for word in processed_text.split():
@@ -85,10 +140,25 @@ class Tokenizer:
85
  self.counter += 1
86
 
87
  def __call__(self, text):
 
 
 
 
 
 
 
 
 
88
  processed_text = self.preprocess_text(text)
89
  return [self.word2idx[word] for word in processed_text.split()]
90
 
91
  def load_vocab(self, word2idx_dict):
 
 
 
 
 
 
92
  self.word2idx = defaultdict(lambda: 0, {k: int(v) for k, v in word2idx_dict.items()})
93
  self.idx2word = {int(v): k for k, v in word2idx_dict.items() if int(v) > 0}
94
  self.counter = max(self.word2idx.values(), default=0) + 1
@@ -97,7 +167,20 @@ class Tokenizer:
97
  # Model Components
98
  # -------------------------------
99
  class ImageEncoder(nn.Module):
 
 
 
 
 
 
 
100
  def __init__(self, embedding_dim=config.color_emb_dim):
 
 
 
 
 
 
101
  super().__init__()
102
  self.backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
103
  self.backbone.fc = nn.Sequential(
@@ -106,17 +189,50 @@ class ImageEncoder(nn.Module):
106
  )
107
 
108
  def forward(self, x):
 
 
 
 
 
 
 
 
 
109
  x = self.backbone(x)
110
  return F.normalize(x, dim=-1)
111
 
112
  class TextEncoder(nn.Module):
 
 
 
 
 
 
 
113
  def __init__(self, vocab_size, embedding_dim=config.color_emb_dim):
 
 
 
 
 
 
 
114
  super().__init__()
115
  self.embedding = nn.Embedding(vocab_size, 32, padding_idx=0) # Keep 32 dimensions
116
  self.dropout = nn.Dropout(0.1) # Add regularization
117
  self.fc = nn.Linear(32, embedding_dim)
118
 
119
  def forward(self, x, lengths=None):
 
 
 
 
 
 
 
 
 
 
120
  emb = self.embedding(x) # [B, T, 32]
121
  emb = self.dropout(emb) # Apply dropout
122
  if lengths is not None:
 
1
+ """
2
+ ColorCLIP model for learning color-aligned embeddings.
3
+ This file contains the ColorCLIP model that learns to encode images and texts
4
+ in an embedding space specialized for color representation. It includes
5
+ a ResNet-based image encoder, a text encoder with custom tokenizer,
6
+ and contrastive loss functions for training.
7
+ """
8
+
9
  import config
10
  import os
11
  import json
 
29
  # Dataset Classes
30
  # -------------------------------
31
  class ColorDataset(Dataset):
32
+ """
33
+ Dataset class for color embedding training.
34
+
35
+ Handles loading images from local paths and tokenizing text descriptions
36
+ for training the ColorCLIP model.
37
+ """
38
+
39
  def __init__(self, dataframe, tokenizer, transform=None):
40
  """
41
+ Initialize the color dataset.
42
+
43
+ Args:
44
+ dataframe: DataFrame with columns for image paths and text descriptions
45
+ tokenizer: Tokenizer instance that converts text to list of integers (tokens)
46
+ transform: Optional image transformations (default: standard ImageNet normalization)
47
  """
48
  self.df = dataframe.reset_index(drop=True)
49
  self.tokenizer = tokenizer
 
55
  ])
56
 
57
  def __len__(self):
58
+ """Return the number of samples in the dataset."""
59
  return len(self.df)
60
 
61
  def __getitem__(self, idx):
62
+ """
63
+ Get a sample from the dataset.
64
+
65
+ Args:
66
+ idx: Index of the sample
67
+
68
+ Returns:
69
+ Tuple of (image_tensor, token_tensor)
70
+ """
71
  row = self.df.iloc[idx]
72
  img = Image.open(config.column_local_image_path).convert("RGB")
73
  img = self.transform(img)
 
78
  # Tokenizer
79
  # -------------------------------
80
  class Tokenizer:
81
+ """
82
+ Tokenizer for extracting color-related keywords from text.
83
+
84
+ This tokenizer filters text to keep only color-related words and basic
85
+ descriptive words, then maps them to integer indices for embedding.
86
+ """
87
+
88
  def __init__(self):
89
+ """
90
+ Initialize the tokenizer.
91
+
92
+ Creates empty word-to-index and index-to-word mappings.
93
+ Index 0 is reserved for padding/unknown tokens.
94
+ """
95
  self.word2idx = defaultdict(lambda: 0) # 0 = pad/unknown
96
  self.idx2word = {}
97
  self.counter = 1
98
 
99
  def preprocess_text(self, text):
100
+ """
101
+ Extract color-related keywords from text.
102
+
103
+ Args:
104
+ text: Input text string
105
+
106
+ Returns:
107
+ Preprocessed text containing only color and descriptive keywords
108
+ """
109
  # Color-related keywords to keep
110
  color_keywords = ['red', 'blue', 'green', 'yellow', 'purple', 'pink', 'orange',
111
  'brown', 'black', 'white', 'gray', 'navy', 'beige', 'aqua', 'lime',
 
125
  return ' '.join(filtered_words) if filtered_words else text.lower()
126
 
127
  def fit(self, texts):
128
+ """
129
+ Build vocabulary from a list of texts.
130
+
131
+ Args:
132
+ texts: List of text strings to build vocabulary from
133
+ """
134
  for text in texts:
135
  processed_text = self.preprocess_text(text)
136
  for word in processed_text.split():
 
140
  self.counter += 1
141
 
142
  def __call__(self, text):
143
+ """
144
+ Tokenize a text string into a list of integer indices.
145
+
146
+ Args:
147
+ text: Input text string
148
+
149
+ Returns:
150
+ List of integer token indices
151
+ """
152
  processed_text = self.preprocess_text(text)
153
  return [self.word2idx[word] for word in processed_text.split()]
154
 
155
  def load_vocab(self, word2idx_dict):
156
+ """
157
+ Load vocabulary from a word-to-index dictionary.
158
+
159
+ Args:
160
+ word2idx_dict: Dictionary mapping words to indices
161
+ """
162
  self.word2idx = defaultdict(lambda: 0, {k: int(v) for k, v in word2idx_dict.items()})
163
  self.idx2word = {int(v): k for k, v in word2idx_dict.items() if int(v) > 0}
164
  self.counter = max(self.word2idx.values(), default=0) + 1
 
167
  # Model Components
168
  # -------------------------------
169
  class ImageEncoder(nn.Module):
170
+ """
171
+ Image encoder based on ResNet18 for extracting image embeddings.
172
+
173
+ Uses a pretrained ResNet18 backbone and replaces the final layer
174
+ to output embeddings of the specified dimension.
175
+ """
176
+
177
  def __init__(self, embedding_dim=config.color_emb_dim):
178
+ """
179
+ Initialize the image encoder.
180
+
181
+ Args:
182
+ embedding_dim: Dimension of the output embedding (default: color_emb_dim)
183
+ """
184
  super().__init__()
185
  self.backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
186
  self.backbone.fc = nn.Sequential(
 
189
  )
190
 
191
  def forward(self, x):
192
+ """
193
+ Forward pass through the image encoder.
194
+
195
+ Args:
196
+ x: Image tensor [batch_size, channels, height, width]
197
+
198
+ Returns:
199
+ Normalized image embeddings [batch_size, embedding_dim]
200
+ """
201
  x = self.backbone(x)
202
  return F.normalize(x, dim=-1)
203
 
204
  class TextEncoder(nn.Module):
205
+ """
206
+ Text encoder for extracting text embeddings from token sequences.
207
+
208
+ Uses an embedding layer followed by mean pooling (with optional length normalization)
209
+ and a linear projection to the output embedding dimension.
210
+ """
211
+
212
  def __init__(self, vocab_size, embedding_dim=config.color_emb_dim):
213
+ """
214
+ Initialize the text encoder.
215
+
216
+ Args:
217
+ vocab_size: Size of the vocabulary
218
+ embedding_dim: Dimension of the output embedding (default: color_emb_dim)
219
+ """
220
  super().__init__()
221
  self.embedding = nn.Embedding(vocab_size, 32, padding_idx=0) # Keep 32 dimensions
222
  self.dropout = nn.Dropout(0.1) # Add regularization
223
  self.fc = nn.Linear(32, embedding_dim)
224
 
225
  def forward(self, x, lengths=None):
226
+ """
227
+ Forward pass through the text encoder.
228
+
229
+ Args:
230
+ x: Token tensor [batch_size, sequence_length]
231
+ lengths: Optional sequence lengths tensor [batch_size] for proper mean pooling
232
+
233
+ Returns:
234
+ Normalized text embeddings [batch_size, embedding_dim]
235
+ """
236
  emb = self.embedding(x) # [B, T, 32]
237
  emb = self.dropout(emb) # Apply dropout
238
  if lengths is not None: