Leacb4 commited on
Commit
398de18
·
verified ·
1 Parent(s): 29ff4f9

Upload color_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. color_model.py +240 -317
color_model.py CHANGED
@@ -1,272 +1,31 @@
 
1
  import os
2
- import time
3
  import json
4
  import torch
5
  from torch.utils.data import Dataset, DataLoader
6
  from torchvision import transforms, models
7
  from PIL import Image
8
- import requests
9
- from io import BytesIO
10
  import torch.nn as nn
11
  import torch.nn.functional as F
12
  import pandas as pd
13
- from tqdm.auto import tqdm
14
-
15
- import asyncio
16
- import aiohttp
17
- import pandas as pd
18
- import os
19
- from pathlib import Path
20
- from tqdm.asyncio import tqdm
21
- import ssl
22
  import logging
23
- from typing import Optional, List, Tuple
24
- from urllib.parse import urlparse
25
- import hashlib
26
- from config import local_dataset_path
27
 
28
  # Configure logging
29
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
30
  logger = logging.getLogger(__name__)
31
-
32
class ImageDownloader:
    """Asynchronous image downloader with retry logic, SSL control, and stats.

    Reads image URLs from a DataFrame's ``image`` column and saves them into
    ``output_dir``, skipping files that already exist on disk.
    """

    def __init__(self,
                 output_dir: str = "athleta_images",
                 max_concurrent: int = 10,
                 timeout: int = 30,
                 retry_attempts: int = 3,
                 verify_ssl: bool = True):
        """
        Initialize the ImageDownloader.

        Args:
            output_dir: Directory to save downloaded images
            max_concurrent: Maximum number of concurrent downloads
            timeout: Request timeout in seconds
            retry_attempts: Number of retry attempts for failed downloads
            verify_ssl: Whether to verify SSL certificates
        """
        self.output_dir = Path(output_dir)
        self.max_concurrent = max_concurrent
        self.timeout = aiohttp.ClientTimeout(total=timeout)
        self.retry_attempts = retry_attempts
        self.verify_ssl = verify_ssl

        # Create output directory up front so workers can write immediately.
        self.output_dir.mkdir(exist_ok=True)

        # Aggregate counters reported by print_statistics().
        self.stats = {
            'total': 0,
            'downloaded': 0,
            'skipped': 0,
            'failed': 0,
            'retries': 0
        }

    def _create_ssl_context(self) -> Optional[ssl.SSLContext]:
        """Return a permissive SSL context when verification is disabled, else None."""
        if not self.verify_ssl:
            ssl_context = ssl.create_default_context()
            ssl_context.check_hostname = False
            ssl_context.verify_mode = ssl.CERT_NONE
            return ssl_context
        return None

    def _generate_filename(self, url: str, index: int) -> str:
        """Generate a safe filename from the URL path, a URL hash, or the row index."""
        try:
            # Prefer the real filename embedded in the URL path.
            parsed_url = urlparse(url)
            filename = os.path.basename(parsed_url.path)
            if filename and '.' in filename:
                return filename
        except Exception:
            pass

        # Fallback: a short URL hash keeps names stable across runs.
        try:
            url_hash = hashlib.md5(url.encode()).hexdigest()[:8]
            return f"image_{url_hash}.jpg"
        except Exception:
            return f"image_{index}.jpg"

    async def _download_single_image(self,
                                     session: aiohttp.ClientSession,
                                     url: str,
                                     save_path: Path,
                                     index: int) -> bool:
        """
        Download a single image with retry logic.

        Returns:
            bool: True if successful, False otherwise
        """
        # Computed once: the per-request SSL override does not change per attempt.
        ssl_context = self._create_ssl_context()

        for attempt in range(self.retry_attempts):
            try:
                if attempt > 0:
                    self.stats['retries'] += 1
                    logger.info(f"Retry {attempt}/{self.retry_attempts} for {url}")

                # BUGFIX: ClientSession.get() does not accept a `connector`
                # keyword (connectors are bound to the session itself), so the
                # original `session.get(url, ssl=..., connector=...)` raised
                # TypeError. Only the per-request `ssl` override is valid here.
                async with session.get(url, ssl=ssl_context) as response:
                    if response.status == 200:
                        content = await response.read()

                        # Validate that it's actually an image
                        if len(content) < 1024:  # Too small to be a real image
                            logger.warning(f"Image too small, skipping: {url}")
                            return False

                        # Ensure directory exists
                        save_path.parent.mkdir(parents=True, exist_ok=True)

                        # Write file
                        with open(save_path, 'wb') as f:
                            f.write(content)

                        logger.debug(f"Successfully downloaded: {save_path}")
                        return True

                    elif response.status == 404:
                        # Permanent failure: retrying a 404 will not help.
                        logger.warning(f"Image not found (404): {url}")
                        return False

                    else:
                        logger.warning(f"HTTP {response.status} for {url}")
                        if attempt == self.retry_attempts - 1:
                            return False

            except asyncio.TimeoutError:
                logger.warning(f"Timeout downloading {url} (attempt {attempt + 1})")
                if attempt == self.retry_attempts - 1:
                    return False

            except Exception as e:
                logger.error(f"Error downloading {url}: {str(e)}")
                if attempt == self.retry_attempts - 1:
                    return False

        return False

    async def _download_batch(self,
                              session: aiohttp.ClientSession,
                              batch: List[Tuple[str, Path, int]]) -> None:
        """Download a batch of images concurrently, bounded by max_concurrent."""
        semaphore = asyncio.Semaphore(self.max_concurrent)

        async def download_with_semaphore(url, save_path, index):
            async with semaphore:
                if save_path.exists():
                    logger.debug(f"File already exists, skipping: {save_path}")
                    self.stats['skipped'] += 1
                    return

                success = await self._download_single_image(session, url, save_path, index)
                if success:
                    self.stats['downloaded'] += 1
                else:
                    self.stats['failed'] += 1

        tasks = [download_with_semaphore(url, save_path, index)
                 for url, save_path, index in batch]
        # return_exceptions=True keeps one failed task from cancelling the rest.
        await asyncio.gather(*tasks, return_exceptions=True)

    def _prepare_download_tasks(self, df: pd.DataFrame) -> List[Tuple[str, Path, int]]:
        """Build (url, save_path, index) tuples for every valid URL in df['image']."""
        tasks = []

        for index, row in df.iterrows():
            # Check if image URL is valid
            if pd.isna(row.get('image')) or not isinstance(row.get('image'), str):
                logger.debug(f"Skipping row {index}: invalid image URL")
                continue

            url = row['image'].strip()
            if not url or not url.startswith(('http://', 'https://')):
                logger.debug(f"Skipping row {index}: invalid URL format")
                continue

            # Generate filename
            filename = self._generate_filename(url, index)
            save_path = self.output_dir / filename

            tasks.append((url, save_path, index))

        return tasks

    async def download_all_images(self, df: pd.DataFrame) -> None:
        """Download all images from the DataFrame, batching to limit server load."""
        logger.info("Preparing download tasks...")
        tasks = self._prepare_download_tasks(df)
        self.stats['total'] = len(tasks)

        if not tasks:
            logger.warning("No valid image URLs found in the dataset")
            return

        logger.info(f"Found {len(tasks)} valid image URLs to download")

        # Session-level connector: the correct place to install a custom
        # SSL context for every request in this session.
        ssl_context = self._create_ssl_context()
        connector = aiohttp.TCPConnector(ssl=ssl_context) if ssl_context else None

        async with aiohttp.ClientSession(
            timeout=self.timeout,
            connector=connector,
            headers={'User-Agent': 'Mozilla/5.0 (compatible; ImageDownloader/1.0)'}
        ) as session:

            # Process in batches to avoid overwhelming the server
            batch_size = self.max_concurrent * 2
            for i in range(0, len(tasks), batch_size):
                batch = tasks[i:i + batch_size]
                logger.info(f"Processing batch {i//batch_size + 1}/{(len(tasks)-1)//batch_size + 1}")

                await self._download_batch(session, batch)

                # Small delay between batches to be respectful
                if i + batch_size < len(tasks):
                    await asyncio.sleep(1)

    def print_statistics(self) -> None:
        """Print download statistics."""
        logger.info("Download Statistics:")
        logger.info(f"  Total URLs processed: {self.stats['total']}")
        logger.info(f"  Successfully downloaded: {self.stats['downloaded']}")
        logger.info(f"  Skipped (already exists): {self.stats['skipped']}")
        logger.info(f"  Failed: {self.stats['failed']}")
        logger.info(f"  Retry attempts: {self.stats['retries']}")

        if self.stats['total'] > 0:
            success_rate = (self.stats['downloaded'] / self.stats['total']) * 100
            logger.info(f"  Success rate: {success_rate:.1f}%")
