Leacb4 committed on
Commit
1b41378
·
verified ·
1 Parent(s): 5014bf7

Delete models/color_model.py

Browse files
Files changed (1) hide show
  1. models/color_model.py +0 -567
models/color_model.py DELETED
@@ -1,567 +0,0 @@
1
- """
2
- ColorCLIP model for learning color-aligned embeddings.
3
- This file contains the ColorCLIP model that learns to encode images and texts
4
- in an embedding space specialized for color representation. It includes
5
- a ResNet-based image encoder, a text encoder with custom tokenizer,
6
- and contrastive loss functions for training.
7
- """
8
-
9
- import config
10
- import os
11
- import json
12
- import torch
13
- from torch.utils.data import Dataset, DataLoader
14
- from torchvision import transforms, models
15
- from PIL import Image
16
- import torch.nn as nn
17
- import torch.nn.functional as F
18
- import pandas as pd
19
- from tqdm.auto import tqdm
20
- from collections import defaultdict
21
- from typing import Optional, List
22
- import logging
23
-
24
-
25
- # Configure logging
26
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
27
- logger = logging.getLogger(__name__)
28
- # -------------------------------
29
- # Dataset Classes
30
- # -------------------------------
31
class ColorDataset(Dataset):
    """
    Dataset class for color embedding training.

    Loads images from local file paths and tokenizes text descriptions
    for training the ColorCLIP model.
    """

    def __init__(self, dataframe, tokenizer, transform=None):
        """
        Initialize the color dataset.

        Args:
            dataframe: DataFrame with columns for image paths and text descriptions
            tokenizer: Callable that converts a text string to a list of integer tokens
            transform: Optional image transformations (default: 224x224 resize plus
                standard ImageNet normalization)
        """
        self.df = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.df)

    def __getitem__(self, idx):
        """
        Get a sample from the dataset.

        Args:
            idx: Index of the sample

        Returns:
            Tuple of (image_tensor, token_tensor)
        """
        row = self.df.iloc[idx]
        # BUG FIX: open the image at this row's path. The previous code passed
        # config.column_local_image_path (the *column name*) straight to
        # Image.open, so every sample tried to open the same bogus path.
        img = Image.open(row[config.column_local_image_path]).convert("RGB")
        img = self.transform(img)
        tokens = torch.tensor(self.tokenizer(row[config.text_column]), dtype=torch.long)
        return img, tokens
76
-
77
- # -------------------------------
78
- # Tokenizer
79
- # -------------------------------
80
class Tokenizer:
    """
    Tokenizer for extracting color-related keywords from text.

    Filters text to keep only color-related words and basic descriptive
    words, then maps them to integer indices for embedding.
    Index 0 is reserved for padding/unknown tokens.
    """

    # Hoisted to class level so the lookups are O(1) sets built once,
    # not lists rebuilt on every preprocess_text call.
    COLOR_KEYWORDS = frozenset([
        'red', 'blue', 'green', 'yellow', 'purple', 'pink', 'orange',
        'brown', 'black', 'white', 'gray', 'navy', 'beige', 'aqua', 'lime',
        'violet', 'turquoise', 'teal', 'tan', 'snow', 'silver', 'plum',
        'olive', 'fuchsia', 'gold', 'cream', 'ivory', 'maroon',
    ])
    DESCRIPTIVE_WORDS = frozenset([
        'shirt', 'dress', 'top', 'bottom', 'shoe', 'bag', 'hat',
        'short', 'long', 'sleeve',
    ])

    def __init__(self):
        """
        Initialize the tokenizer.

        Creates empty word-to-index and index-to-word mappings.
        Index 0 is reserved for padding/unknown tokens.
        """
        self.word2idx = defaultdict(lambda: 0)  # 0 = pad/unknown
        self.idx2word = {}
        self.counter = 1

    def preprocess_text(self, text):
        """
        Extract color-related keywords from text.

        Args:
            text: Input text string

        Returns:
            Preprocessed text containing only color and descriptive keywords;
            falls back to the whole lowercased input when no keyword matches.
        """
        words = text.lower().split()
        filtered_words = [
            word for word in words
            if word in self.COLOR_KEYWORDS or word in self.DESCRIPTIVE_WORDS
        ]
        return ' '.join(filtered_words) if filtered_words else text.lower()

    def fit(self, texts):
        """
        Build vocabulary from a list of texts.

        Args:
            texts: List of text strings to build vocabulary from
        """
        for text in texts:
            processed_text = self.preprocess_text(text)
            for word in processed_text.split():
                if word not in self.word2idx:
                    self.word2idx[word] = self.counter
                    self.idx2word[self.counter] = word
                    self.counter += 1

    def __call__(self, text):
        """
        Tokenize a text string into a list of integer indices.

        Args:
            text: Input text string

        Returns:
            List of integer token indices; unknown words map to 0.
        """
        processed_text = self.preprocess_text(text)
        # BUG FIX: use .get() instead of indexing the defaultdict. Indexing
        # inserted a 0-valued entry for every unseen word, silently growing
        # (and corrupting) the vocabulary during inference.
        return [self.word2idx.get(word, 0) for word in processed_text.split()]

    def load_vocab(self, word2idx_dict):
        """
        Load vocabulary from a word-to-index dictionary.

        Args:
            word2idx_dict: Dictionary mapping words to indices
        """
        self.word2idx = defaultdict(lambda: 0, {k: int(v) for k, v in word2idx_dict.items()})
        # Index 0 is reserved, so it never appears in the reverse mapping.
        self.idx2word = {int(v): k for k, v in word2idx_dict.items() if int(v) > 0}
        self.counter = max(self.word2idx.values(), default=0) + 1
