# NOTE: extraction artifact — the original page header ("Spaces: Runtime error")
# was scraped along with this file; it is not part of the script.
| """Run a HuggingFace LLM agent against the HFT Oversight Environment. | |
| Collects trajectories for fine-tuning. Start with difficulty=1 (obvious errors). | |
| Usage: | |
| uv run python run_agent.py --episodes 10 | |
| uv run python run_agent.py --episodes 10 --difficulty 1 | |
| """ | |
import json
import os
import sys

# Make sibling modules (server/, models.py) importable when run as a script.
# This must happen before the local imports below.
sys.path.insert(0, os.path.dirname(__file__))

# Consolidated: the original imported from huggingface_hub in two separate
# statements, one of them buried in the config section.
from huggingface_hub import InferenceClient, get_token

from server.environment import HFTOversightEnvironment
from models import OversightAction

# --- Config ---
# Prefer an explicit env var; fall back to the cached `huggingface-cli login`
# token; empty string means "not authenticated" (checked in main()).
HF_TOKEN = os.environ.get("HF_TOKEN", "") or get_token() or ""
MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"

SYSTEM_PROMPT = """You are an HFT oversight agent. You manage trading bots and must find and shut down broken ones.
Each turn, respond with ONLY a JSON action. No other text.
Commands:
- {"command": "list_bots"}
- {"command": "read_logs", "bot_id": "NAME"}
- {"command": "check_pnl", "bot_id": "NAME"}
- {"command": "shutdown", "bot_id": "NAME", "reason": "WHY"}
Look for errors, bad prices, or suspicious behavior. Shut down broken bots."""
def parse_action(text: str) -> OversightAction:
    """Extract a JSON action object from raw model output.

    Tolerates markdown code fences and surrounding chatter. Raises
    ValueError when no JSON object can be located (json.JSONDecodeError
    may also propagate from malformed JSON).
    """
    cleaned = text.strip()
    if "```" in cleaned:
        # Keep only the contents of the first fenced block, dropping a
        # leading language tag such as ```json.
        cleaned = cleaned.split("```")[1].removeprefix("json").strip()
    left = cleaned.find("{")
    right = cleaned.rfind("}") + 1
    if left < 0 or right <= left:
        raise ValueError(f"Could not parse action from: {cleaned}")
    payload = json.loads(cleaned[left:right])
    return OversightAction(**payload)
def run_episode(client: InferenceClient, difficulty: int = 1) -> dict:
    """Run one episode against the oversight environment.

    Queries the model each turn, parses its JSON action, steps the
    environment, and records a per-step trajectory for fine-tuning.

    Args:
        client: HF inference client used for chat completions.
        difficulty: Environment difficulty (1 = obvious errors).

    Returns:
        Dict with difficulty, total_reward, steps, trajectory, and the
        full conversation message list.
    """
    env = HFTOversightEnvironment()
    # NOTE(review): pokes a private attribute; confirm the environment has
    # no public way to set difficulty.
    env._difficulty = difficulty
    obs = env.reset()

    print(f"\n{'='*60}")
    print(f"EPISODE (difficulty={difficulty})")
    print(f"{'='*60}")
    print(obs.response[:300])
    if obs.alerts:
        print(f"Alerts: {obs.alerts}")

    # Build conversation
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": obs.response + (f"\n\nAlerts: {obs.alerts}" if obs.alerts else "")},
    ]
    trajectory = []
    total_reward = 0.0
    # BUGFIX: the original counter (`'consecutive_errors' in dir()` hack)
    # never reset after a successful call, so it counted *total* errors,
    # not consecutive ones as the abort message claims.
    consecutive_errors = 0

    while not obs.done:
        # Query the model; on failure fall back to a pass_turn action.
        try:
            response = client.chat_completion(
                messages=messages,
                max_tokens=200,
                temperature=0.3,
            )
            llm_text = response.choices[0].message.content
            consecutive_errors = 0  # reset on success
        except Exception as e:
            print(f" Model error: {e}")
            llm_text = '{"command": "pass_turn"}'
            consecutive_errors += 1
            if consecutive_errors >= 3:
                print(" 3 consecutive model errors — aborting episode.")
                break

        print(f"\n LLM (step {obs.timestep + 1}): {llm_text[:150]}")
        try:
            action = parse_action(llm_text)
        except (ValueError, json.JSONDecodeError) as e:
            print(f" Parse error: {e}")
            action = OversightAction(command="pass_turn")

        # Step environment
        obs = env.step(action)
        total_reward += obs.reward
        print(f" ENV: {obs.response[:150]}")
        print(f" [reward={obs.reward}, total={total_reward}, step={obs.timestep}/{obs.max_timesteps}]")

        # Record trajectory step (messages snapshot BEFORE this turn's reply).
        trajectory.append({
            "messages_so_far": [m.copy() for m in messages],
            "assistant_response": llm_text,
            "action": action.model_dump(exclude_none=True),
            "reward": obs.reward,
            "cumulative_reward": total_reward,
            "done": obs.done,
        })

        # Feed back to conversation
        messages.append({"role": "assistant", "content": llm_text})
        env_msg = obs.response
        if obs.alerts:
            env_msg += f"\n\nAlerts: {obs.alerts}"
        env_msg += f"\n\n[Step {obs.timestep}/{obs.max_timesteps}]"
        messages.append({"role": "user", "content": env_msg})

    print(f"\n DONE — Total reward: {total_reward}")
    return {
        "difficulty": difficulty,
        "total_reward": total_reward,
        "steps": len(trajectory),
        "trajectory": trajectory,
        "full_conversation": messages,
    }
