Leacb4 committed on
Commit
e647a2e
·
verified ·
1 Parent(s): 805d9e9

Upload main_clip_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. main_clip_model.py +639 -0
main_clip_model.py ADDED
@@ -0,0 +1,639 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ # Set environment variable to disable tokenizers parallelism warnings
3
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
4
+
5
+ import torch
6
+ import pytorch_lightning as pl
7
+ from torch.utils.data import DataLoader
8
+ import torch.nn.functional as F
9
+ from torch.utils.data import Dataset
10
+ from torchvision import transforms
11
+ from PIL import Image
12
+ from config import local_dataset_path, column_local_image_path, color_emb_dim, hierarchy_emb_dim, color_model_path, hierarchy_model_path, device, main_model_path
13
+ import matplotlib.pyplot as plt
14
+ from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
15
+ import warnings
16
+ from tqdm import tqdm
17
+ import numpy as np
18
+
19
+ # Suppress warnings
20
+ warnings.filterwarnings("ignore", category=FutureWarning)
21
+ warnings.filterwarnings("ignore", category=UserWarning)
22
+
23
+ # -------------------------------
24
+ # Step 1: Custom Training Functions
25
+ # -------------------------------
26
+
27
def train_one_epoch(model, train_loader, optimizer, feature_models, device, clip_processor, temperature=0.07):
    """
    Train the model for one epoch with the triple contrastive loss.

    Args:
        model: Model called as ``model(**text_inputs, pixel_values=images)``;
            its output must expose ``text_embeds`` and ``image_embeds``.
        train_loader: Iterable yielding ``(images, texts, colors, hierarchy)`` batches.
        optimizer: Optimizer stepped once per batch.
        feature_models: Mapping with a ``'color'`` and a ``'hierarchy'`` entry used
            to embed the colour / hierarchy labels of each batch.
        device: Device the images and tokenised text are moved to.
        clip_processor: Callable used to tokenise ``texts``.
        temperature: Temperature forwarded to ``triple_contrastive_loss``.

    Returns:
        The mean loss over the batches seen, or ``nan`` when the loader yields
        no batch at all (the original code raised ``ZeroDivisionError`` there).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    # Create progress bar for training
    pbar = tqdm(train_loader, desc="Training", leave=False)

    for images, texts, colors, hierarchy in pbar:
        # Move data to device
        images = images.to(device)
        # expand() only broadcasts size-1 dimensions, so this turns a single-channel
        # batch into 3 channels and leaves a 3-channel batch unchanged.
        images = images.expand(-1, 3, -1, -1)

        # Process text inputs. Truncate to the tokenizer's configured maximum length
        # so a long description cannot overflow the text encoder.
        text_inputs = clip_processor(text=texts, padding=True, truncation=True, return_tensors="pt")
        text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

        # Forward pass
        optimizer.zero_grad()
        outputs = model(**text_inputs, pixel_values=images)

        text_features = outputs.text_embeds
        image_features = outputs.image_embeds

        # Get feature embeddings
        # Use exact color-name embeddings if available (new color model)
        color_model = feature_models['color']
        if hasattr(color_model, 'get_color_name_embeddings'):
            color_features = color_model.get_color_name_embeddings(colors)
        else:
            color_features = color_model.get_text_embeddings(colors)
        hierarchy_features = feature_models['hierarchy'].get_text_embeddings(hierarchy)
        concat_features = torch.cat((color_features, hierarchy_features), dim=1)

        # Calculate loss
        loss = triple_contrastive_loss(text_features, image_features, concat_features, temperature)

        # Backward pass
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1

        # Update progress bar
        pbar.set_postfix({
            'Loss': f'{batch_loss:.4f}',
            'Avg Loss': f'{total_loss/num_batches:.4f}'
        })

    if num_batches == 0:
        # Nothing was trained on; report "no measurement" instead of dividing by zero.
        return float('nan')
    return total_loss / num_batches
80
+
81
def valid_one_epoch(model, val_loader, feature_models, device, clip_processor, temperature=0.07):
    """
    Validate the model for one epoch (no gradient tracking, no optimizer step).

    Args:
        model: Model called as ``model(**text_inputs, pixel_values=images)``;
            its output must expose ``text_embeds`` and ``image_embeds``.
        val_loader: Iterable yielding ``(images, texts, colors, hierarchy)`` batches.
        feature_models: Mapping with a ``'color'`` and a ``'hierarchy'`` entry used
            to embed the colour / hierarchy labels of each batch.
        device: Device the images and tokenised text are moved to.
        clip_processor: Callable used to tokenise ``texts``.
        temperature: Temperature forwarded to ``triple_contrastive_loss``.

    Returns:
        The mean loss over the batches seen, or ``nan`` when the loader yields
        no batch at all (the original code raised ``ZeroDivisionError`` there).
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    # Create progress bar for validation
    pbar = tqdm(val_loader, desc="Validation", leave=False)

    with torch.no_grad():
        for images, texts, colors, hierarchy in pbar:
            # Move data to device
            images = images.to(device)
            # expand() only broadcasts size-1 dimensions: 1 channel -> 3, 3 stays 3.
            images = images.expand(-1, 3, -1, -1)

            # Process text inputs. Truncate to the tokenizer's configured maximum
            # length so a long description cannot overflow the text encoder.
            text_inputs = clip_processor(text=texts, padding=True, truncation=True, return_tensors="pt")
            text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

            # Forward pass
            outputs = model(**text_inputs, pixel_values=images)

            text_features = outputs.text_embeds
            image_features = outputs.image_embeds

            # Get feature embeddings (prefer exact color-name embeddings when offered)
            color_model = feature_models['color']
            if hasattr(color_model, 'get_color_name_embeddings'):
                color_features = color_model.get_color_name_embeddings(colors)
            else:
                color_features = color_model.get_text_embeddings(colors)
            hierarchy_features = feature_models['hierarchy'].get_text_embeddings(hierarchy)
            concat_features = torch.cat((color_features, hierarchy_features), dim=1)

            # Calculate loss
            loss = triple_contrastive_loss(text_features, image_features, concat_features, temperature)

            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1

            # Update progress bar
            pbar.set_postfix({
                'Loss': f'{batch_loss:.4f}',
                'Avg Loss': f'{total_loss/num_batches:.4f}'
            })

    if num_batches == 0:
        # Nothing was evaluated; report "no measurement" instead of dividing by zero.
        return float('nan')
    return total_loss / num_batches
