Arjs commited on
Commit
e9d38e7
·
verified ·
1 Parent(s): e9c9d34

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +137 -137
inference.py CHANGED
@@ -1,138 +1,138 @@
1
- import os
2
- import json
3
- import time
4
- from openai import OpenAI
5
- from server.cust_env_environment import DocSweeperEnvironment
6
- from models import DocAction
7
-
8
- IMAGE_NAME = os.getenv("IMAGE_NAME")
9
- API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
10
-
11
- API_BASE_URL = os.getenv("API_BASE_URL") or "https://api.openai.com/v1"
12
- MODEL_NAME = os.getenv("MODEL_NAME") or "gpt-4o-mini"
13
- BENCHMARK_NAME = "doc_sweeper"
14
-
15
- def run_inference(task_name: str):
16
- api_base_url = os.environ.get("API_BASE_URL") or API_BASE_URL
17
- model_name = os.environ.get("MODEL_NAME") or MODEL_NAME
18
- hf_token = os.environ.get("HF_TOKEN") or API_KEY
19
-
20
- if not api_base_url:
21
- raise ValueError("Missing API base url")
22
- if not model_name:
23
- raise ValueError("Missing model name")
24
- if not hf_token:
25
- raise ValueError("Missing hf_token")
26
-
27
- client = OpenAI(
28
- api_key=hf_token,
29
- base_url=api_base_url,
30
- timeout=15.0,
31
- max_retries=1
32
- )
33
-
34
- env = DocSweeperEnvironment(task=task_name)
35
- obs = env.reset()
36
-
37
- done = False
38
- total_reward = 0.0
39
- step_count = 0
40
- rewards_history = []
41
- MAX_STEPS = 20
42
-
43
- print(f"[START] task={task_name} env={BENCHMARK_NAME} model={model_name}", flush=True)
44
-
45
- system_prompt = f"""
46
- You are an elite, systematic documentation engineer. You interact with a virtual file system via JSON tool calls.
47
-
48
- YOUR CURRENT TASK: '{task_name}'
49
- - If 'version_bump': Systematically OPEN EVERY SINGLE FILE in the directory tree. Check for 'v1.0.0' or 'v1.00'. If found, use 'edit' to update to 'v2.0.0'.
50
- - If 'config_migration': Open docker-compose files. Update version to '3.8' and migrate 'links:' to 'networks:'.
51
- - If 'broken_links': Find broken relative links containing '../old-docs/' and edit them to point strictly to './new-docs/'.
52
-
53
- WORKFLOW RULES:
54
- 1. PLAN IN THOUGHT: Use the 'thought' field to reason. NEVER use a tool called "plan". Valid tools are strictly: 'open', 'edit', 'grep', 'done'.
55
- 2. OPEN THEN EDIT: You MUST 'open' a file before you can 'edit' it.
56
- 3. EDIT SAFELY: When editing, use 'old_str' (exact text to replace) and 'new_str'. Do NOT use 'path'.
57
- 4. FINISH: Call 'done' ONLY when you have opened and verified EVERY file in the directory tree.
58
-
59
- OUTPUT SCHEMA:
60
- You MUST output ONLY a single raw JSON object EXACTLY matching this structure:
61
- {{
62
- "thought": "<Mandatory step-by-step reasoning>",
63
- "tool_name": "<MUST be one of: 'open', 'edit', 'grep', 'done'>",
64
- "path": "<Optional. File path for 'open'>",
65
- "old_str": "<Optional. Exact match string for 'edit'>",
66
- "new_str": "<Optional. Replacement string for 'edit'>",
67
- "search_query": "<Optional. Text to search for 'grep'>"
68
- }}
69
- """
70
-
71
- messages = [{"role": "system", "content": system_prompt}]
72
- start_time = time.time()
73
-
74
- while not done and step_count < MAX_STEPS:
75
- step_count += 1
76
- current_state_prompt = f"""
77
- [ENVIRONMENT OBSERVATION]
78
- Active File: {obs.active_file or 'None'}
79
- Terminal Feedback: {obs.terminal_feedback}
80
- Directory Tree: {json.dumps(obs.directory_tree)}
81
- File Content: {obs.file_content}
82
- Linter Issues: {obs.issues_detected}
83
- """
84
- messages.append({"role": "user", "content": current_state_prompt})
85
-
86
- try:
87
- response = client.chat.completions.create(
88
- model=model_name,
89
- messages=messages,
90
- response_format={"type": "json_object"}
91
- )
92
-
93
- raw_reply = response.choices[0].message.content
94
- messages.append({"role": "assistant", "content": raw_reply})
95
-
96
- action_json = json.loads(raw_reply)
97
- if isinstance(action_json, list):
98
- action_json = action_json[0] if len(action_json) > 0 else {"tool_name": "done"}
99
-
100
- thought = action_json.pop("thought", "None")
101
-
102
- valid_fields = DocAction.model_fields.keys()
103
- safe_kwargs = {k: v for k, v in action_json.items() if k in valid_fields}
104
-
105
- action = DocAction(**safe_kwargs)
106
- obs = env.step(action)
107
-
108
- total_reward += obs.reward
109
- rewards_history.append(obs.reward)
110
- done = obs.done
111
-
112
- action_str = f"{action.tool_name}"
113
- done_str = str(done).lower()
114
- print(f"[STEP] step={step_count} action={action_str} reward={obs.reward:.2f} done={done_str} error=null", flush=True)
115
-
116
- except Exception as e:
117
- error_msg = str(e).replace('\n', ' ')
118
- obs.terminal_feedback = f"SYSTEM ERROR: {error_msg}. Review the schema rules."
119
- rewards_history.append(0.0)
120
-
121
- if "timeout" in error_msg.lower() or "connection" in error_msg.lower():
122
- done = True
123
-
124
- done_str = str(done).lower()
125
- print(f"[STEP] step={step_count} action=error reward=0.00 done={done_str} error=\"{error_msg}\"", flush=True)
126
-
127
- final_score = max(0.0, min(1.0, total_reward))
128
- success = final_score > 0.0
129
- success_str = str(success).lower()
130
- rewards_str = ",".join(f"{r:.2f}" for r in rewards_history)
131
-
132
- print(f"[END] success={success_str} steps={step_count} score={final_score:.2f} rewards={rewards_str}", flush=True)
133
-
134
-
135
- if __name__ == "__main__":
136
- tasks = ["version_bump", "config_migration", "broken_links"]
137
- for task in tasks:
138
  run_inference(task)
 
