# hft_oversight/run_agent.py
# (uploaded by schangg via huggingface_hub, commit adf36ff verified)
"""Run a HuggingFace LLM agent against the HFT Oversight Environment.
Collects trajectories for fine-tuning. Start with difficulty=1 (obvious errors).
Usage:
uv run python run_agent.py --episodes 10
uv run python run_agent.py --episodes 10 --difficulty 1
"""
import json
import os
import sys
sys.path.insert(0, os.path.dirname(__file__))
from huggingface_hub import InferenceClient
from server.environment import HFTOversightEnvironment
from models import OversightAction
# --- Config ---
# Token resolution: explicit HF_TOKEN env var wins, then the token cached
# by a prior `huggingface_hub` login; may still end up empty (checked in main()).
from huggingface_hub import get_token
HF_TOKEN = os.environ.get("HF_TOKEN", "") or get_token() or ""
# Inference model queried for every agent turn.
MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
# System prompt: forces the model to emit exactly one JSON action per turn,
# using the command vocabulary parsed into OversightAction by parse_action().
SYSTEM_PROMPT = """You are an HFT oversight agent. You manage trading bots and must find and shut down broken ones.
Each turn, respond with ONLY a JSON action. No other text.
Commands:
- {"command": "list_bots"}
- {"command": "read_logs", "bot_id": "NAME"}
- {"command": "check_pnl", "bot_id": "NAME"}
- {"command": "shutdown", "bot_id": "NAME", "reason": "WHY"}
Look for errors, bad prices, or suspicious behavior. Shut down broken bots."""
def parse_action(text: str) -> OversightAction:
    """Parse a raw LLM completion into an OversightAction.

    Tolerates markdown code fences (``` / ```json) and surrounding prose
    by extracting the outermost {...} span before decoding.

    Args:
        text: the raw assistant message content.

    Returns:
        The decoded OversightAction.

    Raises:
        ValueError: if no JSON object can be located or decoded.
            (json.JSONDecodeError subclasses ValueError, so existing
            callers catching either type keep working.)
    """
    text = text.strip()
    # Strip markdown code fences; the payload is the first fenced segment.
    if "```" in text:
        parts = text.split("```")
        # fixed: guard against a degenerate trailing fence — the original
        # indexed parts[1] unconditionally and could IndexError.
        if len(parts) > 1:
            text = parts[1].removeprefix("json").strip()
    start = text.find("{")
    end = text.rfind("}") + 1
    if start >= 0 and end > start:
        try:
            data = json.loads(text[start:end])
        except json.JSONDecodeError as e:
            # Normalize decode failures to the documented ValueError contract.
            raise ValueError(f"Could not parse action from: {text}") from e
        return OversightAction(**data)
    raise ValueError(f"Could not parse action from: {text}")
def run_episode(client: InferenceClient, difficulty: int = 1) -> dict:
    """Run one full episode against the HFT oversight environment.

    Args:
        client: HuggingFace inference client used for chat completions.
        difficulty: scenario difficulty forwarded to the environment.

    Returns:
        dict with keys: difficulty, total_reward, steps, trajectory
        (per-step records suitable for fine-tuning), full_conversation.
    """
    env = HFTOversightEnvironment()
    # NOTE(review): pokes a private attribute — confirm the environment
    # exposes no public way to set difficulty.
    env._difficulty = difficulty
    obs = env.reset()
    print(f"\n{'='*60}")
    print(f"EPISODE (difficulty={difficulty})")
    print(f"{'='*60}")
    print(obs.response[:300])
    if obs.alerts:
        print(f"Alerts: {obs.alerts}")
    # Seed the conversation with the system prompt and first observation.
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": obs.response + (f"\n\nAlerts: {obs.alerts}" if obs.alerts else "")},
    ]
    trajectory = []
    total_reward = 0.0
    # fixed: the counter was lazily created via an `'in dir()'` hack and
    # never reset, so three errors ANYWHERE in the episode aborted it.
    # Initialize here and reset on every successful completion so only
    # truly consecutive failures trigger the abort.
    consecutive_errors = 0
    while not obs.done:
        # Query the model; on transient API errors fall back to a no-op action.
        try:
            response = client.chat_completion(
                messages=messages,
                max_tokens=200,
                temperature=0.3,
            )
            llm_text = response.choices[0].message.content
            consecutive_errors = 0  # a successful call breaks the error streak
        except Exception as e:
            print(f" Model error: {e}")
            llm_text = '{"command": "pass_turn"}'
            consecutive_errors += 1
            if consecutive_errors >= 3:
                print(" 3 consecutive model errors — aborting episode.")
                break
        print(f"\n LLM (step {obs.timestep + 1}): {llm_text[:150]}")
        try:
            action = parse_action(llm_text)
        except (ValueError, json.JSONDecodeError) as e:
            print(f" Parse error: {e}")
            action = OversightAction(command="pass_turn")
        # Step environment
        obs = env.step(action)
        total_reward += obs.reward
        print(f" ENV: {obs.response[:150]}")
        print(f" [reward={obs.reward}, total={total_reward}, step={obs.timestep}/{obs.max_timesteps}]")
        # Record trajectory step — snapshot messages BEFORE this turn's
        # assistant/user pair is appended.
        trajectory.append({
            "messages_so_far": [m.copy() for m in messages],
            "assistant_response": llm_text,
            "action": action.model_dump(exclude_none=True),
            "reward": obs.reward,
            "cumulative_reward": total_reward,
            "done": obs.done,
        })
        # Feed the environment response back into the conversation.
        messages.append({"role": "assistant", "content": llm_text})
        env_msg = obs.response
        if obs.alerts:
            env_msg += f"\n\nAlerts: {obs.alerts}"
        env_msg += f"\n\n[Step {obs.timestep}/{obs.max_timesteps}]"
        messages.append({"role": "user", "content": env_msg})
    print(f"\n DONE — Total reward: {total_reward}")
    return {
        "difficulty": difficulty,
        "total_reward": total_reward,
        "steps": len(trajectory),
        "trajectory": trajectory,
        "full_conversation": messages,
    }
