File size: 4,008 Bytes
122cc3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""
Mini audit test: generate a small batch of each task type and validate
the structured text responses can be parsed.
"""

import asyncio
import json
import re
import sys
from pathlib import Path

from generate_dataset import generate_prompts, generate_wave, OUTPUT_FILE


def to_decimal(value_str: str) -> float:
    """Convert a numeric string, optionally percent-suffixed, to a float.

    Surrounding whitespace is ignored. A trailing ``%`` means the number is
    a percentage and is divided by 100 (``"50%"`` -> ``0.5``); any other
    string is converted as-is (``"0.5"`` -> ``0.5``).

    Args:
        value_str: Text such as ``"12.5%"``, ``"+5%"`` or ``"0.25"``.

    Returns:
        The value as a float.

    Raises:
        ValueError: If the numeric part is not a valid float literal.
    """
    cleaned = value_str.strip()
    if not cleaned.endswith("%"):
        return float(cleaned)
    # Drop the percent sign and rescale to a fraction.
    return float(cleaned[:-1]) / 100.0


def parse_agent(response: str) -> dict:
    """Parse an agent-decision response into a structured dict.

    The response is expected to contain the fields ``agent:``, ``action:``
    (``buy``/``sell``/``hold`` followed by an asset word and an amount),
    ``reason:`` and ``sentiment:``. Each field is located with a regex
    search, so the fields may appear anywhere in the text.

    Args:
        response: Raw structured-text response.

    Returns:
        ``{"agent": str, "actions": [{"asset", "action", "amount_pct",
        "reason"}], "sentiment": str}`` where ``amount_pct`` is the amount
        converted by ``to_decimal``.

    Raises:
        ValueError: If the response is empty, a field is missing or does
            not match its expected shape, or the amount is not numeric.
    """
    response = response.strip()
    if not response:
        raise ValueError("empty response")

    patterns = {
        "agent": r"agent:\s*(\w+)",
        "action": r"action:\s*(buy|sell|hold)\s+(\w+)\s+([\d.%]+)",
        "reason": r"reason:\s*(.+)",
        "sentiment": r"sentiment:\s*(\w+)",
    }
    matches = {}
    for field, pattern in patterns.items():
        found = re.search(pattern, response)
        # re.search returns None on no match; report which field failed
        # rather than letting `.group()` raise an opaque AttributeError.
        if found is None:
            raise ValueError(f"missing or malformed '{field}' field")
        matches[field] = found

    verb, asset, amount = matches["action"].groups()
    return {
        "agent": matches["agent"].group(1),
        "actions": [
            {
                "asset": asset,
                "action": verb,
                "amount_pct": to_decimal(amount),
                "reason": matches["reason"].group(1).strip(),
            }
        ],
        "sentiment": matches["sentiment"].group(1),
    }


def parse_news(response: str) -> dict:
    """Parse a news-impact response into a structured dict.

    The response is expected to contain ``headline:``, ``impact:`` and
    ``duration:`` fields. The impact value is a whitespace-separated list
    of ``asset:value`` tokens; tokens without a colon are ignored.

    Args:
        response: Raw structured-text response.

    Returns:
        ``{"headline": str, "impact": {asset: float}, "duration_months":
        int}``, with each impact value converted by ``to_decimal``.

    Raises:
        ValueError: If the response is empty, a field is missing or
            malformed, or an impact token has a non-numeric value.
    """
    response = response.strip()
    if not response:
        raise ValueError("empty response")

    headline_match = re.search(r"headline:\s*(.+)", response)
    if headline_match is None:
        raise ValueError("missing or malformed 'headline' field")

    # DOTALL lets the impact list span several lines; the capture ends at a
    # line starting with "duration:" or, failing that, at the end of text.
    impact_match = re.search(
        r"impact:\s*(.+?)(?:\nduration:|$)", response, re.DOTALL
    )
    if impact_match is None:
        raise ValueError("missing or malformed 'impact' field")

    duration_match = re.search(r"duration:\s*(\d+)", response)
    if duration_match is None:
        raise ValueError("missing or malformed 'duration' field")

    impact = {}
    for token in impact_match.group(1).split():
        if ":" not in token:
            continue
        # partition splits on the first colon only, so a token with extra
        # colons is reported as a bad entry instead of an unpacking error.
        asset, _, raw_value = token.partition(":")
        try:
            impact[asset] = to_decimal(raw_value)
        except ValueError as err:
            raise ValueError(f"invalid impact entry {token!r}") from err

    return {
        "headline": headline_match.group(1).strip(),
        "impact": impact,
        "duration_months": int(duration_match.group(1)),
    }


def parse_mentor(response: str) -> dict:
    """Parse a Sharpe-mentor response into a structured dict.

    The response is expected to contain ``roast:``, ``sharpe_ratio:``,
    ``lesson:`` and ``suggestion:`` fields. Text fields capture the rest of
    their line; ``sharpe_ratio`` is converted to a float.

    Args:
        response: Raw structured-text response.

    Returns:
        ``{"roast": str, "sharpe_ratio": float, "lesson": str,
        "suggestion": str}``.

    Raises:
        ValueError: If the response is empty, a field is missing or
            malformed, or the Sharpe ratio is not a valid number.
    """
    response = response.strip()
    if not response:
        raise ValueError("empty response")

    patterns = {
        "roast": r"roast:\s*(.+)",
        "sharpe_ratio": r"sharpe_ratio:\s*([-\d.]+)",
        "lesson": r"lesson:\s*(.+)",
        "suggestion": r"suggestion:\s*(.+)",
    }
    parsed = {}
    for field, pattern in patterns.items():
        found = re.search(pattern, response)
        # re.search returns None on no match; name the missing field rather
        # than failing with an AttributeError on None.
        if found is None:
            raise ValueError(f"missing or malformed '{field}' field")
        parsed[field] = found.group(1).strip()

    parsed["sharpe_ratio"] = float(parsed["sharpe_ratio"])
    return parsed


def parse_guardrail(response: str) -> dict:
    """Parse a guardrail response, which carries a single ``error:`` field.

    Args:
        response: Raw structured-text response.

    Returns:
        ``{"error": str}`` with the rest of the ``error:`` line, stripped.

    Raises:
        ValueError: If the response is empty or has no ``error:`` field.
    """
    response = response.strip()
    if not response:
        raise ValueError("empty response")

    found = re.search(r"error:\s*(.+)", response)
    # re.search returns None on no match; fail with a descriptive message
    # rather than an AttributeError on None.
    if found is None:
        raise ValueError("missing or malformed 'error' field")
    return {"error": found.group(1).strip()}


# Dispatch table: the "task" value stored on each generated row selects the
# parser used to validate that row's "response" text.
PARSERS = {
    "agent_decision": parse_agent,
    "news_impact": parse_news,
    "sharpe_mentor": parse_mentor,
    "guardrail": parse_guardrail,
}


async def main() -> int:
    """Generate a small batch per task type, save it, and validate parsing.

    For each task label a handful of prompts is generated and run through
    ``generate_wave``; the resulting rows are written as JSON Lines next to
    ``OUTPUT_FILE`` and each row's response is then checked with the parser
    registered for its task.

    Returns:
        Exit code: 0 when at least one row was generated and every row
        parsed; 1 when any row failed to parse or nothing was generated.
    """
    # NOTE(review): OUTPUT_FILE is assumed to be a pathlib.Path — confirm.
    test_file = OUTPUT_FILE.with_name("retro-alpha-test.jsonl")
    # Remove stale output first so a crash during generation cannot leave a
    # previous run's rows behind.
    if test_file.exists():
        test_file.unlink()

    waves = [
        ("agent", 5),
        ("news", 5),
        ("mentor", 5),
        ("guardrail", 5),
    ]

    all_rows = []
    for task, count in waves:
        print(f"Testing {task} with {count} rows...")
        prompts = generate_prompts(task, count)
        results = await generate_wave(prompts, f"test-{task}")
        print(f"  Got {len(results)}/{count} non-empty rows")
        all_rows.extend(results)

    with open(test_file, "w", encoding="utf-8") as f:
        for row in all_rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")

    print(f"\nTest output: {test_file}")
    print(f"Total test rows: {len(all_rows)}")

    parse_ok = 0
    parse_fail = 0
    for row in all_rows:
        task = row.get("task")
        response = row.get("response")
        try:
            # Look the parser up inside the try so that a row with an
            # unknown or missing task is counted as a failure instead of
            # aborting the whole audit.
            parser = PARSERS.get(task)
            if parser is None:
                raise ValueError(f"no parser registered for task {task!r}")
            parser(response)
            parse_ok += 1
        except Exception as e:
            # Broad catch is deliberate: any failure on a row is a test
            # failure to report, not a reason to stop checking other rows.
            parse_fail += 1
            preview = response[:200] if isinstance(response, str) else response
            print(f"Parse fail ({task}): {e}")
            print(f"  Response: {repr(preview)}")

    print(f"Parsed successfully: {parse_ok}/{len(all_rows)}")
    if not all_rows:
        # An empty batch validates nothing; reporting success would be a
        # vacuous pass.
        print("No rows were generated; nothing was validated")
        print("TEST FAILED")
        return 1
    if parse_fail > 0:
        print("TEST FAILED")
        return 1

    print("TEST PASSED")
    return 0


# Run the audit only when executed as a script; the integer returned by
# main() becomes the process exit status.
if __name__ == "__main__":
    sys.exit(asyncio.run(main()))