Spaces:
Running
Running
| """ | |
| Validate the generated text-format dataset, filter parseable rows, | |
| and write a clean dataset ready for fine-tuning. | |
| """ | |
| import json | |
| import re | |
| import sys | |
| from collections import Counter | |
| from pathlib import Path | |
# Project root: the grandparent directory of this script's resolved location.
ROOT = Path(__file__).resolve().parent.parent
# Raw generated dataset that ``validate`` reads ...
INPUT_FILE = ROOT.joinpath("data", "retro-alpha-training-v1.jsonl")
# ... and the filtered, parseable subset it writes (overwritten on each run).
CLEAN_FILE = ROOT.joinpath("data", "retro-alpha-clean.jsonl")
def to_decimal(value_str: str) -> float:
    """Convert numeric text to a float, reading a trailing ``%`` as a percentage.

    ``"12.5%"`` becomes ``0.125`` and ``"0.3"`` stays ``0.3``. Surrounding
    whitespace is ignored.

    Raises:
        ValueError: if the text (minus any trailing ``%``) is not a number.
    """
    text = value_str.strip()
    is_percent = text.endswith("%")
    number = float(text[:-1] if is_percent else text)
    return number / 100.0 if is_percent else number
def parse_agent(response: str) -> dict | None:
    """Parse an ``agent_decision`` response.

    The response must contain the fields ``agent: <name>``,
    ``action: <buy|sell|hold> <asset> <amount>``, ``reason: <text>`` and
    ``sentiment: <label>``. ``<amount>`` may be a fraction (``0.2``) or a
    percentage (``20%``); it is normalised with ``to_decimal``.

    Returns:
        ``{"agent", "actions": [{"asset", "action", "amount_pct", "reason"}],
        "sentiment"}`` with ``agent`` and ``sentiment`` lower-cased, or ``None``
        if a field is missing or empty, or the amount is not a number.
    """
    # Only horizontal whitespace ([^\S\r\n]) may follow these labels: with
    # "\s*" an empty "agent:" / "reason:" line would capture the NEXT line
    # (e.g. "action: ...") and an incomplete response would pass validation.
    agent_m = re.search(r"agent:[^\S\r\n]*<?(\w+)>?", response)
    action_m = re.search(r"action:\s*(buy|sell|hold)\s+(\w+)\s+([\d.%]+)", response)
    reason_m = re.search(r"reason:[^\S\r\n]*(.+)", response)
    sentiment_m = re.search(r"sentiment:[^\S\r\n]*<?(\w+)>?", response)
    if not (agent_m and action_m and reason_m and sentiment_m):
        return None

    reason = reason_m.group(1).strip()
    if not reason:  # e.g. "reason:  " or "reason:\r\n"
        return None

    try:
        amount = to_decimal(action_m.group(3))
    except ValueError:  # "[\d.%]+" also matches junk such as "." or "1.2.3"
        return None

    return {
        "agent": agent_m.group(1).lower(),
        "actions": [
            {
                "asset": action_m.group(2),
                "action": action_m.group(1),
                "amount_pct": amount,
                "reason": reason,
            }
        ],
        "sentiment": sentiment_m.group(1).lower(),
    }
def parse_news(response: str) -> dict | None:
    """Parse a ``news_impact`` response.

    Expected fields: ``headline: <text>``, ``impact:`` followed by
    whitespace-separated ``asset:value`` tokens (values normalised with
    ``to_decimal``, so ``-2%`` and ``-0.02`` are equivalent), and
    ``duration: <integer>``.

    Returns:
        ``{"headline", "impact": {asset: float}, "duration_months": int}``, or
        ``None`` if a field is missing, the headline is empty, a token is
        malformed, or any of the eight required assets has no impact value.
    """
    # Horizontal whitespace only after "headline:": with "\s*" an empty
    # headline line would capture the next line (e.g. "impact: ...").
    headline_m = re.search(r"headline:[^\S\r\n]*(.+)", response)
    # Impact tokens may span lines; stop at the "duration:" line or the end.
    impact_m = re.search(r"impact:\s*(.+?)(?:\nduration:|$)", response, re.DOTALL)
    duration_m = re.search(r"duration:\s*(\d+)", response)
    if not (headline_m and impact_m and duration_m):
        return None

    headline = headline_m.group(1).strip()
    if not headline:
        return None

    impact = {}
    try:
        for token in impact_m.group(1).split():
            if ":" in token:
                asset, value = token.split(":")  # ValueError on "a:b:c"
                impact[asset] = to_decimal(value)  # ValueError on non-numbers
    except ValueError:
        return None

    required = ("cash", "fd", "gov_bonds", "nifty_50", "nifty_it", "real_estate", "crypto", "gold")
    if any(asset not in impact for asset in required):
        return None

    return {
        "headline": headline,
        "impact": impact,
        "duration_months": int(duration_m.group(1)),
    }
