Spaces:
Running
Running
| """ | |
| Quant Research Environment Implementation. | |
| A multi-phase code-writing environment where an AI agent acts as a | |
| quantitative researcher: exploring market data, writing Python trading | |
| strategy code, backtesting it, and iterating based on feedback. | |
| """ | |
| import json | |
| import logging | |
| from uuid import uuid4 | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import EnvironmentMetadata | |
| try: | |
| from ..models import QuantAction, QuantObservation, QuantState | |
| except ImportError: | |
| from models import QuantAction, QuantObservation, QuantState | |
| from . import backtester, constraints, grader, sandbox | |
| from .data_loader import get_data_preview, get_data_schema, get_test_data, get_train_data | |
| from .tasks import TASKS | |
| logger = logging.getLogger(__name__) | |
class QuantResearchEnvironment(Environment):
    """
    Quantitative trading strategy research environment.

    Agents interact through structured actions to explore data,
    write strategy code, backtest it, and iterate. Three tasks
    test progressively harder quant skills.
    """

    # NOTE(review): all episode state lives on the instance, so concurrent
    # sessions presumably each get their own instance -- confirm against the
    # env server's session handling before relying on this flag.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True
    def __init__(self):
        """Initialize with a fresh episode and no task loaded yet."""
        super().__init__()
        # Episode state object exposed via state(); reset() replaces it per task.
        self._state = QuantState(episode_id=str(uuid4()), step_count=0)
        self._task_config = None  # entry from TASKS, populated in reset()
        self._submitted_code = ""  # latest strategy code that passed validation
        self._research_md = ""  # research notes (hard task only)
        self._best_score = 0.0  # best reward observed this episode
        self._last_backtest = None  # most recent backtester results dict
        # Training-split market data, loaded in reset().
        self._nifty_train = None
        self._banknifty_train = None
        self._merged_train = None
| def get_metadata(self) -> EnvironmentMetadata: | |
| readme_content = None | |
| for path in ["/app/README.md", "README.md"]: | |
| try: | |
| with open(path, encoding="utf-8") as f: | |
| raw = f.read() | |
| # Strip YAML frontmatter (HF Spaces metadata between --- markers) | |
| if raw.startswith("---"): | |
| end = raw.find("---", 3) | |
| if end != -1: | |
| raw = raw[end + 3:].lstrip("\n") | |
| readme_content = raw | |
| break | |
| except FileNotFoundError: | |
| pass | |
| return EnvironmentMetadata( | |
| name="Quant Research Environment", | |
| description=( | |
| "A quantitative trading strategy research workbench. " | |
| "Agents explore market data, write Python trading strategies, " | |
| "backtest them, and iterate. Three tasks: Easy (SMA crossover), " | |
| "Medium (pairs trading), Hard (alpha research on hidden data)." | |
| ), | |
| version="1.0.0", | |
| readme_content=readme_content, | |
| ) | |
| def reset(self, seed=None, episode_id=None, **kwargs) -> QuantObservation: | |
| """ | |
| Reset environment for a specific task. | |
| Keyword args: | |
| task_id: "easy", "medium", or "hard" (default: "easy") | |
| """ | |
| task_id = kwargs.get("task_id", "easy") | |
| if task_id not in TASKS: | |
| task_id = "easy" | |
| self._task_config = TASKS[task_id] | |
| self._submitted_code = "" | |
| self._research_md = "" | |
| self._best_score = 0.0 | |
| self._last_backtest = None | |
| # Load data | |
| self._nifty_train, self._banknifty_train, self._merged_train = get_train_data() | |
| max_steps = self._task_config["max_steps"] | |
| self._state = QuantState( | |
| episode_id=episode_id or str(uuid4()), | |
| step_count=0, | |
| task_id=task_id, | |
| current_phase="explore", | |
| has_submitted_code=False, | |
| backtest_count=0, | |
| best_score=0.0, | |
| steps_remaining=max_steps, | |
| ) | |
| schema = get_data_schema(self._merged_train) | |
| preview = get_data_preview(self._merged_train) | |
| return QuantObservation( | |
| done=False, | |
| reward=0.0, | |
| task_id=task_id, | |
| task_description=self._task_config["description"], | |
| data_schema=schema, | |
| data_preview=preview, | |
| function_signature=self._task_config["function_signature"], | |
| steps_remaining=max_steps, | |
| phase="explore", | |
| message=f"Task '{self._task_config['name']}' loaded. " | |
| f"You have {max_steps} steps. Explore data, write code, backtest, and submit.", | |
| ) | |
| def step(self, action: QuantAction, timeout_s=None, **kwargs) -> QuantObservation: | |
| """Process agent action and return observation with reward.""" | |
| self._state.step_count += 1 | |
| self._state.steps_remaining = max( | |
| 0, self._task_config["max_steps"] - self._state.step_count | |
| ) | |
| # Check if episode is already done | |
| if self._state.steps_remaining <= 0 and action.action_type != "submit_final": | |
| return self._force_submit() | |
| action_type = action.action_type | |
| handlers = { | |
| "explore_data": self._handle_explore, | |
| "submit_code": self._handle_submit_code, | |
| "run_backtest": self._handle_backtest, | |
| "submit_research": self._handle_submit_research, | |
| "submit_final": self._handle_submit_final, | |
| } | |
| handler = handlers.get(action_type) | |
| if handler is None: | |
| return QuantObservation( | |
| done=False, | |
| reward=self._best_score, | |
| task_id=self._state.task_id, | |
| steps_remaining=self._state.steps_remaining, | |
| phase=self._state.current_phase, | |
| message=f"Unknown action_type '{action_type}'. " | |
| f"Use: explore_data, submit_code, run_backtest, submit_research, submit_final.", | |
| ) | |
| obs = handler(action.content) | |
| # Force submit if out of steps | |
| if self._state.steps_remaining <= 0 and not obs.done: | |
| return self._force_submit() | |
| return obs | |
    def state(self) -> QuantState:
        """Return the live, mutable episode state object."""
        return self._state

    # ------------------------------------------------------------------
    # Action handlers
    # ------------------------------------------------------------------
