Leacb4 committed on
Commit
bda7a5a
·
verified ·
1 Parent(s): 398de18

Upload hierarchy_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. hierarchy_model.py +53 -59
hierarchy_model.py CHANGED
@@ -10,7 +10,7 @@ from sklearn.model_selection import train_test_split
10
  import re
11
  import requests
12
  from io import BytesIO
13
- from config import hierarchy_model_path, device, hierarchy_emb_dim, color_emb_dim, local_dataset_path, main_model_path, color_model_path
14
 
15
  # -------------------------
16
  # 1) Dataset
@@ -41,11 +41,11 @@ class HierarchyDataset(Dataset):
41
 
42
  # Check local image availability
43
  if use_local_images:
44
- if 'local_image_path' not in dataframe.columns:
45
- print("⚠️ Column 'local_image_path' not found. Using URLs.")
46
  self.use_local_images = False
47
  else:
48
- local_available = dataframe['local_image_path'].notna().sum()
49
  total = len(dataframe)
50
  print(f"πŸ“ Local images available: {local_available}/{total} ({local_available/total*100:.1f}%)")
51
 
@@ -61,29 +61,24 @@ class HierarchyDataset(Dataset):
61
  row = self.dataframe.iloc[idx]
62
 
63
  # Try to load local image first
64
- if self.use_local_images and pd.notna(row.get('local_image_path', '')):
65
- local_path = row['local_image_path']
66
  image = Image.open(local_path).convert("RGB")
67
  # Check if image is a dictionary of bytes
68
- elif isinstance(row['image_url'], dict):
69
- image = Image.open(BytesIO(row['image_url']['bytes'])).convert('RGB')
70
  # Otherwise, try to download from URL
71
  else:
72
- try:
73
- image = self._download_image(row['image_url'])
74
- except Exception as e:
75
- print(f"⚠️ Failed to load image {idx}: {e}")
76
- # Create a blank image as fallback
77
- image = Image.new('RGB', (224, 224), color='gray')
78
-
79
  # Apply transforms
80
  if hasattr(self, 'training_mode') and not self.training_mode:
81
  image = self.val_transform(image)
82
  else:
83
  image = self.transform(image)
84
 
85
- description = row['text']
86
- hierarchy = row['hierarchy']
87
 
88
  return image, description, hierarchy
89
 
@@ -190,7 +185,7 @@ class HierarchyExtractor:
190
  return self.class_to_idx[hierarchy]
191
  return None
192
 
193
- def get_hierarchy_embedding(self, text, embed_dim=64):
194
  """Create embedding from hierarchy index"""
195
  hierarchy_idx = self.extract_hierarchy_idx(text)
196
  if hierarchy_idx is not None:
@@ -351,6 +346,9 @@ class Model(nn.Module):
351
  """Get text embeddings for a given text string or list of strings"""
352
 
353
  with torch.no_grad():
 
 
 
354
  # Handle case where text is a list/tuple of hierarchies
355
  if isinstance(text, (list, tuple)):
356
  # Process multiple hierarchies
@@ -365,7 +363,7 @@ class Model(nn.Module):
365
  raise ValueError(f"Expected string, got {type(hierarchy_text)}: {hierarchy_text}")
366
 
367
  # Convert to tensor and move to device
368
- hierarchy_indices = torch.tensor(hierarchy_indices, device=device)
369
 
370
  # Get text embeddings for all hierarchies
371
  output = self.forward(hierarchy_indices=hierarchy_indices)
@@ -379,7 +377,7 @@ class Model(nn.Module):
379
  raise ValueError(f"Could not extract hierarchy for text: '{text}'. Available classes: {self.hierarchy_extractor.hierarchy_classes}")
380
 
381
  # Convert to tensor and move to device
382
- hierarchy_indices = torch.tensor([hierarchy_idx], device=device)
383
 
384
  # Get text embeddings
385
  output = self.forward(hierarchy_indices=hierarchy_indices)
@@ -490,7 +488,7 @@ def collate_fn(batch, hierarchy_extractor):
490
  return {
491
  'image': images,
492
  'hierarchy_indices': hierarchy_indices,
493
- 'hierarchy': hierarchies
494
  }
495
 
496
  def calculate_accuracy(logits, target_hierarchies, hierarchy_classes):
