Majen committed
Commit fc89f5e · verified · 1 Parent(s): ad1d804

Update inference.py

Files changed (1)
  1. inference.py +82 -82
inference.py CHANGED
@@ -1,82 +1,82 @@
- """Submission-grade baseline inference runner for RecallTrace."""
-
- from __future__ import annotations
-
- import json
- import os
- from typing import Any, List
-
- from openai import OpenAI
-
- from env.env import RecallTraceEnv
- from env.models import RecallAction
- from grader.grader import grade_finalize_info
- from baseline.policy import choose_heuristic_action, choose_llm_action
-
- API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
- MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
- API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("HF_TOKEN", "")
- BENCHMARK = "RecallTrace"
-
-
- def log_start(task: str, env: str, model: str) -> None:
-     print(f"[START] task={task} env={env} model={model}", flush=True)
-
-
- def log_step(step: int, action: RecallAction, reward: float, done: bool, error: str | None) -> None:
-     payload = json.dumps(action.model_dump(exclude_none=True), sort_keys=True)
-     error_text = error if error is not None else "null"
-     print(f"[STEP] step={step} action={payload} reward={reward:.4f} done={str(done).lower()} error={error_text}", flush=True)
-
-
- def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
-     print(f"[END] success={str(success).lower()} steps={steps} score={score:.4f} rewards={json.dumps([round(r, 4) for r in rewards])}", flush=True)
-
-
- def run_task(task_id: str, client: OpenAI | None) -> float:
-     env = RecallTraceEnv(task_id=task_id)
-     observation = env.reset()
-
-     history: List[dict[str, Any]] = []
-     rewards: List[float] = []
-     steps_taken = 0
-     final_info: dict[str, Any] = {"score": 0.0}
-
-     log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME if client else "heuristic-baseline")
-
-     for step in range(1, env.task.max_steps + 1):
-         llm_action = choose_llm_action(client, MODEL_NAME, observation, history)
-         action = llm_action or choose_heuristic_action(observation)
-
-         observation, reward, done, info = env.step(action)
-         rewards.append(reward)
-         steps_taken = step
-         final_info = info
-         log_step(step=step, action=action, reward=reward, done=done, error=info.get("error"))
-
-         history.append(
-             {
-                 "step": step,
-                 "action": action.model_dump(exclude_none=True),
-                 "reward": reward,
-                 "done": done,
-                 "message": info.get("message"),
-             }
-         )
-         if done:
-             break
-
-     grade = grade_finalize_info(task_id, steps_taken, final_info)
-     log_end(success=grade.success, steps=steps_taken, score=grade.score, rewards=rewards)
-     return grade.score
-
-
- def main() -> None:
-     client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY) if API_KEY else None
-     task_scores = [run_task(task.task_id, client) for task in RecallTraceEnv.available_tasks()]
-     average_score = sum(task_scores) / len(task_scores)
-     print(json.dumps({"benchmark": BENCHMARK, "average_score": round(average_score, 4), "task_scores": task_scores}), flush=True)
-
-
- if __name__ == "__main__":
-     main()
 
+ """Submission-grade baseline inference runner for RecallTrace."""
+
+ from __future__ import annotations
+
+ import json
+ import os
+ from typing import Any, List
+
+ from openai import OpenAI
+
+ from env.env import RecallTraceEnv
+ from env.models import RecallAction
+ from grader.grader import grade_finalize_info
+ from baseline.policy import choose_heuristic_action, choose_llm_action
+
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
+ MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
+ API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("HF_TOKEN")
+ BENCHMARK = "RecallTrace"
+
+
+ def log_start(task: str, env: str, model: str) -> None:
+     print(f"[START] task={task} env={env} model={model}", flush=True)
+
+
+ def log_step(step: int, action: RecallAction, reward: float, done: bool, error: str | None) -> None:
+     payload = json.dumps(action.model_dump(exclude_none=True), sort_keys=True)
+     error_text = error if error is not None else "null"
+     print(f"[STEP] step={step} action={payload} reward={reward:.2f} done={str(done).lower()} error={error_text}", flush=True)
+
+
+ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
+     print(f"[END] success={str(success).lower()} steps={steps} score={score:.2f} rewards={','.join(f'{r:.2f}' for r in rewards)}", flush=True)
+
+
+ def run_task(task_id: str, client: OpenAI | None) -> float:
+     env = RecallTraceEnv(task_id=task_id)
+     observation = env.reset()
+
+     history: List[dict[str, Any]] = []
+     rewards: List[float] = []
+     steps_taken = 0
+     final_info: dict[str, Any] = {"score": 0.0}
+
+     log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME if client else "heuristic-baseline")
+
+     for step in range(1, env.task.max_steps + 1):
+         llm_action = choose_llm_action(client, MODEL_NAME, observation, history)
+         action = llm_action or choose_heuristic_action(observation)
+
+         observation, reward, done, info = env.step(action)
+         rewards.append(reward)
+         steps_taken = step
+         final_info = info
+         log_step(step=step, action=action, reward=reward, done=done, error=info.get("error"))
+
+         history.append(
+             {
+                 "step": step,
+                 "action": action.model_dump(exclude_none=True),
+                 "reward": reward,
+                 "done": done,
+                 "message": info.get("message"),
+             }
+         )
+         if done:
+             break
+
+     grade = grade_finalize_info(task_id, steps_taken, final_info)
+     log_end(success=grade.success, steps=steps_taken, score=grade.score, rewards=rewards)
+     return grade.score
+
+
+ def main() -> None:
+     client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY) if API_KEY else None
+     task_scores = [run_task(task.task_id, client) for task in RecallTraceEnv.available_tasks()]
+     average_score = sum(task_scores) / len(task_scores)
+     print(json.dumps({"benchmark": BENCHMARK, "average_score": round(average_score, 4), "task_scores": task_scores}), flush=True)
+
+
+ if __name__ == "__main__":
+     main()
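
The substance of this commit is the log formatting: the [END] line drops json.dumps in favor of a comma-joined fixed-point string, and the alternating single/double quotes keep the nested f-string legal on Python 3.8+ (same-quote nesting inside an f-string requires Python 3.12). Below is a minimal, self-contained sketch of how the old and new formats render; the success/steps/score/reward values are hypothetical and the sketch is not part of the repository:

# Standalone sketch of the log_end change (hypothetical values, not repo code).
import json

success, steps, score = True, 3, 0.875
rewards = [0.125, 0.5, 1.0]

# Old format: 4-decimal score, rewards serialized as a JSON list.
print(f"[END] success={str(success).lower()} steps={steps} score={score:.4f} rewards={json.dumps([round(r, 4) for r in rewards])}")
# -> [END] success=true steps=3 score=0.8750 rewards=[0.125, 0.5, 1.0]

# New format: 2-decimal score, rewards as a comma-joined fixed-point string.
print(f"[END] success={str(success).lower()} steps={steps} score={score:.2f} rewards={','.join(f'{r:.2f}' for r in rewards)}")
# -> [END] success=true steps=3 score=0.88 rewards=0.12,0.50,1.00

The API_KEY change is behavior-preserving for the `if API_KEY` guard in main(): an unset HF_TOKEN now yields None instead of an empty string, and both are falsy, so the runner still falls back to the heuristic baseline when no key is configured.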