129
+
130
def triple_contrastive_loss(text_features, image_features, attribute_features, temperature=0.07,
                            weight_text_image=0.7, weight_attr_based=0.15):
    """
    Calculate the triple contrastive loss.

    Every input is L2-normalised over its last dimension first. The leading
    ``color_emb_dim + hierarchy_emb_dim`` columns of the text / image embeddings
    are compared against ``attribute_features``; the remaining columns are
    compared text-vs-image. The three logit matrices are blended and scored with
    a symmetric cross-entropy whose targets are the diagonal (row i matches
    column i).

    Args:
        text_features: 2-D text embeddings, one row per sample.
        image_features: 2-D image embeddings, one row per sample.
        attribute_features: 2-D attribute embeddings whose width must equal
            ``color_emb_dim + hierarchy_emb_dim``.
        temperature: Divisor applied to every similarity matrix.
        weight_text_image: Weight of the text-vs-image logits (default 0.7).
        weight_attr_based: Weight applied to each of the two attribute logit
            matrices (default 0.15).

    Returns:
        Scalar tensor: the mean of the row-wise and column-wise cross-entropy.
    """
    # Boundary between the attribute-aligned columns and the free columns.
    attr_dim = color_emb_dim + hierarchy_emb_dim

    # Normalisation happens on the full vectors, before they are sliced.
    text_features = F.normalize(text_features, dim=-1)
    image_features = F.normalize(image_features, dim=-1)
    attribute_features = F.normalize(attribute_features, dim=-1)

    text_image_logits = (text_features[:, attr_dim:] @ image_features[:, attr_dim:].T) / temperature
    text_attr_logits = (text_features[:, :attr_dim] @ attribute_features.T) / temperature
    image_attr_logits = (attribute_features @ image_features[:, :attr_dim].T) / temperature

    # Weight distribution
    logits = (weight_text_image * text_image_logits +
              weight_attr_based * text_attr_logits +
              weight_attr_based * image_attr_logits)

    # Sample i is the positive for sample i in both directions.
    labels = torch.arange(text_features.shape[0], device=text_features.device)
    loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2

    return loss
