Leacb4 committed on
Commit
fc411a2
·
verified ·
1 Parent(s): 41133e4

Upload training/color_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training/color_model.py +258 -485
training/color_model.py CHANGED
@@ -1,422 +1,205 @@
1
  """
2
  ColorCLIP model for learning color-aligned embeddings.
3
- This file contains the ColorCLIP model that learns to encode images and texts
4
- in an embedding space specialized for color representation. It includes
5
- a ResNet-based image encoder, a text encoder with custom tokenizer,
6
- and contrastive loss functions for training.
 
 
7
  """
8
 
9
  import config
10
- import os
11
- import json
12
  import torch
13
  from torch.utils.data import Dataset, DataLoader
14
- from torchvision import transforms, models
15
  from PIL import Image
16
  import torch.nn as nn
17
  import torch.nn.functional as F
18
  import pandas as pd
19
  from tqdm.auto import tqdm
20
- from collections import defaultdict
21
- from typing import Optional, List
22
  import logging
23
 
24
 
25
  # Configure logging
26
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
27
  logger = logging.getLogger(__name__)
 
 
28
  # -------------------------------
29
  # Dataset Classes
30
  # -------------------------------
 
31
  class ColorDataset(Dataset):
32
- """
33
- Dataset class for color embedding training.
34
-
35
- Handles loading images from local paths and tokenizing text descriptions
36
- for training the ColorCLIP model.
37
- """
38
-
39
- def __init__(self, dataframe, tokenizer, transform=None):
40
- """
41
- Initialize the color dataset.
42
-
43
- Args:
44
- dataframe: DataFrame with columns for image paths and text descriptions
45
- tokenizer: Tokenizer instance that converts text to list of integers (tokens)
46
- transform: Optional image transformations (default: standard ImageNet normalization)
47
- """
48
  self.df = dataframe.reset_index(drop=True)
49
- self.tokenizer = tokenizer
50
- self.transform = transform or transforms.Compose([
51
- transforms.Resize((224,224)),
52
- transforms.ToTensor(),
53
- transforms.Normalize(mean=[0.485,0.456,0.406],
54
- std=[0.229,0.224,0.225])
55
- ])
56
-
57
  def __len__(self):
58
- """Return the number of samples in the dataset."""
59
  return len(self.df)
60
-
61
  def __getitem__(self, idx):
62
- """
63
- Get a sample from the dataset.
64
-
65
- Args:
66
- idx: Index of the sample
67
-
68
- Returns:
69
- Tuple of (image_tensor, token_tensor)
70
- """
71
  row = self.df.iloc[idx]
72
- # Fix: Get the image path from the row, not the column name
73
- img_path = row[config.column_local_image_path]
74
- img = Image.open(img_path).convert("RGB")
75
- img = self.transform(img)
76
- tokens = torch.tensor(self.tokenizer(row[config.text_column]), dtype=torch.long)
77
- return img, tokens
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  # -------------------------------
80
- # Tokenizer
81
  # -------------------------------
82
- class Tokenizer:
83
- """
84
- Tokenizer for extracting color-related keywords from text.
85
-
86
- This tokenizer filters text to keep only color-related words and basic
87
- descriptive words, then maps them to integer indices for embedding.
88
- """
89
-
90
- def __init__(self):
91
- """
92
- Initialize the tokenizer.
93
-
94
- Creates empty word-to-index and index-to-word mappings.
95
- Index 0 is reserved for padding/unknown tokens.
96
- """
97
- self.word2idx = defaultdict(lambda: 0) # 0 = pad/unknown
98
- self.idx2word = {}
99
- self.counter = 1
100
-
101
- def preprocess_text(self, text):
102
- """
103
- Extract color-related keywords from text.
104
-
105
- Args:
106
- text: Input text string
107
-
108
- Returns:
109
- Preprocessed text containing only color and descriptive keywords
110
- """
111
- # Color-related keywords to keep
112
- color_keywords = ['red', 'blue', 'green', 'yellow', 'purple', 'pink', 'orange',
113
- 'brown', 'black', 'white', 'gray', 'navy', 'beige', 'aqua', 'lime',
114
- 'violet', 'turquoise', 'teal', 'tan', 'snow', 'silver', 'plum',
115
- 'olive', 'fuchsia', 'gold', 'cream', 'ivory', 'maroon']
116
-
117
- # Keep only color-related words and basic descriptive words
118
- descriptive_words = ['shirt', 'dress', 'top', 'bottom', 'shoe', 'bag', 'hat', 'short', 'long', 'sleeve']
119
-
120
- words = text.lower().split()
121
- filtered_words = []
122
- for word in words:
123
- # Keep color words and some descriptive words
124
- if word in color_keywords or word in descriptive_words:
125
- filtered_words.append(word)
126
-
127
- return ' '.join(filtered_words) if filtered_words else text.lower()
128
-
129
- def fit(self, texts):
130
- """
131
- Build vocabulary from a list of texts.
132
-
133
- Args:
134
- texts: List of text strings to build vocabulary from
135
- """
136
- for text in texts:
137
- processed_text = self.preprocess_text(text)
138
- for word in processed_text.split():
139
- if word not in self.word2idx:
140
- self.word2idx[word] = self.counter
141
- self.idx2word[self.counter] = word
142
- self.counter += 1
143
-
144
- def __call__(self, text):
145
- """
146
- Tokenize a text string into a list of integer indices.
147
-
148
- Args:
149
- text: Input text string
150
-
151
- Returns:
152
- List of integer token indices
153
- """
154
- processed_text = self.preprocess_text(text)
155
- return [self.word2idx[word] for word in processed_text.split()]
156
-
157
- def load_vocab(self, word2idx_dict):
158
- """
159
- Load vocabulary from a word-to-index dictionary.
160
-
161
- Args:
162
- word2idx_dict: Dictionary mapping words to indices
163
- """
164
- self.word2idx = defaultdict(lambda: 0, {k: int(v) for k, v in word2idx_dict.items()})
165
- self.idx2word = {int(v): k for k, v in word2idx_dict.items() if int(v) > 0}
166
- self.counter = max(self.word2idx.values(), default=0) + 1
167
 
