Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| from typing import List, Optional | |
| from openai import OpenAI | |
| from env import SchedulingEnv, Action | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` record that opens an episode on stdout.

    Args:
        task: Task identifier being run (the task level in this script).
        env: Name of the benchmark environment.
        model: Name of the model driving the agent.
    """
    fields = " ".join((f"task={task}", f"env={env}", f"model={model}"))
    print("[START] " + fields, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` record on stdout.

    Each record must stay on a single physical line so line-oriented consumers
    of stdout can parse it. Line breaks embedded in ``action`` or ``error``
    (``\\n`` and ``\\r``) are therefore replaced with spaces.

    Args:
        step: 1-based index of the step within the episode.
        action: Text describing the action taken, e.g. raw tool-call arguments.
        reward: Reward for this step; printed with two decimals.
        done: Whether the episode ended on this step; printed as ``true``/``false``.
        error: Error message for this step. A falsy value is printed as ``null``.
    """

    def _one_line(text: str) -> str:
        # Replace lone carriage returns too: str.splitlines() and many log
        # readers treat "\r" as a line break, which would split the record.
        return text.replace("\r", " ").replace("\n", " ")

    error_val = _one_line(str(error)) if error else "null"
    done_val = str(done).lower()
    action = _one_line(action)
    print(
        f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
        flush=True,
    )
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` record that closes an episode on stdout.

    Args:
        success: Whether the episode counts as solved; lower-cased when printed.
        steps: Number of steps taken.
        score: Final score; printed with three decimals.
        rewards: Per-step rewards, printed comma-separated with two decimals each.
    """
    rewards_csv = ",".join(map("{:.2f}".format, rewards))
    line = (
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.3f} rewards={rewards_csv}"
    )
    print(line, flush=True)
def _resolve_api_config():
    """Choose the API key and base URL for the OpenAI-compatible client.

    Precedence is ``API_KEY`` (default base URL: OpenAI), then ``HF_TOKEN``
    (default base URL: Hugging Face router), then a placeholder key. In every
    case a set ``API_BASE_URL`` overrides the default base URL.

    Returns:
        tuple[str, str]: ``(api_key, api_base_url)``.
    """
    api_key = os.getenv("API_KEY")
    hf_token = os.getenv("HF_TOKEN")
    if api_key:
        default_base_url = "https://api.openai.com/v1"
    elif hf_token:
        api_key = hf_token
        default_base_url = "https://router.huggingface.co/v1"
    else:
        # Fallback if neither is set. The placeholder only lets the client be
        # constructed; requests are expected to fail and surface as a
        # "Model API Error" step.
        api_key = "sk-..."
        default_base_url = "https://api.openai.com/v1"
    return api_key, os.getenv("API_BASE_URL", default_base_url)


def _tool_message(tool_call, payload):
    """Build the ``role: tool`` message that answers ``tool_call`` with JSON-encoded ``payload``."""
    return {
        "role": "tool",
        "tool_call_id": tool_call.id,
        "name": tool_call.function.name,
        "content": json.dumps(payload),
    }


def main():
    """Run one episode of the scheduling benchmark with an LLM agent.

    Configuration is read from environment variables (``load_dotenv()`` runs at
    import time, so a ``.env`` file is honoured):

    * ``API_KEY`` / ``HF_TOKEN``: credentials; ``API_KEY`` wins if both are set.
    * ``API_BASE_URL``: overrides the endpoint chosen for the credential.
    * ``MODEL_NAME``: chat model to use (default ``gpt-4o``).
    * ``TASK_LEVEL``: passed to ``SchedulingEnv`` (default ``easy``).

    Prints one ``[START]`` record, one ``[STEP]`` record per step, and exactly
    one ``[END]`` record. ``[END]`` is printed even when the environment or
    client raises; the exception then still propagates.
    """
    api_key, api_base_url = _resolve_api_config()
    model_name = os.getenv("MODEL_NAME", "gpt-4o")
    client = OpenAI(
        base_url=api_base_url,
        api_key=api_key,
    )
    task_level = os.getenv("TASK_LEVEL", "easy")
    env = SchedulingEnv(task_level=task_level)
    obs = env.reset()
    log_start(task=task_level, env="scheduling_benchmark", model=model_name)

    rewards_list = []
    # The score is the reward of the last executed step, not a sum. The prompt
    # labels it "Reward so far", so the env presumably reports a cumulative
    # value -- confirm against SchedulingEnv.
    total_reward = 0.0
    steps_taken = 0
    try:
        messages = [
            {
                "role": "system",
                "content": (
                    "You are an AI Scheduling Assistant. "
                    "You manage calendars, check availability, and book meetings. "
                    "Always use the `step_environment` tool to take actions. "
                    "You must carefully read the task description and take iterative steps. "
                    "Do not assume availability, check calendars first."
                ),
            },
            {
                "role": "user",
                "content": f"Initial Observation:\n{json.dumps(obs)}",
            },
        ]
        tools = [
            {
                "type": "function",
                "function": {
                    "name": "step_environment",
                    "description": "Execute an action in the environment.",
                    "parameters": Action.model_json_schema(),
                },
            }
        ]

        for step_num in range(1, env.max_steps + 1):
            steps_taken = step_num
            try:
                response = client.chat.completions.create(
                    model=model_name,
                    messages=messages,
                    tools=tools,
                    tool_choice={"type": "function", "function": {"name": "step_environment"}},
                )
            except Exception as e:  # boundary: any client/network failure ends the episode
                error_msg = f"Model API Error: {str(e)}"
                log_step(step=step_num, action="api_call_failed", reward=total_reward, done=True, error=error_msg)
                break

            response_message = response.choices[0].message
            messages.append(response_message)
            tool_calls = response_message.tool_calls

            if not tool_calls:
                # Model failed to call a tool despite tool_choice: force a submit.
                obs, reward, done = env.step(Action(action_type="submit_task"))
                total_reward = reward
                rewards_list.append(reward)
                log_step(
                    step=step_num,
                    action="submit_task (force_fallback)",
                    reward=reward,
                    done=done,
                    error=obs.get("error_message"),
                )
                break

            tool_call = tool_calls[0]
            action_args_str = tool_call.function.arguments
            parse_error = None
            try:
                action = Action(**json.loads(action_args_str))
            except Exception as e:  # best-effort: any malformed/invalid arguments -> submit
                action = Action(action_type="submit_task")
                parse_error = f"Invalid action arguments: {e}"
                action_args_str = f"INVALID_JSON: {action_args_str}"

            # mode="json" keeps non-JSON-native field types serializable.
            messages.append(_tool_message(tool_call, {"action_taken": action.model_dump(mode="json")}))
            # Every tool call in an assistant message needs a matching tool
            # message, or the next request is rejected. Only the first call is
            # executed, so answer any parallel extras explicitly.
            for ignored_call in tool_calls[1:]:
                messages.append(
                    _tool_message(
                        ignored_call,
                        {"error": "Only one action is executed per step; this call was ignored."},
                    )
                )

            obs, reward, done = env.step(action)
            total_reward = reward
            rewards_list.append(reward)
            # log step as per stdout rules
            log_step(
                step=step_num,
                action=action_args_str,
                reward=reward,
                done=done,
                error=obs.get("error_message") or parse_error,
            )
            if done:
                break
            messages.append({
                "role": "user",
                "content": f"Observation:\n{json.dumps(obs)}\nReward so far: {reward}",
            })
    finally:
        # Always close the record stream, even if the env or client raised.
        success = total_reward >= 1.0
        log_end(success=success, steps=steps_taken, score=total_reward, rewards=rewards_list)


if __name__ == "__main__":
    main()