"""
Quant Research Environment Implementation.
A multi-phase code-writing environment where an AI agent acts as a
quantitative researcher: exploring market data, writing Python trading
strategy code, backtesting it, and iterating based on feedback.
"""
import json
import logging
from uuid import uuid4

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import EnvironmentMetadata

try:
    from ..models import QuantAction, QuantObservation, QuantState
except ImportError:
    from models import QuantAction, QuantObservation, QuantState

from . import backtester, constraints, grader, sandbox
from .data_loader import get_data_preview, get_data_schema, get_test_data, get_train_data
from .tasks import TASKS

logger = logging.getLogger(__name__)


class QuantResearchEnvironment(Environment):
    """
    Quantitative trading strategy research environment.

    Agents interact through structured actions to explore data,
    write strategy code, backtest it, and iterate. Three tasks
    test progressively harder quant skills.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        super().__init__()
        self._state = QuantState(episode_id=str(uuid4()), step_count=0)
        self._task_config = None
        self._submitted_code = ""
        self._research_md = ""
        self._best_score = 0.0
        self._last_backtest = None
        self._nifty_train = None
        self._banknifty_train = None
        self._merged_train = None

    def get_metadata(self) -> EnvironmentMetadata:
        readme_content = None
        for path in ["/app/README.md", "README.md"]:
            try:
                with open(path, encoding="utf-8") as f:
                    raw = f.read()
                # Strip YAML frontmatter (HF Spaces metadata between --- markers)
                if raw.startswith("---"):
                    end = raw.find("---", 3)
                    if end != -1:
                        raw = raw[end + 3:].lstrip("\n")
                readme_content = raw
                break
            except FileNotFoundError:
                pass
        return EnvironmentMetadata(
            name="Quant Research Environment",
            description=(
                "A quantitative trading strategy research workbench. "
                "Agents explore market data, write Python trading strategies, "
                "backtest them, and iterate. Three tasks: Easy (SMA crossover), "
                "Medium (pairs trading), Hard (alpha research on hidden data)."
            ),
            version="1.0.0",
            readme_content=readme_content,
        )

    def reset(self, seed=None, episode_id=None, **kwargs) -> QuantObservation:
        """
        Reset the environment for a specific task.

        Keyword args:
            task_id: "easy", "medium", or "hard" (default: "easy";
                unrecognized values fall back to "easy")

        `seed` is accepted for interface compatibility but is currently unused.
        """
        task_id = kwargs.get("task_id", "easy")
        if task_id not in TASKS:
            task_id = "easy"
        self._task_config = TASKS[task_id]
        self._submitted_code = ""
        self._research_md = ""
        self._best_score = 0.0
        self._last_backtest = None

        # Load data
        self._nifty_train, self._banknifty_train, self._merged_train = get_train_data()

        max_steps = self._task_config["max_steps"]
        self._state = QuantState(
            episode_id=episode_id or str(uuid4()),
            step_count=0,
            task_id=task_id,
            current_phase="explore",
            has_submitted_code=False,
            backtest_count=0,
            best_score=0.0,
            steps_remaining=max_steps,
        )

        schema = get_data_schema(self._merged_train)
        preview = get_data_preview(self._merged_train)
        return QuantObservation(
            done=False,
            reward=0.0,
            task_id=task_id,
            task_description=self._task_config["description"],
            data_schema=schema,
            data_preview=preview,
            function_signature=self._task_config["function_signature"],
            steps_remaining=max_steps,
            phase="explore",
            message=f"Task '{self._task_config['name']}' loaded. "
            f"You have {max_steps} steps. Explore data, write code, backtest, and submit.",
        )

    def step(self, action: QuantAction, timeout_s=None, **kwargs) -> QuantObservation:
        """Process agent action and return observation with reward."""
        self._state.step_count += 1
        self._state.steps_remaining = max(
            0, self._task_config["max_steps"] - self._state.step_count
        )

        # Out of steps: force a final submission unless this action already is one
        if self._state.steps_remaining <= 0 and action.action_type != "submit_final":
            return self._force_submit()

        action_type = action.action_type
        handlers = {
            "explore_data": self._handle_explore,
            "submit_code": self._handle_submit_code,
            "run_backtest": self._handle_backtest,
            "submit_research": self._handle_submit_research,
            "submit_final": self._handle_submit_final,
        }
        handler = handlers.get(action_type)
        if handler is None:
            return QuantObservation(
                done=False,
                reward=self._best_score,
                task_id=self._state.task_id,
                steps_remaining=self._state.steps_remaining,
                phase=self._state.current_phase,
                message=f"Unknown action_type '{action_type}'. "
                "Use: explore_data, submit_code, run_backtest, submit_research, submit_final.",
            )

        obs = handler(action.content)

        # Force submit if out of steps
        if self._state.steps_remaining <= 0 and not obs.done:
            return self._force_submit()
        return obs

    @property
    def state(self) -> QuantState:
        return self._state

    # ------------------------------------------------------------------
    # Action handlers
    # ------------------------------------------------------------------

    def _handle_explore(self, query: str) -> QuantObservation:
        """Execute a read-only data exploration query."""
self._state.current_phase = "explore"
if not query.strip():
return self._obs(message="Empty query. Provide a pandas expression like 'df.describe()' or 'df.head()'.")
output = sandbox.execute_exploration_query(query, self._merged_train)
return self._obs(
execution_output=output,
phase="explore",
message="Query executed. Use submit_code to write your strategy.",
)

    def _handle_submit_code(self, code: str) -> QuantObservation:
        """Validate and store agent strategy code."""
        self._state.current_phase = "code"
        if not code.strip():
            return self._obs(message="Empty code submission. Provide Python code defining generate_trades().")

        # Syntax check
        syntax_error = sandbox.check_syntax(code)
        if syntax_error:
            return self._obs(
                execution_output=syntax_error,
                message=f"Syntax error: {syntax_error}",
                reward=0.0,
            )

        # Check for dangerous patterns
        violations = constraints.scan_for_dangerous_code(code)
        if violations:
            return self._obs(
                constraint_violations=violations,
                message="Code contains forbidden patterns (os, subprocess, etc.). Revise and resubmit.",
                reward=0.0,
            )

        # Static lookahead scan (for all tasks)
        has_lookahead, detail = constraints.scan_for_lookahead(code)
        if has_lookahead:
            return self._obs(
                constraint_violations=[detail],
                message=f"Lookahead detected: {detail}. Revise and resubmit.",
                reward=0.0,
            )

        # Check for generate_trades function
        has_func = "def generate_trades" in code
        self._submitted_code = code
        self._state.has_submitted_code = True
        exec_status = {
            "parses": True,
            "has_signature": has_func,
        }
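        # Reward ladder: 0.05 for code that merely parses, 0.10 once
        # generate_trades() is defined; higher scores require run_backtest.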
        reward = 0.10 if has_func else 0.05
        if reward > self._best_score:
            self._best_score = reward
            self._state.best_score = reward

        msg = "Code submitted and validated." if has_func else "Code submitted but 'def generate_trades' not found."
        msg += " Use run_backtest to test your strategy."
        return self._obs(
            execution_output="Code validation passed." if has_func else "Warning: generate_trades() function not found.",
            phase="code",
            message=msg,
            reward=reward,
        )

    def _handle_backtest(self, _content: str) -> QuantObservation:
        """Execute stored code and run backtest."""
        self._state.current_phase = "backtest"
        if not self._submitted_code:
            return self._obs(message="No code submitted yet. Use submit_code first.")

        task_id = self._state.task_id

        # Execute strategy code
        result = sandbox.execute_strategy_code(
            self._submitted_code,
            self._nifty_train,
            self._banknifty_train,
            task_id,
        )
        exec_status = {
            "parses": True,
            "has_signature": "def generate_trades" in self._submitted_code,
            "runs": result["success"],
            "correct_columns": False,
        }
        if not result["success"]:
            error_msg = result.get("error", "Unknown error")
            reward = 0.10 if exec_status["has_signature"] else 0.05
            return self._obs(
                execution_output=f"Execution failed: {error_msg}\n{result.get('output', '')}",
                message="Code execution failed. Fix the error and resubmit.",
                reward=reward,
            )

        trades_df = result["trades_df"]
        if trades_df is None or len(trades_df) == 0:
            return self._obs(
                execution_output="generate_trades() returned empty or None result.",
                message="Strategy produced no trades. Revise your code.",
                reward=0.20,
            )

        # Check columns
        if task_id == "easy":
            required_cols = {"bar", "position"}
        else:
            required_cols = {"bar", "nifty_position", "banknifty_position"}
        actual_cols = set(trades_df.columns)
        exec_status["correct_columns"] = required_cols.issubset(actual_cols)
        if not exec_status["correct_columns"]:
            return self._obs(
                execution_output=f"Wrong columns. Expected {required_cols}, got {actual_cols}.",
                message=f"Output DataFrame has wrong columns. Expected: {required_cols}.",
                reward=0.20,
            )

        # Run backtester
        try:
            if task_id == "easy":
                bt_results = backtester.replay_trades_single(
                    trades_df, self._merged_train["Close_nifty"]
                )
            else:
                bt_results = backtester.replay_trades_multi(
                    trades_df,
                    self._nifty_train["Close"],
                    self._banknifty_train["Close"],
                )
        except Exception as e:
            return self._obs(
                execution_output=f"Backtest error: {e}",
                message="Backtest failed. Check your output format.",
                reward=0.20,
            )

        self._last_backtest = bt_results
        self._state.backtest_count += 1

        # Grade
        gt = self._task_config["ground_truth"]
        tol = self._task_config["tolerances"]
        if task_id == "easy":
            score = grader.grade_easy(exec_status, bt_results, gt, tol)
        elif task_id == "medium":
            score = grader.grade_medium(exec_status, bt_results, gt, tol)
        else:
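            # Hard task: train-time backtests are graded without OOS results or
            # the runtime lookahead check; both are applied in _grade_hard_final().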
            score = grader.grade_hard(exec_status, bt_results, None, False)
        if score > self._best_score:
            self._best_score = score
            self._state.best_score = score

        # Format results for agent
        display_results = {k: v for k, v in bt_results.items() if k != "bar_pnls"}
        if "exposure_violations" in display_results:
            n_violations = len(display_results["exposure_violations"])
            display_results["exposure_violations"] = f"{n_violations} violations"

        violation_msgs = []
        if "exposure_violations" in bt_results and bt_results["exposure_violations"]:
            n = len(bt_results["exposure_violations"])
            max_exp = bt_results.get("max_net_exposure_ratio", 0)
            violation_msgs.append(f"Exposure constraint violated {n} times (max: {max_exp:.4f} > 0.80)")

        return self._obs(
            execution_output=result.get("output", ""),
            backtest_results=json.dumps(display_results, indent=2),
            constraint_violations=violation_msgs,
            phase="backtest",
            message=f"Backtest complete. Score: {score:.2f}/1.00. "
            f"Best score: {self._best_score:.2f}. "
            "Use submit_code to revise or submit_final to lock in.",
            reward=score,
        )

    def _handle_submit_research(self, content: str) -> QuantObservation:
        """Store research notes (hard task only)."""
        if self._state.task_id != "hard":
            return self._obs(message="submit_research is only for the hard task.")
        self._research_md = content
        return self._obs(
            message="Research notes saved.",
            phase="code",
        )

    def _handle_submit_final(self, _content: str) -> QuantObservation:
        """Lock in final submission and compute final score."""
        if not self._submitted_code:
            return QuantObservation(
                done=True,
                reward=0.0,
                task_id=self._state.task_id,
                steps_remaining=0,
                phase="done",
                message="No code was submitted. Final score: 0.0",
            )

        task_id = self._state.task_id

        # For hard task: run on test data + runtime lookahead check
        if task_id == "hard" and self._last_backtest is not None:
            score = self._grade_hard_final()
        elif self._last_backtest is not None:
            # For easy/medium: use the best backtest score achieved
            score = self._best_score
        else:
            # No backtest was run -- run one now
            result = sandbox.execute_strategy_code(
                self._submitted_code,
                self._nifty_train,
                self._banknifty_train,
                task_id,
            )
            if result["success"] and result["trades_df"] is not None:
                score = 0.20  # Minimum for running code
            else:
                score = 0.10

        self._state.current_phase = "done"
        return QuantObservation(
            done=True,
            reward=score,
            task_id=task_id,
            steps_remaining=0,
            phase="done",
            message=f"Final submission locked. Score: {score:.4f}/1.00",
        )

    def _grade_hard_final(self) -> float:
        """Run full grading for hard task including OOS and lookahead."""
        exec_status = {
            "parses": True,
            "has_signature": "def generate_trades" in self._submitted_code,
            "runs": True,
            "correct_columns": True,
        }
        train_bt = self._last_backtest

        # Runtime lookahead detection
        has_lookahead, _detail = constraints.detect_runtime_lookahead(
            self._submitted_code,
            self._nifty_train,
            self._banknifty_train,
        )

        # Run on test data
        oos_results = None
        try:
            nifty_test, banknifty_test, _merged_test = get_test_data()
            result = sandbox.execute_strategy_code(
                self._submitted_code,
                nifty_test,
                banknifty_test,
                "hard",
            )
            if result["success"] and result["trades_df"] is not None:
                trades_df = result["trades_df"]
                oos_results = backtester.replay_trades_multi(
                    trades_df,
                    nifty_test["Close"],
                    banknifty_test["Close"],
                )
        except FileNotFoundError:
            logger.warning("Test data not available; grading on train data only")
        except Exception as e:
            logger.error(f"OOS evaluation failed: {e}")

        return grader.grade_hard(exec_status, train_bt, oos_results, has_lookahead)

    def _force_submit(self) -> QuantObservation:
        """Force a final submission when steps are exhausted."""
        return self._handle_submit_final("")

    def _obs(self, **kwargs) -> QuantObservation:
        """Build an observation with defaults."""
        defaults = {
            "done": False,
            "reward": self._best_score,
            "task_id": self._state.task_id,
            "steps_remaining": self._state.steps_remaining,
            "phase": self._state.current_phase,
        }
        defaults.update(kwargs)
        return QuantObservation(**defaults)
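

# ---------------------------------------------------------------------------
# Minimal local smoke test -- a sketch, not part of the served environment.
# It assumes QuantAction (imported above) accepts action_type/content keyword
# arguments, matching the attributes step() reads; adjust to the actual model
# definition in models.py if it differs. reset() needs the bundled market
# data to be present, since it calls get_train_data().
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    env = QuantResearchEnvironment()
    obs = env.reset(task_id="easy")
    print(obs.message)
    obs = env.step(QuantAction(action_type="explore_data", content="df.head()"))
    print(obs.execution_output)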