154
+
155
def enhanced_contrastive_loss(text_features, image_features, attribute_features,
                              color_model, colors, temperature=0.07, alignment_weight=0.3):
    """
    Enhanced contrastive loss with direct alignment between color model and main model.

    The result blends two terms:
    ``(1 - alignment_weight) * original_loss + alignment_weight * alignment_loss``
    where ``original_loss`` is the same weighted (0.7 / 0.15 / 0.15) symmetric
    cross-entropy used by the triple contrastive loss, and ``alignment_loss``
    pulls the first ``color_emb_dim`` columns of the main text / image embeddings
    towards the colour model's embedding of ``colors``.

    Args:
        text_features: Main model text embeddings
        image_features: Main model image embeddings
        attribute_features: Concatenated color + hierarchy features
        color_model: Pre-trained color model. ``get_color_name_embeddings`` is
            used when the model offers it, otherwise ``get_text_embeddings`` --
            the same preference the training and validation loops apply.
        colors: List of color strings for this batch
        temperature: Temperature for contrastive loss
        alignment_weight: Weight for the alignment loss

    Returns:
        ``(total_loss, metrics)`` where ``total_loss`` is a scalar tensor and
        ``metrics`` maps ``'original_loss'``, ``'alignment_loss'``,
        ``'text_alignment'``, ``'image_alignment'``, ``'text_cosine'`` and
        ``'image_cosine'`` to Python floats.
    """
    # Boundary between the attribute-aligned columns and the free columns.
    attr_dim = color_emb_dim + hierarchy_emb_dim

    # Original triple contrastive loss
    text_features_norm = F.normalize(text_features, dim=-1)
    image_features_norm = F.normalize(image_features, dim=-1)
    attribute_features_norm = F.normalize(attribute_features, dim=-1)

    text_image_logits = (text_features_norm[:, attr_dim:] @
                         image_features_norm[:, attr_dim:].T) / temperature
    text_attr_logits = (text_features_norm[:, :attr_dim] @
                        attribute_features_norm.T) / temperature
    image_attr_logits = (attribute_features_norm @
                         image_features_norm[:, :attr_dim].T) / temperature

    # Weight distribution for original loss
    weight_text_image = 0.7
    weight_attr_based = 0.15

    original_logits = (weight_text_image * text_image_logits +
                       weight_attr_based * text_attr_logits +
                       weight_attr_based * image_attr_logits)

    labels = torch.arange(text_features.shape[0], device=text_features.device)
    original_loss = (F.cross_entropy(original_logits, labels) +
                     F.cross_entropy(original_logits.T, labels)) / 2

    # Direct alignment loss between the color model and the first color_emb_dim
    # columns of the main model. The target is computed without gradient tracking,
    # so this term never back-propagates into the color model.
    with torch.no_grad():
        if hasattr(color_model, 'get_color_name_embeddings'):
            color_embeddings = color_model.get_color_name_embeddings(colors)
        else:
            color_embeddings = color_model.get_text_embeddings(colors)

    # Extract the color slice from the (un-normalised) main model embeddings
    main_color_text = text_features[:, :color_emb_dim]
    main_color_image = image_features[:, :color_emb_dim]

    # Normalize for better correlation
    color_embeddings_norm = F.normalize(color_embeddings, dim=-1)
    main_color_text_norm = F.normalize(main_color_text, dim=-1)
    main_color_image_norm = F.normalize(main_color_image, dim=-1)

    # Direct alignment loss using MSE and cosine similarity
    text_alignment_loss = F.mse_loss(main_color_text_norm, color_embeddings_norm)
    image_alignment_loss = F.mse_loss(main_color_image_norm, color_embeddings_norm)

    # Also encourage high cosine similarity
    text_cosine_loss = 1 - F.cosine_similarity(main_color_text_norm, color_embeddings_norm).mean()
    image_cosine_loss = 1 - F.cosine_similarity(main_color_image_norm, color_embeddings_norm).mean()

    alignment_loss = (text_alignment_loss + image_alignment_loss +
                      text_cosine_loss + image_cosine_loss) / 4

    # Combine losses
    total_loss = (1 - alignment_weight) * original_loss + alignment_weight * alignment_loss

    return total_loss, {
        'original_loss': original_loss.item(),
        'alignment_loss': alignment_loss.item(),
        'text_alignment': text_alignment_loss.item(),
        'image_alignment': image_alignment_loss.item(),
        'text_cosine': text_cosine_loss.item(),
        'image_cosine': image_cosine_loss.item()
    }