165
-
166
- # -------------------------------
167
- # Model Components
168
- # -------------------------------
169
class ImageEncoder(nn.Module):
    """
    ResNet18-based image encoder producing L2-normalized embeddings.

    The pretrained backbone's classification head is swapped for a
    dropout + linear projection into the target embedding space.
    """

    def __init__(self, embedding_dim=config.color_emb_dim):
        """
        Initialize the image encoder.

        Args:
            embedding_dim: Dimension of the output embedding (default: color_emb_dim)
        """
        super().__init__()
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        in_features = self.backbone.fc.in_features
        # Dropout for regularization, then project down to the embedding space.
        projection_head = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(in_features, embedding_dim),
        )
        self.backbone.fc = projection_head

    def forward(self, x):
        """
        Forward pass through the image encoder.

        Args:
            x: Image tensor [batch_size, channels, height, width]

        Returns:
            L2-normalized image embeddings [batch_size, embedding_dim]
        """
        embedded = self.backbone(x)
        return F.normalize(embedded, dim=-1)
203
-
204
class TextEncoder(nn.Module):
    """
    Text encoder producing L2-normalized embeddings from token sequences.

    Token embeddings are mean-pooled (length-aware when lengths are given,
    so zero-padding does not dilute the mean) and linearly projected.
    """

    def __init__(self, vocab_size, embedding_dim=config.color_emb_dim):
        """
        Initialize the text encoder.

        Args:
            vocab_size: Size of the vocabulary
            embedding_dim: Dimension of the output embedding (default: color_emb_dim)
        """
        super().__init__()
        # 32-dim internal token embeddings; index 0 is the padding token.
        self.embedding = nn.Embedding(vocab_size, 32, padding_idx=0)
        self.dropout = nn.Dropout(0.1)  # regularization
        self.fc = nn.Linear(32, embedding_dim)

    def forward(self, x, lengths=None):
        """
        Forward pass through the text encoder.

        Args:
            x: Token tensor [batch_size, sequence_length]
            lengths: Optional sequence lengths tensor [batch_size] for proper mean pooling

        Returns:
            L2-normalized text embeddings [batch_size, embedding_dim]
        """
        token_emb = self.dropout(self.embedding(x))  # [B, T, 32]
        if lengths is None:
            pooled = token_emb.mean(dim=1)
        else:
            # Sum then divide by true lengths so padded positions are excluded
            # from the average; clamp protects against zero-length sequences.
            pooled = token_emb.sum(dim=1) / lengths.unsqueeze(1).clamp_min(1)
        return F.normalize(self.fc(pooled), dim=-1)