249
-
250
- import os
251
- import time
252
- import json
253
- import torch
254
- from torch.utils.data import Dataset, DataLoader
255
- from torchvision import transforms, models
256
- from PIL import Image
257
- import requests
258
- from io import BytesIO
259
- import torch.nn as nn
260
- import torch.nn.functional as F
261
- import pandas as pd
262
- from tqdm.auto import tqdm
263
-
264
  class ColorDataset(Dataset):
265
  def __init__(self, dataframe, tokenizer, transform=None):
266
  """
267
- dataframe : pd.DataFrame avec colonnes 'image_url' et 'text'
268
- tokenizer : fonction qui convertit texte -> list d'entiers (tokens)
269
- transform : transformations image
270
  """
271
  self.df = dataframe.reset_index(drop=True)
272
  self.tokenizer = tokenizer
@@ -282,20 +41,15 @@ class ColorDataset(Dataset):
282
 
283
  def __getitem__(self, idx):
284
  row = self.df.iloc[idx]
285
- try:
286
- src = row.get('local_image_path', None)
287
- if not src or not os.path.isfile(src):
288
- return None # filtered by collate
289
- img = Image.open(src).convert("RGB")
290
- img = self.transform(img)
291
- tokens = torch.tensor(self.tokenizer(row['text']), dtype=torch.long)
292
- return img, tokens
293
- except Exception:
294
- return None
295
-
296
- from collections import defaultdict
297
-
298
- class SimpleTokenizer:
299
  def __init__(self):
