Upload handler.py

handler.py CHANGED (+17 -2)
@@ -181,6 +181,7 @@ class _DecodingParams:
     temperature: float = 0.8
     top_p: float = 0.9
     max_new_tokens: int = 256
+    min_new_tokens: int = 16  # Minimum tokens before allowing EOS
     stop_quality: float = 0.6
 
     @classmethod
@@ -508,16 +509,30 @@ class EndpointHandler:
             if logits is None:
                 break
             last_index = min(len(token_ids) - 1, logits.shape[1] - 1)
-            next_logits = logits[0, last_index]
+            next_logits = logits[0, last_index].copy()
+
+            # Apply strong penalty to EOS token if we haven't reached min_new_tokens
+            # This reduces the probability of generating EOS prematurely
+            if steps < decoding.min_new_tokens:
+                next_logits[self._tokenizer.eos_token_id] -= 10.0
+
             next_token = self._sample_next_token(next_logits, decoding, rng)
             token_ids.append(int(next_token))
             steps += 1
 
+            # Check if we generated EOS prematurely and replace with space
+            if token_ids[-1] == self._tokenizer.eos_token_id and steps < decoding.min_new_tokens:
+                # Find space token ID (fallback to 'a', then unk, if space not found)
+                space_token_id = self._tokenizer._token_to_id.get(" ", self._tokenizer._token_to_id.get("a", self._tokenizer.unk_token_id))
+                token_ids[-1] = space_token_id
+                # Note: In production, add logging here to track how often this happens
+
             outputs = self._run_candidate(base_feed, token_ids)
             formatted_outputs = outputs
             quality = self._extract_q_hat(outputs)
 
-            if token_ids[-1] == self._tokenizer.eos_token_id:
+            # Only allow EOS break if we've generated at least min_new_tokens (excluding BOS)
+            if token_ids[-1] == self._tokenizer.eos_token_id and steps >= decoding.min_new_tokens:
                 break
             if self._token_sequence_length > 0 and len(token_ids) >= self._token_sequence_length:
                 break
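
The mechanism this commit adds is small enough to check in isolation: soft-mask the EOS logit until min_new_tokens steps have been sampled, then gate the break on the same threshold. Below is a minimal, self-contained sketch of that loop; the vocabulary size, EOS id, and toy logits function are hypothetical stand-ins, and only the control flow mirrors handler.py.

import numpy as np

# Hypothetical stand-ins; the real handler reads these from its tokenizer
# and _DecodingParams (eos_token_id, min_new_tokens, max_new_tokens).
VOCAB_SIZE = 10
EOS_ID = 2
MIN_NEW_TOKENS = 16
MAX_NEW_TOKENS = 64
EOS_PENALTY = 10.0

def sample_with_min_length(logits_fn, rng):
    token_ids = []
    steps = 0
    while steps < MAX_NEW_TOKENS:
        next_logits = logits_fn(token_ids).copy()  # copy before mutating
        if steps < MIN_NEW_TOKENS:
            # Soft-mask EOS: subtracting a large constant scales its
            # post-softmax probability by exp(-10) ~= 4.5e-5.
            next_logits[EOS_ID] -= EOS_PENALTY
        probs = np.exp(next_logits - next_logits.max())
        probs /= probs.sum()
        next_token = int(rng.choice(VOCAB_SIZE, p=probs))
        token_ids.append(next_token)
        steps += 1
        # Only honor EOS once the minimum length has been reached.
        if next_token == EOS_ID and steps >= MIN_NEW_TOKENS:
            break
    return token_ids

# Toy model that strongly favors EOS, so the penalty's effect is visible:
# without the soft mask the loop would stop almost immediately.
def toy_logits(_ids):
    logits = np.zeros(VOCAB_SIZE)
    logits[EOS_ID] = 4.0
    return logits

out = sample_with_min_length(toy_logits, np.random.default_rng(0))
print(len(out), out[-1] == EOS_ID)  # length is at least MIN_NEW_TOKENS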
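
Two details worth flagging. Subtracting 10.0 is a soft mask, not a hard one: it scales the EOS probability by exp(-10), roughly 4.5e-5, so a premature EOS can still be sampled occasionally, which is why the diff also rewrites such a token to a space. Setting the logit to float('-inf') instead would make premature EOS impossible and turn the space-token fallback into dead code. Note also that _token_to_id is a private attribute of this tokenizer; a public vocab lookup, where available, would make the fallback less fragile.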