@@ -525,7 +523,7 @@ def train_one_epoch(model, dataloader, optimizer, device, hierarchy_classes, sch
525
  for batch in pbar:
526
  images = batch['image'].to(device)
527
  hierarchy_indices = batch['hierarchy_indices'].to(device)
528
- target_hierarchies = batch['hierarchy']
529
 
530
  # Set dataset to training mode
531
  if hasattr(dataloader.dataset, 'set_training_mode'):
@@ -590,7 +588,7 @@ def validate(model, dataloader, device, hierarchy_classes):
590
  for batch in pbar:
591
  images = batch['image'].to(device)
592
  hierarchy_indices = batch['hierarchy_indices'].to(device)
593
- target_hierarchies = batch['hierarchy']
594
 
595
  # Set dataset to validation mode
596
  if hasattr(dataloader.dataset, 'set_training_mode'):
@@ -631,28 +629,24 @@ def validate(model, dataloader, device, hierarchy_classes):
631
 
632
  if __name__ == "__main__":
633
  # Configuration
634
- CSV = "data/data_hierarchy_with_local_paths.csv"
635
- DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
636
- BATCH = 16
637
- LR = 5e-5
638
- EPOCHS = 20
639
- VAL_SPLIT = 0.2
640
- EMB_DIM = 16
641
- DROPOUT = 0.4
642
- WEIGHT_DECAY = 1e-3
643
-
644
- print(f"πŸš€ Starting hierarchical training on device: {DEVICE}")
645
- print(f"πŸ“Š Config: {EPOCHS} epochs, batch={BATCH}, lr={LR}, embed_dim={EMB_DIM}")
646
 
647
  # Load dataset
648
- print(f"πŸ“ Using dataset: {CSV}")
649
- df = pd.read_csv(CSV)
650
- df = df[df['hierarchy'] != 'vest']
651
-
652
  print(f"πŸ“ Loaded {len(df)} samples")
653
 
654
  # Get unique hierarchy classes
655
- hierarchy_classes = sorted(df['hierarchy'].unique().tolist())
656
  print(f"πŸ“‹ Found {len(hierarchy_classes)} hierarchy classes")
657
 
658
  # Create hierarchy extractor
@@ -661,9 +655,9 @@ if __name__ == "__main__":
661
  # Train/validation split
662
  train_df, val_df = train_test_split(
663
  df,
664
- test_size=VAL_SPLIT,
665
  random_state=42,
666
- stratify=df['hierarchy']
667
  )
668
  train_df = train_df.reset_index(drop=True)
669
  val_df = val_df.reset_index(drop=True)
@@ -677,13 +671,13 @@ if __name__ == "__main__":
677
  # Create data loaders
678
  train_dl = DataLoader(
679
  train_ds,
680
- batch_size=BATCH,
681
  shuffle=True,
682
  collate_fn=lambda batch: collate_fn(batch, hierarchy_extractor)
683
  )
684
  val_dl = DataLoader(
685
  val_ds,
686
- batch_size=BATCH,
687
  shuffle=False,
688
  collate_fn=lambda batch: collate_fn(batch, hierarchy_extractor)
689
  )
@@ -691,13 +685,13 @@ if __name__ == "__main__":
691
  # Create model
692
  model = Model(
693
  num_hierarchy_classes=len(hierarchy_classes),
694
- embed_dim=EMB_DIM,
695
- dropout=DROPOUT
696
- ).to(DEVICE)
697
 
698
  # Optimizer and scheduler
699
- optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
700
- scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2, eta_min=LR/10)
701
 
702
  print(f"🎯 Model parameters: {sum(p.numel() for p in model.parameters()):,}")
703
  print("\n" + "="*80)
@@ -706,15 +700,15 @@ if __name__ == "__main__":
706
  best_val_loss = float('inf')
707
  training_history = {'train_loss': [], 'val_loss': [], 'val_acc_img': [], 'val_acc_txt': []}
708
 
709
- for e in range(EPOCHS):
710
- print(f"\nπŸ”„ Epoch {e+1}/{EPOCHS}")
711
  print("-" * 50)
712
 
713
  # Training
714
- train_metrics = train_one_epoch(model, train_dl, optimizer, DEVICE, hierarchy_classes, scheduler)
715
 
716
  # Validation
717
- val_metrics = validate(model, val_dl, DEVICE, hierarchy_classes)
718
 
719
  # Track history
720
  training_history['train_loss'].append(train_metrics['loss'])