300
  self.word2idx = defaultdict(lambda: 0) # 0 = pad/unknown
301
  self.idx2word = {}
@@ -339,8 +93,11 @@ class SimpleTokenizer:
339
  self.idx2word = {int(v): k for k, v in word2idx_dict.items() if int(v) > 0}
340
  self.counter = max(self.word2idx.values(), default=0) + 1
341
 
 
 
 
342
  class ImageEncoder(nn.Module):
343
- def __init__(self, embedding_dim=16):
344
  super().__init__()
345
  self.backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
346
  self.backbone.fc = nn.Sequential(
@@ -353,7 +110,7 @@ class ImageEncoder(nn.Module):
353
  return F.normalize(x, dim=-1)
354
 
355
  class TextEncoder(nn.Module):
356
- def __init__(self, vocab_size, embedding_dim=16):
357
  super().__init__()
358
  self.embedding = nn.Embedding(vocab_size, 32, padding_idx=0) # Keep 32 dimensions
359
  self.dropout = nn.Dropout(0.1) # Add regularization
@@ -370,15 +127,52 @@ class TextEncoder(nn.Module):
370
  return F.normalize(self.fc(mean), dim=-1)
371
 
372
  class ColorCLIP(nn.Module):
373
- def __init__(self, vocab_size, embedding_dim=16): # Keep 16 dimensions
 
 
 
 
 
 
 
 
 
 
 
374
  super().__init__()
 
 
375
  self.image_encoder = ImageEncoder(embedding_dim)
376
  self.text_encoder = TextEncoder(vocab_size, embedding_dim)
 
377
 
378
  def forward(self, image, text, lengths=None):
 
 
 
 
 
 
 
 
 
 
 
379
  return self.image_encoder(image), self.text_encoder(text, lengths)
380
 
381
  def get_text_embeddings(self, texts: List[str]) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
 
 
 
382
  token_lists = [self.tokenizer(t) for t in texts]
383
  max_len = max((len(toks) for toks in token_lists), default=0)
384
  padded = [toks + [0] * (max_len - len(toks)) for toks in token_lists]
@@ -387,17 +181,143 @@ class ColorCLIP(nn.Module):
387
  with torch.no_grad():
388
  emb = self.text_encoder(input_ids, lengths)
389
  return emb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390
 
391
 
 
 
 
392
def clip_loss(image_emb, text_emb, temperature=0.07):
    """Symmetric CLIP contrastive loss over a batch of paired embeddings.

    The i-th image is the positive match for the i-th text; cross-entropy is
    applied in both directions (image->text and text->image) and averaged.
    """
    similarity = torch.matmul(image_emb, text_emb.T) / temperature
    targets = torch.arange(image_emb.size(0), device=image_emb.device)
    image_to_text = F.cross_entropy(similarity, targets)
    text_to_image = F.cross_entropy(similarity.T, targets)
    return (image_to_text + text_to_image) / 2
398
 
399
- # Collate qui pad les séquences et filtre les None
400
  def collate_batch(batch):
 
 
 
 
 
 
 
 
 
401
  batch = [b for b in batch if b is not None]
402
  if len(batch) == 0:
403
  return None
@@ -410,33 +330,32 @@ def collate_batch(batch):
410
 
411
 
412
  if __name__ == "__main__":
413
- # Chargement + split train/test + cache local
414
- tokenizer = SimpleTokenizer()
415
- df = pd.read_csv('df_color_with_local_paths.csv')
 
 
 
 
 
 
 
 
 
 
 
416
 
417
- # Reduce to main colors only (11 classes instead of 34)
418
  main_colors = ['beige', 'black', 'blue', 'brown', 'green', 'orange', 'pink', 'purple', 'red', 'white', 'yellow']
419
- df = df[df['color'].isin(main_colors)].copy()
420
  print(f"📊 Filtered dataset: {len(df)} samples with {len(main_colors)} colors")
421
- print(f"🎨 Colors: {sorted(df['color'].unique())}")
422
 
423
- tokenizer.fit(df['text'].tolist())
424
-
425
- # If no local paths column, download/calc it once
426
- if 'local_image_path' not in df.columns or df['local_image_path'].isna().all():
427
- downloader = ImageDownloader(
428
- csv_path='new/df_color_with_local_paths.csv',
429
- images_dir='data/images',
430
- max_workers=16,
431
- timeout=10
432
- )
433
- df_local = downloader.download_all_images()
434
- else:
435
- df_local = df
436
 
437
  # Filter only rows with a valid local file
438
- df_local = df_local[df_local['local_image_path'].astype(str).str.len() > 0]
439
- df_local = df_local[df_local['local_image_path'].apply(lambda p: os.path.isfile(p))]
440
  df_local = df_local.reset_index(drop=True)
441
 
442
 
@@ -450,30 +369,27 @@ if __name__ == "__main__":
450
  train_dataset = ColorDataset(df_train, tokenizer)
451
  test_dataset = ColorDataset(df_test, tokenizer)
452
 
453
- train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_batch, num_workers=0)
454
- test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=collate_batch, num_workers=0)
455
 