244
-
245
class ColorCLIP(nn.Module):
    """
    Color CLIP model for learning color-aligned image-text embeddings.

    Pairs an ImageEncoder and a TextEncoder that both emit L2-normalized
    vectors in a shared embedding space.
    """
    def __init__(self, vocab_size, embedding_dim=config.color_emb_dim, tokenizer=None):
        """
        Initialize ColorCLIP model.

        Args:
            vocab_size: Size of the vocabulary for text encoding
            embedding_dim: Dimension of the embedding space (default: color_emb_dim)
            tokenizer: Optional Tokenizer instance. No default tokenizer is
                created; one must be attached before calling get_text_embeddings.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.image_encoder = ImageEncoder(embedding_dim)
        self.text_encoder = TextEncoder(vocab_size, embedding_dim)
        self.tokenizer = tokenizer

    def forward(self, image, text, lengths=None):
        """
        Forward pass through the model.

        Args:
            image: Image tensor [B, C, H, W]
            text: Text token tensor [B, T]
            lengths: Optional sequence lengths tensor [B]

        Returns:
            Tuple of (image_embeddings, text_embeddings)
        """
        return self.image_encoder(image), self.text_encoder(text, lengths)

    def get_text_embeddings(self, texts: List[str]) -> torch.Tensor:
        """
        Get text embeddings for a list of text strings.

        Args:
            texts: List of text strings

        Returns:
            Text embeddings tensor [batch_size, embedding_dim]

        Raises:
            ValueError: If no tokenizer has been attached to the model.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer must be set before calling get_text_embeddings")

        token_lists = [self.tokenizer(t) for t in texts]
        # Right-pad every sequence with 0 (the pad/unknown index) to the batch max.
        max_len = max((len(toks) for toks in token_lists), default=0)
        padded = [toks + [0] * (max_len - len(toks)) for toks in token_lists]
        input_ids = torch.tensor(padded, dtype=torch.long, device=next(self.parameters()).device)
        # True lengths let the text encoder mean-pool without counting padding.
        lengths = torch.tensor([len(toks) for toks in token_lists], dtype=torch.long, device=input_ids.device)
        with torch.no_grad():
            emb = self.text_encoder(input_ids, lengths)
        return emb

    @classmethod
    def from_pretrained(cls, model_path: str, vocab_path: Optional[str] = None, device: str = "cpu", repo_id: Optional[str] = None):
        """
        Load a pretrained ColorCLIP model from a file path or Hugging Face Hub.

        Args:
            model_path: Path to the model checkpoint (.pt file) or filename if using repo_id
            vocab_path: Optional path to tokenizer vocabulary JSON file or filename if using repo_id
            device: Device to load the model on (default: "cpu")
            repo_id: Optional Hugging Face repository ID (e.g., "username/model-name")
                If provided, model_path and vocab_path should be filenames within the repo

        Returns:
            ColorCLIP model instance (in eval mode)

        Raises:
            ImportError: If repo_id is given but huggingface_hub is not installed.
            ValueError: If the checkpoint is not a dict or vocab_size cannot be determined.

        Example:
            # Load from local file
            model = ColorCLIP.from_pretrained("color_model.pt", "tokenizer_vocab.json")

            # Load from Hugging Face Hub
            from huggingface_hub import hf_hub_download
            model_file = hf_hub_download(repo_id="username/model-name", filename="color_model.pt")
            vocab_file = hf_hub_download(repo_id="username/model-name", filename="tokenizer_vocab.json")
            model = ColorCLIP.from_pretrained(model_file, vocab_file)
        """
        device_obj = torch.device(device)

        # Support loading from Hugging Face Hub if repo_id is provided
        if repo_id:
            try:
                from huggingface_hub import hf_hub_download
                model_path = hf_hub_download(repo_id=repo_id, filename=model_path)
                if vocab_path:
                    vocab_path = hf_hub_download(repo_id=repo_id, filename=vocab_path)
            except ImportError:
                raise ImportError("huggingface_hub is required to load models from Hugging Face. Install it with: pip install huggingface-hub")

        # Load model checkpoint.
        # NOTE(review): torch.load unpickles arbitrary Python objects — only
        # load checkpoints from trusted sources (consider weights_only=True).
        checkpoint = torch.load(model_path, map_location=device_obj)

        # Extract vocab size and embedding dimension from checkpoint
        if isinstance(checkpoint, dict):
            # Try to get vocab_size from metadata first
            vocab_size = checkpoint.get('vocab_size', None)
            # NOTE(review): fallback of 16 presumably matches config.color_emb_dim
            # for legacy checkpoints without metadata — confirm.
            embedding_dim = checkpoint.get('embedding_dim', 16)

            # If not in metadata, try to infer from model state
            if vocab_size is None:
                state_dict = checkpoint.get('model_state_dict', checkpoint)
                if 'text_encoder.embedding.weight' in state_dict:
                    vocab_size = state_dict['text_encoder.embedding.weight'].shape[0]
                else:
                    raise ValueError("Could not determine vocab_size from checkpoint")

            # Load state dict (checkpoint may be the raw state_dict itself)
            state_dict = checkpoint.get('model_state_dict', checkpoint)
        else:
            raise ValueError("Checkpoint must be a dictionary")

        # Initialize model
        model = cls(vocab_size=vocab_size, embedding_dim=embedding_dim)
        model.load_state_dict(state_dict)
        model = model.to(device_obj)

        # Load tokenizer if vocab path is provided
        if vocab_path and os.path.exists(vocab_path):
            tokenizer = Tokenizer()
            with open(vocab_path, 'r') as f:
                vocab_dict = json.load(f)
            tokenizer.load_vocab(vocab_dict)
            model.tokenizer = tokenizer

        model.eval()
        return model

    def save_pretrained(self, save_directory: str, vocab_path: Optional[str] = None):
        """
        Save the model and optionally the tokenizer vocabulary.

        Args:
            save_directory: Directory to save the model
            vocab_path: Optional path to save tokenizer vocabulary
                (defaults to config.tokeniser_path inside save_directory)

        Returns:
            Tuple of (model_path, vocab_path). vocab_path is the argument
            value (possibly None) when no tokenizer is attached.
        """
        os.makedirs(save_directory, exist_ok=True)

        # Save model checkpoint with the metadata from_pretrained expects
        model_path = os.path.join(save_directory, config.color_model_path)
        checkpoint = {
            'model_state_dict': self.state_dict(),
            'vocab_size': self.vocab_size,
            'embedding_dim': self.embedding_dim
        }
        torch.save(checkpoint, model_path)

        # Save tokenizer vocabulary if available
        if self.tokenizer is not None:
            vocab_dict = dict(self.tokenizer.word2idx)
            if vocab_path is None:
                vocab_path = os.path.join(save_directory, config.tokeniser_path)
            with open(vocab_path, 'w') as f:
                json.dump(vocab_dict, f)

        return model_path, vocab_path