168
  # -------------------------------
169
- # Model Components
170
  # -------------------------------
171
- class ImageEncoder(nn.Module):
172
- """
173
- Image encoder based on ResNet18 for extracting image embeddings.
174
-
175
- Uses a pretrained ResNet18 backbone and replaces the final layer
176
- to output embeddings of the specified dimension.
177
- """
178
-
179
- def __init__(self, embedding_dim=config.color_emb_dim):
180
- """
181
- Initialize the image encoder.
182
-
183
- Args:
184
- embedding_dim: Dimension of the output embedding (default: color_emb_dim)
185
- """
186
- super().__init__()
187
- self.backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
188
- self.backbone.fc = nn.Sequential(
189
- nn.Dropout(0.1), # Add regularization
190
- nn.Linear(self.backbone.fc.in_features, embedding_dim)
191
- )
192
-
193
- def forward(self, x):
194
- """
195
- Forward pass through the image encoder.
196
-
197
- Args:
198
- x: Image tensor [batch_size, channels, height, width]
199
-
200
- Returns:
201
- Normalized image embeddings [batch_size, embedding_dim]
202
- """
203
- x = self.backbone(x)
204
- return F.normalize(x, dim=-1)
205
-
206
- class TextEncoder(nn.Module):
207
- """
208
- Text encoder for extracting text embeddings from token sequences.
209
-
210
- Uses an embedding layer followed by mean pooling (with optional length normalization)
211
- and a linear projection to the output embedding dimension.
212
- """
213
-
214
- def __init__(self, vocab_size, embedding_dim=config.color_emb_dim):
215
- """
216
- Initialize the text encoder.
217
-
218
- Args:
219
- vocab_size: Size of the vocabulary
220
- embedding_dim: Dimension of the output embedding (default: color_emb_dim)
221
- """
222
- super().__init__()
223
- self.embedding = nn.Embedding(vocab_size, 32, padding_idx=0) # Keep 32 dimensions
224
- self.dropout = nn.Dropout(0.1) # Add regularization
225
- self.fc = nn.Linear(32, embedding_dim)
226
-
227
- def forward(self, x, lengths=None):
228
- """
229
- Forward pass through the text encoder.
230
-
231
- Args:
232
- x: Token tensor [batch_size, sequence_length]
233
- lengths: Optional sequence lengths tensor [batch_size] for proper mean pooling
234
-
235
- Returns:
236
- Normalized text embeddings [batch_size, embedding_dim]
237
- """
238
- emb = self.embedding(x) # [B, T, 32]
239
- emb = self.dropout(emb) # Apply dropout
240
- if lengths is not None:
241
- summed = emb.sum(dim=1) # [B, 32]
242
- mean = summed / lengths.unsqueeze(1).clamp_min(1)
243
- else:
244
- mean = emb.mean(dim=1)
245
- return F.normalize(self.fc(mean), dim=-1)
246
 
247
  class ColorCLIP(nn.Module):
248
  """
249
- Color CLIP model for learning color-aligned image-text embeddings.
 
 
 
 
250
  """
251
- def __init__(self, vocab_size, embedding_dim=config.color_emb_dim, tokenizer=None):
252
- """
253
- Initialize ColorCLIP model.
254
-
255
- Args:
256
- vocab_size: Size of the vocabulary for text encoding
257
- embedding_dim: Dimension of the embedding space (default: color_emb_dim)
258
- tokenizer: Optional Tokenizer instance (will create one if None)
259
- """
260
  super().__init__()
261
- self.vocab_size = vocab_size
 
262
  self.embedding_dim = embedding_dim