456
- device = "mps" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else "cpu"
457
  print(f"Using device: {device}")
458
 
459
- model = ColorCLIP(vocab_size=tokenizer.counter).to(device)
460
- optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5) # Add weight decay
461
 
462
  # Save tokenizer vocab once (or update) so evaluation can reload the same mapping
463
  here = os.path.dirname(__file__)
464
- vocab_out = os.path.join(here, "tokenizer_vocab.json")
465
  with open(vocab_out, "w") as f:
466
  json.dump(dict(tokenizer.word2idx), f)
 
467
 
468
 
469
- from collections import defaultdict
470
-
471
-
472
- EPOCHS = 50 # Increased from 10 to 50
473
- for epoch in range(EPOCHS):
474
  model.train()
475
- pbar = tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{EPOCHS} - train", leave=False)
476
- last_loss = None
477
  for batch in train_loader:
478
  if batch is None:
479
  pbar.update(1)
@@ -487,20 +403,22 @@ if __name__ == "__main__":
487
  loss = clip_loss(img_emb, text_emb)
488
  loss.backward()
489
  optimizer.step()
490
- last_loss = loss.item()
491
- pbar.set_postfix({"loss": f"{last_loss:.4f}"})
492
  pbar.update(1)
493
  pbar.close()
494
- if last_loss is not None:
495
- print(f"[Train] Epoch {epoch+1}/{EPOCHS} - last batch loss: {last_loss:.4f}")
 
 
496
  else:
497
- print(f"[Train] Epoch {epoch+1}/{EPOCHS} - no valid batches")
498
 
499
  # Eval rapide sur test avec barre
500
  model.eval()
501
  test_losses = []
502
  with torch.no_grad():