| def _handle_explore(self, query: str) -> QuantObservation: | |
| """Execute a read-only data exploration query.""" | |
| self._state.current_phase = "explore" | |
| if not query.strip(): | |
| return self._obs(message="Empty query. Provide a pandas expression like 'df.describe()' or 'df.head()'.") | |
| output = sandbox.execute_exploration_query(query, self._merged_train) | |
| return self._obs( | |
| execution_output=output, | |
| phase="explore", | |
| message="Query executed. Use submit_code to write your strategy.", | |
| ) | |
| def _handle_submit_code(self, code: str) -> QuantObservation: | |
| """Validate and store agent strategy code.""" | |
| self._state.current_phase = "code" | |
| if not code.strip(): | |
| return self._obs(message="Empty code submission. Provide Python code defining generate_trades().") | |
| # Syntax check | |
| syntax_error = sandbox.check_syntax(code) | |
| if syntax_error: | |
| return self._obs( | |
| execution_output=syntax_error, | |
| message=f"Syntax error: {syntax_error}", | |
| reward=0.0, | |
| ) | |
| # Check for dangerous patterns | |
| violations = constraints.scan_for_dangerous_code(code) | |
| if violations: | |
| return self._obs( | |
| constraint_violations=violations, | |
| message="Code contains forbidden patterns (os, subprocess, etc). Revise and resubmit.", | |
| reward=0.0, | |
| ) | |
| # Static lookahead scan (for all tasks) | |
| has_lookahead, detail = constraints.scan_for_lookahead(code) | |
| if has_lookahead: | |
| return self._obs( | |
| constraint_violations=[detail], | |
| message=f"Lookahead detected: {detail}. Revise and resubmit.", | |
| reward=0.0, | |
| ) | |
| # Check for generate_trades function | |
| has_func = "def generate_trades" in code | |
| self._submitted_code = code | |
| self._state.has_submitted_code = True | |
| exec_status = { | |
| "parses": True, | |
| "has_signature": has_func, | |
| } | |
| reward = 0.10 if has_func else 0.05 | |
| if reward > self._best_score: | |
| self._best_score = reward | |
| self._state.best_score = reward | |
| msg = "Code submitted and validated." if has_func else "Code submitted but 'def generate_trades' not found." | |
| msg += " Use run_backtest to test your strategy." | |
| return self._obs( | |
| execution_output="Code validation passed." if has_func else "Warning: generate_trades() function not found.", | |
| phase="code", | |
| message=msg, | |
| reward=reward, | |
| ) | |
    def _handle_backtest(self, _content: str) -> QuantObservation:
        """Execute stored code and run backtest.

        Pipeline: run the stored strategy in the sandbox, validate the
        returned trades DataFrame (non-empty, required columns), replay the
        trades through the backtester, grade against the task's ground truth,
        and report metrics plus any constraint violations. Rewards are shaped
        by how far the submission gets through the pipeline.
        """
        self._state.current_phase = "backtest"
        if not self._submitted_code:
            return self._obs(message="No code submitted yet. Use submit_code first.")
        task_id = self._state.task_id
        # Execute strategy code
        result = sandbox.execute_strategy_code(
            self._submitted_code,
            self._nifty_train,
            self._banknifty_train,
            task_id,
        )
        # Execution-status flags consumed by the grader below.
        exec_status = {
            "parses": True,
            "has_signature": "def generate_trades" in self._submitted_code,
            "runs": result["success"],
            "correct_columns": False,
        }
        if not result["success"]:
            error_msg = result.get("error", "Unknown error")
            # Mirrors the submit_code shaping: code exists but failed to run.
            reward = 0.10 if exec_status["has_signature"] else 0.05
            return self._obs(
                execution_output=f"Execution failed: {error_msg}\n{result.get('output', '')}",
                message=f"Code execution failed. Fix the error and resubmit.",
                reward=reward,
            )
        trades_df = result["trades_df"]
        if trades_df is None or len(trades_df) == 0:
            return self._obs(
                execution_output="generate_trades() returned empty or None result.",
                message="Strategy produced no trades. Revise your code.",
                reward=0.20,
            )
        # Check columns: single-instrument schema for easy, dual for medium/hard.
        if task_id == "easy":
            required_cols = {"bar", "position"}
        else:
            required_cols = {"bar", "nifty_position", "banknifty_position"}
        actual_cols = set(trades_df.columns)
        exec_status["correct_columns"] = required_cols.issubset(actual_cols)
        if not exec_status["correct_columns"]:
            return self._obs(
                execution_output=f"Wrong columns. Expected {required_cols}, got {actual_cols}.",
                message=f"Output DataFrame has wrong columns. Expected: {required_cols}.",
                reward=0.20,
            )
        # Run backtester
        try:
            if task_id == "easy":
                bt_results = backtester.replay_trades_single(
                    trades_df, self._merged_train["Close_nifty"]
                )
            else:
                bt_results = backtester.replay_trades_multi(
                    trades_df,
                    self._nifty_train["Close"],
                    self._banknifty_train["Close"],
                )
        except Exception as e:
            # Broad catch: malformed agent output can break the replay in
            # many ways; surface it as a graded failure, not a crash.
            return self._obs(
                execution_output=f"Backtest error: {e}",
                message="Backtest failed. Check your output format.",
                reward=0.20,
            )
        self._last_backtest = bt_results
        self._state.backtest_count += 1
        # Grade against the task-specific ground truth and tolerances.
        gt = self._task_config["ground_truth"]
        tol = self._task_config["tolerances"]
        if task_id == "easy":
            score = grader.grade_easy(exec_status, bt_results, gt, tol)
        elif task_id == "medium":
            score = grader.grade_medium(exec_status, bt_results, gt, tol)
        else:
            # Hard task: OOS data and lookahead checks are deferred to
            # submit_final (see _grade_hard_final).
            score = grader.grade_hard(exec_status, bt_results, None, False)
        if score > self._best_score:
            self._best_score = score
            self._state.best_score = score
        # Format results for agent: drop the bulky per-bar PnL series and
        # collapse the violations list to a count.
        display_results = {k: v for k, v in bt_results.items() if k != "bar_pnls"}
        if "exposure_violations" in display_results:
            n_violations = len(display_results["exposure_violations"])
            display_results["exposure_violations"] = f"{n_violations} violations"
        violation_msgs = []
        if "exposure_violations" in bt_results and bt_results["exposure_violations"]:
            n = len(bt_results["exposure_violations"])
            max_exp = bt_results.get("max_net_exposure_ratio", 0)
            violation_msgs.append(f"Exposure constraint violated {n} times (max: {max_exp:.4f} > 0.80)")
        return self._obs(
            execution_output=result.get("output", ""),
            backtest_results=json.dumps(display_results, indent=2),
            constraint_violations=violation_msgs,
            phase="backtest",
            message=f"Backtest complete. Score: {score:.2f}/1.00. "
            f"Best score: {self._best_score:.2f}. "
            f"Use submit_code to revise or submit_final to lock in.",
            reward=score,
        )