404
-
405
-
406
- # -------------------------------
407
- # Loss Functions and Utilities
408
- # -------------------------------
409
def clip_loss(image_emb, text_emb, temperature=0.07):
    """
    Symmetric CLIP contrastive loss (InfoNCE in both directions).

    Args:
        image_emb: Image embeddings [batch_size, embedding_dim]
        text_emb: Text embeddings [batch_size, embedding_dim]
        temperature: Temperature scaling parameter

    Returns:
        Scalar loss: mean of image-to-text and text-to-image cross-entropy,
        where the i-th image matches the i-th text.
    """
    # Pairwise similarity matrix, sharpened by the temperature.
    similarity = torch.matmul(image_emb, text_emb.T) / temperature
    # Ground truth: the diagonal pairs are the positives.
    targets = torch.arange(image_emb.size(0), device=image_emb.device)
    image_to_text = F.cross_entropy(similarity, targets)
    text_to_image = F.cross_entropy(similarity.T, targets)
    return 0.5 * (image_to_text + text_to_image)
426
-
427
def collate_batch(batch):
    """
    DataLoader collate function: drops None samples, stacks images, and
    right-pads token sequences with zeros.

    Args:
        batch: List of (image, tokens) tuples, possibly containing None

    Returns:
        Tuple of (images, padded_tokens, lengths), or None when no valid
        samples remain after filtering.
    """
    valid_samples = [sample for sample in batch if sample is not None]
    if not valid_samples:
        return None
    images, token_seqs = zip(*valid_samples)
    stacked_images = torch.stack(images, dim=0)
    # Record true lengths before padding so the encoder can mean-pool exactly.
    seq_lengths = torch.tensor([seq.size(0) for seq in token_seqs], dtype=torch.long)
    padded_tokens = nn.utils.rnn.pad_sequence(token_seqs, batch_first=True, padding_value=0)
    return stacked_images, padded_tokens, seq_lengths
