import os
import textwrap
from typing import List
from openai import OpenAI
from client import AwsRlEnv
from models import AwsRlAction, AwsRlObservation
from dotenv import load_dotenv
load_dotenv() # Load variables from .env file if present
# Inference endpoint and model; each overridable via environment variables.
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
MODEL_NAME = os.getenv("MODEL_NAME") or "meta-llama/Llama-3.1-8B-Instruct"
# Prefer a Hugging Face token; fall back to a generic API key.
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
# Benchmark identifier used only in log lines.
BENCHMARK = "aws-rl-env"
# Hard cap on CLI commands issued per episode.
MAX_STEPS = 15
# OpenAI-compatible client pointed at the configured router endpoint.
client_llm = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
SYSTEM_PROMPT = textwrap.dedent(
"""
You are an AWS cloud engineer interacting with a real AWS environment via CLI.
Each turn you must send exactly ONE valid AWS CLI command (starting with 'aws').
You will be given a task to accomplish. Read the task description carefully.
Use the command output and error messages to guide your next action.
Rules:
- Only send AWS CLI commands (e.g. 'aws s3 ls', 'aws dynamodb create-table ...')
- One command per turn — no pipes, no shell syntax, no chaining
- Reply with ONLY the command, nothing else — no explanations, no quotes
- If unsure, use 'aws help' to get unstuck, but try to be specific to the service if possible (e.g. 'aws s3 help')
- When ever you need a hint, use 'aws help --task-hint' to get a task-specific hint (you can use this multiple times for more hints, but hints reduce your reward)
"""
).strip()
def build_user_prompt(
    task_description: str,
    step: int,
    last_output: str,
    last_error: str,
    last_reward: float,
    history: List[str],
) -> str:
    """Assemble the per-turn user prompt sent to the model.

    Args:
        task_description: The environment's task text.
        step: Current step counter (as reported by the environment).
        last_output: stdout of the previous command (repr'd into the prompt).
        last_error: error text of the previous command (repr'd into the prompt).
        last_reward: reward from the previous step, rendered with 2 decimals.
        history: prior step summaries; only the last 6 are included,
            or the literal "None" when empty.

    Returns:
        The prompt body with no leading/trailing whitespace.
    """
    history_block = "\n".join(history[-6:]) if history else "None"
    # Join explicit lines instead of textwrap.dedent(f"..."):
    # dedent computes the common leading whitespace across ALL lines, so a
    # multi-line interpolated history_block (whose lines carry no indentation)
    # would defeat dedenting and leave stray indentation in the prompt.
    lines = [
        f"TASK: {task_description}",
        f"Step: {step}",
        f"Last command output: {last_output!r}",
        f"Last error: {last_error!r}",
        f"Last reward: {last_reward:.2f}",
        "Previous steps:",
        history_block,
        "Send your next AWS CLI command.",
    ]
    return "\n".join(lines)
def get_model_command(
    client: OpenAI,
    task_description: str,
    step: int,
    last_output: str,
    last_error: str,
    last_reward: float,
    history: List[str],
) -> str:
    """Ask the chat model for the next AWS CLI command.

    Builds the per-turn prompt, queries the model, and sanitizes the reply.
    Falls back to 'aws help' whenever the request fails or the reply does
    not look like an AWS CLI command.
    """
    prompt = build_user_prompt(
        task_description, step, last_output, last_error, last_reward, history
    )
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            max_tokens=800,
        )
        reply = (response.choices[0].message.content or "").strip()
        # Drop markdown code fences if the model wrapped the command in them.
        if reply.startswith("```"):
            kept = [ln for ln in reply.split("\n") if not ln.startswith("```")]
            reply = "\n".join(kept).strip()
        if reply.startswith("aws "):
            return reply
        return "aws help"
    except Exception as exc:
        # Best-effort agent loop: log and fall back rather than abort the episode.
        print(f"[DEBUG] Model request failed: {exc}", flush=True)
        return "aws help"
def run_task(env_url: str) -> None:
    """Run up to 11 episodes of the AWS RL environment with the LLM agent.

    For each episode: reset the environment, then loop up to MAX_STEPS asking
    the model for one AWS CLI command per turn and stepping the environment
    with it. Emits [START]/[STEP]/[END] log lines consumed by a validator.

    Args:
        env_url: Base URL of the environment server.
    """
    with AwsRlEnv(base_url=env_url).sync() as env:
        for _ in range(11):
            result = env.reset()
            obs: AwsRlObservation = result.observation
            last_output = obs.command_output
            last_error = ""
            last_reward = 0.0
            history: List[str] = []
            rewards: List[float] = []
            steps = 0  # guard: defined even if MAX_STEPS is ever set below 1
            print(f"[START] task={obs.task.task_id} env={BENCHMARK} model={MODEL_NAME}")
            for step in range(1, MAX_STEPS + 1):
                command = get_model_command(
                    client_llm,
                    obs.task.description,
                    obs.step_count,
                    last_output,
                    last_error,
                    last_reward,
                    history,
                )
                result = env.step(AwsRlAction(command=command))
                obs = result.observation
                reward = obs.reward or 0.0
                done = result.done
                last_error = obs.error
                last_output = obs.command_output
                last_reward = reward
                # Record the step so the model actually sees its past actions:
                # previously `history` was never appended to, so the prompt's
                # "Previous steps" section was always "None".
                history.append(f"{step}. {command} (reward={reward:.2f})")
                # Clamp reward to strictly (0, 1) for validator
                if reward <= 0.0:
                    reward = 0.01
                elif reward >= 1.0:
                    reward = 0.99
                rewards.append(reward)
                steps = step
                done_str = "true" if done else "false"
                print(
                    f"[STEP] step={step} action={command!r} reward={reward:.2f} done={done_str} error={last_error!r}"
                )
                # Stop on task success or when the environment signals the
                # episode is over.
                if obs.task_achieved or done:
                    break
            score = max(rewards) if rewards else 0.1
            score = min(max(score, 0.01), 0.99)  # clamp to (0, 1)
            success_str = "true" if obs.task_achieved else "false"
            rewards_str = ",".join(f"{r:.2f}" for r in rewards)
            print(
                f"[END] success={success_str} steps={steps} score={score:.2f} rewards={rewards_str}"
            )
if __name__ == "__main__":
    # Environment server URL; defaults to a locally running instance.
    ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")
    run_task(ENV_URL)