# Residue from the web file viewer this script was copied from (not part of the program):
#   Spaces: Sleeping
#   File size: 5,716 Bytes
#   e9d38e7 e03aa3c
import os
import json
import time
from openai import OpenAI
from server.cust_env_environment import DocSweeperEnvironment
from models import DocAction
# Configuration read from the environment once, at import time.
# An empty string counts as "unset" wherever an `or` fallback is used.

# Value of the IMAGE_NAME variable (None when unset).
IMAGE_NAME = os.environ.get("IMAGE_NAME")
# API credential: HF_TOKEN takes precedence, API_KEY is the fallback.
API_KEY = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY")
# Base URL of the OpenAI-compatible endpoint.
API_BASE_URL = os.environ.get("API_BASE_URL") or "https://api.openai.com/v1"
# Chat model requested from that endpoint.
MODEL_NAME = os.environ.get("MODEL_NAME") or "gpt-4o-mini"
# Label printed as `env=` in the [START] log line.
BENCHMARK_NAME = "doc_sweeper"
def _build_system_prompt(task_name: str) -> str:
    """Return the system prompt (task rules, workflow rules, reply schema) for *task_name*."""
    # The text is sent to the model verbatim, so its lines are kept flush-left
    # to avoid leaking source indentation into the prompt.
    return f"""
You are an elite, systematic documentation engineer. You interact with a virtual file system via JSON tool calls.
YOUR CURRENT TASK: '{task_name}'
- If 'version_bump': Systematically OPEN EVERY SINGLE FILE in the directory tree. Check for 'v1.0.0' or 'v1.00'. If found, use 'edit' to update to 'v2.0.0'.
- If 'config_migration': Open docker-compose files. Update version to '3.8' and migrate 'links:' to 'networks:'.
- If 'broken_links': Find broken relative links containing '../old-docs/' and edit them to point strictly to './new-docs/'.
WORKFLOW RULES:
1. PLAN IN THOUGHT: Use the 'thought' field to reason. NEVER use a tool called "plan". Valid tools are strictly: 'open', 'edit', 'grep', 'done'.
2. OPEN THEN EDIT: You MUST 'open' a file before you can 'edit' it.
3. EDIT SAFELY: When editing, use 'old_str' (exact text to replace) and 'new_str'. Do NOT use 'path'.
4. FINISH: Call 'done' ONLY when you have opened and verified EVERY file in the directory tree.
OUTPUT SCHEMA:
You MUST output ONLY a single raw JSON object EXACTLY matching this structure:
{{
"thought": "<Mandatory step-by-step reasoning>",
"tool_name": "<MUST be one of: 'open', 'edit', 'grep', 'done'>",
"path": "<Optional. File path for 'open'>",
"old_str": "<Optional. Exact match string for 'edit'>",
"new_str": "<Optional. Replacement string for 'edit'>",
"search_query": "<Optional. Text to search for 'grep'>"
}}
"""


def _format_observation(obs) -> str:
    """Render the latest environment observation as the next user message."""
    return f"""
[ENVIRONMENT OBSERVATION]
Active File: {obs.active_file or 'None'}
Terminal Feedback: {obs.terminal_feedback}
Directory Tree: {json.dumps(obs.directory_tree)}
File Content: {obs.file_content}
Linter Issues: {obs.issues_detected}
"""


def _parse_action_fields(raw_reply: str) -> dict:
    """Decode the model reply into a dict of candidate action fields.

    A JSON array is reduced to its first element; an empty array is treated as
    a 'done' action.

    Raises:
        ValueError: if the reply is not valid JSON (``json.JSONDecodeError`` is
            a ``ValueError`` subclass) or does not decode to a JSON object.
    """
    action_json = json.loads(raw_reply)
    if isinstance(action_json, list):
        action_json = action_json[0] if action_json else {"tool_name": "done"}
    if not isinstance(action_json, dict):
        # Fail with a message the model can act on instead of an opaque
        # AttributeError further down.
        raise ValueError("Model reply must be a JSON object")
    return action_json