445
-
446
-
447
-
448
- if __name__ == "__main__":
449
- """
450
- Training script for ColorCLIP model.
451
- This code only runs when the file is executed directly, not when imported.
452
- """
453
- # Configuration
454
- batch_size = 16
455
- lr = 1e-4
456
- epochs=50
457
-
458
-
459
-
460
- # Load dataset and split train/test
461
- tokenizer = Tokenizer()
462
- df = pd.read_csv(config.local_dataset_path)
463
-
464
- # Data preparation: Reduce to main colors only (11 classes instead of 34)
465
- main_colors = ['beige', 'black', 'blue', 'brown', 'green', 'orange', 'pink', 'purple', 'red', 'white', 'yellow']
466
- df = df[df[config.color_column].isin(main_colors)].copy()
467
- print(f"📊 Filtered dataset: {len(df)} samples with {len(main_colors)} colors")
468
- print(f"🎨 Colors: {sorted(df[config.color_column].unique())}")
469
-
470
- tokenizer.fit(df[config.text_column].tolist())
471
-
472
- # Filter only rows with a valid local file
473
- df_local = df[df[config.column_local_image_path].astype(str).str.len() > 0]
474
- df_local = df_local[df_local[config.column_local_image_path].apply(lambda p: os.path.isfile(p))]
475
- df_local = df_local.reset_index(drop=True)
476
-
477
-
478
- # split 90/10
479
- df_local = df_local.sample(frac=1.0, random_state=42).reset_index(drop=True)
480
- split_idx = int(0.9 * len(df_local))
481
- df_train = df_local.iloc[:split_idx].reset_index(drop=True)
482
- df_test = df_local.iloc[split_idx:].reset_index(drop=True)
483
-
484
-
485
- train_dataset = ColorDataset(df_train, tokenizer)
486
- test_dataset = ColorDataset(df_test, tokenizer)
487
-
488
- train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_batch, num_workers=0)
489
- test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_batch, num_workers=0)
490
-
491
- device = config.device
492
- print(f"Using device: {device}")
493
-
494
- model = ColorCLIP(vocab_size=tokenizer.counter, embedding_dim=config.color_emb_dim, tokenizer=tokenizer).to(device)
495
- optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5) # Add weight decay
496
-
497
- # Save tokenizer vocab once (or update) so evaluation can reload the same mapping
498
- here = os.path.dirname(__file__)
499
- vocab_out = os.path.join(here, config.tokeniser_path)
500
- with open(vocab_out, "w") as f:
501
- json.dump(dict(tokenizer.word2idx), f)
502
- print(f"Tokenizer vocabulary saved to: {vocab_out}")
503
-
504
-
505
- for epoch in range(epochs):
506
- model.train()
507
- pbar = tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs} - train", leave=False)
508
- epoch_losses = []
509
- for batch in train_loader:
510
- if batch is None:
511
- pbar.update(1)
512
- continue
513
- imgs, texts, lengths = batch
514
- imgs = imgs.to(device)
515
- texts = texts.to(device)
516
- lengths = lengths.to(device)
517
- optimizer.zero_grad()
518
- img_emb, text_emb = model(imgs, texts, lengths)
519
- loss = clip_loss(img_emb, text_emb)
520
- loss.backward()
521
- optimizer.step()
522
- epoch_losses.append(loss.item())
523
- pbar.set_postfix({"loss": f"{loss.item():.4f}", "avg": f"{sum(epoch_losses)/len(epoch_losses):.4f}"})
524
- pbar.update(1)
525
- pbar.close()
526
-
527
- avg_train_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else None
528
- if avg_train_loss is not None:
529
- print(f"[Train] Epoch {epoch+1}/{epochs} - avg loss: {avg_train_loss:.4f}")
530
- else:
531
- print(f"[Train] Epoch {epoch+1}/{epochs} - no valid batches")
532
-
533
- # Eval rapide sur test avec barre
534
- model.eval()
535
- test_losses = []
536
- with torch.no_grad():
537
- pbar_t = tqdm(total=len(test_loader), desc=f"Epoch {epoch+1}/{epochs} - test", leave=False)
538
- for batch in test_loader:
539
- if batch is None:
540
- pbar_t.update(1)
541
- continue
542
- imgs, texts, lengths = batch
543
- imgs = imgs.to(device)
544
- texts = texts.to(device)
545
- lengths = lengths.to(device)
546
- img_emb, text_emb = model(imgs, texts, lengths)
547
- test_losses.append(clip_loss(img_emb, text_emb).item())
548
- pbar_t.update(1)
549
- pbar_t.close()
550
- if len(test_losses) > 0:
551
- avg_test_loss = sum(test_losses) / len(test_losses)
552
- print(f"[Test ] Epoch {epoch+1}/{epochs} - avg loss: {avg_test_loss:.4f}")
553
- else:
554
- print(f"[Test ] Epoch {epoch+1}/{epochs} - no valid batches")
555
-
556
- # --- Save checkpoint at every epoch ---
557
- ckpt_dir = here
558
- latest_path = os.path.join(ckpt_dir, config.color_model_path)
559
- epoch_path = os.path.join(ckpt_dir, f"color_model_epoch_{epoch+1}.pt")
560
- checkpoint = {
561
- 'model_state_dict': model.state_dict(),
562
- 'vocab_size': model.vocab_size,
563
- 'embedding_dim': model.embedding_dim
564
- }
565
- torch.save(checkpoint, latest_path)
566
- torch.save(checkpoint, epoch_path)
567
- print(f"[Save ] Saved checkpoints: {latest_path} and {epoch_path}")