503
- pbar_t = tqdm(total=len(test_loader), desc=f"Epoch {epoch+1}/{EPOCHS} - test", leave=False)
504
  for batch in test_loader:
505
  if batch is None:
506
  pbar_t.update(1)
@@ -514,15 +432,20 @@ if __name__ == "__main__":
514
  pbar_t.update(1)
515
  pbar_t.close()
516
  if len(test_losses) > 0:
517
- print(f"[Test ] Epoch {epoch+1}/{EPOCHS} - avg loss: {sum(test_losses)/len(test_losses):.4f}")
 
518
  else:
519
- print(f"[Test ] Epoch {epoch+1}/{EPOCHS} - no valid batches")
520
 
521
  # --- Save checkpoint at every epoch ---
522
  ckpt_dir = here
523
- latest_path = os.path.join(ckpt_dir, "colorclip_image_text.pt")
524
- epoch_path = os.path.join(ckpt_dir, f"colorclip_image_text_epoch_{epoch+1}.pt")
525
- state_dict = model.state_dict()
526
- torch.save(state_dict, latest_path)
527
- torch.save(state_dict, epoch_path)
 
 
 
 
528
  print(f"[Save ] Saved checkpoints: {latest_path} and {epoch_path}")
 
1
+ import config
2
  import os
 
3
  import json
4
  import torch
5
  from torch.utils.data import Dataset, DataLoader
6
  from torchvision import transforms, models
7
  from PIL import Image
 
 
8
  import torch.nn as nn
9
  import torch.nn.functional as F
10
  import pandas as pd
11
+ from tqdm.auto import tqdm
12
+ from collections import defaultdict
13
+ from typing import Optional, List
 
 
 
 
 
 
14
  import logging
15
+
 
 
 
16
 
17
  # Configure logging
18
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
19
  logger = logging.getLogger(__name__)
20
+ # -------------------------------
21
+ # Dataset Classes
22
+ # -------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  class ColorDataset(Dataset):
24
  def __init__(self, dataframe, tokenizer, transform=None):
25
  """
26
+ dataframe : pd.DataFrame with columns image and text columns
27
+ tokenizer : function that converts text -> list of integers (tokens)
28
+ transform : transformations on the image
29
  """
30
  self.df = dataframe.reset_index(drop=True)
31
  self.tokenizer = tokenizer
 
41
 
42
  def __getitem__(self, idx):
43
  row = self.df.iloc[idx]
44
+ img = Image.open(config.column_local_image_path).convert("RGB")
45
+ img = self.transform(img)
46
+ tokens = torch.tensor(self.tokenizer(row[config.text_column]), dtype=torch.long)
47
+ return img, tokens
48
+
49
+ # -------------------------------
50
+ # Tokenizer
51
+ # -------------------------------
52
+ class Tokenizer:
 
 
 
 
 
53
  def __init__(self):
54
  self.word2idx = defaultdict(lambda: 0) # 0 = pad/unknown
55
  self.idx2word = {}
 
93
  self.idx2word = {int(v): k for k, v in word2idx_dict.items() if int(v) > 0}
94
  self.counter = max(self.word2idx.values(), default=0) + 1
95
 
96
+ # -------------------------------
97
+ # Model Components
98
+ # -------------------------------
99
  class ImageEncoder(nn.Module):
100
+ def __init__(self, embedding_dim=config.color_emb_dim):
101
  super().__init__()
102
  self.backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
