# Source: deepbattler / RL / rewrite_battleground_rewards.py
# Uploaded by wyksdsg via huggingface_hub ("Upload folder using huggingface_hub"), commit 787c99c (verified).
#!/usr/bin/env python
import argparse
import json
from pathlib import Path
# Fixed reward assigned to a candidate according to its "role" field; roles
# not listed here keep whatever reward they already had.
ROLE_TO_REWARD = {"expert": 1.0, "medium": 0.0, "bad": -0.5}
def rewrite_file(
    input_path: Path,
    output_path: Path,
    *,
    role_to_reward: "dict[str, float] | None" = None,
) -> None:
    """Rewrite rewards in a Battlegrounds RLAIF JSONL file based on candidate role.

    Each non-blank line of ``input_path`` is parsed as JSON. For every entry in
    its ``"candidates"`` list whose ``"role"`` is in the mapping, ``"reward"``
    is overwritten with the mapped value. The (possibly modified) object is
    written as one line to ``output_path``; blank input lines are dropped.

    Default mapping (``ROLE_TO_REWARD``):
      - expert -> 1.0
      - medium -> 0.0
      - bad    -> -0.5
    Other roles keep their original reward.

    Args:
        input_path: Source JSONL file.
        output_path: Destination JSONL file; created or truncated.
        role_to_reward: Optional role -> reward mapping used instead of
            ``ROLE_TO_REWARD``.

    Raises:
        ValueError: If ``output_path`` is the same file as ``input_path``
            (opening it for writing would truncate the input before reading).
        json.JSONDecodeError: If a line is not valid JSON; the message names
            the input path and the 1-based line number.
    """
    mapping = ROLE_TO_REWARD if role_to_reward is None else role_to_reward

    # Opening the output with "w" truncates it immediately, so writing onto the
    # input file would silently destroy the data before it is ever read.
    if output_path.exists() and input_path.samefile(output_path):
        raise ValueError(f"Refusing to overwrite input file in-place: {input_path}")

    with input_path.open("r", encoding="utf-8") as fin, output_path.open(
        "w", encoding="utf-8"
    ) as fout:
        for lineno, raw in enumerate(fin, start=1):
            line = raw.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError as err:
                # Keep the exception type callers may catch, but say where the bad line is.
                raise json.JSONDecodeError(
                    f"{input_path} line {lineno}: {err.msg}", err.doc, err.pos
                ) from err
            # `or []` also tolerates an explicit "candidates": null.
            for cand in obj.get("candidates") or []:
                role = cand.get("role")
                if role in mapping:
                    cand["reward"] = mapping[role]
            fout.write(json.dumps(obj, ensure_ascii=False) + "\n")
def main() -> None:
    """Command-line entry point: rewrite rewards from ``--input`` into ``--output``.

    Exits via ``SystemExit`` with a message when the input path does not
    exist, when ``--output`` refers to the input file (including through a
    symlink or hard link), or when the input contains a line that cannot be
    parsed. A missing parent directory for ``--output`` is created.
    """
    parser = argparse.ArgumentParser(
        description="Rewrite rewards in Battlegrounds RLAIF JSONL dataset based on candidate role."
    )
    parser.add_argument(
        "--input",
        type=Path,
        required=True,
        help="Path to input JSONL file (original dataset).",
    )
    parser.add_argument(
        "--output",
        type=Path,
        required=True,
        help="Path to output JSONL file (rewritten dataset).",
    )
    args = parser.parse_args()
    input_path: Path = args.input
    output_path: Path = args.output

    if not input_path.exists():
        raise SystemExit(f"Input file not found: {input_path}")
    # samefile() compares the underlying files, so it also catches hard links,
    # which a resolve()-based path comparison would miss.
    if output_path.exists() and input_path.samefile(output_path):
        raise SystemExit("Refusing to overwrite input file in-place. Use a different --output path.")

    # Allow writing into a directory that does not exist yet.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    try:
        rewrite_file(input_path, output_path)
    except ValueError as err:  # json.JSONDecodeError is a ValueError subclass
        raise SystemExit(f"Failed to rewrite {input_path}: {err}") from err
    print(f"Rewritten dataset saved to: {output_path}")


if __name__ == "__main__":
    main()