# Curriculum ladder for adaptive mode in main(): levels advance left-to-right.
DIFFICULTY_LEVELS = [1, 2, 3, 5, 7]
FAST_SOLVE_THRESHOLD = 3  # solved in <= this many steps = "quick"
STREAK_TO_ADVANCE = 3  # consecutive wins to level up
def main():
    """CLI entry point: run episodes, adapt difficulty, save trajectories as JSONL."""
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--episodes", type=int, default=20)
    parser.add_argument("--difficulty", type=int, default=1, help="Starting difficulty")
    parser.add_argument("--adaptive", action="store_true", default=True,
                        help="Auto-increase difficulty (default: on)")
    parser.add_argument("--no-adaptive", dest="adaptive", action="store_false")
    parser.add_argument("--output", type=str, default="trajectories.jsonl")
    args = parser.parse_args()
    if not HF_TOKEN:
        print("Set HF_TOKEN env var: export HF_TOKEN=hf_xxx")
        sys.exit(1)
    client = InferenceClient(model=MODEL_ID, token=HF_TOKEN)
    # Adaptive difficulty state.
    difficulty = args.difficulty
    win_streak = 0
    # NOTE(review): if --difficulty is not in DIFFICULTY_LEVELS, level_idx
    # starts at 0 and the first level-up snaps to DIFFICULTY_LEVELS[1],
    # which may LOWER the effective difficulty — confirm intended.
    level_idx = DIFFICULTY_LEVELS.index(difficulty) if difficulty in DIFFICULTY_LEVELS else 0
    print(f"Model: {MODEL_ID}")
    print(f"Running {args.episodes} episodes, starting difficulty={difficulty}")
    if args.adaptive:
        print(f"Adaptive mode: level up after {STREAK_TO_ADVANCE} wins or a fast solve (<={FAST_SOLVE_THRESHOLD} steps)")
    all_results = []
    for i in range(args.episodes):
        print(f"\n{'─'*60}")
        print(f"Episode {i+1}/{args.episodes} | difficulty={difficulty} | win_streak={win_streak}")
        print(f"{'─'*60}")
        result = run_episode(client, difficulty)
        # fixed: removed redundant `result["difficulty"] = difficulty` —
        # run_episode already records the same value.
        all_results.append(result)
        if args.adaptive:
            won = result["total_reward"] > 0
            fast = won and result["steps"] <= FAST_SOLVE_THRESHOLD
            win_streak = win_streak + 1 if won else 0
            # Advance on a win streak or a single fast solve; a loss only
            # resets the streak (fixed: removed the dead
            # `if not won and level_idx > 0: pass` branch).
            should_advance = (win_streak >= STREAK_TO_ADVANCE) or fast
            if should_advance and level_idx < len(DIFFICULTY_LEVELS) - 1:
                level_idx += 1
                difficulty = DIFFICULTY_LEVELS[level_idx]
                win_streak = 0
                print(f"\n >> LEVEL UP! Now at difficulty {difficulty}")
            elif won:
                print(f"\n >> Win streak: {win_streak}/{STREAK_TO_ADVANCE}")
    # Save trajectories, one JSON object per line (JSONL).
    with open(args.output, "w") as f:
        for r in all_results:
            f.write(json.dumps(r) + "\n")
    # Summary by difficulty
    print(f"\n{'='*60}")
    print("RESULTS SUMMARY")
    print(f"{'='*60}")
    for lvl in DIFFICULTY_LEVELS:
        lvl_results = [r for r in all_results if r["difficulty"] == lvl]
        if not lvl_results:
            continue
        rewards = [r["total_reward"] for r in lvl_results]
        wins = sum(1 for r in rewards if r > 0)
        avg_steps = sum(r["steps"] for r in lvl_results) / len(lvl_results)
        print(f" Difficulty {lvl}: {wins}/{len(lvl_results)} wins, "
              f"avg reward={sum(rewards)/len(rewards):.1f}, avg steps={avg_steps:.1f}")
    print(f"\n Total episodes: {len(all_results)}")
    # fixed: guard against --episodes 0, where max() would raise ValueError.
    if all_results:
        print(f" Max difficulty reached: {max(r['difficulty'] for r in all_results)}")
    print(f" Saved to: {args.output}")
if __name__ == "__main__":
    main()