File size: 3,755 Bytes
47e89f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import argparse
import json
import os
import sys

from openai import OpenAI

from core.environment import EmailOpsEnv
from core.models import Action

# Runtime configuration read from the environment (mandatory per the OpenEnv spec).
# Endpoint and model fall back to OpenAI defaults; the token has no fallback and
# is None when unset, so the __main__ guard can refuse to run without it.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
HF_TOKEN = os.environ.get("HF_TOKEN")  # deliberately no default

def _build_system_prompt(task_description) -> str:
    """Return the system prompt stating the task goal and the allowed action types."""
    return (
        "You are an intelligent email operations agent. "
        f"Your current goal is: {task_description}\n"
        "You must perform actions to achieve this goal. Once you are finished, output the 'submit' action.\n"
        "Available action types:\n"
        " - open_email (requires email_id)\n"
        " - close_email\n"
        " - move_email (requires email_id, folder_name)\n"
        " - reply (requires email_id, reply_body)\n"
        " - delete_email (requires email_id)\n"
        " - flag_email (requires email_id)\n"
        " - submit"
    )


def run_baseline(
    api_key: str,
    model_name: str,
    base_url: str,
    *,
    max_steps: int = 15,
    tasks=("easy", "medium", "hard"),
) -> None:
    """Run the LLM baseline agent on each EmailOps task and print graded logs.

    For every task the environment is reset, then the model is asked for one
    structured ``Action`` per step (given only the current observation) until
    the environment reports done, the model submits, an error occurs, or
    ``max_steps`` actions have been applied.

    Graded progress is written to stdout as ``START: <task>``,
    ``STEP: <action json>`` and ``END: <result json>`` lines; diagnostics
    (refusals, exceptions) go to stderr so they cannot pollute that stream.

    Args:
        api_key: Key passed to the OpenAI client.
        model_name: Model used for the structured chat completions.
        base_url: Base URL of the OpenAI-compatible endpoint.
        max_steps: Per-task cap on actions applied to the environment.
        tasks: Iterable of task names passed to ``env.reset``, run in order.
    """
    client = OpenAI(api_key=api_key, base_url=base_url)
    env = EmailOpsEnv()

    print(f"Running baseline on model: {model_name}")
    print("=" * 40)

    for task_name in tasks:
        # START: Structured logging for OpenEnv automated grading
        print(f"START: {task_name}", flush=True)

        obs = env.reset(task_name)

        step_count = 0
        is_done = False
        # Reward returned by the most recent successful env.step, reported in END.
        # NOTE(review): whether EmailOpsEnv returns a cumulative or per-step reward
        # is not visible here — confirm before treating this as a total.
        total_reward = 0.0

        while not is_done and step_count < max_steps:
            system_prompt = _build_system_prompt(env.task.description)

            try:
                response = client.beta.chat.completions.parse(
                    model=model_name,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": f"Current Observation:\n{obs.model_dump_json(indent=2)}\nWhat is your next action?"}
                    ],
                    response_format=Action,
                    temperature=0.1
                )
                message = response.choices[0].message
                action = message.parsed
                if action is None:
                    # No parsable Action (e.g. the model refused); nothing to apply.
                    refusal = getattr(message, "refusal", None)
                    print(f"No action returned by model (refusal: {refusal})", file=sys.stderr)
                    break

                # STEP: Structured logging for OpenEnv automated grading
                print(f"STEP: {action.model_dump_json()}", flush=True)

                obs, reward, is_done, _step_metrics = env.step(action)
            except Exception as e:
                # Top-level boundary for one episode: report and move to the next task.
                print(f"Error during inference: {e}", file=sys.stderr)
                break

            # Count every action the environment accepted, including 'submit'.
            step_count += 1
            total_reward = reward

            if action.action_type == "submit":
                break

        # END: Structured logging for OpenEnv automated grading
        result = {
            "task": task_name,
            "steps": step_count,
            "reward": total_reward,
            "metrics": env.metrics
        }
        print(f"END: {json.dumps(result)}", flush=True)
        print("-" * 40)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run the LLM baseline agent on the EmailOps tasks."
    )
    # Defaults come from environment variables, as required for automated runs;
    # explicit flags override them. The API-key default is deliberately not shown
    # in --help so the token is never echoed.
    parser.add_argument("--api-key", type=str, default=HF_TOKEN,
                        help="API key for the endpoint (default: $HF_TOKEN)")
    parser.add_argument("--model", type=str, default=MODEL_NAME,
                        help="model name (default: %(default)s)")
    parser.add_argument("--base-url", type=str, default=API_BASE_URL,
                        help="OpenAI-compatible base URL (default: %(default)s)")
    args = parser.parse_args()

    # HF_TOKEN is mandatory for automated submissions.
    if not args.api_key:
        # sys.exit(str) writes the message to stderr and exits with status 1;
        # unlike the site-provided exit(), it is always available.
        sys.exit("Please set HF_TOKEN environment variable.")

    run_baseline(args.api_key, args.model, args.base_url)