263
- self.image_encoder = ImageEncoder(embedding_dim)
264
- self.text_encoder = TextEncoder(vocab_size, embedding_dim)
265
- self.tokenizer = tokenizer
266
-
267
- def forward(self, image, text, lengths=None):
268
- """
269
- Forward pass through the model.
270
-
271
- Args:
272
- image: Image tensor [B, C, H, W]
273
- text: Text token tensor [B, T]
274
- lengths: Optional sequence lengths tensor [B]
275
-
276
- Returns:
277
- Tuple of (image_embeddings, text_embeddings)
278
- """
279
- return self.image_encoder(image), self.text_encoder(text, lengths)
280
-
281
- def get_text_embeddings(self, texts: List[str]) -> torch.Tensor:
282
- """
283
- Get text embeddings for a list of text strings.
284
-
285
- Args:
286
- texts: List of text strings
287
-
288
- Returns:
289
- Text embeddings tensor [batch_size, embedding_dim]
290
- """
291
- if self.tokenizer is None:
292
- raise ValueError("Tokenizer must be set before calling get_text_embeddings")
293
-
294
- token_lists = [self.tokenizer(t) for t in texts]
295
- max_len = max((len(toks) for toks in token_lists), default=0)
296
- padded = [toks + [0] * (max_len - len(toks)) for toks in token_lists]
297
- input_ids = torch.tensor(padded, dtype=torch.long, device=next(self.parameters()).device)
298
- lengths = torch.tensor([len(toks) for toks in token_lists], dtype=torch.long, device=input_ids.device)
299
  with torch.no_grad():
300
- emb = self.text_encoder(input_ids, lengths)
301
- return emb
302
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
  @classmethod
304
- def from_pretrained(cls, model_path: str, vocab_path: Optional[str] = None, device: str = "cpu", repo_id: Optional[str] = None):
305
- """
306
- Load a pretrained ColorCLIP model from a file path or Hugging Face Hub.
307
-
308
- Args:
309
- model_path: Path to the model checkpoint (.pt file) or filename if using repo_id
310
- vocab_path: Optional path to tokenizer vocabulary JSON file or filename if using repo_id
311
- device: Device to load the model on (default: "cpu")
312
- repo_id: Optional Hugging Face repository ID (e.g., "username/model-name")
313
- If provided, model_path and vocab_path should be filenames within the repo
314
-
315
- Returns:
316
- ColorCLIP model instance
317
-
318
- Example:
319
- # Load from local file
320
- model = ColorCLIP.from_pretrained("color_model.pt", "tokenizer_vocab.json")
321
-
322
- # Load from Hugging Face Hub
323
- from huggingface_hub import hf_hub_download
324
- model_file = hf_hub_download(repo_id="username/model-name", filename="color_model.pt")
325
- vocab_file = hf_hub_download(repo_id="username/model-name", filename="tokenizer_vocab.json")
326
- model = ColorCLIP.from_pretrained(model_file, vocab_file)
327
- """
328
- device_obj = torch.device(device)
329
-
330
- # Support loading from Hugging Face Hub if repo_id is provided
331
- if repo_id:
332
- try:
333
- from huggingface_hub import hf_hub_download
334
- model_path = hf_hub_download(repo_id=repo_id, filename=model_path)
335
- if vocab_path:
336
- vocab_path = hf_hub_download(repo_id=repo_id, filename=vocab_path)
337
- except ImportError:
338
- raise ImportError("huggingface_hub is required to load models from Hugging Face. Install it with: pip install huggingface-hub")
339
-
340
- # Load model checkpoint
341
- checkpoint = torch.load(model_path, map_location=device_obj)
342
-
343
- # Extract vocab size and embedding dimension from checkpoint
344
- if isinstance(checkpoint, dict):
345
- # Try to get vocab_size from metadata first
346
- vocab_size = checkpoint.get('vocab_size', None)
347
- embedding_dim = checkpoint.get('embedding_dim', 16)
348
-
349
- # If not in metadata, try to infer from model state
350
- if vocab_size is None:
351
- state_dict = checkpoint.get('model_state_dict', checkpoint)
352
- if 'text_encoder.embedding.weight' in state_dict:
353
- vocab_size = state_dict['text_encoder.embedding.weight'].shape[0]
354
- else:
355
- raise ValueError("Could not determine vocab_size from checkpoint")
356
-
357
- # Load state dict
358
- state_dict = checkpoint.get('model_state_dict', checkpoint)
359
- else:
360
- raise ValueError("Checkpoint must be a dictionary")
361
-
362
- # Initialize model
363
- model = cls(vocab_size=vocab_size, embedding_dim=embedding_dim)
364
- model.load_state_dict(state_dict)
365
- model = model.to(device_obj)
366
-
367
- # Load tokenizer if vocab path is provided
368
- if vocab_path and os.path.exists(vocab_path):
369
- tokenizer = Tokenizer()
370
- with open(vocab_path, 'r') as f:
371
- vocab_dict = json.load(f)
372
- tokenizer.load_vocab(vocab_dict)
373
- model.tokenizer = tokenizer
374
-
375
  model.eval()
376
  return model