| def _handle_submit_research(self, content: str) -> QuantObservation: | |
| """Store research notes (hard task only).""" | |
| if self._state.task_id != "hard": | |
| return self._obs(message="submit_research is only for the hard task.") | |
| self._research_md = content | |
| return self._obs( | |
| message="Research notes saved.", | |
| phase="code", | |
| ) | |
| def _handle_submit_final(self, _content: str) -> QuantObservation: | |
| """Lock in final submission and compute final score.""" | |
| if not self._submitted_code: | |
| return QuantObservation( | |
| done=True, | |
| reward=0.0, | |
| task_id=self._state.task_id, | |
| steps_remaining=0, | |
| phase="done", | |
| message="No code was submitted. Final score: 0.0", | |
| ) | |
| task_id = self._state.task_id | |
| # For hard task: run on test data + runtime lookahead check | |
| if task_id == "hard" and self._last_backtest is not None: | |
| score = self._grade_hard_final() | |
| elif self._last_backtest is not None: | |
| # For easy/medium: use last backtest score | |
| score = self._best_score | |
| else: | |
| # No backtest was run -- run one now | |
| result = sandbox.execute_strategy_code( | |
| self._submitted_code, | |
| self._nifty_train, | |
| self._banknifty_train, | |
| task_id, | |
| ) | |
| if result["success"] and result["trades_df"] is not None: | |
| score = 0.20 # Minimum for running code | |
| else: | |
| score = 0.10 | |
| self._state.current_phase = "done" | |
| return QuantObservation( | |
| done=True, | |
| reward=score, | |
| task_id=task_id, | |
| steps_remaining=0, | |
| phase="done", | |
| message=f"Final submission locked. Score: {score:.4f}/1.00", | |
| ) | |
| def _grade_hard_final(self) -> float: | |
| """Run full grading for hard task including OOS and lookahead.""" | |
| exec_status = { | |
| "parses": True, | |
| "has_signature": "def generate_trades" in self._submitted_code, | |
| "runs": True, | |
| "correct_columns": True, | |
| } | |
| train_bt = self._last_backtest | |
| # Runtime lookahead detection | |
| has_lookahead, _detail = constraints.detect_runtime_lookahead( | |
| self._submitted_code, | |
| self._nifty_train, | |
| self._banknifty_train, | |
| ) | |
| # Run on test data | |
| oos_results = None | |
| try: | |
| nifty_test, banknifty_test, _merged_test = get_test_data() | |
| result = sandbox.execute_strategy_code( | |
| self._submitted_code, | |
| nifty_test, | |
| banknifty_test, | |
| "hard", | |
| ) | |
| if result["success"] and result["trades_df"] is not None: | |
| trades_df = result["trades_df"] | |
| oos_results = backtester.replay_trades_multi( | |
| trades_df, | |
| nifty_test["Close"], | |
| banknifty_test["Close"], | |
| ) | |
| except FileNotFoundError: | |
| logger.warning("Test data not available; grading on train data only") | |
| except Exception as e: | |
| logger.error(f"OOS evaluation failed: {e}") | |
| return grader.grade_hard(exec_status, train_bt, oos_results, has_lookahead) | |
    def _force_submit(self) -> QuantObservation:
        """Force a final submission when steps are exhausted."""
        # Delegates to the normal final-submission path with empty content.
        return self._handle_submit_final("")
| def _obs(self, **kwargs) -> QuantObservation: | |
| """Build an observation with defaults.""" | |
| defaults = { | |
| "done": False, | |
| "reward": self._best_score, | |
| "task_id": self._state.task_id, | |
| "steps_remaining": self._state.steps_remaining, | |
| "phase": self._state.current_phase, | |
| } | |
| defaults.update(kwargs) | |
| return QuantObservation(**defaults) | |