1
+ import os
2
+ import json
3
+ import time
4
+ from openai import OpenAI
5
+ from server.cust_env_environment import DocSweeperEnvironment
6
+ from models import DocAction
7
+
8
+ IMAGE_NAME = os.getenv("IMAGE_NAME")
9
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
10
+
11
+ API_BASE_URL = os.getenv("API_BASE_URL") or "https://api.openai.com/v1"
12
+ MODEL_NAME = os.getenv("MODEL_NAME") or "gpt-4o-mini"
13
+ BENCHMARK_NAME = "doc_sweeper"
14
+
15
+ def run_inference(task_name: str):
16
+ api_base_url = os.environ.get("API_BASE_URL") or API_BASE_URL
17
+ model_name = os.environ.get("MODEL_NAME") or MODEL_NAME
18
+ hf_token = os.environ.get("HF_TOKEN") or API_KEY
19
+
20
+ if not api_base_url:
21
+ raise ValueError("Missing API base url")
22
+ if not model_name:
23
+ raise ValueError("Missing model name")
24
+ if not hf_token:
25
+ raise ValueError("Missing hf_token")
26
+
27
+ client = OpenAI(
28
+ api_key=hf_token,
29
+ base_url=api_base_url,
30
+ timeout=15.0,
31
+ max_retries=1
32
+ )
33
+
34
+ env = DocSweeperEnvironment(task=task_name)
35
+ obs = env.reset()
36
+
37
+ done = False
38
+ total_reward = 0.0
39
+ step_count = 0
40
+ rewards_history = []
41
+ MAX_STEPS = 20
42
+
43
+ print(f"[START] task={task_name} env={BENCHMARK_NAME} model={model_name}", flush=True)
44
+
45
+ system_prompt = f"""
46
+ You are an elite, systematic documentation engineer. You interact with a virtual file system via JSON tool calls.
47
+
48
+ YOUR CURRENT TASK: '{task_name}'
49
+ - If 'version_bump': Systematically OPEN EVERY SINGLE FILE in the directory tree. Check for 'v1.0.0' or 'v1.00'. If found, use 'edit' to update to 'v2.0.0'.
50
+ - If 'config_migration': Open docker-compose files. Update version to '3.8' and migrate 'links:' to 'networks:'.
51
+ - If 'broken_links': Find broken relative links containing '../old-docs/' and edit them to point strictly to './new-docs/'.
52
+
53
+ WORKFLOW RULES:
54
+ 1. PLAN IN THOUGHT: Use the 'thought' field to reason. NEVER use a tool called "plan". Valid tools are strictly: 'open', 'edit', 'grep', 'done'.
55
+ 2. OPEN THEN EDIT: You MUST 'open' a file before you can 'edit' it.
56
+ 3. EDIT SAFELY: When editing, use 'old_str' (exact text to replace) and 'new_str'. Do NOT use 'path'.
57
+ 4. FINISH: Call 'done' ONLY when you have opened and verified EVERY file in the directory tree.
58
+
59
+ OUTPUT SCHEMA:
60
+ You MUST output ONLY a single raw JSON object EXACTLY matching this structure:
61
+ {{
62
+ "thought": "<Mandatory step-by-step reasoning>",
63
+ "tool_name": "<MUST be one of: 'open', 'edit', 'grep', 'done'>",
64
+ "path": "<Optional. File path for 'open'>",
65
+ "old_str": "<Optional. Exact match string for 'edit'>",
66
+ "new_str": "<Optional. Replacement string for 'edit'>",
67
+ "search_query": "<Optional. Text to search for 'grep'>"
68
+ }}
69
+ """
70
+
71
+ messages = [{"role": "system", "content": system_prompt}]
72
+ start_time = time.time()
73
+
74
+ while not done and step_count < MAX_STEPS:
75
+ step_count += 1
76
+ current_state_prompt = f"""
77
+ [ENVIRONMENT OBSERVATION]
78
+ Active File: {obs.active_file or 'None'}
79
+ Terminal Feedback: {obs.terminal_feedback}
80
+ Directory Tree: {json.dumps(obs.directory_tree)}
81
+ File Content: {obs.file_content}
82
+ Linter Issues: {obs.issues_detected}
83
+ """
84
+ messages.append({"role": "user", "content": current_state_prompt})
85
+
86
+ try:
87
+ response = client.chat.completions.create(
88
+ model=model_name,
89
+ messages=messages,
90
+ response_format={"type": "json_object"}
91
+ )
92
+
93
+ raw_reply = response.choices[0].message.content
94
+ messages.append({"role": "assistant", "content": raw_reply})
95
+
96
+ action_json = json.loads(raw_reply)
97
+ if isinstance(action_json, list):
98
+ action_json = action_json[0] if len(action_json) > 0 else {"tool_name": "done"}
99
+
100
+ thought = action_json.pop("thought", "None")
101
+
102
+ valid_fields = DocAction.model_fields.keys()
103
+ safe_kwargs = {k: v for k, v in action_json.items() if k in valid_fields}
104
+
105
+ action = DocAction(**safe_kwargs)
106
+ obs = env.step(action)
107
+
108
+ total_reward += obs.reward
109
+ rewards_history.append(obs.reward)
110
+ done = obs.done
111
+
112
+ action_str = f"{action.tool_name}"
113
+ done_str = str(done).lower()
114
+ print(f"[STEP] step={step_count} action={action_str} reward={obs.reward:.2f} done={done_str} error=null", flush=True)
115
+
116
+ except Exception as e:
117
+ error_msg = str(e).replace('\n', ' ')
118
+ obs.terminal_feedback = f"SYSTEM ERROR: {error_msg}. Review the schema rules."
119
+ rewards_history.append(0.0)
120
+
121
+ if "timeout" in error_msg.lower() or "connection" in error_msg.lower():
122
+ done = True
123
+
124
+ done_str = str(done).lower()
125
+ print(f"[STEP] step={step_count} action=error reward=0.00 done={done_str} error=\"{error_msg}\"", flush=True)
126
+
127
+ final_score = max(0.01, min(.99, total_reward))
128
+ success = final_score > 0.0
129
+ success_str = str(success).lower()
130
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards_history)
131
+
132
+ print(f"[END] success={success_str} steps={step_count} score={final_score:.2f} rewards={rewards_str}", flush=True)
133
+
134
+
135
+ if __name__ == "__main__":
136
+ tasks = ["version_bump", "config_migration", "broken_links"]
137
+ for task in tasks:
138
  run_inference(task)