@@ -741,10 +735,10 @@ if __name__ == "__main__":
741
  'hierarchy_classes': hierarchy_classes,
742
  'epoch': e+1,
743
  'config': {
744
- 'embed_dim': EMB_DIM,
745
- 'dropout': DROPOUT
746
  }
747
- }, "final_model_16.pth")
748
 
749
  # Save model every 2 epochs
750
  if (e + 1) % 2 == 0:
@@ -755,10 +749,10 @@ if __name__ == "__main__":
755
  'hierarchy_classes': hierarchy_classes,
756
  'epoch': e+1,
757
  'config': {
758
- 'embed_dim': EMB_DIM,
759
- 'dropout': DROPOUT
760
  }
761
- }, f"model_checkpoint_epoch_{e+1}_16.pth")
762
 
763
  print("\n" + "="*80)
764
  print("πŸŽ‰ Training completed!")
 
10
  import re
11
  import requests
12
  from io import BytesIO
13
+ import config
14
 
15
  # -------------------------
16
  # 1) Dataset
 
41
 
42
  # Check local image availability
43
  if use_local_images:
44
+ if config.column_local_image_path not in dataframe.columns:
45
+ print(f"⚠️ Column {config.column_local_image_path} not found. Using URLs.")
46
  self.use_local_images = False
47
  else:
48
+ local_available = dataframe[config.column_local_image_path].notna().sum()
49
  total = len(dataframe)
50
  print(f"πŸ“ Local images available: {local_available}/{total} ({local_available/total*100:.1f}%)")
51
 
 
61
  row = self.dataframe.iloc[idx]
62
 
63
  # Try to load local image first
64
+ if self.use_local_images and pd.notna(row.get(config.column_local_image_path, '')):
65
+ local_path = row[config.column_local_image_path]
66
  image = Image.open(local_path).convert("RGB")
67
  # Check if image is a dictionary of bytes
68
+ elif isinstance(row[config.column_url_image], dict):
69
+ image = Image.open(BytesIO(row[config.column_url_image]['bytes'])).convert('RGB')
70
  # Otherwise, try to download from URL
71
  else:
72
+ image = self._download_image(row[config.column_url_image])
73
+
 
 
 
 
 
74
  # Apply transforms
75
  if hasattr(self, 'training_mode') and not self.training_mode:
76
  image = self.val_transform(image)
77
  else:
78
  image = self.transform(image)
79
 
80
+ description = row[config.text_column]
81
+ hierarchy = row[config.hierarchy_column]
82
 
83
  return image, description, hierarchy
84
 
 
185
  return self.class_to_idx[hierarchy]
186
  return None
187
 
188
+ def get_hierarchy_embedding(self, text, embed_dim=config.hierarchy_emb_dim):
189
  """Create embedding from hierarchy index"""
190
  hierarchy_idx = self.extract_hierarchy_idx(text)
191
  if hierarchy_idx is not None:
 
346
  """Get text embeddings for a given text string or list of strings"""
347
 
348
  with torch.no_grad():
349
+ # Get the device of the model
350
+ model_device = next(self.parameters()).device
351
+
352
  # Handle case where text is a list/tuple of hierarchies
353
  if isinstance(text, (list, tuple)):
354
  # Process multiple hierarchies
 
363
  raise ValueError(f"Expected string, got {type(hierarchy_text)}: {hierarchy_text}")
364
 
365
  # Convert to tensor and move to device
366
+ hierarchy_indices = torch.tensor(hierarchy_indices, device=model_device)
367
 
368
  # Get text embeddings for all hierarchies
369
  output = self.forward(hierarchy_indices=hierarchy_indices)
 
377
  raise ValueError(f"Could not extract hierarchy for text: '{text}'. Available classes: {self.hierarchy_extractor.hierarchy_classes}")
378
 
379
  # Convert to tensor and move to device
380
+ hierarchy_indices = torch.tensor([hierarchy_idx], device=model_device)
381
 
382
  # Get text embeddings
383
  output = self.forward(hierarchy_indices=hierarchy_indices)
 
488
  return {
489
  'image': images,
490
  'hierarchy_indices': hierarchy_indices,
491
+ config.hierarchy_column: hierarchies
492
  }
493
 
494
  def calculate_accuracy(logits, target_hierarchies, hierarchy_classes):
 
523
  for batch in pbar:
524
  images = batch['image'].to(device)
525
  hierarchy_indices = batch['hierarchy_indices'].to(device)