229
+
230
def train_one_epoch_enhanced(model, train_loader, optimizer, feature_models, color_model,
                             device, clip_processor, temperature=0.07, alignment_weight=0.3):
    """
    Enhanced training with direct color alignment loss.

    Args:
        model: Model called as ``model(**text_inputs, pixel_values=images)``;
            its output must expose ``text_embeds`` and ``image_embeds``.
        train_loader: Iterable yielding ``(images, texts, colors, hierarchy)`` batches.
        optimizer: Optimizer stepped once per batch.
        feature_models: Mapping with a ``'color'`` and a ``'hierarchy'`` entry used
            to build the concatenated attribute features.
        color_model: Color model handed to ``enhanced_contrastive_loss`` as the
            alignment target.
        device: Device the images and tokenised text are moved to.
        clip_processor: Callable used to tokenise ``texts``.
        temperature: Temperature forwarded to the loss.
        alignment_weight: Alignment weight forwarded to the loss.

    Returns:
        ``(mean_loss, mean_metrics)`` averaged over the batches seen. When the
        loader yields no batch, ``mean_loss`` and every metric are ``nan``
        (the original code raised ``ZeroDivisionError`` there).
    """
    model.train()
    total_loss = 0.0
    total_metrics = {
        'original_loss': 0.0,
        'alignment_loss': 0.0,
        'text_alignment': 0.0,
        'image_alignment': 0.0,
        'text_cosine': 0.0,
        'image_cosine': 0.0
    }
    num_batches = 0

    pbar = tqdm(train_loader, desc="Training Enhanced", leave=False)

    for images, texts, colors, hierarchy in pbar:
        # Move data to device
        images = images.to(device)
        # expand() only broadcasts size-1 dimensions: 1 channel -> 3, 3 stays 3.
        images = images.expand(-1, 3, -1, -1)

        # Process text inputs. Truncate to the tokenizer's configured maximum length
        # so a long description cannot overflow the text encoder.
        text_inputs = clip_processor(text=texts, padding=True, truncation=True, return_tensors="pt")
        text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

        # Forward pass
        optimizer.zero_grad()
        outputs = model(**text_inputs, pixel_values=images)

        text_features = outputs.text_embeds
        image_features = outputs.image_embeds

        # Get feature embeddings (prefer exact color-name embeddings when offered)
        attr_color_model = feature_models['color']
        if hasattr(attr_color_model, 'get_color_name_embeddings'):
            color_features = attr_color_model.get_color_name_embeddings(colors)
        else:
            color_features = attr_color_model.get_text_embeddings(colors)
        hierarchy_features = feature_models['hierarchy'].get_text_embeddings(hierarchy)
        concat_features = torch.cat((color_features, hierarchy_features), dim=1)

        # Calculate enhanced loss
        loss, metrics = enhanced_contrastive_loss(
            text_features, image_features, concat_features,
            color_model, colors, temperature, alignment_weight
        )

        # Backward pass
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        for key, value in metrics.items():
            total_metrics[key] += value
        num_batches += 1

        # Update progress bar
        pbar.set_postfix({
            'Loss': f'{batch_loss:.4f}',
            'Align': f'{metrics["alignment_loss"]:.4f}',
            'Text_Cos': f'{metrics["text_cosine"]:.4f}',
            'Img_Cos': f'{metrics["image_cosine"]:.4f}'
        })

    if num_batches == 0:
        # Nothing was trained on; report "no measurement" instead of dividing by zero.
        return float('nan'), {key: float('nan') for key in total_metrics}

    avg_metrics = {key: value / num_batches for key, value in total_metrics.items()}
    return total_loss / num_batches, avg_metrics
