gnai-creator committed (verified)
Commit f06ec25 · 1 Parent(s): e63609a

Upload handler.py

Files changed (1)
  1. handler.py +17 -2
handler.py CHANGED
@@ -181,6 +181,7 @@ class _DecodingParams:
     temperature: float = 0.8
     top_p: float = 0.9
     max_new_tokens: int = 256
+    min_new_tokens: int = 16  # Minimum tokens before allowing EOS
     stop_quality: float = 0.6

     @classmethod
@@ -508,16 +509,30 @@ class EndpointHandler:
             if logits is None:
                 break
             last_index = min(len(token_ids) - 1, logits.shape[1] - 1)
-            next_logits = logits[0, last_index]
+            next_logits = logits[0, last_index].copy()
+
+            # Apply strong penalty to EOS token if we haven't reached min_new_tokens
+            # This reduces the probability of generating EOS prematurely
+            if steps < decoding.min_new_tokens:
+                next_logits[self._tokenizer.eos_token_id] -= 10.0
+
             next_token = self._sample_next_token(next_logits, decoding, rng)
             token_ids.append(int(next_token))
             steps += 1

+            # Check if we generated EOS prematurely and replace with space
+            if token_ids[-1] == self._tokenizer.eos_token_id and steps < decoding.min_new_tokens:
+                # Find space token ID (fallback to 'a' if space not found)
+                space_token_id = self._tokenizer._token_to_id.get(" ", self._tokenizer._token_to_id.get("a", self._tokenizer.unk_token_id))
+                token_ids[-1] = space_token_id
+                # Note: In production, add logging here to track how often this happens
+
             outputs = self._run_candidate(base_feed, token_ids)
             formatted_outputs = outputs
             quality = self._extract_q_hat(outputs)

+            # Only allow EOS break if we've generated at least min_new_tokens (excluding BOS)
+            if token_ids[-1] == self._tokenizer.eos_token_id and steps >= decoding.min_new_tokens:
                 break
             if self._token_sequence_length > 0 and len(token_ids) >= self._token_sequence_length:
                 break
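
For readers skimming the diff, here is a minimal, self-contained sketch of the decoding pattern this commit introduces: penalize the EOS logit until a minimum number of tokens has been generated, then allow normal termination. All names below (sample_with_min_length, toy_logits, the vocabulary constants) are hypothetical stand-ins; the real handler routes sampling through self._sample_next_token and reads token IDs from self._tokenizer.

import numpy as np

# Hypothetical constants for the demo; the real values come from the tokenizer.
EOS_TOKEN_ID = 2
VOCAB_SIZE = 32


def sample_with_min_length(get_logits, eos_token_id, min_new_tokens=16,
                           max_new_tokens=256, eos_penalty=10.0, seed=0):
    """Sampling loop that suppresses EOS until min_new_tokens is reached."""
    rng = np.random.default_rng(seed)
    token_ids = []
    for step in range(max_new_tokens):
        # Copy before mutating, mirroring logits[0, last_index].copy() above.
        logits = get_logits(token_ids).copy()
        # Subtract a large constant from the EOS logit while below the floor.
        if step < min_new_tokens:
            logits[eos_token_id] -= eos_penalty
        # Plain softmax sampling; the handler delegates this step to
        # self._sample_next_token with temperature/top_p.
        probs = np.exp(logits - logits.max())
        probs /= probs.sum()
        next_token = int(rng.choice(len(probs), p=probs))
        token_ids.append(next_token)
        # Only honor EOS once the minimum length has been generated.
        if next_token == eos_token_id and step + 1 >= min_new_tokens:
            break
    return token_ids


def toy_logits(_token_ids):
    # A degenerate "model" that always prefers EOS, to show the floor working.
    logits = np.zeros(VOCAB_SIZE)
    logits[EOS_TOKEN_ID] = 5.0
    return logits


if __name__ == "__main__":
    ids = sample_with_min_length(toy_logits, EOS_TOKEN_ID)
    print(len(ids))  # >= 16 despite the model's strong EOS preference

Note that a finite penalty only lowers the EOS probability rather than zeroing it, which is presumably why the commit also swaps any EOS that slips through for a space token before re-running the candidate.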