| def parse_mentor(response: str) -> dict | None: | |
| try: | |
| roast = re.search(r"roast:\s*(.+)", response).group(1).strip() | |
| sharpe = float(re.search(r"sharpe_ratio:\s*([-\d.]+)", response).group(1)) | |
| lesson = re.search(r"lesson:\s*(.+)", response).group(1).strip() | |
| suggestion = re.search(r"suggestion:\s*(.+)", response).group(1).strip() | |
| return {"roast": roast, "sharpe_ratio": sharpe, "lesson": lesson, "suggestion": suggestion} | |
| except Exception: | |
| return None | |
| def parse_guardrail(response: str) -> dict | None: | |
| try: | |
| return {"error": re.search(r"error:\s*(.+)", response).group(1).strip()} | |
| except Exception: | |
| return None | |
# Dispatch table used by ``validate``: a row's "task" value selects the
# function that parses its "response"; unknown tasks are reported as errors.
PARSERS = dict(
    agent_decision=parse_agent,
    news_impact=parse_news,
    sharpe_mentor=parse_mentor,
    guardrail=parse_guardrail,
)
| def validate(): | |
| if not INPUT_FILE.exists(): | |
| print(f"Dataset not found: {INPUT_FILE}") | |
| sys.exit(1) | |
| total = 0 | |
| valid = 0 | |
| errors = 0 | |
| empty = 0 | |
| task_counts = Counter() | |
| valid_task_counts = Counter() | |
| with open(INPUT_FILE, "r", encoding="utf-8") as f_in, open(CLEAN_FILE, "w", encoding="utf-8") as f_out: | |
| for line_num, line in enumerate(f_in, 1): | |
| line = line.strip() | |
| if not line: | |
| continue | |
| total += 1 | |
| try: | |
| row = json.loads(line) | |
| except json.JSONDecodeError as e: | |
| print(f"Line {line_num}: JSON parse error: {e}") | |
| errors += 1 | |
| continue | |
| task = row.get("task") | |
| task_counts[task] += 1 | |
| response = row.get("response", "").strip() | |
| if not response: | |
| empty += 1 | |
| continue | |
| parser = PARSERS.get(task) | |
| if not parser: | |
| print(f"Line {line_num}: unknown task '{task}'") | |
| errors += 1 | |
| continue | |
| parsed = parser(response) | |
| if parsed is None: | |
| errors += 1 | |
| if errors <= 20: | |
| print(f"Line {line_num} ({task}): parse error") | |
| print(f" Response: {repr(response[:200])}") | |
| continue | |
| valid += 1 | |
| valid_task_counts[task] += 1 | |
| f_out.write(json.dumps(row, ensure_ascii=False) + "\n") | |
| print("\n=== Validation Report ===") | |
| print(f"Total rows: {total}") | |
| print(f"Valid rows: {valid}") | |
| print(f"Invalid/empty rows: {errors + empty}") | |
| print(f"Empty responses: {empty}") | |
| print("Task distribution (input):") | |
| for task, count in task_counts.most_common(): | |
| print(f" {task}: {count}") | |
| print("Valid task distribution:") | |
| for task, count in valid_task_counts.most_common(): | |
| print(f" {task}: {count}") | |
| print(f"\nClean dataset: {CLEAN_FILE}") | |
| if valid >= 1500: | |
| print(f"\nDataset is usable for fine-tuning ({valid} valid rows).") | |
| return 0 | |
| else: | |
| print(f"\nOnly {valid} valid rows. Consider generating more.") | |
| return 1 | |
if __name__ == "__main__":
    # Script entry point: the process exit status is validate()'s return code.
    raise SystemExit(validate())