377
-
378
- def save_pretrained(self, save_directory: str, vocab_path: Optional[str] = None):
379
- """
380
- Save the model and optionally the tokenizer vocabulary.
381
-
382
- Args:
383
- save_directory: Directory to save the model
384
- vocab_path: Optional path to save tokenizer vocabulary
385
- """
386
- os.makedirs(save_directory, exist_ok=True)
387
-
388
- # Save model checkpoint
389
- model_path = os.path.join(save_directory, config.color_model_path)
390
- checkpoint = {
391
- 'model_state_dict': self.state_dict(),
392
- 'vocab_size': self.vocab_size,
393
- 'embedding_dim': self.embedding_dim
394
- }
395
- torch.save(checkpoint, model_path)
396
-
397
- # Save tokenizer vocabulary if available
398
- if self.tokenizer is not None:
399
- vocab_dict = dict(self.tokenizer.word2idx)
400
- if vocab_path is None:
401
- vocab_path = os.path.join(save_directory, config.tokeniser_path)
402
- with open(vocab_path, 'w') as f:
403
- json.dump(vocab_dict, f)
404
-
405
- return model_path, vocab_path
406
 
407
 
408
  # -------------------------------
409
- # Loss Functions and Utilities
410
  # -------------------------------
 
411
  def clip_loss(image_emb, text_emb, temperature=0.07):
412
  """
413
  CLIP contrastive loss function.
414
-
415
  Args:
416
  image_emb: Image embeddings [batch_size, embedding_dim]
417
  text_emb: Text embeddings [batch_size, embedding_dim]
418
  temperature: Temperature scaling parameter
419
-
420
  Returns:
421
  Contrastive loss value
422
  """
@@ -426,144 +209,134 @@ def clip_loss(image_emb, text_emb, temperature=0.07):
426
  loss_t2i = F.cross_entropy(logits.T, labels)
427
  return (loss_i2t + loss_t2i) / 2
428
 
429
- def collate_batch(batch):
430
- """
431
- Collate function for DataLoader that pads sequences and filters None values.
432
-
433
- Args:
434
- batch: List of (image, tokens) tuples or None
435
-
436
- Returns:
437
- Tuple of (images, padded_tokens, lengths) or None if batch is empty
438
- """
439
- batch = [b for b in batch if b is not None]
440
- if len(batch) == 0:
441
- return None
442
- imgs, tokens = zip(*batch)
443
- imgs = torch.stack(imgs, dim=0)
444
- lengths = torch.tensor([t.size(0) for t in tokens], dtype=torch.long)
445
- tokens_padded = nn.utils.rnn.pad_sequence(tokens, batch_first=True, padding_value=0)
446
- return imgs, tokens_padded, lengths
447
 
 
 
 
448
 
 
 
 
 
 
 
 
449
 
450
- if __name__ == "__main__":
451
- """
452
- Training script for ColorCLIP model.
453
- This code only runs when the file is executed directly, not when imported.
454
- """
455
- # Configuration
456
- batch_size = 16
457
- lr = 1e-4
458
- epochs=50
459
 
 
 
 
460
 
 
 
 
 
461
 
462
- # Load dataset and split train/test
463
- tokenizer = Tokenizer()
464
  df = pd.read_csv(config.local_dataset_path)
465
-
466
- # Data preparation: Reduce to main colors only (11 classes instead of 34)
467
- main_colors = ['beige', 'black', 'blue', 'brown', 'green', 'orange', 'pink', 'purple', 'red', 'white', 'yellow']
468
  df = df[df[config.color_column].isin(main_colors)].copy()
469
- print(f"📊 Filtered dataset: {len(df)} samples with {len(main_colors)} colors")
470
- print(f"🎨 Colors: {sorted(df[config.color_column].unique())}")
471
-
472
- tokenizer.fit(df[config.text_column].tolist())
473
-
474
- # Filter only rows with a valid local file
475
- df_local = df[df[config.column_local_image_path].astype(str).str.len() > 0]
476
- df_local = df_local[df_local[config.column_local_image_path].apply(lambda p: os.path.isfile(p))]
477
- df_local = df_local.reset_index(drop=True)
478
-
479
-
480
- # split 90/10
481
- df_local = df_local.sample(frac=1.0, random_state=42).reset_index(drop=True)
482
- split_idx = int(0.9 * len(df_local))
483
- df_train = df_local.iloc[:split_idx].reset_index(drop=True)
484
- df_test = df_local.iloc[split_idx:].reset_index(drop=True)
485
-
486
-
487
- train_dataset = ColorDataset(df_train, tokenizer)
488
- test_dataset = ColorDataset(df_test, tokenizer)
489
-
490
- train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_batch, num_workers=0)
491
- test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_batch, num_workers=0)
492
-
493
- device = config.device
494
- print(f"Using device: {device}")
495
-
496
- model = ColorCLIP(vocab_size=tokenizer.counter, embedding_dim=config.color_emb_dim, tokenizer=tokenizer).to(device)
497
- optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5) # Add weight decay
498
-
499
- # Save tokenizer vocab once (or update) so evaluation can reload the same mapping
500
- here = os.path.dirname(__file__)
501
- vocab_out = os.path.join(here, config.tokeniser_path)
502
- with open(vocab_out, "w") as f:
503
- json.dump(dict(tokenizer.word2idx), f)
504
- print(f"Tokenizer vocabulary saved to: {vocab_out}")
505
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
506
 
507
  for epoch in range(epochs):
