Spaces:
Sleeping
Sleeping
Commit ·
1ce69cc
1
Parent(s): d4e41c1
added logic for get_figures and grader implemntation
Browse files- server/earnings_analyst_environment.py +10 -3
- tasks/get_figures/grader.py +69 -4
- tasks/get_figures/spec.py +49 -9
server/earnings_analyst_environment.py
CHANGED
|
@@ -9,6 +9,7 @@ from __future__ import annotations
|
|
| 9 |
|
| 10 |
import math
|
| 11 |
import os
|
|
|
|
| 12 |
import random
|
| 13 |
from typing import Any
|
| 14 |
from uuid import uuid4
|
|
@@ -120,10 +121,16 @@ class EarningsAnalystEnvironment(Environment):
|
|
| 120 |
Terminal observation with reward and metadata including ground truth.
|
| 121 |
"""
|
| 122 |
self._state.step_count += 1
|
| 123 |
-
label_col = self._cfg
|
| 124 |
-
label_values = list(self._cfg
|
| 125 |
row = self._current_row or {}
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
grade_fn = get_grader(self._task_id)
|
| 129 |
reward = float(
|
|
|
|
| 9 |
|
| 10 |
import math
|
| 11 |
import os
|
| 12 |
+
import json
|
| 13 |
import random
|
| 14 |
from typing import Any
|
| 15 |
from uuid import uuid4
|
|
|
|
| 121 |
Terminal observation with reward and metadata including ground truth.
|
| 122 |
"""
|
| 123 |
self._state.step_count += 1
|
| 124 |
+
label_col = self._cfg.get("label_col", "symbol")
|
| 125 |
+
label_values = list(self._cfg.get("label_values", []))
|
| 126 |
row = self._current_row or {}
|
| 127 |
+
|
| 128 |
+
# Handle composite ground truth if multiple columns are specified (e.g. for get_figures)
|
| 129 |
+
if "xbrl_columns" in self._cfg:
|
| 130 |
+
gt_data = {col: row.get(col) for col in self._cfg["xbrl_columns"]}
|
| 131 |
+
ground_truth = json.dumps(gt_data)
|
| 132 |
+
else:
|
| 133 |
+
ground_truth = str(row.get(label_col, "")).strip()
|
| 134 |
|
| 135 |
grade_fn = get_grader(self._task_id)
|
| 136 |
reward = float(
|
tasks/get_figures/grader.py
CHANGED
|
@@ -2,13 +2,78 @@
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
import json
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
def grade(predicted: str, ground_truth: str, label_values: list[str]) -> float:
|
| 8 |
"""
|
| 9 |
-
Score the agent's extraction performance.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
| 13 |
"""
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
import json
|
| 5 |
+
import math
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _safe_float(val: Any) -> float | None:
|
| 10 |
+
if val is None or val == "":
|
| 11 |
+
return None
|
| 12 |
+
try:
|
| 13 |
+
return float(val)
|
| 14 |
+
except (ValueError, TypeError):
|
| 15 |
+
return None
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _get_score(pred: float | None, target: float | None, tolerance: float = 0.01) -> float:
|
| 19 |
+
"""Compare pred to target with relative error tolerance."""
|
| 20 |
+
if target is None:
|
| 21 |
+
# If ground truth is null, reward 1.0 if prediction is also null, else 0.0
|
| 22 |
+
return 1.0 if pred is None else 0.0
|
| 23 |
+
|
| 24 |
+
if pred is None:
|
| 25 |
+
return 0.0
|
| 26 |
+
|
| 27 |
+
if abs(target) < 1e-9:
|
| 28 |
+
return 1.0 if abs(pred) < 1e-9 else 0.0
|
| 29 |
+
|
| 30 |
+
relative_error = abs(pred - target) / abs(target)
|
| 31 |
+
return 1.0 if relative_error <= tolerance else 0.0
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _flatten_metrics(data: dict[str, Any]) -> dict[str, float | None]:
|
| 35 |
+
"""Helper to flatten the nested metrics JSON provided by the agent."""
|
| 36 |
+
flat = {}
|
| 37 |
+
for section in ["income_statement", "balance_sheet", "cash_flow"]:
|
| 38 |
+
if section in data and isinstance(data[section], dict):
|
| 39 |
+
for key, val in data[section].items():
|
| 40 |
+
flat[key] = _safe_float(val)
|
| 41 |
+
return flat
|
| 42 |
|
| 43 |
|
| 44 |
def grade(predicted: str, ground_truth: str, label_values: list[str]) -> float:
|
| 45 |
"""
|
| 46 |
+
Score the agent's extraction performance across multiple financial metrics.
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
predicted: Agent's response string (expected JSON).
|
| 50 |
+
ground_truth: Environment's packed JSON string of XBRL values.
|
| 51 |
+
label_values: Unused.
|
| 52 |
|
| 53 |
+
Returns:
|
| 54 |
+
Average score (0.0 to 1.0) across all Metrics.
|
| 55 |
"""
|
| 56 |
+
try:
|
| 57 |
+
pred_data = json.loads(predicted)
|
| 58 |
+
target_data = json.loads(ground_truth)
|
| 59 |
+
except (json.JSONDecodeError, TypeError):
|
| 60 |
+
return 0.0
|
| 61 |
+
|
| 62 |
+
# Flatten the agent's nested response
|
| 63 |
+
pred_metrics = _flatten_metrics(pred_data)
|
| 64 |
+
|
| 65 |
+
# Environment's target_data is already flat (mapping column_name -> value)
|
| 66 |
+
# We need to map the canonical keys (revenue, etc.) to the column values.
|
| 67 |
+
from .spec import METRIC_TO_COLUMN
|
| 68 |
+
|
| 69 |
+
scores = []
|
| 70 |
+
for metric_key, col_name in METRIC_TO_COLUMN.items():
|
| 71 |
+
pred_val = pred_metrics.get(metric_key)
|
| 72 |
+
target_val = _safe_float(target_data.get(col_name))
|
| 73 |
+
|
| 74 |
+
scores.append(_get_score(pred_val, target_val))
|
| 75 |
+
|
| 76 |
+
if not scores:
|
| 77 |
+
return 0.0
|
| 78 |
+
|
| 79 |
+
return sum(scores) / len(scores)
|
tasks/get_figures/spec.py
CHANGED
|
@@ -5,25 +5,65 @@ from ..types import TaskSpec
|
|
| 5 |
|
| 6 |
CANONICAL_TASK_ID = "get_figures"
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
SPEC: TaskSpec = {
|
| 9 |
"task_id": CANONICAL_TASK_ID,
|
| 10 |
-
"implemented":
|
| 11 |
"text_cols": [
|
| 12 |
"earnings_transcript",
|
| 13 |
"press_release_8k_body",
|
| 14 |
"press_release_ex991",
|
| 15 |
"press_release_ex992",
|
| 16 |
-
"press_release_sources",
|
| 17 |
],
|
| 18 |
-
"numerical_cols": [
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
| 20 |
"label_values": [],
|
| 21 |
"task_instruction": (
|
| 22 |
"Extract key financial figures from the provided earnings call materials.\n\n"
|
| 23 |
-
"Return a JSON object matching this exact schema:\n"
|
| 24 |
-
|
| 25 |
-
"
|
| 26 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
),
|
| 28 |
-
"kind": "
|
|
|
|
|
|
|
| 29 |
}
|
|
|
|
| 5 |
|
| 6 |
CANONICAL_TASK_ID = "get_figures"
|
| 7 |
|
| 8 |
+
# Mapping from JSON metric keys to dataset XBRL columns
|
| 9 |
+
METRIC_TO_COLUMN: dict[str, str] = {
|
| 10 |
+
"revenue": "xbrl_revenue",
|
| 11 |
+
"cost_of_revenue": "xbrl_cost_of_revenue",
|
| 12 |
+
"gross_profit": "xbrl_gross_profit",
|
| 13 |
+
"operating_income": "xbrl_operating_income",
|
| 14 |
+
"net_income": "xbrl_net_income",
|
| 15 |
+
"eps_basic": "xbrl_eps_basic",
|
| 16 |
+
"eps_diluted": "xbrl_eps_diluted",
|
| 17 |
+
"cash_and_cash_equivalents": "xbrl_cash_and_cash_equivalents",
|
| 18 |
+
"total_assets": "xbrl_total_assets",
|
| 19 |
+
"total_liabilities": "xbrl_total_liabilities",
|
| 20 |
+
"net_cash_operating_activities": "xbrl_net_cash_operating_activities",
|
| 21 |
+
"capital_expenditures": "xbrl_capital_expenditures",
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
SPEC: TaskSpec = {
|
| 25 |
"task_id": CANONICAL_TASK_ID,
|
| 26 |
+
"implemented": True,
|
| 27 |
"text_cols": [
|
| 28 |
"earnings_transcript",
|
| 29 |
"press_release_8k_body",
|
| 30 |
"press_release_ex991",
|
| 31 |
"press_release_ex992",
|
|
|
|
| 32 |
],
|
| 33 |
+
"numerical_cols": [
|
| 34 |
+
"price_momentum_30d",
|
| 35 |
+
"avg_volume_20d",
|
| 36 |
+
],
|
| 37 |
+
"label_col": "xbrl_revenue", # Primary ground truth column
|
| 38 |
"label_values": [],
|
| 39 |
"task_instruction": (
|
| 40 |
"Extract key financial figures from the provided earnings call materials.\n\n"
|
| 41 |
+
"Return a JSON object matching this exact US-GAAP taxonomy schema:\n"
|
| 42 |
+
"{\n"
|
| 43 |
+
' "taxonomy_version": "us-gaap-2024",\n'
|
| 44 |
+
' "income_statement": {\n'
|
| 45 |
+
' "revenue": <float>,\n'
|
| 46 |
+
' "cost_of_revenue": <float>,\n'
|
| 47 |
+
' "gross_profit": <float>,\n'
|
| 48 |
+
' "operating_income": <float>,\n'
|
| 49 |
+
' "net_income": <float>,\n'
|
| 50 |
+
' "eps_basic": <float>,\n'
|
| 51 |
+
' "eps_diluted": <float>\n'
|
| 52 |
+
" },\n"
|
| 53 |
+
' "balance_sheet": {\n'
|
| 54 |
+
' "cash_and_cash_equivalents": <float>,\n'
|
| 55 |
+
' "total_assets": <float>,\n'
|
| 56 |
+
' "total_liabilities": <float>\n'
|
| 57 |
+
" },\n"
|
| 58 |
+
' "cash_flow": {\n'
|
| 59 |
+
' "net_cash_operating_activities": <float>,\n'
|
| 60 |
+
' "capital_expenditures": <float>\n'
|
| 61 |
+
" }\n"
|
| 62 |
+
"}\n\n"
|
| 63 |
+
"Values should ideally be in USD. If a figure is not found or not mentioned, use null.\n"
|
| 64 |
+
"Do not include any other keys, explanations, or markdown blocks."
|
| 65 |
),
|
| 66 |
+
"kind": "extraction",
|
| 67 |
+
# Metadata for the environment to pack all ground truth figures
|
| 68 |
+
"xbrl_columns": list(METRIC_TO_COLUMN.values()),
|
| 69 |
}
|