retro / scripts /test_generation.py
sankalphs's picture
Phase 1: project setup, MiniMax-M3 synthetic dataset generation, 1446-row clean dataset
122cc3c
Raw
History Blame Contribute Delete
4.01 kB
"""
Mini audit test: generate a small batch of each task type and validate
the structured text responses can be parsed.
"""
import asyncio
import json
import re
import sys
from pathlib import Path
from generate_dataset import generate_prompts, generate_wave, OUTPUT_FILE
def to_decimal(value_str: str) -> float:
value_str = value_str.strip()
if value_str.endswith("%"):
return float(value_str[:-1]) / 100.0
return float(value_str)
def parse_agent(response: str) -> dict:
response = response.strip()
if not response:
raise ValueError("empty response")
agent = re.search(r"agent:\s*(\w+)", response).group(1)
action_match = re.search(r"action:\s*(buy|sell|hold)\s+(\w+)\s+([\d.%]+)", response)
reason = re.search(r"reason:\s*(.+)", response).group(1).strip()
sentiment = re.search(r"sentiment:\s*(\w+)", response).group(1)
return {
"agent": agent,
"actions": [{"asset": action_match.group(2), "action": action_match.group(1), "amount_pct": to_decimal(action_match.group(3)), "reason": reason}],
"sentiment": sentiment
}
def parse_news(response: str) -> dict:
response = response.strip()
if not response:
raise ValueError("empty response")
headline = re.search(r"headline:\s*(.+)", response).group(1).strip()
impact_match = re.search(r"impact:\s*(.+?)(?:\nduration:|$)", response, re.DOTALL)
duration = int(re.search(r"duration:\s*(\d+)", response).group(1))
impact = {}
for token in impact_match.group(1).strip().split():
if ":" in token:
k, v = token.split(":")
impact[k] = to_decimal(v)
return {"headline": headline, "impact": impact, "duration_months": duration}
def parse_mentor(response: str) -> dict:
response = response.strip()
if not response:
raise ValueError("empty response")
roast = re.search(r"roast:\s*(.+)", response).group(1).strip()
sharpe = float(re.search(r"sharpe_ratio:\s*([-\d.]+)", response).group(1))
lesson = re.search(r"lesson:\s*(.+)", response).group(1).strip()
suggestion = re.search(r"suggestion:\s*(.+)", response).group(1).strip()
return {"roast": roast, "sharpe_ratio": sharpe, "lesson": lesson, "suggestion": suggestion}
def parse_guardrail(response: str) -> dict:
response = response.strip()
if not response:
raise ValueError("empty response")
return {"error": re.search(r"error:\s*(.+)", response).group(1).strip()}
PARSERS = {
"agent_decision": parse_agent,
"news_impact": parse_news,
"sharpe_mentor": parse_mentor,
"guardrail": parse_guardrail,
}
async def main():
test_file = OUTPUT_FILE.with_name("retro-alpha-test.jsonl")
if test_file.exists():
test_file.unlink()
waves = [
("agent", 5),
("news", 5),
("mentor", 5),
("guardrail", 5),
]
all_rows = []
for task, count in waves:
print(f"Testing {task} with {count} rows...")
prompts = generate_prompts(task, count)
results = await generate_wave(prompts, f"test-{task}")
print(f" Got {len(results)}/{count} non-empty rows")
all_rows.extend(results)
with open(test_file, "w", encoding="utf-8") as f:
for row in all_rows:
f.write(json.dumps(row, ensure_ascii=False) + "\n")
print(f"\nTest output: {test_file}")
print(f"Total test rows: {len(all_rows)}")
parse_ok = 0
parse_fail = 0
for row in all_rows:
parser = PARSERS[row["task"]]
try:
parser(row["response"])
parse_ok += 1
except Exception as e:
parse_fail += 1
print(f"Parse fail ({row['task']}): {e}")
print(f" Response: {repr(row['response'][:200])}")
print(f"Parsed successfully: {parse_ok}/{len(all_rows)}")
if parse_fail > 0:
print("TEST FAILED")
return 1
print("TEST PASSED")
return 0
if __name__ == "__main__":
sys.exit(asyncio.run(main()))