508
- model.train()
509
- pbar = tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs} - train", leave=False)
510
- epoch_losses = []
511
- for batch in train_loader:
512
  if batch is None:
513
- pbar.update(1)
514
  continue
515
- imgs, texts, lengths = batch
516
- imgs = imgs.to(device)
517
- texts = texts.to(device)
518
- lengths = lengths.to(device)
519
  optimizer.zero_grad()
520
- img_emb, text_emb = model(imgs, texts, lengths)
521
- loss = clip_loss(img_emb, text_emb)
 
522
  loss.backward()
523
  optimizer.step()
524
- epoch_losses.append(loss.item())
525
- pbar.set_postfix({"loss": f"{loss.item():.4f}", "avg": f"{sum(epoch_losses)/len(epoch_losses):.4f}"})
526
- pbar.update(1)
527
- pbar.close()
528
-
529
- avg_train_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else None
530
- if avg_train_loss is not None:
531
- print(f"[Train] Epoch {epoch+1}/{epochs} - avg loss: {avg_train_loss:.4f}")
532
- else:
533
- print(f"[Train] Epoch {epoch+1}/{epochs} - no valid batches")
534
-
535
- # Eval rapide sur test avec barre
536
- model.eval()
537
  test_losses = []
538
  with torch.no_grad():
539
- pbar_t = tqdm(total=len(test_loader), desc=f"Epoch {epoch+1}/{epochs} - test", leave=False)
540
  for batch in test_loader:
541
  if batch is None:
542
- pbar_t.update(1)
543
  continue
544
- imgs, texts, lengths = batch
545
- imgs = imgs.to(device)
546
- texts = texts.to(device)
547
- lengths = lengths.to(device)
548
- img_emb, text_emb = model(imgs, texts, lengths)
549
- test_losses.append(clip_loss(img_emb, text_emb).item())
550
- pbar_t.update(1)
551
- pbar_t.close()
552
- if len(test_losses) > 0:
553
- avg_test_loss = sum(test_losses) / len(test_losses)
554
- print(f"[Test ] Epoch {epoch+1}/{epochs} - avg loss: {avg_test_loss:.4f}")
555
- else:
556
- print(f"[Test ] Epoch {epoch+1}/{epochs} - no valid batches")
557
-
558
- # --- Save checkpoint at every epoch ---
559
- ckpt_dir = here
560
- latest_path = os.path.join(ckpt_dir, config.color_model_path)
561
- epoch_path = os.path.join(ckpt_dir, f"color_model_epoch_{epoch+1}.pt")
562
- checkpoint = {
563
- 'model_state_dict': model.state_dict(),
564
- 'vocab_size': model.vocab_size,
565
- 'embedding_dim': model.embedding_dim
566
- }
567
- torch.save(checkpoint, latest_path)
568
- torch.save(checkpoint, epoch_path)
569
- print(f"[Save ] Saved checkpoints: {latest_path} and {epoch_path}")
 
1
  """
2
  ColorCLIP model for learning color-aligned embeddings.
3
+
4
+ Architecture: frozen CLIP (ViT-B/32) encoders with trainable linear projections
5
+ to a compact 16-dimensional embedding space. CLIP provides rich image and text
6
+ understanding while the learned projections specialise the representation for
7
+ color similarity. Only the two small Linear layers are trained; the CLIP
8
+ backbone remains frozen throughout.
9
  """
10
 
11
  import config
 
 
12
  import torch
13
  from torch.utils.data import Dataset, DataLoader
 
14
  from PIL import Image
15
  import torch.nn as nn
16
  import torch.nn.functional as F
17
  import pandas as pd
18
  from tqdm.auto import tqdm
 
 
19
  import logging
20
 
21
 
22
  # Configure logging
23
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
24
  logger = logging.getLogger(__name__)
25
+
26
+
27
  # -------------------------------
28
  # Dataset Classes
29
  # -------------------------------
30
+
class ColorDataset(Dataset):
    """Dataset for ColorCLIP -- returns raw text strings and CLIP-preprocessed images.

    Each item is ``(pixel_values, text, color)``. A sample whose image cannot
    be loaded yields ``None`` instead, so the collate function used with this
    dataset must filter ``None`` entries.
    """

    def __init__(self, dataframe, processor):
        """
        Args:
            dataframe: DataFrame providing the image-path, text and color
                columns named by ``config.column_local_image_path``,
                ``config.text_column`` and ``config.color_column``.
            processor: Callable invoked as
                ``processor(images=img, return_tensors="pt")`` whose result
                exposes a ``"pixel_values"`` entry.
        """
        # Positional access via ``iloc`` below relies on a clean 0..N-1 index.
        self.df = dataframe.reset_index(drop=True)
        self.processor = processor

    def __len__(self):
        """Return the number of rows in the wrapped DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(pixel_values, text, color)`` for row ``idx``, or ``None``.

        Raises:
            KeyError: If the configured image-path column is missing. This is
                a configuration error, so it is surfaced instead of turning
                every sample into ``None``.
        """
        row = self.df.iloc[idx]
        # Looked up outside the try block so a wrong column name fails loudly.
        img_path = row[config.column_local_image_path]
        try:
            # The context manager releases the file handle as soon as the
            # converted RGB copy has been created.
            with Image.open(img_path) as handle:
                img = handle.convert("RGB")
        except Exception as exc:
            # Best effort: an unreadable image skips the sample, but the skip
            # is logged so missing data does not go unnoticed.
            logging.getLogger(__name__).warning(
                "Skipping sample %s: cannot load image %r (%s)", idx, img_path, exc
            )
            return None
        # Drop the batch dimension the processor adds for a single image.
        pixel_values = self.processor(images=img, return_tensors="pt")["pixel_values"].squeeze(0)
        text = str(row[config.text_column])
        color = str(row[config.color_column])
        return pixel_values, text, color
