Update modeling_geolip_vit.py
Browse files- modeling_geolip_vit.py +80 -26
modeling_geolip_vit.py
CHANGED
|
@@ -862,31 +862,33 @@ class DualStreamViT(nn.Module):
|
|
| 862 |
|
| 863 |
class MasteryQueue:
|
| 864 |
"""
|
| 865 |
-
Cross-batch embedding cache
|
| 866 |
-
|
| 867 |
-
|
| 868 |
-
|
| 869 |
-
|
| 870 |
-
|
| 871 |
-
|
| 872 |
-
|
| 873 |
-
|
| 874 |
-
|
| 875 |
-
|
| 876 |
-
natural 0.20-0.23 band. Dense same-class regions + sparse boundaries
|
| 877 |
-
= volume variation = higher CV.
|
| 878 |
"""
|
| 879 |
-
def __init__(self, dim,
|
| 880 |
-
|
|
|
|
|
|
|
| 881 |
self.dim = dim
|
|
|
|
| 882 |
self.max_size = max_size
|
|
|
|
| 883 |
self.patience = patience
|
| 884 |
self.device = device
|
| 885 |
self.active = False
|
| 886 |
|
| 887 |
# Queue storage
|
| 888 |
-
self._embs = None
|
| 889 |
-
self._labels = None
|
| 890 |
|
| 891 |
# Activation tracking
|
| 892 |
self._perfect_count = 0
|
|
@@ -896,11 +898,19 @@ class MasteryQueue:
|
|
| 896 |
# Progressive margin
|
| 897 |
self._margin_start = margin_start
|
| 898 |
self._margin_end = margin_end
|
| 899 |
-
self._margin_warmup = margin_warmup
|
| 900 |
-
self._mastery_steps = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 901 |
|
| 902 |
def check_activation(self, nce_acc):
|
| 903 |
-
"""Call each batch. Activates when nce_acc=
|
| 904 |
self._total_batches += 1
|
| 905 |
if nce_acc >= 0.99:
|
| 906 |
self._perfect_count += 1
|
|
@@ -912,21 +922,64 @@ class MasteryQueue:
|
|
| 912 |
self._activated_at = self._total_batches
|
| 913 |
print(f"\n β
MASTERY ACTIVATED at batch {self._total_batches} "
|
| 914 |
f"(nce_acc=1.0 for {self.patience} consecutive) "
|
| 915 |
-
f"[InfoNCE stays ON, margin {self._margin_start}β{self._margin_end}]"
|
|
|
|
| 916 |
|
| 917 |
if self.active:
|
| 918 |
self._mastery_steps += 1
|
| 919 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 920 |
@property
|
| 921 |
def current_margin(self):
|
| 922 |
-
"""Progressive margin: linearly ramps from start to end over warmup steps."""
|
| 923 |
if not self.active:
|
| 924 |
return self._margin_start
|
| 925 |
t = min(self._mastery_steps / max(self._margin_warmup, 1), 1.0)
|
| 926 |
return self._margin_start + t * (self._margin_end - self._margin_start)
|
| 927 |
|
| 928 |
def push(self, emb, labels):
|
| 929 |
-
"""Add batch to queue. FIFO eviction."""
|
| 930 |
emb = emb.detach().to(self.device)
|
| 931 |
labels = labels.detach().to(self.device)
|
| 932 |
|
|
@@ -934,11 +987,10 @@ class MasteryQueue:
|
|
| 934 |
self._embs = emb
|
| 935 |
self._labels = labels
|
| 936 |
else:
|
| 937 |
-
self._embs = torch.cat([self._embs, emb], 0)[-self.
|
| 938 |
-
self._labels = torch.cat([self._labels, labels], 0)[-self.
|
| 939 |
|
| 940 |
def get(self):
|
| 941 |
-
"""Return current queue contents."""
|
| 942 |
if self._embs is None:
|
| 943 |
return None, None
|
| 944 |
return self._embs, self._labels
|
|
@@ -955,6 +1007,8 @@ class MasteryQueue:
|
|
| 955 |
'activated_at': self._activated_at,
|
| 956 |
'mastery_steps': self._mastery_steps,
|
| 957 |
'current_margin': self.current_margin,
|
|
|
|
|
|
|
| 958 |
}
|
| 959 |
|
| 960 |
|
|
|
|
| 862 |
|
| 863 |
class MasteryQueue:
|
| 864 |
"""
|
| 865 |
+
Cross-batch embedding cache with adaptive queue sizing.
|
| 866 |
+
|
| 867 |
+
Activation: when nce_acc >= 0.99 for `patience` consecutive batches.
|
| 868 |
+
Progressive margin: ramps from margin_start β margin_end over margin_warmup.
|
| 869 |
+
|
| 870 |
+
Adaptive queue sizing (call update_size each epoch):
|
| 871 |
+
- Monitors train_acc - val_acc gap (overfitting indicator)
|
| 872 |
+
- Gap growing β increase queue (more diverse negatives = regularization)
|
| 873 |
+
- Gap shrinking β decrease queue (tighter contrastive signal)
|
| 874 |
+
- Cooldown prevents oscillation: no resize for `resize_cooldown` epochs
|
| 875 |
+
after each change.
|
|
|
|
|
|
|
| 876 |
"""
|
| 877 |
+
def __init__(self, dim, min_size=1024, max_size=8192, initial_size=4096,
|
| 878 |
+
patience=50, device='cuda',
|
| 879 |
+
margin_start=0.1, margin_end=0.3, margin_warmup=5000,
|
| 880 |
+
resize_step=1024, resize_cooldown=5, overfit_threshold=3.0):
|
| 881 |
self.dim = dim
|
| 882 |
+
self.min_size = min_size
|
| 883 |
self.max_size = max_size
|
| 884 |
+
self._current_max = initial_size
|
| 885 |
self.patience = patience
|
| 886 |
self.device = device
|
| 887 |
self.active = False
|
| 888 |
|
| 889 |
# Queue storage
|
| 890 |
+
self._embs = None
|
| 891 |
+
self._labels = None
|
| 892 |
|
| 893 |
# Activation tracking
|
| 894 |
self._perfect_count = 0
|
|
|
|
| 898 |
# Progressive margin
|
| 899 |
self._margin_start = margin_start
|
| 900 |
self._margin_end = margin_end
|
| 901 |
+
self._margin_warmup = margin_warmup
|
| 902 |
+
self._mastery_steps = 0
|
| 903 |
+
|
| 904 |
+
# Adaptive sizing
|
| 905 |
+
self._resize_step = resize_step
|
| 906 |
+
self._resize_cooldown = resize_cooldown
|
| 907 |
+
self._overfit_threshold = overfit_threshold
|
| 908 |
+
self._epochs_since_resize = resize_cooldown # allow first resize
|
| 909 |
+
self._prev_gap = None
|
| 910 |
+
self._resize_history = []
|
| 911 |
|
| 912 |
def check_activation(self, nce_acc):
|
| 913 |
+
"""Call each batch. Activates when nce_acc >= 0.99 for patience steps."""
|
| 914 |
self._total_batches += 1
|
| 915 |
if nce_acc >= 0.99:
|
| 916 |
self._perfect_count += 1
|
|
|
|
| 922 |
self._activated_at = self._total_batches
|
| 923 |
print(f"\n β
MASTERY ACTIVATED at batch {self._total_batches} "
|
| 924 |
f"(nce_acc=1.0 for {self.patience} consecutive) "
|
| 925 |
+
f"[InfoNCE stays ON, margin {self._margin_start}β{self._margin_end}]"
|
| 926 |
+
f" queue={self._current_max}")
|
| 927 |
|
| 928 |
if self.active:
|
| 929 |
self._mastery_steps += 1
|
| 930 |
|
| 931 |
+
def update_size(self, train_acc, val_acc, epoch):
    """
    Call once per epoch. Adjusts queue size based on overfit gap.

    Gap = train_acc - val_acc.
    Gap growing → queue grows (more negatives = regularization)
    Gap shrinking → queue shrinks (tighter signal)
    Cooldown prevents oscillation.
    """
    # Adaptive sizing only makes sense once mastery mode is running.
    if not self.active:
        return

    self._epochs_since_resize += 1
    overfit_gap = train_acc - val_acc

    cooldown_over = self._epochs_since_resize >= self._resize_cooldown
    if self._prev_gap is not None and cooldown_over:
        gap_delta = overfit_gap - self._prev_gap
        prev_cap = self._current_max

        if gap_delta > self._overfit_threshold:
            # Overfitting increasing → grow queue for regularization
            grown = self._current_max + self._resize_step
            self._current_max = min(grown, self.max_size)
        elif gap_delta < -self._overfit_threshold:
            # Overfitting decreasing → shrink queue for sharper signal
            shrunk = self._current_max - self._resize_step
            self._current_max = max(shrunk, self.min_size)

        if self._current_max != prev_cap:
            # A resize actually happened: log it, restart the cooldown
            # window, and record it for later inspection via stats.
            arrow = "↑" if self._current_max > prev_cap else "↓"
            print(f"  ↔ Queue {arrow} {prev_cap}→{self._current_max} "
                  f"(gap {self._prev_gap:.1f}→{overfit_gap:.1f}, Δ={gap_delta:+.1f})")
            self._epochs_since_resize = 0
            self._resize_history.append(
                (epoch, prev_cap, self._current_max, overfit_gap))

    # Trim queue if it shrunk
    if self._embs is not None and self._embs.shape[0] > self._current_max:
        self._embs = self._embs[-self._current_max:]
        self._labels = self._labels[-self._current_max:]

    self._prev_gap = overfit_gap
|
| 973 |
+
|
| 974 |
@property
def current_margin(self):
    """Progressive margin: linearly ramps from start to end over warmup steps."""
    if self.active:
        # Fraction of warmup completed, clamped to [0, 1]; max(..., 1)
        # guards against a zero warmup length.
        frac = min(self._mastery_steps / max(self._margin_warmup, 1), 1.0)
        return self._margin_start + frac * (self._margin_end - self._margin_start)
    # Before activation the margin stays at its starting value.
    return self._margin_start
|
| 980 |
|
| 981 |
def push(self, emb, labels):
    """Add batch to queue. FIFO eviction at current_max."""
    # Detach so queued embeddings never hold onto the autograd graph.
    new_embs = emb.detach().to(self.device)
    new_labels = labels.detach().to(self.device)

    if self._embs is None:
        # First push: the queue is exactly this batch.
        self._embs = new_embs
        self._labels = new_labels
        return

    # Append, then keep only the newest current_max entries (FIFO evict).
    cap = self._current_max
    self._embs = torch.cat([self._embs, new_embs], 0)[-cap:]
    self._labels = torch.cat([self._labels, new_labels], 0)[-cap:]
|
| 992 |
|
| 993 |
def get(self):
    """Return current queue contents."""
    if self._embs is not None:
        return self._embs, self._labels
    # Nothing has been pushed yet.
    return None, None
|
|
|
|
| 1007 |
'activated_at': self._activated_at,
|
| 1008 |
'mastery_steps': self._mastery_steps,
|
| 1009 |
'current_margin': self.current_margin,
|
| 1010 |
+
'current_max': self._current_max,
|
| 1011 |
+
'resize_history': self._resize_history,
|
| 1012 |
}
|
| 1013 |
|
| 1014 |
|