AbstractPhil committed
Commit 27c2e91 · verified · 1 Parent(s): c1ea032

Update modeling_geolip_vit.py

Files changed (1)
  1. modeling_geolip_vit.py +80 -26
modeling_geolip_vit.py CHANGED
@@ -862,31 +862,33 @@ class DualStreamViT(nn.Module):
 
 class MasteryQueue:
     """
-    Cross-batch embedding cache for progressive hard contrastive learning.
-
-    Stage 1: Inactive. Standard InfoNCE handles in-batch discrimination.
-    Stage 2: Activates when nce_acc=1.0 for `patience` consecutive batches.
-        Caches embeddings + labels from recent batches.
-        compute_loss uses the queue to find:
-          - hard negatives: closest different-class embedding
-          - hard positives: furthest same-class embedding
-        Margin loss forces the model to separate these.
-
-    This creates the class-level asymmetry that moves CV toward the
-    natural 0.20-0.23 band. Dense same-class regions + sparse boundaries
-    = volume variation = higher CV.
+    Cross-batch embedding cache with adaptive queue sizing.
+
+    Activation: when nce_acc >= 0.99 for `patience` consecutive batches.
+    Progressive margin: ramps from margin_start → margin_end over margin_warmup.
+
+    Adaptive queue sizing (call update_size each epoch):
+      - Monitors train_acc - val_acc gap (overfitting indicator)
+      - Gap growing → increase queue (more diverse negatives = regularization)
+      - Gap shrinking → decrease queue (tighter contrastive signal)
+      - Cooldown prevents oscillation: no resize for `resize_cooldown` epochs
+        after each change.
     """
-    def __init__(self, dim, max_size=4096, patience=50, device='cuda',
-                 margin_start=0.1, margin_end=0.3, margin_warmup=5000):
+    def __init__(self, dim, min_size=1024, max_size=8192, initial_size=4096,
+                 patience=50, device='cuda',
+                 margin_start=0.1, margin_end=0.3, margin_warmup=5000,
+                 resize_step=1024, resize_cooldown=5, overfit_threshold=3.0):
         self.dim = dim
+        self.min_size = min_size
         self.max_size = max_size
+        self._current_max = initial_size
         self.patience = patience
         self.device = device
         self.active = False
 
         # Queue storage
-        self._embs = None    # (Q, dim)
-        self._labels = None  # (Q,)
+        self._embs = None
+        self._labels = None
 
         # Activation tracking
         self._perfect_count = 0
@@ -896,11 +898,19 @@ class MasteryQueue:
         # Progressive margin
         self._margin_start = margin_start
         self._margin_end = margin_end
-        self._margin_warmup = margin_warmup  # batches after activation to reach max
-        self._mastery_steps = 0              # batches since activation
+        self._margin_warmup = margin_warmup
+        self._mastery_steps = 0
+
+        # Adaptive sizing
+        self._resize_step = resize_step
+        self._resize_cooldown = resize_cooldown
+        self._overfit_threshold = overfit_threshold
+        self._epochs_since_resize = resize_cooldown  # allow first resize
+        self._prev_gap = None
+        self._resize_history = []
 
     def check_activation(self, nce_acc):
-        """Call each batch. Activates when nce_acc=1.0 for patience steps."""
+        """Call each batch. Activates when nce_acc >= 0.99 for patience steps."""
         self._total_batches += 1
         if nce_acc >= 0.99:
             self._perfect_count += 1
@@ -912,21 +922,64 @@ class MasteryQueue:
             self._activated_at = self._total_batches
             print(f"\n  ★ MASTERY ACTIVATED at batch {self._total_batches} "
                   f"(nce_acc=1.0 for {self.patience} consecutive) "
-                  f"[InfoNCE stays ON, margin {self._margin_start}→{self._margin_end}]")
+                  f"[InfoNCE stays ON, margin {self._margin_start}→{self._margin_end}]"
+                  f" queue={self._current_max}")
 
         if self.active:
             self._mastery_steps += 1
 
+    def update_size(self, train_acc, val_acc, epoch):
+        """
+        Call once per epoch. Adjusts queue size based on overfit gap.
+
+        Gap = train_acc - val_acc.
+        Gap growing → queue grows (more negatives = regularization)
+        Gap shrinking → queue shrinks (tighter signal)
+        Cooldown prevents oscillation.
+        """
+        if not self.active:
+            return
+
+        self._epochs_since_resize += 1
+        gap = train_acc - val_acc
+
+        if self._prev_gap is not None and self._epochs_since_resize >= self._resize_cooldown:
+            delta = gap - self._prev_gap
+            old_size = self._current_max
+
+            if delta > self._overfit_threshold:
+                # Overfitting increasing → grow queue for regularization
+                self._current_max = min(
+                    self._current_max + self._resize_step, self.max_size)
+            elif delta < -self._overfit_threshold:
+                # Overfitting decreasing → shrink queue for sharper signal
+                self._current_max = max(
+                    self._current_max - self._resize_step, self.min_size)
+
+            if self._current_max != old_size:
+                direction = "↑" if self._current_max > old_size else "↓"
+                print(f"  ⚙ Queue {direction} {old_size}→{self._current_max} "
+                      f"(gap {self._prev_gap:.1f}→{gap:.1f}, Δ={delta:+.1f})")
+                self._epochs_since_resize = 0
+                self._resize_history.append(
+                    (epoch, old_size, self._current_max, gap))
+
+        # Trim queue if it shrunk
+        if self._embs is not None and self._embs.shape[0] > self._current_max:
+            self._embs = self._embs[-self._current_max:]
+            self._labels = self._labels[-self._current_max:]
+
+        self._prev_gap = gap
+
     @property
     def current_margin(self):
-        """Progressive margin: linearly ramps from start to end over warmup steps."""
         if not self.active:
             return self._margin_start
         t = min(self._mastery_steps / max(self._margin_warmup, 1), 1.0)
         return self._margin_start + t * (self._margin_end - self._margin_start)
 
     def push(self, emb, labels):
-        """Add batch to queue. FIFO eviction."""
+        """Add batch to queue. FIFO eviction at current_max."""
        emb = emb.detach().to(self.device)
         labels = labels.detach().to(self.device)
 
@@ -934,11 +987,10 @@ class MasteryQueue:
             self._embs = emb
             self._labels = labels
         else:
-            self._embs = torch.cat([self._embs, emb], 0)[-self.max_size:]
-            self._labels = torch.cat([self._labels, labels], 0)[-self.max_size:]
+            self._embs = torch.cat([self._embs, emb], 0)[-self._current_max:]
+            self._labels = torch.cat([self._labels, labels], 0)[-self._current_max:]
 
     def get(self):
-        """Return current queue contents."""
         if self._embs is None:
             return None, None
         return self._embs, self._labels
@@ -955,6 +1007,8 @@ class MasteryQueue:
             'activated_at': self._activated_at,
             'mastery_steps': self._mastery_steps,
             'current_margin': self.current_margin,
+            'current_max': self._current_max,
+            'resize_history': self._resize_history,
         }
 
 
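The docstring removed in this commit describes how `compute_loss` is expected to mine the queue: the closest different-class embedding becomes a hard negative, the furthest same-class embedding a hard positive, and a margin loss pushes the two apart. `compute_loss` itself is not part of this diff, so the snippet below is only a minimal sketch of that mining step; `queue_margin_loss`, the cosine-similarity metric, and the mean reduction are assumptions, not the repository's actual implementation.

```python
import torch
import torch.nn.functional as F

def queue_margin_loss(emb, labels, q_embs, q_labels, margin):
    """Hypothetical sketch: hardest-positive / hardest-negative margin loss
    against the MasteryQueue contents, using cosine similarity."""
    emb = F.normalize(emb, dim=-1)
    q_embs = F.normalize(q_embs, dim=-1)
    sim = emb @ q_embs.T                          # (B, Q) similarities
    same = labels[:, None] == q_labels[None, :]   # (B, Q) same-class mask

    # Hard negative: most similar embedding of a *different* class.
    neg_sim = sim.masked_fill(same, float('-inf')).max(dim=1).values
    # Hard positive: least similar embedding of the *same* class.
    pos_sim = sim.masked_fill(~same, float('inf')).min(dim=1).values

    # Only rows with both a positive and a negative in the queue contribute.
    valid = torch.isfinite(neg_sim) & torch.isfinite(pos_sim)
    if not valid.any():
        return sim.new_zeros(())
    # Require the hardest positive to sit above the hardest negative by `margin`.
    return F.relu(neg_sim[valid] - pos_sim[valid] + margin).mean()
```

Since `push` detaches embeddings before caching them, only the in-batch embeddings would receive gradient from such a term; the queue acts purely as a bank of comparison points.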
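The new `update_size` hook is meant to be called once per epoch, while `check_activation` and `push` run per batch. Below is a hypothetical wiring sketch, not code from this commit: the embedding dimension, the random tensors, and the accuracy numbers are placeholders for the repository's real model outputs and metrics, and the default `overfit_threshold=3.0` suggests (but does not guarantee) that the train/val gap is measured in percentage points.

```python
import torch

# Stand-in dimensions and metrics; only the MasteryQueue calls come from this commit.
DIM = 512
queue = MasteryQueue(dim=DIM, min_size=1024, max_size=8192, initial_size=4096,
                     patience=50, device='cpu')

for epoch in range(10):
    for _ in range(200):                          # per-batch loop
        emb = torch.randn(64, DIM)                # stand-in for model embeddings
        labels = torch.randint(0, 10, (64,))
        nce_acc = 1.0                             # stand-in for in-batch InfoNCE accuracy
        queue.check_activation(nce_acc)           # may flip queue.active after `patience` batches
        if queue.active:
            queue.push(emb, labels)               # FIFO cache, capped at the adaptive size
            # ...a margin term against queue.get() would be added to the loss
            #    here, weighted by queue.current_margin...

    # Per-epoch: adapt queue size to the overfit gap (values assumed in percent).
    train_acc, val_acc = 95.0, 90.0               # stand-ins for real epoch metrics
    queue.update_size(train_acc, val_acc, epoch)
```

The cooldown in `update_size` means the queue can change size at most once every `resize_cooldown` epochs, so oscillation between growth and shrinkage is damped even if the gap is noisy.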