def run_inference(task_name: str) -> None:
    """Run one episode of *task_name*, driving the environment with an LLM.

    Each step sends the current observation to the chat model, parses the JSON
    reply into a ``DocAction`` and applies it to the environment. The loop
    stops when the environment reports ``done``, after ``MAX_STEPS`` steps, or
    when an error mentioning "timeout" or "connection" occurs. Progress is
    reported on stdout as ``[START]``, ``[STEP]`` and ``[END]`` lines.

    Args:
        task_name: Task identifier passed to ``DocSweeperEnvironment`` and
            interpolated into the system prompt.

    Raises:
        ValueError: if the API base URL, model name or API token cannot be
            resolved from the environment or the module-level defaults.
    """
    # Environment variables are re-read at call time and win over the
    # module-level values captured at import.
    api_base_url = os.environ.get("API_BASE_URL") or API_BASE_URL
    model_name = os.environ.get("MODEL_NAME") or MODEL_NAME
    hf_token = os.environ.get("HF_TOKEN") or API_KEY

    if not api_base_url:
        raise ValueError("Missing API base url")
    if not model_name:
        raise ValueError("Missing model name")
    if not hf_token:
        raise ValueError("Missing hf_token")

    client = OpenAI(
        api_key=hf_token,
        base_url=api_base_url,
        timeout=15.0,
        max_retries=1,
    )

    env = DocSweeperEnvironment(task=task_name)
    obs = env.reset()

    done = False
    total_reward = 0.0
    step_count = 0
    rewards_history = []
    MAX_STEPS = 20

    print(f"[START] task={task_name} env={BENCHMARK_NAME} model={model_name}", flush=True)

    messages = [{"role": "system", "content": _build_system_prompt(task_name)}]

    while not done and step_count < MAX_STEPS:
        step_count += 1
        messages.append({"role": "user", "content": _format_observation(obs)})

        try:
            response = client.chat.completions.create(
                model=model_name,
                messages=messages,
                response_format={"type": "json_object"},
            )
            raw_reply = response.choices[0].message.content
            if raw_reply is None:
                # Reject before the reply enters the history: an assistant
                # message without content would be resent on every later call.
                raise ValueError("Model returned an empty reply")
            messages.append({"role": "assistant", "content": raw_reply})

            action_json = _parse_action_fields(raw_reply)
            # The model's reasoning is never forwarded to the environment.
            action_json.pop("thought", None)

            # Drop any keys the action model does not declare.
            valid_fields = DocAction.model_fields.keys()
            safe_kwargs = {k: v for k, v in action_json.items() if k in valid_fields}
            action = DocAction(**safe_kwargs)

            obs = env.step(action)
            total_reward += obs.reward
            rewards_history.append(obs.reward)
            done = obs.done

            done_str = str(done).lower()
            print(
                f"[STEP] step={step_count} action={action.tool_name} "
                f"reward={obs.reward:.2f} done={done_str} error=null",
                flush=True,
            )
        except Exception as e:
            # Per-step boundary: any failure is logged and fed back to the
            # model through the observation so it can correct its next reply.
            error_msg = str(e).replace('\n', ' ')
            obs.terminal_feedback = f"SYSTEM ERROR: {error_msg}. Review the schema rules."
            rewards_history.append(0.0)

            # Transport failures will not be fixed by another prompt; stop.
            lowered = error_msg.lower()
            if "timeout" in lowered or "connection" in lowered:
                done = True

            done_str = str(done).lower()
            print(
                f'[STEP] step={step_count} action=error reward=0.00 '
                f'done={done_str} error="{error_msg}"',
                flush=True,
            )

    # The reported score is clamped into [0.01, 0.99].
    final_score = max(0.01, min(.99, total_reward))
    # Judge success on the raw total: the clamped score is always positive,
    # so comparing it against 0 could never report a failure.
    success = total_reward > 0.0
    success_str = str(success).lower()
    rewards_str = ",".join(f"{r:.2f}" for r in rewards_history)
    print(
        f"[END] success={success_str} steps={step_count} "
        f"score={final_score:.2f} rewards={rewards_str}",
        flush=True,
    )
if __name__ == "__main__":
    # Run every benchmark task in sequence, one episode per task.
    tasks = ["version_bump", "config_migration", "broken_links"]
    for task in tasks:
        run_inference(task)