Scheduling-agent / inference.py
Aryan
changed variable name
d4cfa03
Raw
History Blame Contribute Delete
5.27 kB
import os
import json
from typing import List, Optional
from openai import OpenAI
from env import SchedulingEnv, Action
from dotenv import load_dotenv
# Load variables from a local .env file into os.environ so the os.getenv() calls
# in main() can read them (python-dotenv leaves variables that are already set
# untouched unless override=True is passed).
load_dotenv()
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` banner that opens a benchmark run.

    Args:
        task: Task level / identifier being run.
        env: Name of the environment.
        model: Name of the model driving the agent.
    """
    banner = " ".join(
        (
            "[START]",
            f"task={task}",
            f"env={env}",
            f"model={model}",
        )
    )
    print(banner, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` record describing a single environment step.

    The record must stay on a single physical line so it can be parsed line by
    line, so line breaks embedded in ``action`` and ``error`` are replaced by
    spaces.

    Args:
        step: 1-based step number.
        action: Text describing the action taken (e.g. the raw tool-call JSON).
        reward: Reward for this step; printed with two decimals.
        done: Whether the episode ended on this step; printed as ``true``/``false``.
        error: Error message for the step; any falsy value is printed as ``null``.
    """
    def _one_line(text: str) -> str:
        # Flatten \r as well as \n: a bare carriage return also counts as a line
        # boundary (str.splitlines) and can overwrite the record in a terminal.
        return text.replace("\r", " ").replace("\n", " ")

    error_val = _one_line(str(error)) if error else "null"
    done_val = str(done).lower()
    action_val = _one_line(action)
    print(
        f"[STEP] step={step} action={action_val} reward={reward:.2f} done={done_val} error={error_val}",
        flush=True,
    )
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` summary record that closes a benchmark run.

    Args:
        success: Whether the run counts as successful; printed lower-cased.
        steps: Number of steps taken.
        score: Final score, printed with three decimals.
        rewards: Per-step rewards, printed comma-separated with two decimals each.
    """
    reward_fields = []
    for r in rewards:
        reward_fields.append(format(r, ".2f"))
    fields = (
        "[END]",
        "success=" + str(success).lower(),
        f"steps={steps}",
        "score=" + format(score, ".3f"),
        "rewards=" + ",".join(reward_fields),
    )
    print(" ".join(fields), flush=True)
def _resolve_api_config() -> tuple:
    """Choose the API key and base URL for the OpenAI-compatible client.

    Preference order:
      * ``API_KEY`` (default endpoint: OpenAI);
      * ``HF_TOKEN`` (default endpoint: the Hugging Face router);
      * a placeholder key.

    ``API_BASE_URL``, when set, overrides the endpoint in every case.

    Returns:
        ``(api_key, api_base_url)``.
    """
    api_key = os.getenv("API_KEY")
    if api_key:
        return api_key, os.getenv("API_BASE_URL", "https://api.openai.com/v1")
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        return hf_token, os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
    # Fallback if neither is set. NOTE(review): "sk-..." is a placeholder. The first
    # request presumably fails authentication, and the run then ends through the
    # "api_call_failed" path in main().
    return "sk-...", os.getenv("API_BASE_URL", "https://api.openai.com/v1")


def _initial_messages(obs) -> list:
    """Build the system prompt plus the first user turn carrying the reset observation."""
    return [
        {
            "role": "system",
            "content": (
                "You are an AI Scheduling Assistant. "
                "You manage calendars, check availability, and book meetings. "
                "Always use the `step_environment` tool to take actions. "
                "You must carefully read the task description and take iterative steps. "
                "Do not assume availability, check calendars first."
            ),
        },
        {
            "role": "user",
            "content": f"Initial Observation:\n{json.dumps(obs)}",
        },
    ]


def _build_tools() -> list:
    """Expose the ``Action`` JSON schema as the single ``step_environment`` tool."""
    return [
        {
            "type": "function",
            "function": {
                "name": "step_environment",
                "description": "Execute an action in the environment.",
                "parameters": Action.model_json_schema(),
            },
        }
    ]


def _parse_action(args_str):
    """Convert a tool call's JSON argument string into an ``Action``.

    If the arguments are not valid JSON or do not validate as an ``Action``, this
    falls back to ``Action(action_type="submit_task")``.

    Returns:
        ``(action, action_log_text, parse_error)``. ``parse_error`` is ``None`` on
        success and a short description of the failure otherwise.
    """
    try:
        action = Action(**json.loads(args_str))
    except Exception as exc:  # deliberate catch-all: any malformed model output takes the fallback
        return (
            Action(action_type="submit_task"),
            f"INVALID_JSON: {args_str}",
            f"Invalid action arguments: {exc}",
        )
    return action, args_str, None


