AbstractPhil committed on
Commit
185899c
·
verified ·
1 Parent(s): 36af25f

Create big_vocabulary_tests.py

Browse files
Files changed (1) hide show
  1. big_vocabulary_tests.py +70 -0
big_vocabulary_tests.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Coefficient of variation (CV) of random simplex volumes at D=32 with absurd vocabulary sizes.
3
+ Does V matter at scale? We say no.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import math
10
+ import time
11
+
12
+
13
def cayley_menger_vol2(points):
    """Return the squared volume of every simplex in a batch.

    Uses the Cayley-Menger determinant. For a simplex with ``N`` vertices
    (simplex dimension ``k = N - 1``)::

        vol^2 = (-1)^(k + 1) / (2^k * (k!)^2) * det(CM)

    where ``CM`` is the ``(N + 1) x (N + 1)`` matrix holding the pairwise
    squared distances, bordered by a first row and column of ones with a
    zero in the corner.

    Args:
        points: Floating-point tensor of shape ``(B, N, D)``: ``B``
            simplices, each described by ``N`` vertices in ``D`` dimensions.

    Returns:
        Tensor of shape ``(B,)`` with dtype ``points.dtype`` containing the
        squared volumes. Rounding can make the value zero or slightly
        negative for (near-)degenerate simplices, so callers should filter
        before taking a square root.

    Raises:
        ValueError: If ``points`` does not have exactly three dimensions.
    """
    batch, n_vertices, _ = points.shape

    # Compute in at least float32: half/bfloat16 inputs are upcast, while
    # float64 inputs keep their full precision instead of being squeezed
    # through a float32 determinant.
    compute_dtype = torch.promote_types(points.dtype, torch.float32)
    pts = points.to(compute_dtype)

    gram = torch.bmm(pts, pts.transpose(1, 2))
    sq_norms = torch.diagonal(gram, dim1=1, dim2=2)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b; relu clamps the tiny negative
    # values that rounding can produce.
    sq_dists = F.relu(sq_norms.unsqueeze(2) + sq_norms.unsqueeze(1) - 2 * gram)

    # Bordered distance matrix: zero corner, ones along the first row and
    # column, squared distances in the remaining N x N block.
    cm = torch.zeros(
        batch,
        n_vertices + 1,
        n_vertices + 1,
        device=points.device,
        dtype=compute_dtype,
    )
    cm[:, 0, 1:] = 1.0
    cm[:, 1:, 0] = 1.0
    cm[:, 1:, 1:] = sq_dists

    k = n_vertices - 1
    sign = -1.0 if k % 2 == 0 else 1.0  # equals (-1) ** (k + 1)
    scale = (2 ** k) * (math.factorial(k) ** 2)
    return (sign * torch.linalg.det(cm) / scale).to(points.dtype)
26
+
27
+
28
def cv_metric(weight, n_samples=500, n_points=5, pool_size=512):
    """Estimate the coefficient of variation of random simplex volumes.

    Draws ``n_samples`` simplices whose vertices are distinct rows taken
    from the first ``min(V, pool_size)`` rows of ``weight``, computes each
    simplex volume with ``cayley_menger_vol2`` and returns ``std / mean`` of
    the volumes.

    Args:
        weight: 2-D tensor of shape ``(V, D)``, one embedding per row.
        n_samples: Number of random simplices to draw.
        n_points: Number of vertices per simplex. If the pool holds fewer
            rows than this, each simplex uses every row in the pool.
        pool_size: Upper bound on how many leading rows of ``weight`` are
            eligible as vertices.

    Returns:
        The coefficient of variation as a Python float, or ``None`` when
        fewer than 10 sampled simplices have a squared volume above
        ``1e-20``.
    """
    vocab_size, _ = weight.shape
    pool = min(vocab_size, pool_size)

    # The result is a plain float, so there is nothing to back-propagate;
    # no_grad avoids building an autograd graph when ``weight`` is a
    # trainable parameter.
    with torch.no_grad():
        candidates = weight[:pool]
        # One independent permutation per sample guarantees distinct
        # vertices within each simplex.
        indices = torch.stack([
            torch.randperm(pool, device=weight.device)[:n_points]
            for _ in range(n_samples)
        ])
        vol2 = cayley_menger_vol2(candidates[indices])

        # Drop degenerate / numerically non-positive volumes before sqrt.
        valid = vol2 > 1e-20
        if valid.sum() < 10:
            return None

        vols = vol2[valid].sqrt()
        # The epsilon keeps the ratio finite if the mean volume is ~0.
        return (vols.std() / (vols.mean() + 1e-8)).item()
42
+
43
+
44
if __name__ == "__main__":
    # Experiment: hold the embedding width fixed and grow the vocabulary to
    # see whether the CV of simplex volumes depends on vocabulary size.
    D = 32
    vocabs = [32, 512, 8192, 65536, 131072, 500000, 1000000, 4000000, 13000000]

    print(f"D={D} fixed. CV across vocab sizes.")
    print("Pool capped at 512 for fair comparison.")
    print("=" * 60)

    for V in vocabs:
        t0 = time.perf_counter()
        # Use raw tensor instead of nn.Embedding for huge sizes
        weight = torch.randn(V, D)
        cv = cv_metric(weight, n_samples=500)
        elapsed = time.perf_counter() - t0
        # Free the tensor before the next (larger) allocation so two big
        # tensors are never alive at the same time.
        del weight
        # cv_metric returns None when too few simplices are valid; a bare
        # ":.4f" format on None would raise TypeError.
        cv_text = "n/a" if cv is None else f"{cv:.4f}"
        # 4 bytes per element, assuming torch's default float32 dtype.
        mem_mb = V * D * 4 / 1e6
        print(f" V={V:>10,} D={D} CV={cv_text} {elapsed:.1f}s {mem_mb:.0f}MB")

    # Also uncap the pool for the big ones
    print()
    print("=" * 60)
    print("Now uncapped pool (sample from ALL embeddings):")
    print("=" * 60)

    for V in [512, 8192, 65536, 500000]:
        weight = torch.randn(V, D)
        cv = cv_metric(weight, n_samples=500, pool_size=V)
        del weight
        cv_text = "n/a" if cv is None else f"{cv:.4f}"
        print(f" V={V:>10,} D={D} CV={cv_text} pool={V}")