"""
Quant Research Environment Implementation.

A multi-phase code-writing environment where an AI agent acts as a
quantitative researcher: exploring market data, writing Python trading
strategy code, backtesting it, and iterating based on feedback.
"""

import json
import logging
from uuid import uuid4

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import EnvironmentMetadata

try:
    from ..models import QuantAction, QuantObservation, QuantState
except ImportError:
    from models import QuantAction, QuantObservation, QuantState

from . import backtester, constraints, grader, sandbox
from .data_loader import get_data_preview, get_data_schema, get_test_data, get_train_data
from .tasks import TASKS

logger = logging.getLogger(__name__)


class QuantResearchEnvironment(Environment):
    """
    Quantitative trading strategy research environment.

    Agents interact through structured actions to explore data, write
    strategy code, backtest it, and iterate. Three tasks test progressively
    harder quant skills.

    Supported ``action_type`` values: ``explore_data``, ``submit_code``,
    ``run_backtest``, ``submit_research`` (hard task only) and
    ``submit_final``.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        super().__init__()
        self._state = QuantState(episode_id=str(uuid4()), step_count=0)
        self._task_config = None  # set by reset(); step() requires it
        self._submitted_code = ""  # latest code accepted by submit_code
        self._research_md = ""  # notes stored by submit_research (hard task)
        self._best_score = 0.0
        self._last_backtest = None  # results of the latest successful backtest
        self._nifty_train = None
        self._banknifty_train = None
        self._merged_train = None

    def get_metadata(self) -> EnvironmentMetadata:
        """Return environment metadata, embedding the README when one is readable.

        The first readable file among ``/app/README.md`` and ``README.md`` is
        used, with any leading ``---``-delimited YAML frontmatter removed.
        When neither can be read, ``readme_content`` is ``None``.
        """
        readme_content = None
        for path in ["/app/README.md", "README.md"]:
            try:
                with open(path, encoding="utf-8") as f:
                    raw = f.read()
            except (OSError, UnicodeDecodeError):
                # Missing, unreadable or undecodable: metadata is best-effort,
                # so just try the next candidate path.
                continue
            # Strip YAML frontmatter (HF Spaces metadata between --- markers)
            if raw.startswith("---"):
                end = raw.find("---", 3)
                if end != -1:
                    raw = raw[end + 3:].lstrip("\n")
            readme_content = raw
            break
        return EnvironmentMetadata(
            name="Quant Research Environment",
            description=(
                "A quantitative trading strategy research workbench. "
                "Agents explore market data, write Python trading strategies, "
                "backtest them, and iterate. Three tasks: Easy (SMA crossover), "
                "Medium (pairs trading), Hard (alpha research on hidden data)."
            ),
            version="1.0.0",
            readme_content=readme_content,
        )

    def reset(self, seed=None, episode_id=None, **kwargs) -> QuantObservation:
        """
        Reset environment for a specific task.

        Args:
            seed: Accepted for interface compatibility; not used here.
            episode_id: Optional episode id; a random UUID is used if omitted.

        Keyword args:
            task_id: "easy", "medium", or "hard" (default: "easy"). Unknown
                ids fall back to "easy".

        Returns:
            The initial observation carrying the task description, the data
            schema and preview of the merged training data, and the required
            function signature.
        """
        task_id = kwargs.get("task_id", "easy")
        if task_id not in TASKS:
            task_id = "easy"

        self._task_config = TASKS[task_id]
        self._submitted_code = ""
        self._research_md = ""
        self._best_score = 0.0
        self._last_backtest = None

        # Load data
        self._nifty_train, self._banknifty_train, self._merged_train = get_train_data()

        max_steps = self._task_config["max_steps"]
        self._state = QuantState(
            episode_id=episode_id or str(uuid4()),
            step_count=0,
            task_id=task_id,
            current_phase="explore",
            has_submitted_code=False,
            backtest_count=0,
            best_score=0.0,
            steps_remaining=max_steps,
        )

        schema = get_data_schema(self._merged_train)
        preview = get_data_preview(self._merged_train)

        return QuantObservation(
            done=False,
            reward=0.0,
            task_id=task_id,
            task_description=self._task_config["description"],
            data_schema=schema,
            data_preview=preview,
            function_signature=self._task_config["function_signature"],
            steps_remaining=max_steps,
            phase="explore",
            message=f"Task '{self._task_config['name']}' loaded. "
            f"You have {max_steps} steps. Explore data, write code, backtest, and submit.",
        )

    def step(self, action: QuantAction, timeout_s=None, **kwargs) -> QuantObservation:
        """Process one agent action and return the resulting observation.

        Every call consumes one step of the task's ``max_steps`` budget. The
        action taken on the last step is still executed; if it does not end
        the episode, a final submission is forced right after it.

        Args:
            action: ``action.action_type`` selects the handler and
                ``action.content`` is passed to it.
            timeout_s: Accepted for interface compatibility; not used here.

        Raises:
            RuntimeError: If called before :meth:`reset`.
        """
        if self._task_config is None:
            raise RuntimeError("reset() must be called before step().")

        max_steps = self._task_config["max_steps"]
        self._state.step_count += 1
        self._state.steps_remaining = max(0, max_steps - self._state.step_count)

        # Budget was already exhausted *before* this action (e.g. step() is
        # called again after the episode ended): skip the action and finalize.
        # Comparing against step_count (not steps_remaining) lets the action
        # on step == max_steps actually run instead of being discarded.
        if self._state.step_count > max_steps and action.action_type != "submit_final":
            return self._force_submit()

        action_type = action.action_type
        handlers = {
            "explore_data": self._handle_explore,
            "submit_code": self._handle_submit_code,
            "run_backtest": self._handle_backtest,
            "submit_research": self._handle_submit_research,
            "submit_final": self._handle_submit_final,
        }

        handler = handlers.get(action_type)
        if handler is None:
            obs = self._obs(
                message=f"Unknown action_type '{action_type}'. "
                f"Use: explore_data, submit_code, run_backtest, submit_research, submit_final.",
            )
        else:
            obs = handler(action.content)

        # Force submit if this action used the last step without ending the
        # episode (applies to unknown actions too).
        if self._state.steps_remaining <= 0 and not obs.done:
            return self._force_submit()

        return obs

    @property
    def state(self) -> QuantState:
        """Current episode state."""
        return self._state

    # ------------------------------------------------------------------
    # Action handlers
    # ------------------------------------------------------------------

    def _handle_explore(self, query: str) -> QuantObservation:
        """Execute a data exploration query against the merged training data.

        The query is run through ``sandbox.execute_exploration_query`` and its
        output is returned as ``execution_output``.
        """
        self._state.current_phase = "explore"
        if not query or not query.strip():
            return self._obs(message="Empty query. Provide a pandas expression like 'df.describe()' or 'df.head()'.")

        output = sandbox.execute_exploration_query(query, self._merged_train)
        return self._obs(
            execution_output=output,
            phase="explore",
            message="Query executed. Use submit_code to write your strategy.",
        )

    def _handle_submit_code(self, code: str) -> QuantObservation:
        """Validate and store agent strategy code.

        The code is rejected (reward 0.0, nothing stored) if it has a syntax
        error, contains forbidden patterns, or fails the static lookahead
        scan. Otherwise it replaces any previously stored code and earns 0.10
        if it defines ``generate_trades`` or 0.05 if it does not.
        """
        self._state.current_phase = "code"
        if not code or not code.strip():
            return self._obs(message="Empty code submission. Provide Python code defining generate_trades().")

        # Syntax check
        syntax_error = sandbox.check_syntax(code)
        if syntax_error:
            return self._obs(
                execution_output=syntax_error,
                message=f"Syntax error: {syntax_error}",
                reward=0.0,
            )

        # Check for dangerous patterns
        violations = constraints.scan_for_dangerous_code(code)
        if violations:
            return self._obs(
                constraint_violations=violations,
                message="Code contains forbidden patterns (os, subprocess, etc). Revise and resubmit.",
                reward=0.0,
            )

        # Static lookahead scan (for all tasks)
        has_lookahead, detail = constraints.scan_for_lookahead(code)
        if has_lookahead:
            return self._obs(
                constraint_violations=[detail],
                message=f"Lookahead detected: {detail}. Revise and resubmit.",
                reward=0.0,
            )

        # Textual check only: any "def generate_trades" substring counts.
        has_func = "def generate_trades" in code
        self._submitted_code = code
        self._state.has_submitted_code = True

        reward = 0.10 if has_func else 0.05
        if reward > self._best_score:
            self._best_score = reward
            self._state.best_score = reward

        msg = "Code submitted and validated." if has_func else "Code submitted but 'def generate_trades' not found."
        msg += " Use run_backtest to test your strategy."

        return self._obs(
            execution_output="Code validation passed." if has_func else "Warning: generate_trades() function not found.",
            phase="code",
            message=msg,
            reward=reward,
        )

    def _handle_backtest(self, _content: str) -> QuantObservation:
        """Execute the stored code on the training data, replay its trades and grade them.

        Partial rewards are returned for failure modes: 0.10/0.05 when the
        code fails to run (with/without a ``generate_trades`` definition) and
        0.20 for empty output, wrong columns or a backtester error. These
        partial rewards do not update the best score; only a graded backtest
        can.
        """
        self._state.current_phase = "backtest"

        if not self._submitted_code:
            return self._obs(message="No code submitted yet. Use submit_code first.")

        task_id = self._state.task_id

        # Execute strategy code
        result = sandbox.execute_strategy_code(
            self._submitted_code,
            self._nifty_train,
            self._banknifty_train,
            task_id,
        )

        exec_status = {
            "parses": True,  # submit_code only stores code that passed the syntax check
            "has_signature": "def generate_trades" in self._submitted_code,
            "runs": result["success"],
            "correct_columns": False,
        }

        if not result["success"]:
            error_msg = result.get("error", "Unknown error")
            reward = 0.10 if exec_status["has_signature"] else 0.05
            return self._obs(
                execution_output=f"Execution failed: {error_msg}\n{result.get('output', '')}",
                message="Code execution failed. Fix the error and resubmit.",
                reward=reward,
            )

        trades_df = result["trades_df"]
        if trades_df is None or len(trades_df) == 0:
            return self._obs(
                execution_output="generate_trades() returned empty or None result.",
                message="Strategy produced no trades. Revise your code.",
                reward=0.20,
            )

        # Check columns: the easy task trades a single instrument, the others
        # trade the NIFTY/BANKNIFTY pair.
        if task_id == "easy":
            required_cols = {"bar", "position"}
        else:
            required_cols = {"bar", "nifty_position", "banknifty_position"}

        actual_cols = set(trades_df.columns)
        exec_status["correct_columns"] = required_cols.issubset(actual_cols)

        if not exec_status["correct_columns"]:
            return self._obs(
                execution_output=f"Wrong columns. Expected {required_cols}, got {actual_cols}.",
                message=f"Output DataFrame has wrong columns. Expected: {required_cols}.",
                reward=0.20,
            )

        # Run backtester
        try:
            if task_id == "easy":
                bt_results = backtester.replay_trades_single(
                    trades_df, self._merged_train["Close_nifty"]
                )
            else:
                bt_results = backtester.replay_trades_multi(
                    trades_df,
                    self._nifty_train["Close"],
                    self._banknifty_train["Close"],
                )
        except Exception as e:
            # Agent-produced output can break the backtester in arbitrary
            # ways; report it back instead of crashing the episode.
            return self._obs(
                execution_output=f"Backtest error: {e}",
                message="Backtest failed. Check your output format.",
                reward=0.20,
            )

        self._last_backtest = bt_results
        self._state.backtest_count += 1

        # Grade
        gt = self._task_config["ground_truth"]
        tol = self._task_config["tolerances"]

        if task_id == "easy":
            score = grader.grade_easy(exec_status, bt_results, gt, tol)
        elif task_id == "medium":
            score = grader.grade_medium(exec_status, bt_results, gt, tol)
        else:
            # Hard task: train-only grade here (no OOS results, no runtime
            # lookahead check); the full grade happens in submit_final.
            score = grader.grade_hard(exec_status, bt_results, None, False)

        if score > self._best_score:
            self._best_score = score
            self._state.best_score = score

        # Format results for agent: drop per-bar PnL and summarize violations.
        display_results = {k: v for k, v in bt_results.items() if k != "bar_pnls"}
        if "exposure_violations" in display_results:
            n_violations = len(display_results["exposure_violations"])
            display_results["exposure_violations"] = f"{n_violations} violations"

        violation_msgs = []
        if "exposure_violations" in bt_results and bt_results["exposure_violations"]:
            n = len(bt_results["exposure_violations"])
            max_exp = bt_results.get("max_net_exposure_ratio", 0)
            violation_msgs.append(f"Exposure constraint violated {n} times (max: {max_exp:.4f} > 0.80)")

        return self._obs(
            execution_output=result.get("output", ""),
            backtest_results=json.dumps(display_results, indent=2),
            constraint_violations=violation_msgs,
            phase="backtest",
            message=f"Backtest complete. Score: {score:.2f}/1.00. "
            f"Best score: {self._best_score:.2f}. "
            f"Use submit_code to revise or submit_final to lock in.",
            reward=score,
        )

    def _handle_submit_research(self, content: str) -> QuantObservation:
        """Store research notes (hard task only)."""
        if self._state.task_id != "hard":
            return self._obs(message="submit_research is only for the hard task.")

        self._research_md = content
        # Keep state in sync with the phase reported in the observation.
        self._state.current_phase = "code"
        return self._obs(
            message="Research notes saved.",
            phase="code",
        )

    def _handle_submit_final(self, _content: str) -> QuantObservation:
        """Lock in final submission and compute final score.

        Scoring:
            * no stored code: 0.0
            * hard task with a prior successful backtest: full grade from
              ``_grade_hard_final`` (OOS + runtime lookahead check)
            * easy/medium with a prior successful backtest: best score so far
            * no successful backtest: 0.20 if the code runs and returns a
              result, else 0.10
        """
        if not self._submitted_code:
            self._state.current_phase = "done"
            return QuantObservation(
                done=True,
                reward=0.0,
                task_id=self._state.task_id,
                steps_remaining=0,
                phase="done",
                message="No code was submitted. Final score: 0.0",
            )

        task_id = self._state.task_id

        # For hard task: run on test data + runtime lookahead check
        if task_id == "hard" and self._last_backtest is not None:
            score = self._grade_hard_final()
        elif self._last_backtest is not None:
            # For easy/medium: use best score seen so far
            score = self._best_score
        else:
            # No successful backtest: only check that the code executes and
            # award partial credit (no backtest is run here).
            result = sandbox.execute_strategy_code(
                self._submitted_code,
                self._nifty_train,
                self._banknifty_train,
                task_id,
            )
            if result["success"] and result["trades_df"] is not None:
                score = 0.20  # Minimum for running code
            else:
                score = 0.10

        self._state.current_phase = "done"

        return QuantObservation(
            done=True,
            reward=score,
            task_id=task_id,
            steps_remaining=0,
            phase="done",
            message=f"Final submission locked. Score: {score:.4f}/1.00",
        )

    def _grade_hard_final(self) -> float:
        """Run full grading for hard task including OOS and lookahead.

        Combines the stored train backtest, a runtime lookahead check on the
        training data and, when test data is available, an out-of-sample
        replay of the stored code. OOS failures are logged and graded as
        "no OOS results" rather than raised.
        """
        # Only reached after a successful backtest, hence runs/correct_columns.
        exec_status = {
            "parses": True,
            "has_signature": "def generate_trades" in self._submitted_code,
            "runs": True,
            "correct_columns": True,
        }

        # NOTE(review): _last_backtest is not cleared by submit_code, so this
        # may be the train backtest of an *earlier* code version than the one
        # evaluated out-of-sample below — confirm this pairing is intended.
        train_bt = self._last_backtest

        # Runtime lookahead detection
        has_lookahead, _detail = constraints.detect_runtime_lookahead(
            self._submitted_code,
            self._nifty_train,
            self._banknifty_train,
        )

        # Run on test data
        oos_results = None
        try:
            nifty_test, banknifty_test, _merged_test = get_test_data()
            result = sandbox.execute_strategy_code(
                self._submitted_code,
                nifty_test,
                banknifty_test,
                "hard",
            )
            if result["success"] and result["trades_df"] is not None:
                trades_df = result["trades_df"]
                oos_results = backtester.replay_trades_multi(
                    trades_df,
                    nifty_test["Close"],
                    banknifty_test["Close"],
                )
        except FileNotFoundError:
            logger.warning("Test data not available; grading on train data only")
        except Exception as e:
            # Best-effort: grade without OOS results, but keep the traceback.
            logger.exception("OOS evaluation failed: %s", e)

        return grader.grade_hard(exec_status, train_bt, oos_results, has_lookahead)

    def _force_submit(self) -> QuantObservation:
        """Force a final submission when steps are exhausted."""
        return self._handle_submit_final("")

    def _obs(self, **kwargs) -> QuantObservation:
        """Build a non-terminal observation from current state, overridden by ``kwargs``.

        Defaults: ``done=False``, ``reward`` = best score so far, plus the
        current task id, remaining steps and phase.
        """
        defaults = {
            "done": False,
            "reward": self._best_score,
            "task_id": self._state.task_id,
            "steps_remaining": self._state.steps_remaining,
            "phase": self._state.current_phase,
        }
        defaults.update(kwargs)
        return QuantObservation(**defaults)