def main():
    """Run one scheduling episode with an LLM agent and log it as [START]/[STEP]/[END] records.

    Configuration comes from environment variables:
      * ``API_KEY`` / ``HF_TOKEN`` / ``API_BASE_URL``: credentials and endpoint
        (see ``_resolve_api_config``);
      * ``MODEL_NAME``: default ``gpt-4o``;
      * ``TASK_LEVEL``: default ``easy``.

    Each step forces the model to call the ``step_environment`` tool, executes
    the resulting ``Action`` in ``SchedulingEnv`` and feeds the observation back.

    The episode stops when any of these happens:
      * the environment reports ``done``;
      * ``env.max_steps`` is reached;
      * the model API call fails;
      * the model returns no tool call (``submit_task`` is forced first).

    The ``[END]`` record is always printed, even if the environment raises.
    """
    api_key, api_base_url = _resolve_api_config()
    model_name = os.getenv("MODEL_NAME", "gpt-4o")
    client = OpenAI(
        base_url=api_base_url,
        api_key=api_key,
    )

    task_level = os.getenv("TASK_LEVEL", "easy")
    env = SchedulingEnv(task_level=task_level)
    obs = env.reset()
    log_start(task=task_level, env="scheduling_benchmark", model=model_name)

    messages = _initial_messages(obs)
    tools = _build_tools()

    rewards_list = []
    # Holds the most recent env reward (not a running sum). Presumably the env
    # reports a cumulative score, since the prompt calls it "Reward so far";
    # TODO confirm against env.py.
    total_reward = 0.0
    steps_taken = 0
    try:
        for step_num in range(1, env.max_steps + 1):
            steps_taken = step_num
            try:
                response = client.chat.completions.create(
                    model=model_name,
                    messages=messages,
                    tools=tools,
                    tool_choice={"type": "function", "function": {"name": "step_environment"}},
                )
            except Exception as e:
                # Boundary for network/auth/rate-limit failures: end the episode cleanly.
                error_msg = f"Model API Error: {str(e)}"
                log_step(step=step_num, action="api_call_failed", reward=total_reward, done=True, error=error_msg)
                break

            response_message = response.choices[0].message
            messages.append(response_message)
            tool_calls = response_message.tool_calls

            if not tool_calls:
                # Model failed to call a tool, force a submit task
                obs, reward, done = env.step(Action(action_type="submit_task"))
                total_reward = reward
                rewards_list.append(reward)
                log_step(
                    step=step_num,
                    action="submit_task (force_fallback)",
                    reward=reward,
                    done=done,
                    error=obs.get("error_message"),
                )
                break

            tool_call = tool_calls[0]
            action, action_log, parse_error = _parse_action(tool_call.function.arguments)
            messages.append({
                "role": "tool",
                "tool_call_id": tool_call.id,
                "name": tool_call.function.name,
                "content": json.dumps({"action_taken": action.model_dump()}),
            })
            # The API rejects the next request unless every tool_call_id in the
            # assistant message is answered. Acknowledge any extra (parallel) calls,
            # even though only the first one is executed.
            for extra_call in tool_calls[1:]:
                messages.append({
                    "role": "tool",
                    "tool_call_id": extra_call.id,
                    "name": extra_call.function.name,
                    "content": json.dumps(
                        {"error": "Only one action is executed per step; this call was ignored."}
                    ),
                })

            obs, reward, done = env.step(action)
            total_reward = reward
            rewards_list.append(reward)
            # log step as per stdout rules. Surface the parse failure if the env reports no error.
            log_step(
                step=step_num,
                action=action_log,
                reward=reward,
                done=done,
                error=obs.get("error_message") or parse_error,
            )
            if done:
                break

            messages.append({
                "role": "user",
                "content": f"Observation:\n{json.dumps(obs)}\nReward so far: {reward}",
            })
    finally:
        # Always close the run log, even if env.step or json.dumps raised above.
        success = total_reward >= 1.0
        log_end(success=success, steps=steps_taken, score=total_reward, rewards=rewards_list)
# Entry point: run one episode only when executed as a script, not when imported.
if __name__ == "__main__":
    main()