526
+ target_hierarchies = batch[config.hierarchy_column]
527
 
528
  # Set dataset to training mode
529
  if hasattr(dataloader.dataset, 'set_training_mode'):
 
588
  for batch in pbar:
589
  images = batch['image'].to(device)
590
  hierarchy_indices = batch['hierarchy_indices'].to(device)
591
+ target_hierarchies = batch[config.hierarchy_column]
592
 
593
  # Set dataset to validation mode
594
  if hasattr(dataloader.dataset, 'set_training_mode'):
 
629
 
630
  if __name__ == "__main__":
631
  # Configuration
632
+ device = config.device
633
+ batch_size = 16
634
+ lr = 5e-5
635
+ epochs = 20
636
+ val_split = 0.2
637
+ dropout = 0.4
638
+ weight_decay = 1e-3
639
+
640
+ print(f"πŸš€ Starting hierarchical training on device: {device}")
641
+ print(f"πŸ“Š Config: {epochs} epochs, batch={batch_size}, lr={lr}, embed_dim={config.hierarchy_emb_dim}")
 
 
642
 
643
  # Load dataset
644
+ print(f"πŸ“ Using dataset: { config.local_dataset_path}")
645
+ df = pd.read_csv(config.local_dataset_path)
 
 
646
  print(f"πŸ“ Loaded {len(df)} samples")
647
 
648
  # Get unique hierarchy classes
649
+ hierarchy_classes = sorted(df[config.hierarchy_column].unique().tolist())
650
  print(f"πŸ“‹ Found {len(hierarchy_classes)} hierarchy classes")
651
 
652
  # Create hierarchy extractor
 
655
  # Train/validation split
656
  train_df, val_df = train_test_split(
657
  df,
658
+ test_size=val_split,
659
  random_state=42,
660
+ stratify=df[config.hierarchy_column]
661
  )
662
  train_df = train_df.reset_index(drop=True)
663
  val_df = val_df.reset_index(drop=True)
 
671
  # Create data loaders
672
  train_dl = DataLoader(
673
  train_ds,
674
+ batch_size=batch_size,
675
  shuffle=True,
676
  collate_fn=lambda batch: collate_fn(batch, hierarchy_extractor)
677
  )
678
  val_dl = DataLoader(
679
  val_ds,
680
+ batch_size=batch_size,
681
  shuffle=False,
682
  collate_fn=lambda batch: collate_fn(batch, hierarchy_extractor)
683
  )
 
685
  # Create model
686
  model = Model(
687
  num_hierarchy_classes=len(hierarchy_classes),
688
+ embed_dim=config.hierarchy_emb_dim,
689
+ dropout=dropout
690
+ ).to(device)
691
 
692
  # Optimizer and scheduler
693
+ optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
694
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2, eta_min=lr/10)
695
 
696
  print(f"🎯 Model parameters: {sum(p.numel() for p in model.parameters()):,}")
697
  print("\n" + "="*80)
 
700
  best_val_loss = float('inf')
701
  training_history = {'train_loss': [], 'val_loss': [], 'val_acc_img': [], 'val_acc_txt': []}
702
 
703
+ for e in range(epochs):
704
+ print(f"\nπŸ”„ Epoch {e+1}/{epochs}")
705
  print("-" * 50)
706
 
707
  # Training
708
+ train_metrics = train_one_epoch(model, train_dl, optimizer, device, hierarchy_classes, scheduler)
709
 
710
  # Validation
711
+ val_metrics = validate(model, val_dl, device, hierarchy_classes)
712
 
713
  # Track history
714
  training_history['train_loss'].append(train_metrics['loss'])
 
735
  'hierarchy_classes': hierarchy_classes,
736
  'epoch': e+1,
737
  'config': {
738
+ 'embed_dim': config.hierarchy_emb_dim,
739
+ 'dropout': dropout
740
  }
741
+ }, config.hierarchy_model_path)
742
 
743
  # Save model every 2 epochs
744
  if (e + 1) % 2 == 0:
 
749
  'hierarchy_classes': hierarchy_classes,
750
  'epoch': e+1,
751
  'config': {
752
+ 'embed_dim': config.hierarchy_emb_dim,
753
+ 'dropout': dropout
754
  }
755
+ }, f"model_checkpoint_epoch_{e+1}.pth")
756
 
757
  print("\n" + "="*80)
758
  print("πŸŽ‰ Training completed!")