103
  self.backbone.fc = nn.Sequential(
 
110
  return F.normalize(x, dim=-1)
111
 
112
  class TextEncoder(nn.Module):
113
+ def __init__(self, vocab_size, embedding_dim=config.color_emb_dim):
114
  super().__init__()
115
  self.embedding = nn.Embedding(vocab_size, 32, padding_idx=0) # Keep 32 dimensions
116
  self.dropout = nn.Dropout(0.1) # Add regularization
 
127
  return F.normalize(self.fc(mean), dim=-1)
128
 
129
  class ColorCLIP(nn.Module):
130
+ """
131
+ Color CLIP model for learning color-aligned image-text embeddings.
132
+ """
133
+ def __init__(self, vocab_size, embedding_dim=config.color_emb_dim, tokenizer=None):
134
+ """
135
+ Initialize ColorCLIP model.
136
+
137
+ Args:
138
+ vocab_size: Size of the vocabulary for text encoding
139
+ embedding_dim: Dimension of the embedding space (default: color_emb_dim)
140
+ tokenizer: Optional Tokenizer instance (will create one if None)
141
+ """
142
  super().__init__()
143
+ self.vocab_size = vocab_size
144
+ self.embedding_dim = embedding_dim
145
  self.image_encoder = ImageEncoder(embedding_dim)
146
  self.text_encoder = TextEncoder(vocab_size, embedding_dim)
147
+ self.tokenizer = tokenizer
148
 
149
  def forward(self, image, text, lengths=None):
150
+ """
151
+ Forward pass through the model.
152
+
153
+ Args:
154
+ image: Image tensor [B, C, H, W]
155
+ text: Text token tensor [B, T]
156
+ lengths: Optional sequence lengths tensor [B]
157
+
158
+ Returns:
159
+ Tuple of (image_embeddings, text_embeddings)
160
+ """
161
  return self.image_encoder(image), self.text_encoder(text, lengths)
162
 
163
  def get_text_embeddings(self, texts: List[str]) -> torch.Tensor:
164
+ """
165
+ Get text embeddings for a list of text strings.
166
+
167
+ Args:
168
+ texts: List of text strings
169
+
170
+ Returns:
171
+ Text embeddings tensor [batch_size, embedding_dim]
172
+ """
173
+ if self.tokenizer is None:
174
+ raise ValueError("Tokenizer must be set before calling get_text_embeddings")
175
+
176
  token_lists = [self.tokenizer(t) for t in texts]
177
  max_len = max((len(toks) for toks in token_lists), default=0)
178
  padded = [toks + [0] * (max_len - len(toks)) for toks in token_lists]
 
181
  with torch.no_grad():
182
  emb = self.text_encoder(input_ids, lengths)
183
  return emb
184
+
185
    @classmethod
    def from_pretrained(cls, model_path: str, vocab_path: Optional[str] = None, device: str = "cpu", repo_id: Optional[str] = None):
        """
        Load a pretrained ColorCLIP model from a file path or Hugging Face Hub.

        Args:
            model_path: Path to the model checkpoint (.pt file) or filename if using repo_id
            vocab_path: Optional path to tokenizer vocabulary JSON file or filename if using repo_id
            device: Device to load the model on (default: "cpu")
            repo_id: Optional Hugging Face repository ID (e.g., "username/model-name")
                     If provided, model_path and vocab_path should be filenames within the repo

        Returns:
            ColorCLIP model instance (in eval mode; tokenizer attached when vocab_path loads)

        Example:
            # Load from local file
            model = ColorCLIP.from_pretrained("color_model.pt", "tokenizer_vocab.json")

            # Load from Hugging Face Hub
            from huggingface_hub import hf_hub_download
            model_file = hf_hub_download(repo_id="username/model-name", filename="color_model.pt")
            vocab_file = hf_hub_download(repo_id="username/model-name", filename="tokenizer_vocab.json")
            model = ColorCLIP.from_pretrained(model_file, vocab_file)
        """
        device_obj = torch.device(device)

        # Support loading from Hugging Face Hub if repo_id is provided.
        # hf_hub_download resolves the filenames into local cache paths.
        if repo_id:
            try:
                from huggingface_hub import hf_hub_download
                model_path = hf_hub_download(repo_id=repo_id, filename=model_path)
                if vocab_path:
                    vocab_path = hf_hub_download(repo_id=repo_id, filename=vocab_path)
            except ImportError:
                raise ImportError("huggingface_hub is required to load models from Hugging Face. Install it with: pip install huggingface-hub")

        # Load model checkpoint.
        # NOTE(review): torch.load on a downloaded file deserializes pickled
        # data — only load checkpoints from trusted repos; consider
        # weights_only=True on newer torch versions.
        checkpoint = torch.load(model_path, map_location=device_obj)

        # Extract vocab size and embedding dimension from checkpoint
        if isinstance(checkpoint, dict):
            # Try to get vocab_size from metadata first (written by save_pretrained)
            vocab_size = checkpoint.get('vocab_size', None)
            # NOTE(review): fallback of 16 is hard-coded here while the class
            # default is config.color_emb_dim — confirm they agree.
            embedding_dim = checkpoint.get('embedding_dim', 16)

            # If not in metadata, try to infer from model state:
            # the embedding table's first dimension is the vocabulary size.
            if vocab_size is None:
                state_dict = checkpoint.get('model_state_dict', checkpoint)
                if 'text_encoder.embedding.weight' in state_dict:
                    vocab_size = state_dict['text_encoder.embedding.weight'].shape[0]
                else:
                    raise ValueError("Could not determine vocab_size from checkpoint")

            # Load state dict (supports both wrapped and raw state-dict checkpoints)
            state_dict = checkpoint.get('model_state_dict', checkpoint)
        else:
            raise ValueError("Checkpoint must be a dictionary")

        # Initialize model with the recovered hyper-parameters
        model = cls(vocab_size=vocab_size, embedding_dim=embedding_dim)
        model.load_state_dict(state_dict)
        model = model.to(device_obj)

        # Load tokenizer if vocab path is provided
        # NOTE(review): assumes Tokenizer exposes load_vocab(dict) matching the
        # word2idx JSON written at training time — confirm against Tokenizer.
        if vocab_path and os.path.exists(vocab_path):
            tokenizer = Tokenizer()
            with open(vocab_path, 'r') as f:
                vocab_dict = json.load(f)
            tokenizer.load_vocab(vocab_dict)
            model.tokenizer = tokenizer

        # Inference-ready by default (disables dropout etc.)
        model.eval()
        return model
259
+
260
+ def save_pretrained(self, save_directory: str, vocab_path: Optional[str] = None):
261
+ """
262
+ Save the model and optionally the tokenizer vocabulary.
263
+
264
+ Args:
265
+ save_directory: Directory to save the model
266
+ vocab_path: Optional path to save tokenizer vocabulary
267
+ """
268
+ os.makedirs(save_directory, exist_ok=True)
269
+
270
+ # Save model checkpoint
271
+ model_path = os.path.join(save_directory, config.color_model_path)
272
+ checkpoint = {
273
+ 'model_state_dict': self.state_dict(),
274
+ 'vocab_size': self.vocab_size,
275
+ 'embedding_dim': self.embedding_dim
276
+ }
277
+ torch.save(checkpoint, model_path)
278
+
279
+ # Save tokenizer vocabulary if available
280
+ if self.tokenizer is not None:
281
+ vocab_dict = dict(self.tokenizer.word2idx)
282
+ if vocab_path is None:
283
+ vocab_path = os.path.join(save_directory, config.tokeniser_path)
284
+ with open(vocab_path, 'w') as f:
285
+ json.dump(vocab_dict, f)
286
+
287
+ return model_path, vocab_path
288
 
289
 
290
+ # -------------------------------
291
+ # Loss Functions and Utilities
292
+ # -------------------------------
293
def clip_loss(image_emb, text_emb, temperature=0.07):
    """Bidirectional contrastive (CLIP) loss.

    Args:
        image_emb: Image embeddings [batch_size, embedding_dim].
        text_emb: Text embeddings [batch_size, embedding_dim].
        temperature: Temperature scaling applied to the similarity logits.

    Returns:
        Scalar tensor: the mean of the image->text and text->image
        cross-entropy losses, where pair i is the positive for row i.
    """
    scaled = (image_emb @ text_emb.T).div(temperature)
    idx = torch.arange(scaled.size(0), device=image_emb.device)
    return (F.cross_entropy(scaled, idx) + F.cross_entropy(scaled.t(), idx)) / 2
310
 
 
311
  def collate_batch(batch):
312
+ """
313
+ Collate function for DataLoader that pads sequences and filters None values.
314
+
315
+ Args:
316
+ batch: List of (image, tokens) tuples or None
317
+
318
+ Returns:
319
+ Tuple of (images, padded_tokens, lengths) or None if batch is empty
320
+ """
321
  batch = [b for b in batch if b is not None]
322
  if len(batch) == 0:
323
  return None
 
330
 
331
 
332
  if __name__ == "__main__":
333
+ """
334
+ Training script for ColorCLIP model.
335
+ This code only runs when the file is executed directly, not when imported.
336
+ """
337
+ # Configuration
338
+ batch_size = 16
339
+ lr = 1e-4
340
+ epochs=50
341
+
342
+
343
+
344
+ # Load dataset and split train/test
345
+ tokenizer = Tokenizer()
346
+ df = pd.read_csv(config.local_dataset_path)
347
 
348
+ # Data preparation: Reduce to main colors only (11 classes instead of 34)
349
  main_colors = ['beige', 'black', 'blue', 'brown', 'green', 'orange', 'pink', 'purple', 'red', 'white', 'yellow']
350
+ df = df[df[config.color_column].isin(main_colors)].copy()
351
  print(f"📊 Filtered dataset: {len(df)} samples with {len(main_colors)} colors")
352
+ print(f"🎨 Colors: {sorted(df[config.color_column].unique())}")
353
 
354
+ tokenizer.fit(df[config.text_column].tolist())
 
 
 
 
 
 
 
 
 
 
 
 
355
 
356
  # Filter only rows with a valid local file
357
+ df_local = df[df[config.column_local_image_path].astype(str).str.len() > 0]
358
+ df_local = df_local[df_local[config.column_local_image_path].apply(lambda p: os.path.isfile(p))]
359
  df_local = df_local.reset_index(drop=True)
360
 
361
 
 
369
  train_dataset = ColorDataset(df_train, tokenizer)
370
  test_dataset = ColorDataset(df_test, tokenizer)
371
 
372
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_batch, num_workers=0)
373
+ test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_batch, num_workers=0)
374
 