51
+
52
+
class PrecomputedColorDataset(Dataset):
    """Dataset that serves pre-computed CLIP features instead of raw images.

    Image features are looked up by image path and text features by color
    name. A sample missing either feature is returned as ``None`` and is
    dropped later by :meth:`collate`.
    """

    def __init__(self, image_paths, colors, image_features, text_features):
        """
        Args:
            image_paths: Sequence of image paths, one per sample.
            colors: Sequence of color names, aligned with ``image_paths``.
            image_features: Mapping with a ``.get`` method, path -> feature.
            text_features: Mapping with a ``.get`` method, color -> feature.
        """
        self.image_paths = image_paths
        self.colors = colors
        self.image_features = image_features
        self.text_features = text_features

    def __len__(self):
        """Return the number of samples (one per image path)."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image_feature, text_feature)`` or ``None`` if either is missing."""
        path_key = self.image_paths[idx]
        color_key = self.colors[idx]
        image_vec = self.image_features.get(path_key)
        text_vec = self.text_features.get(color_key)
        if image_vec is not None and text_vec is not None:
            return image_vec, text_vec
        return None

    @staticmethod
    def collate(batch):
        """Drop ``None`` samples and stack the rest; return ``None`` for an empty batch."""
        kept = [sample for sample in batch if sample is not None]
        if len(kept) == 0:
            return None
        image_vecs, text_vecs = zip(*kept)
        stacked_images = torch.stack(image_vecs, dim=0)
        stacked_texts = torch.stack(text_vecs, dim=0)
        return stacked_images, stacked_texts
81
+
82
 
83
  # -------------------------------
84
+ # Collate Function
85
  # -------------------------------
86
+
def collate_batch(batch):
    """Collate samples of the form ``(pixel_values, text, color)``.

    ``None`` samples are discarded. Image tensors are stacked along a new
    leading dimension while texts and colors are returned as plain lists.
    Returns ``None`` when no valid sample remains.
    """
    valid = [sample for sample in batch if sample is not None]
    if not valid:
        return None
    pixel_tensors, captions, color_names = zip(*valid)
    stacked = torch.stack(pixel_tensors, dim=0)
    return stacked, list(captions), list(color_names)
94
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
  # -------------------------------
97
+ # Model
98
  # -------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