298
+
299
def validate_correlation(model, color_model, val_loader, clip_processor, device, max_batches=50):
    """
    Validate the correlation between color model and main model embeddings.

    For up to ``max_batches`` batches, the color model's embedding of each colour
    label is collected next to the first ``color_emb_dim`` columns of the main
    model's text and image embeddings. All collected values are flattened and a
    single Pearson correlation coefficient is computed for text and for image.

    Both ``model`` and ``color_model`` are switched to eval mode and are left in
    that mode on return.

    Args:
        model: Main model; its output must expose ``text_embeds`` and ``image_embeds``.
        color_model: Color model; ``get_color_name_embeddings`` is used when
            present, otherwise ``get_text_embeddings``.
        val_loader: Iterable yielding ``(images, texts, colors, hierarchy)`` batches.
        clip_processor: Callable used to tokenise ``texts``.
        device: Device the images and tokenised text are moved to.
        max_batches: Maximum number of batches to evaluate (default 50, the
            previously hard-coded limit). ``None`` evaluates the whole loader.

    Returns:
        Dict with ``'text_correlation'`` and ``'image_correlation'``. Both are
        ``nan`` when no batch was evaluated (the original code let
        ``np.vstack`` raise on the empty list).
    """
    model.eval()
    color_model.eval()

    all_color_embeddings = []
    all_main_text_color = []
    all_main_image_color = []

    with torch.no_grad():
        for batch_idx, (images, texts, colors, hierarchy) in enumerate(tqdm(val_loader, desc="Validation Correlation", leave=False)):
            if max_batches is not None and batch_idx >= max_batches:  # Limit validation samples
                break

            images = images.to(device)
            # expand() only broadcasts size-1 dimensions: 1 channel -> 3, 3 stays 3.
            images = images.expand(-1, 3, -1, -1)

            # Truncate to the tokenizer's configured maximum length so a long
            # description cannot overflow the text encoder.
            text_inputs = clip_processor(text=texts, padding=True, truncation=True, return_tensors="pt")
            text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

            # Get embeddings
            outputs = model(**text_inputs, pixel_values=images)
            if hasattr(color_model, 'get_color_name_embeddings'):
                color_emb = color_model.get_color_name_embeddings(colors)
            else:
                color_emb = color_model.get_text_embeddings(colors)

            # Extract the color slice of the main model embeddings
            main_text_color = outputs.text_embeds[:, :color_emb_dim]
            main_image_color = outputs.image_embeds[:, :color_emb_dim]

            all_color_embeddings.append(color_emb.cpu().numpy())
            all_main_text_color.append(main_text_color.cpu().numpy())
            all_main_image_color.append(main_image_color.cpu().numpy())

    if not all_color_embeddings:
        # No batch was seen: there is nothing to correlate.
        return {
            'text_correlation': float('nan'),
            'image_correlation': float('nan')
        }

    # Flatten everything so each correlation is computed over all values at once
    color_flat = np.vstack(all_color_embeddings).flatten()
    text_flat = np.vstack(all_main_text_color).flatten()
    image_flat = np.vstack(all_main_image_color).flatten()

    text_correlation = np.corrcoef(color_flat, text_flat)[0, 1]
    image_correlation = np.corrcoef(color_flat, image_flat)[0, 1]

    return {
        'text_correlation': text_correlation,
        'image_correlation': image_correlation
    }