375
+ device = config.device
376
  print(f"Using device: {device}")
377
 
378
+ model = ColorCLIP(vocab_size=tokenizer.counter, embedding_dim=config.color_emb_dim, tokenizer=tokenizer).to(device)
379
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5) # Add weight decay
380
 
381
  # Save tokenizer vocab once (or update) so evaluation can reload the same mapping
382
  here = os.path.dirname(__file__)
383
+ vocab_out = os.path.join(here, config.tokeniser_path)
384
  with open(vocab_out, "w") as f:
385
  json.dump(dict(tokenizer.word2idx), f)
386
+ print(f"Tokenizer vocabulary saved to: {vocab_out}")
387
 
388
 
389
+ for epoch in range(epochs):
 
 
 
 
390
  model.train()
391
+ pbar = tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs} - train", leave=False)
392
+ epoch_losses = []
393
  for batch in train_loader:
394
  if batch is None:
395
  pbar.update(1)
 
403
  loss = clip_loss(img_emb, text_emb)
404
  loss.backward()
405
  optimizer.step()
406
+ epoch_losses.append(loss.item())
407
+ pbar.set_postfix({"loss": f"{loss.item():.4f}", "avg": f"{sum(epoch_losses)/len(epoch_losses):.4f}"})
408
  pbar.update(1)
