Leacb4 commited on
Commit
347015b
·
verified ·
1 Parent(s): eface4c

Upload color_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. color_model.py +528 -0
color_model.py ADDED
@@ -0,0 +1,528 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import json
4
+ import torch
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from torchvision import transforms, models
7
+ from PIL import Image
8
+ import requests
9
+ from io import BytesIO
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import pandas as pd
13
+ from tqdm.auto import tqdm
14
+
15
+ import asyncio
16
+ import aiohttp
17
+ import pandas as pd
18
+ import os
19
+ from pathlib import Path
20
+ from tqdm.asyncio import tqdm
21
+ import ssl
22
+ import logging
23
+ from typing import Optional, List, Tuple
24
+ from urllib.parse import urlparse
25
+ import hashlib
26
+ from config import local_dataset_path
27
+
28
+ # Configure logging
29
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
30
+ logger = logging.getLogger(__name__)
31
+
32
+ class ImageDownloader:
33
+ """Enhanced image downloader with better error handling, retry logic, and progress tracking."""
34
+
35
+ def __init__(self,
36
+ output_dir: str = "athleta_images",
37
+ max_concurrent: int = 10,
38
+ timeout: int = 30,
39
+ retry_attempts: int = 3,
40
+ verify_ssl: bool = True):
41
+ """
42
+ Initialize the ImageDownloader.
43
+
44
+ Args:
45
+ output_dir: Directory to save downloaded images
46
+ max_concurrent: Maximum number of concurrent downloads
47
+ timeout: Request timeout in seconds
48
+ retry_attempts: Number of retry attempts for failed downloads
49
+ verify_ssl: Whether to verify SSL certificates
50
+ """
51
+ self.output_dir = Path(output_dir)
52
+ self.max_concurrent = max_concurrent
53
+ self.timeout = aiohttp.ClientTimeout(total=timeout)
54
+ self.retry_attempts = retry_attempts
55
+ self.verify_ssl = verify_ssl
56
+
57
+ # Create output directory
58
+ self.output_dir.mkdir(exist_ok=True)
59
+
60
+ # Statistics
61
+ self.stats = {
62
+ 'total': 0,
63
+ 'downloaded': 0,
64
+ 'skipped': 0,
65
+ 'failed': 0,
66
+ 'retries': 0
67
+ }
68
+
69
+ def _create_ssl_context(self) -> Optional[ssl.SSLContext]:
70
+ """Create SSL context based on verification settings."""
71
+ if not self.verify_ssl:
72
+ ssl_context = ssl.create_default_context()
73
+ ssl_context.check_hostname = False
74
+ ssl_context.verify_mode = ssl.CERT_NONE
75
+ return ssl_context
76
+ return None
77
+
78
+ def _generate_filename(self, url: str, index: int) -> str:
79
+ """Generate a safe filename from URL or index."""
80
+ try:
81
+ # Try to extract filename from URL
82
+ parsed_url = urlparse(url)
83
+ filename = os.path.basename(parsed_url.path)
84
+ if filename and '.' in filename:
85
+ return filename
86
+ except Exception:
87
+ pass
88
+
89
+ # Fallback: use URL hash or index
90
+ try:
91
+ url_hash = hashlib.md5(url.encode()).hexdigest()[:8]
92
+ return f"image_{url_hash}.jpg"
93
+ except Exception:
94
+ return f"image_{index}.jpg"
95
+
96
+ async def _download_single_image(self,
97
+ session: aiohttp.ClientSession,
98
+ url: str,
99
+ save_path: Path,
100
+ index: int) -> bool:
101
+ """
102
+ Download a single image with retry logic.
103
+
104
+ Returns:
105
+ bool: True if successful, False otherwise
106
+ """
107
+ for attempt in range(self.retry_attempts):
108
+ try:
109
+ if attempt > 0:
110
+ self.stats['retries'] += 1
111
+ logger.info(f"Retry {attempt}/{self.retry_attempts} for {url}")
112
+
113
+ ssl_context = self._create_ssl_context()
114
+ connector = aiohttp.TCPConnector(ssl=ssl_context) if ssl_context else None
115
+
116
+ async with session.get(url, ssl=ssl_context, connector=connector) as response:
117
+ if response.status == 200:
118
+ content = await response.read()
119
+
120
+ # Validate that it's actually an image
121
+ if len(content) < 1024: # Too small to be a real image
122
+ logger.warning(f"Image too small, skipping: {url}")
123
+ return False
124
+
125
+ # Ensure directory exists
126
+ save_path.parent.mkdir(parents=True, exist_ok=True)
127
+
128
+ # Write file
129
+ with open(save_path, 'wb') as f:
130
+ f.write(content)
131
+
132
+ logger.debug(f"Successfully downloaded: {save_path}")
133
+ return True
134
+
135
+ elif response.status == 404:
136
+ logger.warning(f"Image not found (404): {url}")
137
+ return False
138
+
139
+ else:
140
+ logger.warning(f"HTTP {response.status} for {url}")
141
+ if attempt == self.retry_attempts - 1:
142
+ return False
143
+
144
+ except asyncio.TimeoutError:
145
+ logger.warning(f"Timeout downloading {url} (attempt {attempt + 1})")
146
+ if attempt == self.retry_attempts - 1:
147
+ return False
148
+
149
+ except Exception as e:
150
+ logger.error(f"Error downloading {url}: {str(e)}")
151
+ if attempt == self.retry_attempts - 1:
152
+ return False
153
+
154
+ return False
155
+
156
+ async def _download_batch(self,
157
+ session: aiohttp.ClientSession,
158
+ batch: List[Tuple[str, Path, int]]) -> None:
159
+ """Download a batch of images concurrently."""
160
+ semaphore = asyncio.Semaphore(self.max_concurrent)
161
+
162
+ async def download_with_semaphore(url, save_path, index):
163
+ async with semaphore:
164
+ if save_path.exists():
165
+ logger.debug(f"File already exists, skipping: {save_path}")
166
+ self.stats['skipped'] += 1
167
+ return
168
+
169
+ success = await self._download_single_image(session, url, save_path, index)
170
+ if success:
171
+ self.stats['downloaded'] += 1
172
+ else:
173
+ self.stats['failed'] += 1
174
+
175
+ tasks = [download_with_semaphore(url, save_path, index)
176
+ for url, save_path, index in batch]
177
+ await asyncio.gather(*tasks, return_exceptions=True)
178
+
179
+ def _prepare_download_tasks(self, df: pd.DataFrame) -> List[Tuple[str, Path, int]]:
180
+ """Prepare download tasks from DataFrame."""
181
+ tasks = []
182
+
183
+ for index, row in df.iterrows():
184
+ # Check if image URL is valid
185
+ if pd.isna(row.get('image')) or not isinstance(row.get('image'), str):
186
+ logger.debug(f"Skipping row {index}: invalid image URL")
187
+ continue
188
+
189
+ url = row['image'].strip()
190
+ if not url or not url.startswith(('http://', 'https://')):
191
+ logger.debug(f"Skipping row {index}: invalid URL format")
192
+ continue
193
+
194
+ # Generate filename
195
+ filename = self._generate_filename(url, index)
196
+ save_path = self.output_dir / filename
197
+
198
+ tasks.append((url, save_path, index))
199
+
200
+ return tasks
201
+
202
+ async def download_all_images(self, df: pd.DataFrame) -> None:
203
+ """Download all images from the DataFrame."""
204
+ logger.info("Preparing download tasks...")
205
+ tasks = self._prepare_download_tasks(df)
206
+ self.stats['total'] = len(tasks)
207
+
208
+ if not tasks:
209
+ logger.warning("No valid image URLs found in the dataset")
210
+ return
211
+
212
+ logger.info(f"Found {len(tasks)} valid image URLs to download")
213
+
214
+ # Create session with proper configuration
215
+ ssl_context = self._create_ssl_context()
216
+ connector = aiohttp.TCPConnector(ssl=ssl_context) if ssl_context else None
217
+
218
+ async with aiohttp.ClientSession(
219
+ timeout=self.timeout,
220
+ connector=connector,
221
+ headers={'User-Agent': 'Mozilla/5.0 (compatible; ImageDownloader/1.0)'}
222
+ ) as session:
223
+
224
+ # Process in batches to avoid overwhelming the server
225
+ batch_size = self.max_concurrent * 2
226
+ for i in range(0, len(tasks), batch_size):
227
+ batch = tasks[i:i + batch_size]
228
+ logger.info(f"Processing batch {i//batch_size + 1}/{(len(tasks)-1)//batch_size + 1}")
229
+
230
+ await self._download_batch(session, batch)
231
+
232
+ # Small delay between batches to be respectful
233
+ if i + batch_size < len(tasks):
234
+ await asyncio.sleep(1)
235
+
236
+ def print_statistics(self) -> None:
237
+ """Print download statistics."""
238
+ logger.info("Download Statistics:")
239
+ logger.info(f" Total URLs processed: {self.stats['total']}")
240
+ logger.info(f" Successfully downloaded: {self.stats['downloaded']}")
241
+ logger.info(f" Skipped (already exists): {self.stats['skipped']}")
242
+ logger.info(f" Failed: {self.stats['failed']}")
243
+ logger.info(f" Retry attempts: {self.stats['retries']}")
244
+
245
+ if self.stats['total'] > 0:
246
+ success_rate = (self.stats['downloaded'] / self.stats['total']) * 100
247
+ logger.info(f" Success rate: {success_rate:.1f}%")
248
+
249
+
250
+ import os
251
+ import time
252
+ import json
253
+ import torch
254
+ from torch.utils.data import Dataset, DataLoader
255
+ from torchvision import transforms, models
256
+ from PIL import Image
257
+ import requests
258
+ from io import BytesIO
259
+ import torch.nn as nn
260
+ import torch.nn.functional as F
261
+ import pandas as pd
262
+ from tqdm.auto import tqdm
263
+
264
+ class ColorDataset(Dataset):
265
+ def __init__(self, dataframe, tokenizer, transform=None):
266
+ """
267
+ dataframe : pd.DataFrame avec colonnes 'image_url' et 'text'
268
+ tokenizer : fonction qui convertit texte -> list d'entiers (tokens)
269
+ transform : transformations image
270
+ """
271
+ self.df = dataframe.reset_index(drop=True)
272
+ self.tokenizer = tokenizer
273
+ self.transform = transform or transforms.Compose([
274
+ transforms.Resize((224,224)),
275
+ transforms.ToTensor(),
276
+ transforms.Normalize(mean=[0.485,0.456,0.406],
277
+ std=[0.229,0.224,0.225])
278
+ ])
279
+
280
+ def __len__(self):
281
+ return len(self.df)
282
+
283
+ def __getitem__(self, idx):
284
+ row = self.df.iloc[idx]
285
+ try:
286
+ src = row.get('local_image_path', None)
287
+ if not src or not os.path.isfile(src):
288
+ return None # filtered by collate
289
+ img = Image.open(src).convert("RGB")
290
+ img = self.transform(img)
291
+ tokens = torch.tensor(self.tokenizer(row['text']), dtype=torch.long)
292
+ return img, tokens
293
+ except Exception:
294
+ return None
295
+
296
+ from collections import defaultdict
297
+
298
+ class SimpleTokenizer:
299
+ def __init__(self):
300
+ self.word2idx = defaultdict(lambda: 0) # 0 = pad/unknown
301
+ self.idx2word = {}
302
+ self.counter = 1
303
+
304
+ def preprocess_text(self, text):
305
+ """Extract color-related keywords from text"""
306
+ # Color-related keywords to keep
307
+ color_keywords = ['red', 'blue', 'green', 'yellow', 'purple', 'pink', 'orange',
308
+ 'brown', 'black', 'white', 'gray', 'navy', 'beige', 'aqua', 'lime',
309
+ 'violet', 'turquoise', 'teal', 'tan', 'snow', 'silver', 'plum',
310
+ 'olive', 'fuchsia', 'gold', 'cream', 'ivory', 'maroon']
311
+
312
+ # Keep only color-related words and basic descriptive words
313
+ descriptive_words = ['shirt', 'dress', 'top', 'bottom', 'shoe', 'bag', 'hat', 'short', 'long', 'sleeve']
314
+
315
+ words = text.lower().split()
316
+ filtered_words = []
317
+ for word in words:
318
+ # Keep color words and some descriptive words
319
+ if word in color_keywords or word in descriptive_words:
320
+ filtered_words.append(word)
321
+
322
+ return ' '.join(filtered_words) if filtered_words else text.lower()
323
+
324
+ def fit(self, texts):
325
+ for text in texts:
326
+ processed_text = self.preprocess_text(text)
327
+ for word in processed_text.split():
328
+ if word not in self.word2idx:
329
+ self.word2idx[word] = self.counter
330
+ self.idx2word[self.counter] = word
331
+ self.counter += 1
332
+
333
+ def __call__(self, text):
334
+ processed_text = self.preprocess_text(text)
335
+ return [self.word2idx[word] for word in processed_text.split()]
336
+
337
+ def load_vocab(self, word2idx_dict):
338
+ self.word2idx = defaultdict(lambda: 0, {k: int(v) for k, v in word2idx_dict.items()})
339
+ self.idx2word = {int(v): k for k, v in word2idx_dict.items() if int(v) > 0}
340
+ self.counter = max(self.word2idx.values(), default=0) + 1
341
+
342
+ class ImageEncoder(nn.Module):
343
+ def __init__(self, embedding_dim=16):
344
+ super().__init__()
345
+ self.backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
346
+ self.backbone.fc = nn.Sequential(
347
+ nn.Dropout(0.1), # Add regularization
348
+ nn.Linear(self.backbone.fc.in_features, embedding_dim)
349
+ )
350
+
351
+ def forward(self, x):
352
+ x = self.backbone(x)
353
+ return F.normalize(x, dim=-1)
354
+
355
+ class TextEncoder(nn.Module):
356
+ def __init__(self, vocab_size, embedding_dim=16):
357
+ super().__init__()
358
+ self.embedding = nn.Embedding(vocab_size, 32, padding_idx=0) # Keep 32 dimensions
359
+ self.dropout = nn.Dropout(0.1) # Add regularization
360
+ self.fc = nn.Linear(32, embedding_dim)
361
+
362
+ def forward(self, x, lengths=None):
363
+ emb = self.embedding(x) # [B, T, 32]
364
+ emb = self.dropout(emb) # Apply dropout
365
+ if lengths is not None:
366
+ summed = emb.sum(dim=1) # [B, 32]
367
+ mean = summed / lengths.unsqueeze(1).clamp_min(1)
368
+ else:
369
+ mean = emb.mean(dim=1)
370
+ return F.normalize(self.fc(mean), dim=-1)
371
+
372
+ class ColorCLIP(nn.Module):
373
+ def __init__(self, vocab_size, embedding_dim=16): # Keep 16 dimensions
374
+ super().__init__()
375
+ self.image_encoder = ImageEncoder(embedding_dim)
376
+ self.text_encoder = TextEncoder(vocab_size, embedding_dim)
377
+
378
+ def forward(self, image, text, lengths=None):
379
+ return self.image_encoder(image), self.text_encoder(text, lengths)
380
+
381
+ def get_text_embeddings(self, texts: List[str]) -> torch.Tensor:
382
+ token_lists = [self.tokenizer(t) for t in texts]
383
+ max_len = max((len(toks) for toks in token_lists), default=0)
384
+ padded = [toks + [0] * (max_len - len(toks)) for toks in token_lists]
385
+ input_ids = torch.tensor(padded, dtype=torch.long, device=next(self.parameters()).device)
386
+ lengths = torch.tensor([len(toks) for toks in token_lists], dtype=torch.long, device=input_ids.device)
387
+ with torch.no_grad():
388
+ emb = self.text_encoder(input_ids, lengths)
389
+ return emb
390
+
391
+
392
+ def clip_loss(image_emb, text_emb, temperature=0.07):
393
+ logits = image_emb @ text_emb.T / temperature
394
+ labels = torch.arange(len(image_emb), device=image_emb.device)
395
+ loss_i2t = F.cross_entropy(logits, labels)
396
+ loss_t2i = F.cross_entropy(logits.T, labels)
397
+ return (loss_i2t + loss_t2i) / 2
398
+
399
+ # Collate qui pad les séquences et filtre les None
400
+ def collate_batch(batch):
401
+ batch = [b for b in batch if b is not None]
402
+ if len(batch) == 0:
403
+ return None
404
+ imgs, tokens = zip(*batch)
405
+ imgs = torch.stack(imgs, dim=0)
406
+ lengths = torch.tensor([t.size(0) for t in tokens], dtype=torch.long)
407
+ tokens_padded = nn.utils.rnn.pad_sequence(tokens, batch_first=True, padding_value=0)
408
+ return imgs, tokens_padded, lengths
409
+
410
+
411
+
412
+ if __name__ == "__main__":
413
+ # Chargement + split train/test + cache local
414
+ tokenizer = SimpleTokenizer()
415
+ df = pd.read_csv('df_color_with_local_paths.csv')
416
+
417
+ # Reduce to main colors only (11 classes instead of 34)
418
+ main_colors = ['beige', 'black', 'blue', 'brown', 'green', 'orange', 'pink', 'purple', 'red', 'white', 'yellow']
419
+ df = df[df['color'].isin(main_colors)].copy()
420
+ print(f"📊 Filtered dataset: {len(df)} samples with {len(main_colors)} colors")
421
+ print(f"🎨 Colors: {sorted(df['color'].unique())}")
422
+
423
+ tokenizer.fit(df['text'].tolist())
424
+
425
+ # If no local paths column, download/calc it once
426
+ if 'local_image_path' not in df.columns or df['local_image_path'].isna().all():
427
+ downloader = ImageDownloader(
428
+ csv_path='new/df_color_with_local_paths.csv',
429
+ images_dir='data/images',
430
+ max_workers=16,
431
+ timeout=10
432
+ )
433
+ df_local = downloader.download_all_images()
434
+ else:
435
+ df_local = df
436
+
437
+ # Filter only rows with a valid local file
438
+ df_local = df_local[df_local['local_image_path'].astype(str).str.len() > 0]
439
+ df_local = df_local[df_local['local_image_path'].apply(lambda p: os.path.isfile(p))]
440
+ df_local = df_local.reset_index(drop=True)
441
+
442
+
443
+ # split 90/10
444
+ df_local = df_local.sample(frac=1.0, random_state=42).reset_index(drop=True)
445
+ split_idx = int(0.9 * len(df_local))
446
+ df_train = df_local.iloc[:split_idx].reset_index(drop=True)
447
+ df_test = df_local.iloc[split_idx:].reset_index(drop=True)
448
+
449
+
450
+ train_dataset = ColorDataset(df_train, tokenizer)
451
+ test_dataset = ColorDataset(df_test, tokenizer)
452
+
453
+ train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_batch, num_workers=0)
454
+ test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=collate_batch, num_workers=0)
455
+
456
+ device = "mps" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else "cpu"
457
+ print(f"Using device: {device}")
458
+
459
+ model = ColorCLIP(vocab_size=tokenizer.counter).to(device)
460
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5) # Add weight decay
461
+
462
+ # Save tokenizer vocab once (or update) so evaluation can reload the same mapping
463
+ here = os.path.dirname(__file__)
464
+ vocab_out = os.path.join(here, "tokenizer_vocab.json")
465
+ with open(vocab_out, "w") as f:
466
+ json.dump(dict(tokenizer.word2idx), f)
467
+
468
+
469
+ from collections import defaultdict
470
+
471
+
472
+ EPOCHS = 50 # Increased from 10 to 50
473
+ for epoch in range(EPOCHS):
474
+ model.train()
475
+ pbar = tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{EPOCHS} - train", leave=False)
476
+ last_loss = None
477
+ for batch in train_loader:
478
+ if batch is None:
479
+ pbar.update(1)
480
+ continue
481
+ imgs, texts, lengths = batch
482
+ imgs = imgs.to(device)
483
+ texts = texts.to(device)
484
+ lengths = lengths.to(device)
485
+ optimizer.zero_grad()
486
+ img_emb, text_emb = model(imgs, texts, lengths)
487
+ loss = clip_loss(img_emb, text_emb)
488
+ loss.backward()
489
+ optimizer.step()
490
+ last_loss = loss.item()
491
+ pbar.set_postfix({"loss": f"{last_loss:.4f}"})
492
+ pbar.update(1)
493
+ pbar.close()
494
+ if last_loss is not None:
495
+ print(f"[Train] Epoch {epoch+1}/{EPOCHS} - last batch loss: {last_loss:.4f}")
496
+ else:
497
+ print(f"[Train] Epoch {epoch+1}/{EPOCHS} - no valid batches")
498
+
499
+ # Eval rapide sur test avec barre
500
+ model.eval()
501
+ test_losses = []
502
+ with torch.no_grad():
503
+ pbar_t = tqdm(total=len(test_loader), desc=f"Epoch {epoch+1}/{EPOCHS} - test", leave=False)
504
+ for batch in test_loader:
505
+ if batch is None:
506
+ pbar_t.update(1)
507
+ continue
508
+ imgs, texts, lengths = batch
509
+ imgs = imgs.to(device)
510
+ texts = texts.to(device)
511
+ lengths = lengths.to(device)
512
+ img_emb, text_emb = model(imgs, texts, lengths)
513
+ test_losses.append(clip_loss(img_emb, text_emb).item())
514
+ pbar_t.update(1)
515
+ pbar_t.close()
516
+ if len(test_losses) > 0:
517
+ print(f"[Test ] Epoch {epoch+1}/{EPOCHS} - avg loss: {sum(test_losses)/len(test_losses):.4f}")
518
+ else:
519
+ print(f"[Test ] Epoch {epoch+1}/{EPOCHS} - no valid batches")
520
+
521
+ # --- Save checkpoint at every epoch ---
522
+ ckpt_dir = here
523
+ latest_path = os.path.join(ckpt_dir, "colorclip_image_text.pt")
524
+ epoch_path = os.path.join(ckpt_dir, f"colorclip_image_text_epoch_{epoch+1}.pt")
525
+ state_dict = model.state_dict()
526
+ torch.save(state_dict, latest_path)
527
+ torch.save(state_dict, epoch_path)
528
+ print(f"[Save ] Saved checkpoints: {latest_path} and {epoch_path}")