File size: 5,454 Bytes
2073b3f
 
4b33490
2073b3f
 
 
 
4b33490
 
2073b3f
 
 
4b33490
 
 
 
 
 
2073b3f
4b33490
2073b3f
 
 
 
 
 
 
 
 
 
 
 
0f8f2c1
 
2073b3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e56d042
2073b3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e56d042
2073b3f
 
 
 
 
 
 
 
 
 
 
 
 
 
4b33490
eb4dbc2
4b33490
 
 
 
2073b3f
4b33490
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e56d042
4b33490
 
 
 
 
 
 
 
 
 
 
 
 
e56d042
4b33490
 
 
 
e56d042
 
 
4b33490
 
 
 
 
 
 
 
 
 
 
 
 
e56d042
 
 
2073b3f
 
 
4b33490
 
e56d042
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import os
import textwrap
from typing import List

from openai import OpenAI

from client import AwsRlEnv
from models import AwsRlAction, AwsRlObservation
from dotenv import load_dotenv

load_dotenv()  # Load variables from .env file if present

# Inference endpoint configuration. Each setting falls back to a default when
# the environment variable is unset; the `or` also replaces an empty string.
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
MODEL_NAME = os.getenv("MODEL_NAME") or "meta-llama/Llama-3.1-8B-Instruct"
# HF_TOKEN takes precedence over API_KEY; the value is None if neither is set.
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")

BENCHMARK = "aws-rl-env"  # benchmark name printed in each [START] log line
MAX_STEPS = 15  # maximum number of model actions per episode

# Shared OpenAI SDK client pointed at API_BASE_URL, created at import time.
# NOTE(review): API_KEY may be None here; confirm the client accepts that (or
# fail earlier with a clearer message) before relying on import-time setup.
client_llm = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
# System message sent with every completion request. It is dedented and
# stripped so the model does not see the source-code indentation.
SYSTEM_PROMPT = textwrap.dedent(
    """
    You are an AWS cloud engineer interacting with a real AWS environment via CLI.
    Each turn you must send exactly ONE valid AWS CLI command (starting with 'aws').

    You will be given a task to accomplish. Read the task description carefully.
    Use the command output and error messages to guide your next action.

    Rules:
    - Only send AWS CLI commands (e.g. 'aws s3 ls', 'aws dynamodb create-table ...')
    - One command per turn — no pipes, no shell syntax, no chaining
    - Reply with ONLY the command, nothing else — no explanations, no quotes
    - If unsure, use 'aws help' to get unstuck, but try to be specific to the service if possible (e.g. 'aws s3 help')
    - When ever you need a hint, use 'aws help --task-hint' to get a task-specific hint (you can use this multiple times for more hints, but hints reduce your reward)
    """
).strip()


def build_user_prompt(
    task_description: str,
    step: int,
    last_output: str,
    last_error: str,
    last_reward: float,
    history: List[str],
) -> str:
    """Render the per-turn user message sent to the model.

    Args:
        task_description: Description of the task the agent must accomplish.
        step: Step counter shown to the model.
        last_output: Output of the previous command, rendered with ``repr``.
        last_error: Error from the previous command, rendered with ``repr``.
        last_reward: Reward received for the previous command.
        history: Log of earlier steps; only the last six entries are shown.

    Returns:
        The prompt text: one field per line, with no leading indentation.
    """
    # Only the most recent entries are included, to bound prompt length.
    history_block = "\n".join(history[-6:]) if history else "None"
    # Assemble line by line instead of dedenting an interpolated f-string.
    # textwrap.dedent computes the common indent *after* interpolation, so any
    # multi-line value (e.g. two or more history entries) turned it into a
    # no-op and left every template line indented by eight spaces.
    lines = [
        f"TASK: {task_description}",
        "",
        f"Step: {step}",
        f"Last command output: {last_output!r}",
        f"Last error: {last_error!r}",
        f"Last reward: {last_reward:.2f}",
        "",
        "Previous steps:",
        history_block,
        "",
        "Send your next AWS CLI command.",
    ]
    return "\n".join(lines)