409
  pbar.close()
410
+
411
+ avg_train_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else None
412
+ if avg_train_loss is not None:
413
+ print(f"[Train] Epoch {epoch+1}/{epochs} - avg loss: {avg_train_loss:.4f}")
414
  else:
415
+ print(f"[Train] Epoch {epoch+1}/{epochs} - no valid batches")
416
 
417
  # Eval rapide sur test avec barre
418
  model.eval()
419
  test_losses = []
420
  with torch.no_grad():
421
+ pbar_t = tqdm(total=len(test_loader), desc=f"Epoch {epoch+1}/{epochs} - test", leave=False)
422
  for batch in test_loader:
423
  if batch is None:
424
  pbar_t.update(1)
 
432
  pbar_t.update(1)
433
  pbar_t.close()
434
  if len(test_losses) > 0:
435
+ avg_test_loss = sum(test_losses) / len(test_losses)
436
+ print(f"[Test ] Epoch {epoch+1}/{epochs} - avg loss: {avg_test_loss:.4f}")
437
  else:
438
+ print(f"[Test ] Epoch {epoch+1}/{epochs} - no valid batches")
439
 
440
  # --- Save checkpoint at every epoch ---
441
  ckpt_dir = here
442
+ latest_path = os.path.join(ckpt_dir, config.color_model_path)
443
+ epoch_path = os.path.join(ckpt_dir, f"color_model_epoch_{epoch+1}.pt")
444
+ checkpoint = {
445
+ 'model_state_dict': model.state_dict(),
446
+ 'vocab_size': model.vocab_size,
447
+ 'embedding_dim': model.embedding_dim
448
+ }
449
+ torch.save(checkpoint, latest_path)
450
+ torch.save(checkpoint, epoch_path)
451
  print(f"[Save ] Saved checkpoints: {latest_path} and {epoch_path}")