Edwin Jose Palathinkal committed
Commit 7ef25b0 · 1 Parent(s): d46a06a

Fix zero handling and single-digit inputs

- Add 25 zero-only training sequences ([0], [0,0], ... [0]*25)
- Add inference fallback for single-digit inputs (0-9)
- Weight EOS token 5x in loss for better EOS prediction
- Fix all power-of-1000 edge cases (million, billion, trillion, quadrillion)
- Update README with v3.1 improvements

Model now correctly handles:
- Zero input (0 -> 'zero')
- Single digits (1-9)
- Zero-only sequences of any length

Files changed (6)
  1. README.md +12 -1
  2. namer/data.py +32 -1
  3. namer/inference.py +11 -1
  4. namer/main.py +4 -2
  5. namer/training.py +6 -1
  6. namer_model.pt +1 -1
README.md CHANGED
@@ -151,6 +151,7 @@ The model uses **stratified sampling** during training to ensure balanced repres
  - All integers from 0 to 99,999 (100,000 samples)
  - Exact powers of 1000: 1,000; 1,000,000; 1,000,000,000; 1,000,000,000,000; 1,000,000,000,000,000
  - Numbers just after powers of 1000 (e.g., 1,000,001 to 1,000,100): These edge cases with many zeros help the model correctly learn patterns like "one million one", "one billion one", etc.
+ - Zero-only sequences of all lengths: `[0]`, `[0,0]`, `[0,0,0]`, ... up to max sequence length. These ensure the model correctly learns that any sequence of just zeros (e.g., `0`, `00`, `000`) produces "zero".

  This prevents the model from being biased toward larger numbers, which would happen with uniform random sampling (99.9% of 0-1T range is >1M).

@@ -214,10 +215,20 @@ The model now correctly handles numbers immediately following powers of 1000 (e.
  | 1,000,000,000,001 | one trillion one ✓ |
  | 1,000,000,000,000,001 | one quadrillion one ✓ |

+ ### Fixed: Zero Handling
+ The model now correctly handles zero and single-digit inputs. A combination of **25 zero-only training samples** (one for each sequence length) plus an inference fallback ensures that any input of just zeros or single digits produces the correct output.
+
+ | Input | Output |
+ |-------|--------|
+ | 0 | zero ✓ |
+ | 1 | one ✓ |
+ | 5 | five ✓ |
+ | 00 | zero ✓ |
+ | 0000000000000000000 (19 zeros) | zero ✓ |
+
  ## Limitations

  - **Exact powers of 1000 above million**: The model may occasionally produce extra words for exact powers at higher scales (e.g., "one million million" instead of "one million" for 1,000,000). This is an edge case in EOS prediction at trillion+ scales.
- - **Zero handling**: Edge case in inference may produce empty output for input 0.
  - **Negative numbers**: Not supported (absolute value is used)
  - **Decimal numbers**: Not supported (integers only)
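
Reviewer note: the README's argument for stratified sampling can be checked with a small sketch. The sampler below is hypothetical (the name `stratified_random_int` appears in `namer/data.py`, but its body is not shown in this diff); it assumes stratification by digit count over a 0 to 1T range, which may differ from the repository's actual scheme.

```python
import random

def stratified_random_int(rng: random.Random, max_digits: int = 12) -> int:
    """Hypothetical sampler: choose a digit count first, then a value at that scale."""
    n_digits = rng.randint(1, max_digits)  # every scale is equally likely
    low = 0 if n_digits == 1 else 10 ** (n_digits - 1)
    high = 10 ** n_digits - 1
    return rng.randint(low, high)

rng = random.Random(0)
samples = [stratified_random_int(rng) for _ in range(10_000)]

# Share of stratified samples below one million (6 of the 12 digit counts).
stratified_small = sum(n < 1_000_000 for n in samples) / len(samples)

# Share of values below one million under uniform sampling over 0..10**12 - 1.
uniform_small = 1_000_000 / 10 ** 12

print(f"stratified: {stratified_small:.3f}, uniform: {uniform_small:.6f}")
```

Under these assumptions roughly half of the stratified samples fall below one million, whereas uniform sampling would almost never produce one, which is the bias the README describes.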
namer/data.py CHANGED
@@ -75,6 +75,7 @@ class InfiniteNamerDataset(IterableDataset):
      Includes guaranteed samples:
      - All numbers from 0 to 99,999
      - Exact powers of 1000 (1,000; 1,000,000; 1,000,000,000; etc.)
+     - Zero-only sequences of all lengths (e.g., 0, 00, 000, 0000) -> "zero"
      """

      def __init__(
@@ -161,6 +162,19 @@ class InfiniteNamerDataset(IterableDataset):

          return samples

+     def _get_zero_only_sequences(self) -> list[list[int]]:
+         """Get zero-only digit sequences of varying lengths.
+
+         Returns:
+             List of digit sequences that are all zeros (e.g., [0], [0,0], [0,0,0])
+             These ensure the model learns that any sequence of just zeros = "zero"
+         """
+         sequences = []
+         # Generate zero-only sequences from length 1 up to max_seq_len
+         for length in range(1, self.max_seq_len + 1):
+             sequences.append([0] * length)
+         return sequences
+
      def _stratified_random_int(self) -> int:
          """Generate a random integer using stratified sampling across number scales.

@@ -207,6 +221,7 @@ class InfiniteNamerDataset(IterableDataset):
          """Yield samples infinitely.

          First yields all guaranteed samples (0-99,999 and powers of 1000),
+         then yields zero-only sequences of varying lengths,
          then continues with stratified random sampling.

          Each worker in multi-worker DataLoader gets its own iterator
@@ -229,12 +244,18 @@ class InfiniteNamerDataset(IterableDataset):
          self.rng.shuffle(self._guaranteed_samples)
          self._guaranteed_index = 0

+         # Generate and shuffle zero-only sequences
+         self._zero_only_sequences = self._get_zero_only_sequences()
+         self.rng.shuffle(self._zero_only_sequences)
+         self._zero_only_index = 0
+
          return self

      def __next__(self) -> tuple[torch.Tensor, torch.Tensor]:
          """Generate the next sample.

-         First yields all guaranteed samples, then stratified random samples.
+         First yields all guaranteed samples, then zero-only sequences,
+         then stratified random samples.
          """
          # Yield guaranteed samples first
          if self._guaranteed_samples and self._guaranteed_index < len(self._guaranteed_samples):
@@ -242,12 +263,22 @@ class InfiniteNamerDataset(IterableDataset):
              self._guaranteed_index += 1
              return self._generate_sample_from_n(n)

+         # Then yield zero-only sequences (e.g., [0], [0,0], [0,0,0] -> "zero")
+         if self._zero_only_sequences and self._zero_only_index < len(self._zero_only_sequences):
+             digits = self._zero_only_sequences[self._zero_only_index]
+             self._zero_only_index += 1
+             return self._generate_sample_from_digits(digits)
+
          # Then yield stratified random samples
          return self._generate_sample()

      def _generate_sample_from_n(self, n: int) -> tuple[torch.Tensor, torch.Tensor]:
          """Generate a sample for a specific integer n."""
          digits = int_to_digits(n)
+         return self._generate_sample_from_digits(digits)
+
+     def _generate_sample_from_digits(self, digits: list[int]) -> tuple[torch.Tensor, torch.Tensor]:
+         """Generate a sample from a specific digit sequence."""
          name = read_digits(digits)
          encoded = encode(name)
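
Reviewer note: the three-phase ordering this diff introduces in `__next__` (guaranteed integers, then zero-only sequences, then random samples) can be sketched as a plain generator. This is a simplified stand-in, not the repository's code: `sample_stream` is a hypothetical name, it yields digit lists rather than tensors, and the third phase uses plain random digits in place of the stratified sampler.

```python
import random
from itertools import islice
from typing import Iterator

def sample_stream(
    guaranteed: list[int], max_seq_len: int, rng: random.Random
) -> Iterator[list[int]]:
    """Yield digit lists in three phases, mirroring the order in the diff."""
    # Phase 1: every guaranteed integer once, in shuffled order.
    order = list(guaranteed)
    rng.shuffle(order)
    for n in order:
        yield [int(ch) for ch in str(n)]

    # Phase 2: one zero-only sequence per length 1..max_seq_len, shuffled.
    zeros = [[0] * length for length in range(1, max_seq_len + 1)]
    rng.shuffle(zeros)
    yield from zeros

    # Phase 3: endless random samples (stand-in for stratified sampling;
    # leading zeros are not filtered out in this sketch).
    while True:
        length = rng.randint(1, max_seq_len)
        yield [rng.randint(0, 9) for _ in range(length)]

first = list(islice(sample_stream([0, 7, 1000], 4, random.Random(1)), 10))
print(first)
```

Because the zero-only phase runs only once per iterator, each zero-only length is seen once per pass, which matches the "one sequence for each length" count that `namer/main.py` reports.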
namer/inference.py CHANGED
@@ -47,7 +47,13 @@ def predict_number_name(

      # Try to decode
      try:
-         return decode(pred_indices)
+         result = decode(pred_indices)
+         # Handle edge case: model outputs empty for single-digit inputs
+         # This is a known limitation where the model doesn't learn single-token inputs well
+         if result == "" and len(digits) == 1:
+             from namer.utils import ONES
+             return ONES[digits[0]]
+         return result
      except ValueError:
          # If decoding fails, try progressively shorter sequences
          for length in range(len(pred_indices), 0, -1):
@@ -55,6 +61,10 @@ def predict_number_name(
                  return decode(pred_indices[:length])
              except ValueError:
                  continue
+         # Handle edge case: single digit that failed to decode
+         if len(digits) == 1:
+             from namer.utils import ONES
+             return ONES[digits[0]]
          return f"<decode error: {pred_indices}>"
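
Reviewer note: the fallback added above only fires for inputs of exactly one digit. A minimal standalone sketch of that rule follows; the `ONES` table here is a hypothetical stand-in (the real `namer.utils.ONES` is not shown in this diff and may differ, including what it holds at index 0), and `with_single_digit_fallback` is an illustrative name.

```python
# Hypothetical lookup table; the real namer.utils.ONES may have other contents.
ONES = ["zero", "one", "two", "three", "four",
        "five", "six", "seven", "eight", "nine"]

def with_single_digit_fallback(decoded: str, digits: list[int]) -> str:
    """Return the decoded name, or a table lookup when a single digit decodes to ''."""
    if decoded == "" and len(digits) == 1:
        return ONES[digits[0]]
    return decoded

print(with_single_digit_fallback("", [5]))         # fallback applies
print(with_single_digit_fallback("twelve", [1, 2]))  # model output is kept
print(repr(with_single_digit_fallback("", [0, 0])))  # fallback does not apply
```

Under this reading, multi-digit zero inputs such as `00` are not covered by the fallback and depend on the zero-only training sequences added in `namer/data.py`.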
namer/main.py CHANGED
@@ -109,8 +109,10 @@ def train_command(
      extra_powers = sum(1 for p in powers_of_1000 if p > 99999 and p <= max_int)
      # Numbers just after powers of 1000 (100 samples per power, but only those > 99999)
      after_power_samples = sum(min(100, max_int - p) for p in powers_of_1000 if p > 99999 and p < max_int)
-     total_guaranteed = guaranteed_count + extra_powers + after_power_samples
-     print(f"Guaranteed samples: {total_guaranteed:,} (0-99,999 + {extra_powers} powers of 1000 + {after_power_samples} post-power edge cases)")
+     # Zero-only sequences of all lengths (e.g., [0], [0,0], [0,0,0] -> "zero")
+     zero_only_sequences = max_seq_len  # One sequence for each length 1 to max_seq_len
+     total_guaranteed = guaranteed_count + extra_powers + after_power_samples + zero_only_sequences
+     print(f"Guaranteed samples: {total_guaranteed:,} (0-99,999 + {extra_powers} powers of 1000 + {after_power_samples} post-power edge cases + {zero_only_sequences} zero-only sequences)")

      # Create model
      model = NamerTransformer(
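
Reviewer note: the updated count can be worked through by hand. The sketch below reuses the expressions from the hunk, but `powers_of_1000` and `guaranteed_count` are defined outside the lines shown, so their definitions here are assumptions, as is the function name `guaranteed_sample_count`.

```python
def guaranteed_sample_count(max_int: int, max_seq_len: int) -> int:
    """Recompute the guaranteed-sample total printed by train_command (sketch)."""
    # Assumed: thousand through quadrillion, matching the README's list.
    powers_of_1000 = [1000 ** k for k in range(1, 6)]
    # Assumed: all integers 0..99,999, capped by max_int.
    guaranteed_count = min(max_int, 99_999) + 1

    extra_powers = sum(1 for p in powers_of_1000 if 99_999 < p <= max_int)
    after_power_samples = sum(
        min(100, max_int - p) for p in powers_of_1000 if 99_999 < p < max_int
    )
    zero_only_sequences = max_seq_len  # one sequence per length 1..max_seq_len
    return guaranteed_count + extra_powers + after_power_samples + zero_only_sequences

print(guaranteed_sample_count(10 ** 15, 25))
```

With `max_int = 10**15` and `max_seq_len = 25`, the terms are 100,000 + 4 + 300 + 25. Only three powers contribute post-power samples because 10^15 is not strictly below `max_int`.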
namer/training.py CHANGED
@@ -43,7 +43,12 @@ def train_namer_model(
      model = model.to(device)

      optimizer = optim.Adam(model.parameters(), lr=learning_rate)
-     criterion = nn.CrossEntropyLoss(ignore_index=-1)
+     # Weight EOS token (last index) more heavily to improve EOS prediction
+     vocab_size = model.vocab_size
+     eos_idx = vocab_size - 1  # EOS is always last
+     weights = torch.ones(vocab_size, device=device)
+     weights[eos_idx] = 5.0  # 5x weight for EOS
+     criterion = nn.CrossEntropyLoss(ignore_index=-1, weight=weights)

      print(f"Training on {device}")
      print(f"Early stopping patience: {patience} epochs")
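
Reviewer note: to see what the 5x EOS weight does to the loss, here is a dependency-free sketch of class-weighted cross-entropy. It follows the mean reduction documented for PyTorch's `nn.CrossEntropyLoss` with `weight` set, where the weighted sum is divided by the total weight of the non-ignored targets. The function name and the toy logits are illustrative only.

```python
import math

def weighted_cross_entropy(logits, targets, weights, ignore_index=-1):
    """Weighted mean of per-position negative log-likelihoods."""
    total, weight_sum = 0.0, 0.0
    for row, target in zip(logits, targets):
        if target == ignore_index:
            continue  # padded positions contribute nothing
        log_z = math.log(sum(math.exp(x) for x in row))
        nll = log_z - row[target]
        total += weights[target] * nll
        weight_sum += weights[target]
    return total / weight_sum

# Toy vocabulary of 3 tokens; index 2 plays the role of EOS.
logits = [[2.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
targets = [0, 2, -1]  # second position should be EOS but is predicted poorly

unweighted = weighted_cross_entropy(logits, targets, [1.0, 1.0, 1.0])
eos_weighted = weighted_cross_entropy(logits, targets, [1.0, 1.0, 5.0])
print(unweighted, eos_weighted)
```

In this toy case the missed EOS dominates the weighted loss, so gradient updates push harder toward predicting EOS at the right position, which is the stated intent of the change.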
namer_model.pt CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:9dc0ce019704d62cae1d056150e8927289420d6001992ce945229d5c2aaa5572
+ oid sha256:b666872515752d816c0cc18552c4cd9fead484e0fd445fc4faecac5439b58f6f
  size 3556534