# Curriculum the adaptive loop climbs through, in order.
DIFFICULTY_LEVELS = [1, 2, 3, 5, 7]
FAST_SOLVE_THRESHOLD = 3  # solved in <= this many steps = "quick"
STREAK_TO_ADVANCE = 3  # consecutive wins to level up
def main():
    """CLI entry point: run episodes, adapt difficulty, save trajectories.

    Writes one JSON object per episode to --output (JSONL) and prints a
    per-difficulty summary. Exits with status 1 if no HF token is set.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--episodes", type=int, default=20)
    parser.add_argument("--difficulty", type=int, default=1, help="Starting difficulty")
    # NOTE: --adaptive is a no-op (adaptive already defaults to on); kept
    # for CLI symmetry with --no-adaptive.
    parser.add_argument("--adaptive", action="store_true", default=True,
                        help="Auto-increase difficulty (default: on)")
    parser.add_argument("--no-adaptive", dest="adaptive", action="store_false")
    parser.add_argument("--output", type=str, default="trajectories.jsonl")
    args = parser.parse_args()

    if not HF_TOKEN:
        print("Set HF_TOKEN env var: export HF_TOKEN=hf_xxx")
        sys.exit(1)

    client = InferenceClient(model=MODEL_ID, token=HF_TOKEN)

    # Adaptive difficulty state
    difficulty = args.difficulty
    win_streak = 0
    level_idx = DIFFICULTY_LEVELS.index(difficulty) if difficulty in DIFFICULTY_LEVELS else 0

    print(f"Model: {MODEL_ID}")
    print(f"Running {args.episodes} episodes, starting difficulty={difficulty}")
    if args.adaptive:
        print(f"Adaptive mode: level up after {STREAK_TO_ADVANCE} wins or a fast solve (<={FAST_SOLVE_THRESHOLD} steps)")

    all_results = []
    for i in range(args.episodes):
        print(f"\n{'─'*60}")
        print(f"Episode {i+1}/{args.episodes} | difficulty={difficulty} | win_streak={win_streak}")
        print(f"{'─'*60}")
        result = run_episode(client, difficulty)
        result["difficulty"] = difficulty
        all_results.append(result)

        if args.adaptive:
            won = result["total_reward"] > 0
            fast = won and result["steps"] <= FAST_SOLVE_THRESHOLD
            win_streak = win_streak + 1 if won else 0
            # Advance on a sustained streak OR a single fast solve.
            should_advance = (win_streak >= STREAK_TO_ADVANCE) or fast
            if should_advance and level_idx < len(DIFFICULTY_LEVELS) - 1:
                level_idx += 1
                difficulty = DIFFICULTY_LEVELS[level_idx]
                win_streak = 0
                print(f"\n >> LEVEL UP! Now at difficulty {difficulty}")
            elif won:
                print(f"\n >> Win streak: {win_streak}/{STREAK_TO_ADVANCE}")
            # A loss only resets the streak; the level is never lowered.
            # (The original had a dead `if not won ...: pass` branch here
            # whose comment wrongly suggested a difficulty drop-down.)

    # Save trajectories: one JSON object per line (JSONL).
    with open(args.output, "w") as f:
        for r in all_results:
            f.write(json.dumps(r) + "\n")

    # Summary by difficulty
    print(f"\n{'='*60}")
    print("RESULTS SUMMARY")
    print(f"{'='*60}")
    for lvl in DIFFICULTY_LEVELS:
        lvl_results = [r for r in all_results if r["difficulty"] == lvl]
        if not lvl_results:
            continue
        rewards = [r["total_reward"] for r in lvl_results]
        wins = sum(1 for r in rewards if r > 0)
        avg_steps = sum(r["steps"] for r in lvl_results) / len(lvl_results)
        print(f" Difficulty {lvl}: {wins}/{len(lvl_results)} wins, "
              f"avg reward={sum(rewards)/len(rewards):.1f}, avg steps={avg_steps:.1f}")

    print(f"\n Total episodes: {len(all_results)}")
    # BUGFIX: max() raises ValueError on an empty sequence (--episodes 0).
    if all_results:
        print(f" Max difficulty reached: {max(r['difficulty'] for r in all_results)}")
    print(f" Saved to: {args.output}")
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()