jadechoghari committed on
Commit
ae34bdc
·
verified ·
1 Parent(s): 465963f

Create env.py

Browse files
Files changed (1) hide show
  1. env.py +51 -0
env.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import verifiers as vf
2
+ from verifiers.envs.textarena_env import TextArenaEnv
3
+
4
+ THINK_GUESS_SYSTEM_PROMPT = """..."""
5
+ NOTHINK_GUESS_SYSTEM_PROMPT = """..."""
6
+
7
def wordle_feedback_fn(observation: str) -> str:
    """Return the part of an observation that follows the last "Feedback:" marker.

    Args:
        observation: Raw observation text.

    Returns:
        Everything after the final occurrence of ``"Feedback:"`` (leading
        whitespace is kept as-is), or ``observation`` unchanged when the
        marker does not appear.
    """
    marker = "Feedback:"
    if marker not in observation:
        return observation
    # The tail of rpartition is the text after the last marker occurrence.
    _, _, tail = observation.rpartition(marker)
    return tail
12
+
13
def check_answer_reward_func(parser, completion, answer, **kwargs) -> float:
    """Score 1.0 when the parsed guess equals the bracketed answer, else 0.0.

    Args:
        parser: Object exposing ``parse_answer(completion)``.
        completion: Conversation passed through to the parser.
        answer: Target word; compared as ``"[" + answer + "]"``.
        **kwargs: Accepted and ignored.

    Returns:
        ``1.0`` on an exact match, ``0.0`` otherwise.
    """
    parsed_guess = parser.parse_answer(completion)
    # The comparison target is the answer wrapped in square brackets.
    expected = "[" + answer + "]"
    if parsed_guess == expected:
        return 1.0
    return 0.0
16
+
17
def count_turns_reward_func(parser, completion, answer, **kwargs) -> float:
    """Scale the correctness reward down by the number of assistant turns.

    Args:
        parser: Forwarded to ``check_answer_reward_func``.
        completion: Sequence of message dicts, each with a ``"role"`` key.
        answer: Forwarded to ``check_answer_reward_func``.
        **kwargs: Forwarded to ``check_answer_reward_func``.

    Returns:
        The correctness score divided by (assistant turn count + 1).
    """
    assistant_turns = sum(
        1 for message in completion if message["role"] == "assistant"
    )
    solved = check_answer_reward_func(parser, completion, answer, **kwargs)
    return solved / (assistant_turns + 1)
21
+
22
def partial_credit_reward_func(parser, completion, **kwargs) -> float:
    """Give partial credit from the scoring line of the last user message.

    The last message returned by ``parser.get_user_messages`` is stripped and
    split on newlines; its second line is treated as the scoring line.
    NOTE(review): this presumes the feedback is laid out as a guess line
    followed by a scoring line of G/Y marks -- confirm against the
    environment's actual output.

    Args:
        parser: Object exposing ``get_user_messages(completion)``, returning
            a sequence of message dicts with a ``"content"`` key.
        completion: Conversation passed through to the parser.
        **kwargs: Accepted and ignored.

    Returns:
        ``0.2`` per ``"G"`` plus ``0.1`` per ``"Y"`` on the scoring line, or
        ``0.0`` when there is no user message or the last one has fewer than
        two lines (so no scoring line exists).
    """
    user_messages = parser.get_user_messages(completion)
    if not user_messages:
        # Nothing from the environment yet: no feedback to score.
        return 0.0

    lines = user_messages[-1]["content"].strip().split("\n")
    if len(lines) < 2:
        # Single-line responses carry no scoring line; unpacking them
        # previously raised ValueError.
        return 0.0

    scoring = lines[1]
    num_greens = scoring.count("G")
    num_yellows = scoring.count("Y")
    return 0.2 * num_greens + 0.1 * num_yellows
28
+
29
def load_env(num_train_examples=2000, num_eval_examples=20, use_think=True):
    """Build the Wordle ``TextArenaEnv`` with its parser and reward rubric.

    Args:
        num_train_examples: Forwarded to ``TextArenaEnv``.
        num_eval_examples: Forwarded to ``TextArenaEnv``.
        use_think: When true, use the "think" system prompt and an XML parser
            with ``think`` and ``guess`` fields; otherwise use the no-think
            prompt and a parser with only the ``guess`` field.

    Returns:
        A ``TextArenaEnv`` for the ``"Wordle-v0"`` game.
    """
    system_prompt = (
        THINK_GUESS_SYSTEM_PROMPT if use_think else NOTHINK_GUESS_SYSTEM_PROMPT
    )
    parser_fields = ["think", "guess"] if use_think else ["guess"]
    parser = vf.XMLParser(fields=parser_fields, answer_field="guess")

    rubric = vf.Rubric(parser=parser)
    # These three are registered without an explicit weight, in this order.
    for reward_func in (
        check_answer_reward_func,
        partial_credit_reward_func,
        count_turns_reward_func,
    ):
        rubric.add_reward_func(reward_func)
    # The format reward is the only one registered with an explicit weight.
    rubric.add_reward_func(parser.get_format_reward_func(), weight=0.2)

    env_kwargs = {
        "game": "Wordle-v0",
        "num_train_examples": num_train_examples,
        "num_eval_examples": num_eval_examples,
        "system_prompt": system_prompt,
        "parser": parser,
        "rubric": rubric,
        "feedback_fn": wordle_feedback_fn,
    }
    return TextArenaEnv(**env_kwargs)