def _extract_command(raw: str) -> str:
    """Return the first AWS CLI command found in a raw model reply.

    Markdown code-fence lines are skipped and inline backticks are trimmed.
    A command split over several lines with trailing ``\\`` is folded back
    into a single line. Falls back to ``"aws help"`` when no line starts
    with ``aws ``.
    """
    lines = [ln.strip() for ln in raw.splitlines() if not ln.strip().startswith("```")]
    for i, line in enumerate(lines):
        command = line.strip("`").strip()
        if not command.startswith("aws "):
            continue
        # Fold shell line continuations into one command line.
        j = i + 1
        while command.endswith("\\") and j < len(lines):
            command = f"{command[:-1].rstrip()} {lines[j].strip('`').strip()}".rstrip()
            j += 1
        return command
    return "aws help"


def get_model_command(
    client: OpenAI,
    task_description: str,
    step: int,
    last_output: str,
    last_error: str,
    last_reward: float,
    history: List[str],
) -> str:
    """Ask the model for the next AWS CLI command.

    Args:
        client: OpenAI SDK client used for the chat completion request.
        task_description: Description of the current task.
        step: Step counter shown to the model.
        last_output: Output of the previous command.
        last_error: Error from the previous command.
        last_reward: Reward received for the previous command.
        history: Log of earlier steps.

    Returns:
        A single command line starting with ``aws ``. Returns ``"aws help"``
        if the request fails or the reply contains no AWS command.
    """
    user_prompt = build_user_prompt(
        task_description, step, last_output, last_error, last_reward, history
    )
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            max_tokens=800,
        )
        reply = completion.choices[0].message.content or ""
    except Exception as exc:
        # Best-effort: an API/network failure must not abort the episode, so
        # log it and fall back to a harmless command.
        print(f"[DEBUG] Model request failed: {exc}", flush=True)
        return "aws help"
    return _extract_command(reply)


def _clamp_unit(value: float) -> float:
    """Clamp *value* to [0.01, 0.99] so it prints strictly inside (0, 1) at 2 d.p."""
    return min(max(value, 0.01), 0.99)


def _run_episode(env) -> None:
    """Reset *env*, let the model act for up to MAX_STEPS turns, and log the episode.

    Prints one ``[START]`` line, one ``[STEP]`` line per action, and one
    ``[END]`` line summarising success, step count, score and rewards.
    """
    result = env.reset()
    obs: AwsRlObservation = result.observation
    last_output = obs.command_output
    last_error = ""
    last_reward = 0.0
    history: List[str] = []
    rewards: List[float] = []
    steps = 0  # defined up front so [END] works even if no step runs
    print(f"[START] task={obs.task.task_id} env={BENCHMARK} model={MODEL_NAME}")

    for step in range(1, MAX_STEPS + 1):
        command = get_model_command(
            client_llm,
            obs.task.description,
            obs.step_count,
            last_output,
            last_error,
            last_reward,
            history,
        )

        result = env.step(AwsRlAction(command=command))
        obs = result.observation

        raw_reward = obs.reward or 0.0
        done = result.done
        last_error = obs.error
        last_output = obs.command_output
        last_reward = raw_reward  # the model sees the unclamped reward

        # Record the step so the next prompt's "Previous steps" is populated.
        entry = f"Step {step}: {command} -> reward {raw_reward:.2f}"
        if last_error:
            entry += f" | error: {str(last_error)[:200]}"
        history.append(entry)

        # The validator requires rewards strictly in (0, 1) as printed at 2 d.p.
        reward = _clamp_unit(raw_reward)
        rewards.append(reward)
        steps = step

        done_str = "true" if done else "false"
        print(
            f"[STEP] step={step} action={command!r} reward={reward:.2f} done={done_str} error={last_error!r}"
        )

        # Stop when the task is achieved or the environment ends the episode.
        if obs.task_achieved or done:
            break

    # NOTE(review): the 0.1 fallback only applies when no step ran; it differs
    # from the 0.01 floor used elsewhere — confirm intended.
    score = _clamp_unit(max(rewards) if rewards else 0.1)

    success_str = "true" if obs.task_achieved else "false"
    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
    print(
        f"[END] success={success_str} steps={steps} score={score:.2f} rewards={rewards_str}"
    )


def run_task(env_url: str, num_episodes: int = 11) -> None:
    """Run *num_episodes* episodes against the environment server at *env_url*.

    Args:
        env_url: Base URL of the AWS RL environment server.
        num_episodes: Number of reset/act/log cycles to run (default 11).
    """
    with AwsRlEnv(base_url=env_url).sync() as env:
        for _ in range(num_episodes):
            _run_episode(env)


if __name__ == "__main__":
    # Environment server location; override it with the ENV_URL variable.
    ENV_URL = os.environ.get("ENV_URL", "http://localhost:8000")
    run_task(ENV_URL)