DocSweeper / inference.py
Arjs's picture
Upload inference.py
e9d38e7 verified
import os
import json
import time
from openai import OpenAI
from server.cust_env_environment import DocSweeperEnvironment
from models import DocAction
IMAGE_NAME = os.getenv("IMAGE_NAME")
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
API_BASE_URL = os.getenv("API_BASE_URL") or "https://api.openai.com/v1"
MODEL_NAME = os.getenv("MODEL_NAME") or "gpt-4o-mini"
BENCHMARK_NAME = "doc_sweeper"
def run_inference(task_name: str):
api_base_url = os.environ.get("API_BASE_URL") or API_BASE_URL
model_name = os.environ.get("MODEL_NAME") or MODEL_NAME
hf_token = os.environ.get("HF_TOKEN") or API_KEY
if not api_base_url:
raise ValueError("Missing API base url")
if not model_name:
raise ValueError("Missing model name")
if not hf_token:
raise ValueError("Missing hf_token")
client = OpenAI(
api_key=hf_token,
base_url=api_base_url,
timeout=15.0,
max_retries=1
)
env = DocSweeperEnvironment(task=task_name)
obs = env.reset()
done = False
total_reward = 0.0
step_count = 0
rewards_history = []
MAX_STEPS = 20
print(f"[START] task={task_name} env={BENCHMARK_NAME} model={model_name}", flush=True)
system_prompt = f"""
You are an elite, systematic documentation engineer. You interact with a virtual file system via JSON tool calls.
YOUR CURRENT TASK: '{task_name}'
- If 'version_bump': Systematically OPEN EVERY SINGLE FILE in the directory tree. Check for 'v1.0.0' or 'v1.00'. If found, use 'edit' to update to 'v2.0.0'.
- If 'config_migration': Open docker-compose files. Update version to '3.8' and migrate 'links:' to 'networks:'.
- If 'broken_links': Find broken relative links containing '../old-docs/' and edit them to point strictly to './new-docs/'.
WORKFLOW RULES:
1. PLAN IN THOUGHT: Use the 'thought' field to reason. NEVER use a tool called "plan". Valid tools are strictly: 'open', 'edit', 'grep', 'done'.
2. OPEN THEN EDIT: You MUST 'open' a file before you can 'edit' it.
3. EDIT SAFELY: When editing, use 'old_str' (exact text to replace) and 'new_str'. Do NOT use 'path'.
4. FINISH: Call 'done' ONLY when you have opened and verified EVERY file in the directory tree.
OUTPUT SCHEMA:
You MUST output ONLY a single raw JSON object EXACTLY matching this structure:
{{
"thought": "<Mandatory step-by-step reasoning>",
"tool_name": "<MUST be one of: 'open', 'edit', 'grep', 'done'>",
"path": "<Optional. File path for 'open'>",
"old_str": "<Optional. Exact match string for 'edit'>",
"new_str": "<Optional. Replacement string for 'edit'>",
"search_query": "<Optional. Text to search for 'grep'>"
}}
"""
messages = [{"role": "system", "content": system_prompt}]
start_time = time.time()
while not done and step_count < MAX_STEPS:
step_count += 1
current_state_prompt = f"""
[ENVIRONMENT OBSERVATION]
Active File: {obs.active_file or 'None'}
Terminal Feedback: {obs.terminal_feedback}
Directory Tree: {json.dumps(obs.directory_tree)}
File Content: {obs.file_content}
Linter Issues: {obs.issues_detected}
"""
messages.append({"role": "user", "content": current_state_prompt})
try:
response = client.chat.completions.create(
model=model_name,
messages=messages,
response_format={"type": "json_object"}
)
raw_reply = response.choices[0].message.content
messages.append({"role": "assistant", "content": raw_reply})
action_json = json.loads(raw_reply)
if isinstance(action_json, list):
action_json = action_json[0] if len(action_json) > 0 else {"tool_name": "done"}
thought = action_json.pop("thought", "None")
valid_fields = DocAction.model_fields.keys()
safe_kwargs = {k: v for k, v in action_json.items() if k in valid_fields}
action = DocAction(**safe_kwargs)
obs = env.step(action)
total_reward += obs.reward
rewards_history.append(obs.reward)
done = obs.done
action_str = f"{action.tool_name}"
done_str = str(done).lower()
print(f"[STEP] step={step_count} action={action_str} reward={obs.reward:.2f} done={done_str} error=null", flush=True)
except Exception as e:
error_msg = str(e).replace('\n', ' ')
obs.terminal_feedback = f"SYSTEM ERROR: {error_msg}. Review the schema rules."
rewards_history.append(0.0)
if "timeout" in error_msg.lower() or "connection" in error_msg.lower():
done = True
done_str = str(done).lower()
print(f"[STEP] step={step_count} action=error reward=0.00 done={done_str} error=\"{error_msg}\"", flush=True)
final_score = max(0.01, min(.99, total_reward))
success = final_score > 0.0
success_str = str(success).lower()
rewards_str = ",".join(f"{r:.2f}" for r in rewards_history)
print(f"[END] success={success_str} steps={step_count} score={final_score:.2f} rewards={rewards_str}", flush=True)
if __name__ == "__main__":
tasks = ["version_bump", "config_migration", "broken_links"]
for task in tasks:
run_inference(task)