353
+
354
+ # -------------------------------
355
+ # Step 2: Define Dataset
356
+ # -------------------------------
357
+
358
class CustomDataset(Dataset):
    """
    Image / text dataset backed by a dataframe.

    Each item is ``(image_tensor, description, color, hierarchy)``. The image is
    read from the path stored in the ``column_local_image_path`` column; the
    other three values come from the ``'text'``, ``'color'`` and ``'hierarchy'``
    columns. Whether the augmenting or the plain transform is applied is
    controlled by ``set_training_mode``.
    """

    def __init__(self, dataframe, use_local_images=True, image_size=224):
        """
        Args:
            dataframe: Object supporting ``len()`` and ``.iloc[idx]`` row access.
            use_local_images: Stored on the instance; nothing in this class reads it.
            image_size: Side length images are resized to (square).
        """
        self.dataframe = dataframe
        self.use_local_images = use_local_images
        self.image_size = image_size

        # NOTE(review): these mean/std values are the usual ImageNet statistics.
        # The CLIP checkpoint used elsewhere in this file has its own image
        # mean/std in its processor config -- confirm the mismatch is intended.
        resize = transforms.Resize((image_size, image_size))
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        # Transforms with augmentation for training
        self.transform = transforms.Compose([
            resize,
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.15),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            to_tensor,
            normalize,
        ])

        # Transforms for validation (no augmentation)
        self.val_transform = transforms.Compose([
            resize,
            to_tensor,
            normalize,
        ])

        # Augmentation is on by default; callers must opt out for evaluation.
        self.training_mode = True

    def set_training_mode(self, training=True):
        """Select the augmenting transform (True) or the plain one (False)."""
        self.training_mode = training

    def __len__(self):
        """Return the number of rows in the backing dataframe."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Load row ``idx`` and return ``(image, description, color, hierarchy)``."""
        row = self.dataframe.iloc[idx]

        image_path = row[column_local_image_path]
        # The context manager closes the underlying file deterministically;
        # convert() returns a fully loaded copy that outlives the handle.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        # Apply appropriate transform
        pipeline = self.transform if self.training_mode else self.val_transform
        image = pipeline(image)

        # Get text and labels
        description = row['text']
        color = row['color']
        hierarchy = row['hierarchy']

        return image, description, color, hierarchy
