Upload train_1.7B_grpo.py with huggingface_hub
Browse files- train_1.7B_grpo.py +12 -4
train_1.7B_grpo.py
CHANGED
|
@@ -166,9 +166,17 @@ def score_expansion(query: str, expansion: str) -> float:
|
|
| 166 |
"""Score expansion. Returns 0.0-1.0 for RL reward."""
|
| 167 |
text = expansion.strip()
|
| 168 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
# HARD FAIL: Chat template artifacts
|
| 170 |
-
if any(token in text for token in ['<|im_start|>', '
|
| 171 |
-
'\nassistant\n', '\nuser\n', '<|endoftext|>']):
|
| 172 |
return 0.0
|
| 173 |
|
| 174 |
# HARD FAIL: EVERY line must start with lex:, vec:, or hyde:
|
|
@@ -275,8 +283,8 @@ def score_expansion(query: str, expansion: str) -> float:
|
|
| 275 |
elif not entities:
|
| 276 |
entity_score = 10
|
| 277 |
|
| 278 |
-
total = format_score + diversity_score + hyde_score + quality_score + entity_score
|
| 279 |
-
max_possible =
|
| 280 |
return max(0.0, min(1.0, total / max_possible))
|
| 281 |
|
| 282 |
|
|
|
|
| 166 |
"""Score expansion. Returns 0.0-1.0 for RL reward."""
|
| 167 |
text = expansion.strip()
|
| 168 |
|
| 169 |
+
# Strip end token if present
|
| 170 |
+
text = text.replace('<|im_end|>', '').strip()
|
| 171 |
+
|
| 172 |
+
# Check for <think>...</think> blocks - strip and mark as not skipped
|
| 173 |
+
skipped_think = 20 # Bonus for not using thinking mode
|
| 174 |
+
if '<think>' in text and '</think>' in text:
|
| 175 |
+
text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
|
| 176 |
+
skipped_think = 0 # No bonus if thinking was used
|
| 177 |
+
|
| 178 |
# HARD FAIL: Chat template artifacts
|
| 179 |
+
if any(token in text for token in ['<|im_start|>', '\nassistant\n', '\nuser\n', '<|endoftext|>']):
|
|
|
|
| 180 |
return 0.0
|
| 181 |
|
| 182 |
# HARD FAIL: EVERY line must start with lex:, vec:, or hyde:
|
|
|
|
| 283 |
elif not entities:
|
| 284 |
entity_score = 10
|
| 285 |
|
| 286 |
+
total = format_score + diversity_score + hyde_score + quality_score + entity_score + skipped_think
|
| 287 |
+
max_possible = 140 if parsed["hyde"] else 120 # +20 for skipped_think bonus
|
| 288 |
return max(0.0, min(1.0, total / max_possible))
|
| 289 |
|
| 290 |
|