class ColorCLIP(nn.Module):
    """
    Color model: frozen CLIP encoders + trainable linear projections.

    The CLIP backbone has ``requires_grad`` disabled on every parameter; only
    ``image_projection`` and ``text_projection`` (two ``nn.Linear`` layers that
    map CLIP's ``projection_dim`` to ``embedding_dim``) remain trainable.
    Replaces the earlier custom tokenizer / ResNet18 approach with CLIP's full
    encoders.
    """

    CLIP_MODEL_NAME = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"

    def __init__(self, embedding_dim: int = config.color_emb_dim,
                 clip_model_name: str | None = None):
        """
        Args:
            embedding_dim: Output dimension of both projections.
            clip_model_name: Identifier passed to ``from_pretrained`` for the
                CLIP model and processor; falls back to ``CLIP_MODEL_NAME``.
        """
        super().__init__()
        # Imported lazily so the module can be imported without transformers.
        from transformers import CLIPModel as _CLIPModel, CLIPProcessor as _CLIPProc

        self.embedding_dim = embedding_dim
        self.clip_model_name = clip_model_name or self.CLIP_MODEL_NAME

        # Frozen CLIP backbone
        self.clip = _CLIPModel.from_pretrained(self.clip_model_name)
        self.processor = _CLIPProc.from_pretrained(self.clip_model_name)
        for p in self.clip.parameters():
            p.requires_grad = False

        clip_dim = self.clip.config.projection_dim
        self.image_projection = nn.Linear(clip_dim, embedding_dim)
        self.text_projection = nn.Linear(clip_dim, embedding_dim)

    # ------ forward / embedding helpers ------

    def _encode_text(self, texts: list[str], device) -> torch.Tensor:
        """Tokenise ``texts`` and return frozen CLIP text features on ``device``."""
        with torch.no_grad():
            text_inputs = self.processor(
                text=texts, padding=True, truncation=True, return_tensors="pt"
            )
            # The processor returns CPU tensors; move them next to the model.
            text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
            return self.clip.get_text_features(**text_inputs)

    def forward(self, pixel_values: torch.Tensor, texts: list[str]):
        """Return (image_emb, text_emb) each [B, embedding_dim], L2-normalised.

        The CLIP features are computed under ``no_grad``; the projections are
        applied outside it so gradients reach the two linear layers.
        """
        device = pixel_values.device
        with torch.no_grad():
            image_features = self.clip.get_image_features(pixel_values=pixel_values)
        text_features = self._encode_text(texts, device)

        img_emb = F.normalize(self.image_projection(image_features), dim=-1)
        txt_emb = F.normalize(self.text_projection(text_features), dim=-1)
        return img_emb, txt_emb

    def get_text_embeddings(self, texts: list[str]) -> torch.Tensor:
        """Return L2-normalised text embeddings [B, embedding_dim]."""
        device = next(self.parameters()).device
        text_features = self._encode_text(texts, device)
        # NOTE(review): the projection is assumed to run under no_grad like the
        # encoder (inference helper) -- confirm against the original file.
        with torch.no_grad():
            return F.normalize(self.text_projection(text_features), dim=-1)

    def get_image_embeddings(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Return L2-normalised image embeddings [B, embedding_dim] from preprocessed pixel_values."""
        # Mirror get_text_embeddings: run on the device the model lives on.
        device = next(self.parameters()).device
        # NOTE(review): projection assumed to run under no_grad -- confirm.
        with torch.no_grad():
            image_features = self.clip.get_image_features(
                pixel_values=pixel_values.to(device)
            )
            return F.normalize(self.image_projection(image_features), dim=-1)

    # ------ serialization ------

    def save_checkpoint(self, path: str):
        """Save only the trainable projection weights plus metadata to ``path``."""
        torch.save({
            "model_version": "v2",
            "embedding_dim": self.embedding_dim,
            "clip_model_name": self.clip_model_name,
            "image_projection": self.image_projection.state_dict(),
            "text_projection": self.text_projection.state_dict(),
        }, path)

    @classmethod
    def from_checkpoint(cls, path: str, device: torch.device | str = "cpu"):
        """Load a ColorCLIP model from a checkpoint written by ``save_checkpoint``.

        Args:
            path: Checkpoint file path.
            device: Device used for ``map_location`` and for the returned model.

        Returns:
            The model in eval mode on ``device``.

        Raises:
            KeyError: If the checkpoint lacks ``embedding_dim``,
                ``image_projection`` or ``text_projection``.
        """
        ckpt = torch.load(path, map_location=device)
        # Validate before instantiating: building the model loads the whole
        # CLIP backbone, which is wasted work for an incompatible checkpoint.
        required = ("embedding_dim", "image_projection", "text_projection")
        missing = [key for key in required if key not in ckpt]
        if missing:
            raise KeyError(
                f"Checkpoint {path!r} is missing {missing}; expected the "
                "projection-only format written by ColorCLIP.save_checkpoint()"
            )
        model = cls(
            embedding_dim=ckpt["embedding_dim"],
            clip_model_name=ckpt.get("clip_model_name", cls.CLIP_MODEL_NAME),
        )
        model.image_projection.load_state_dict(ckpt["image_projection"])
        model.text_projection.load_state_dict(ckpt["text_projection"])
        model.to(device)
        model.eval()
        return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
 
190
  # -------------------------------
191
+ # Loss Functions
192
  # -------------------------------
193
+
194
  def clip_loss(image_emb, text_emb, temperature=0.07):
195
  """
196
  CLIP contrastive loss function.
197
+
198
  Args:
199
  image_emb: Image embeddings [batch_size, embedding_dim]
200
  text_emb: Text embeddings [batch_size, embedding_dim]
201
  temperature: Temperature scaling parameter
202
+
203
  Returns:
204
  Contrastive loss value
205
  """
 
209
  loss_t2i = F.cross_entropy(logits.T, labels)
210
  return (loss_i2t + loss_t2i) / 2
211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
+ # -------------------------------
214
+ # Training
215
+ # -------------------------------
216
 
def _projected_clip_loss(image_proj, text_proj, batch, device, temperature):
    """Project one batch of pre-computed features and return its contrastive loss."""
    img_feat, txt_feat = batch
    img_feat, txt_feat = img_feat.to(device), txt_feat.to(device)
    img_emb = F.normalize(image_proj(img_feat), dim=-1)
    txt_emb = F.normalize(text_proj(txt_feat), dim=-1)
    return clip_loss(img_emb, txt_emb, temperature)


def train_color(batch_size=256, lr=1e-3, epochs=30, temperature=0.07):
    """Train ColorCLIP using pre-computed CLIP features (fast).

    Only two linear projections are trained, on feature files expected next
    to ``config.local_dataset_path``. The best checkpoint is written to
    ``config.color_model_path`` in the same format as
    ``ColorCLIP.save_checkpoint``.

    Args:
        batch_size: Batch size for both the train and test loaders.
        lr: Initial AdamW learning rate (cosine-annealed over ``epochs``).
        epochs: Number of training epochs.
        temperature: Temperature forwarded to ``clip_loss``.

    Returns:
        None. Returns early, after printing a message, when the feature files
        are missing or too few samples have features.
    """
    from pathlib import Path

    device = config.device
    print(f"Using device: {device}")

    # Load pre-computed features
    feat_dir = Path(config.local_dataset_path).parent
    img_feat_path = feat_dir / "clip_image_features.pt"
    txt_feat_path = feat_dir / "clip_text_features.pt"

    if not img_feat_path.exists() or not txt_feat_path.exists():
        print("Pre-computed features not found. Run data/precompute_clip_features.py first.")
        return

    print("Loading pre-computed CLIP features...")
    image_features = torch.load(img_feat_path, map_location="cpu")
    text_features = torch.load(txt_feat_path, map_location="cpu")
    print(f" Image features: {len(image_features)}, Text features: {len(text_features)}")

    # Load data
    df = pd.read_csv(config.local_dataset_path)
    main_colors = ['beige', 'black', 'blue', 'brown', 'green', 'orange',
                   'pink', 'purple', 'red', 'white', 'yellow']
    df = df[df[config.color_column].isin(main_colors)].copy()

    # Filter to rows with pre-computed features
    df = df[df[config.column_local_image_path].isin(image_features.keys())]
    df = df[df[config.color_column].isin(text_features.keys())]
    df = df.reset_index(drop=True)
    print(f"Training samples (with features): {len(df)}")

    # Split 90/10
    df = df.sample(frac=1.0, random_state=42).reset_index(drop=True)
    split_idx = int(0.9 * len(df))
    if split_idx == 0:
        # An empty train split would make DataLoader(shuffle=True) raise.
        print("Not enough samples with pre-computed features to train.")
        return
    df_train = df.iloc[:split_idx]
    df_test = df.iloc[split_idx:]

    train_ds = PrecomputedColorDataset(
        df_train[config.column_local_image_path].tolist(),
        df_train[config.color_column].tolist(),
        image_features, text_features,
    )
    test_ds = PrecomputedColorDataset(
        df_test[config.column_local_image_path].tolist(),
        df_test[config.color_column].tolist(),
        image_features, text_features,
    )
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              collate_fn=PrecomputedColorDataset.collate, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False,
                             collate_fn=PrecomputedColorDataset.collate, num_workers=0)

    # Create model (only projection layers). Input widths are read from the
    # stored features rather than hard-coded, so other CLIP variants work too.
    first_row = df.iloc[0]
    image_dim = image_features[first_row[config.column_local_image_path]].shape[-1]
    text_dim = text_features[first_row[config.color_column]].shape[-1]
    emb_dim = config.color_emb_dim
    image_proj = nn.Linear(image_dim, emb_dim).to(device)
    text_proj = nn.Linear(text_dim, emb_dim).to(device)

    params = list(image_proj.parameters()) + list(text_proj.parameters())
    optimizer = torch.optim.AdamW(params, lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    best_test_loss = float("inf")
    save_path = config.color_model_path

    # NOTE(review): text features are keyed by color, so one batch can hold
    # the same text feature many times; confirm clip_loss copes with repeated
    # positives.
    for epoch in range(epochs):
        image_proj.train()
        text_proj.train()
        train_losses = []
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} train", leave=False):
            if batch is None:
                continue
            optimizer.zero_grad()
            loss = _projected_clip_loss(image_proj, text_proj, batch, device, temperature)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        scheduler.step()
        avg_train = sum(train_losses) / len(train_losses) if train_losses else float("nan")

        # Eval
        image_proj.eval()
        text_proj.eval()
        test_losses = []
        with torch.no_grad():
            for batch in test_loader:
                if batch is None:
                    continue
                test_losses.append(
                    _projected_clip_loss(image_proj, text_proj, batch, device, temperature).item()
                )
        # NaN (not 0) marks "no test batches": a 0 would win model selection
        # in epoch 1 and block every later checkpoint.
        avg_test = sum(test_losses) / len(test_losses) if test_losses else float("nan")

        print(f"Epoch {epoch+1}/{epochs} train={avg_train:.4f} test={avg_test:.4f} lr={scheduler.get_last_lr()[0]:.2e}")

        # Select on the held-out loss; fall back to the train loss when the
        # test split produced no batches. NaN compares False, so it never saves.
        monitored = avg_test if test_losses else avg_train
        if monitored < best_test_loss:
            best_test_loss = monitored
            torch.save({
                "model_version": "v2",
                "embedding_dim": emb_dim,
                "clip_model_name": ColorCLIP.CLIP_MODEL_NAME,
                "image_projection": image_proj.state_dict(),
                "text_projection": text_proj.state_dict(),
            }, save_path)
            if test_losses:
                print(f" -> Saved best model (test_loss={avg_test:.4f})")
            else:
                print(f" -> Saved best model (train_loss={avg_train:.4f}; no test batches)")

    print(f"\nTraining complete. Best test loss: {best_test_loss:.4f}")
    print(f"Model saved to: {save_path}")
339
+
340
+
if __name__ == "__main__":
    # Script entry point: run projection training when executed directly.
    train_color()