408
+
409
def train_model(model, train_loader, val_loader, feature_models, device,
                num_epochs=20, learning_rate=1e-5, temperature=0.07,
                save_path=main_model_path, use_enhanced_loss=False, alignment_weight=0.3,
                color_alignment_model=None, early_stopping_patience=5):
    """
    Custom training loop using train_one_epoch and valid_one_epoch functions.

    Each epoch trains, validates, steps a ``ReduceLROnPlateau`` scheduler on the
    validation loss and writes a checkpoint to ``save_path`` whenever the
    validation loss improves. Training stops early after
    ``early_stopping_patience`` consecutive epochs without improvement. The loss
    curves are written to ``training_curves.png`` at the end.

    Args:
        model: Model to train; moved to ``device``.
        train_loader: Training loader (needs ``.dataset`` and ``.batch_size``).
        val_loader: Validation loader (needs ``.dataset``).
        feature_models: Mapping with ``'color'`` and ``'hierarchy'`` models.
        device: Device to train on.
        num_epochs: Maximum number of epochs.
        learning_rate: Adam learning rate.
        temperature: Temperature forwarded to the loss functions.
        save_path: Destination of the best checkpoint.
        use_enhanced_loss: Train with ``train_one_epoch_enhanced`` instead of
            ``train_one_epoch``. Validation always uses ``valid_one_epoch``.
        alignment_weight: Alignment weight for the enhanced loss.
        color_alignment_model: Alignment target for the enhanced loss; defaults
            to ``feature_models['color']``.
        early_stopping_patience: Epochs without improvement tolerated before
            stopping (default 5, the previously hard-coded value).

    Returns:
        ``(train_losses, val_losses)``: the per-epoch loss lists.
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)

    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    patience_counter = 0
    checkpoint_saved = False

    print(f"Starting training for {num_epochs} epochs...")
    print(f"Learning rate: {learning_rate}")
    print(f"Temperature: {temperature}")
    print(f"Device: {device}")
    print(f"Training samples: {len(train_loader.dataset)}")
    print(f"Validation samples: {len(val_loader.dataset)}")
    print(f"Batch size: {train_loader.batch_size}")
    print(f"Estimated time per epoch: ~{len(train_loader) * 2 / 60:.1f} minutes")

    # Create processor once for efficiency
    processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')

    # Resolve the alignment target once instead of on every epoch
    if use_enhanced_loss and color_alignment_model is None:
        color_alignment_model = feature_models['color']

    # Create progress bar for epochs
    epoch_pbar = tqdm(range(num_epochs), desc="Training Progress", position=0)

    for epoch in epoch_pbar:
        # Update epoch progress bar
        epoch_pbar.set_description(f"Epoch {epoch+1}/{num_epochs}")

        # Training
        if use_enhanced_loss:
            train_loss, align_metrics = train_one_epoch_enhanced(
                model, train_loader, optimizer, feature_models, color_alignment_model,
                device, processor, temperature, alignment_weight
            )
        else:
            train_loss = train_one_epoch(model, train_loader, optimizer, feature_models, device, processor, temperature)
            align_metrics = None
        train_losses.append(train_loss)

        # Validation
        val_loss = valid_one_epoch(model, val_loader, feature_models, device, processor, temperature)
        val_losses.append(val_loss)

        # Learning rate scheduling
        scheduler.step(val_loss)

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0

            # Save checkpoint
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'best_val_loss': best_val_loss,
            }, save_path)
            checkpoint_saved = True
        else:
            patience_counter += 1

        # Update epoch progress bar with metrics. This runs after the best-loss
        # bookkeeping so 'Best Val' includes the epoch that just finished
        # (it used to lag by one epoch and show 'inf' on the first one).
        postfix = {
            'Train Loss': f'{train_loss:.4f}',
            'Val Loss': f'{val_loss:.4f}',
            'LR': f'{optimizer.param_groups[0]["lr"]:.2e}',
            'Best Val': f'{best_val_loss:.4f}'
        }
        if align_metrics is not None:
            postfix.update({
                'Align': f"{align_metrics['alignment_loss']:.3f}",
                'TextCos': f"{align_metrics['text_cosine']:.3f}",
                'ImgCos': f"{align_metrics['image_cosine']:.3f}"
            })
        epoch_pbar.set_postfix(postfix)

        # Early stopping
        if patience_counter >= early_stopping_patience:
            print(f"\n🛑 Early stopping triggered after {patience_counter} epochs without improvement")
            break

    # Plot training curves
    fig = plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss', color='blue')
    plt.plot(val_losses, label='Val Loss', color='red')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 2, 2)
    plt.plot(train_losses, label='Train Loss', color='blue')
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('training_curves.png', dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures
    plt.close(fig)

    print(f"\nTraining completed!")
    print(f"Best validation loss: {best_val_loss:.4f}")
    if checkpoint_saved:
        print(f"Final model saved to: {save_path}")
    else:
        # Happens when no epoch ran or the validation loss never beat its initial value
        print(f"No checkpoint was written to: {save_path}")
    print(f"Training curves saved to: training_curves.png")

    return train_losses, val_losses
526
+
527
def load_models():
    """
    Load the pre-trained color and hierarchy feature models.

    Reads ``tokenizer_vocab.json`` from the working directory when it exists,
    the color checkpoint from ``color_model_path`` and the hierarchy checkpoint
    from ``hierarchy_model_path``. Both models are moved to ``device`` and put
    in eval mode.

    Returns:
        Dict mapping ``'color'`` and ``'hierarchy'`` to the loaded models.
    """
    # Load feature models
    from color_model import ColorCLIP, SimpleTokenizer
    from hierarchy_model import Model, HierarchyExtractor
    import json

    # Initialize tokenizer first
    tokenizer = SimpleTokenizer()

    # Load vocabulary if available
    vocab_path = 'tokenizer_vocab.json'
    if os.path.exists(vocab_path):
        # Explicit encoding so the result does not depend on the platform default
        with open(vocab_path, 'r', encoding='utf-8') as f:
            vocab_dict = json.load(f)
        tokenizer.load_vocab(vocab_dict)
        print(f"Tokenizer vocabulary loaded from {vocab_path}")
    else:
        print(f"Warning: {vocab_path} not found. Using default tokenizer.")

    # NOTE(security): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    # Load trained model first to get correct vocab size
    checkpoint = torch.load(color_model_path, map_location=device)

    # Extract vocab size from the checkpoint's embedding layer
    vocab_size_from_checkpoint = checkpoint['text_encoder.embedding.weight'].shape[0]
    print(f"Vocab size from checkpoint: {vocab_size_from_checkpoint}")
    print(f"Vocab size from tokenizer: {tokenizer.counter}")

    # The model must be built with the checkpoint's vocabulary size: any other
    # size makes load_state_dict fail on the embedding shape. (Taking the larger
    # of the two sizes, as before, guaranteed that failure whenever the
    # tokenizer was the bigger one.)
    vocab_size = vocab_size_from_checkpoint
    if tokenizer.counter > vocab_size:
        print(f"Warning: tokenizer has {tokenizer.counter} entries but the checkpoint "
              f"embedding only has {vocab_size} rows; token ids beyond the checkpoint "
              f"vocabulary cannot be embedded.")

    # Initialize model with correct vocab size
    color_model = ColorCLIP(vocab_size=vocab_size, embedding_dim=color_emb_dim).to(device)
    color_model.tokenizer = tokenizer

    # Load the checkpoint
    color_model.load_state_dict(checkpoint)
    print(f"Model loaded from {color_model_path}")

    color_model.eval()
    color_model.name = 'color'

    # Load hierarchy model
    hierarchy_checkpoint = torch.load(hierarchy_model_path, map_location=device)
    hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', [])
    hierarchy_model = Model(
        num_hierarchy_classes=len(hierarchy_classes),
        embed_dim=hierarchy_emb_dim
    ).to(device)
    hierarchy_model.load_state_dict(hierarchy_checkpoint['model_state'])

    # Set up hierarchy extractor
    hierarchy_extractor = HierarchyExtractor(hierarchy_classes, verbose=False)
    hierarchy_model.set_hierarchy_extractor(hierarchy_extractor)
    hierarchy_model.eval()
    hierarchy_model.name = 'hierarchy'

    # Key each model by the name assigned above
    feature_models = {model.name: model for model in [color_model, hierarchy_model]}

    return feature_models
586
+
587
if __name__ == "__main__":
    # Script entry point: load the CSV, build train/val loaders over a random
    # subset, load the feature models and fine-tune the CLIP model.
    import pandas as pd

    print("Loading data...")
    df = pd.read_csv(local_dataset_path)
    print(f"Loaded {len(df)} samples")

    # Filter out rows with NaN values in image path
    df_clean = df.dropna(subset=[column_local_image_path])
    print(f"After filtering NaN image paths: {len(df_clean)} samples")

    # Create datasets. Training and validation get their own CustomDataset so the
    # augmentation switch can differ: with one shared instance the validation
    # split was also run through the random training augmentations.
    train_source = CustomDataset(df_clean)
    val_source = CustomDataset(df_clean)
    val_source.set_training_mode(False)

    # Split for train/val - use at most 10k samples for faster training
    subset_size = min(10000, len(train_source))
    train_size = int(0.8 * subset_size)
    val_size = subset_size - train_size

    # Sampling without replacement yields the indices in random order, so slicing
    # is a random 80/20 split. tolist() converts them to plain Python ints.
    subset_indices = np.random.choice(len(train_source), subset_size, replace=False).tolist()
    train_dataset = torch.utils.data.Subset(train_source, subset_indices[:train_size])
    val_dataset = torch.utils.data.Subset(val_source, subset_indices[train_size:])

    # Create dataloaders with optimized parameters
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2, pin_memory=True)

    print(f"Train samples: {len(train_dataset)}")
    print(f"Val samples: {len(val_dataset)}")

    print("Loading models...")
    feature_models = load_models()

    # Create the main CLIP model
    clip_model = CLIPModel_transformers.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')

    print("Training...")

    # Train using custom training loop
    train_losses, val_losses = train_model(
        model=clip_model,
        train_loader=train_loader,
        val_loader=val_loader,
        feature_models=feature_models,
        device=device,
        num_epochs=20,
        learning_rate=2e-5,
        temperature=0.07
    )