Virendrasinh10 commited on
Commit
1ce69cc
·
1 Parent(s): d4e41c1

added logic for get_figures and grader implementation

Browse files
server/earnings_analyst_environment.py CHANGED
@@ -9,6 +9,7 @@ from __future__ import annotations
9
 
10
  import math
11
  import os
 
12
  import random
13
  from typing import Any
14
  from uuid import uuid4
@@ -120,10 +121,16 @@ class EarningsAnalystEnvironment(Environment):
120
  Terminal observation with reward and metadata including ground truth.
121
  """
122
  self._state.step_count += 1
123
- label_col = self._cfg["label_col"]
124
- label_values = list(self._cfg["label_values"])
125
  row = self._current_row or {}
126
- ground_truth = str(row.get(label_col, "")).strip()
 
 
 
 
 
 
127
 
128
  grade_fn = get_grader(self._task_id)
129
  reward = float(
 
9
 
10
  import math
11
  import os
12
+ import json
13
  import random
14
  from typing import Any
15
  from uuid import uuid4
 
121
  Terminal observation with reward and metadata including ground truth.
122
  """
123
  self._state.step_count += 1
124
+ label_col = self._cfg.get("label_col", "symbol")
125
+ label_values = list(self._cfg.get("label_values", []))
126
  row = self._current_row or {}
127
+
128
+ # Handle composite ground truth if multiple columns are specified (e.g. for get_figures)
129
+ if "xbrl_columns" in self._cfg:
130
+ gt_data = {col: row.get(col) for col in self._cfg["xbrl_columns"]}
131
+ ground_truth = json.dumps(gt_data)
132
+ else:
133
+ ground_truth = str(row.get(label_col, "")).strip()
134
 
135
  grade_fn = get_grader(self._task_id)
136
  reward = float(
tasks/get_figures/grader.py CHANGED
@@ -2,13 +2,78 @@
2
 
3
  from __future__ import annotations
4
  import json
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
 
7
  def grade(predicted: str, ground_truth: str, label_values: list[str]) -> float:
8
  """
9
- Score the agent's extraction performance.
 
 
 
 
 
10
 
11
- Currently a stub until XBRL ground truth is provided.
12
- Always returns 0.0 with implemented: False in spec.
13
  """
14
- return 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
  import json
5
+ import math
6
+ from typing import Any
7
+
8
+
9
+ def _safe_float(val: Any) -> float | None:
10
+ if val is None or val == "":
11
+ return None
12
+ try:
13
+ return float(val)
14
+ except (ValueError, TypeError):
15
+ return None
16
+
17
+
18
+ def _get_score(pred: float | None, target: float | None, tolerance: float = 0.01) -> float:
19
+ """Compare pred to target with relative error tolerance."""
20
+ if target is None:
21
+ # If ground truth is null, reward 1.0 if prediction is also null, else 0.0
22
+ return 1.0 if pred is None else 0.0
23
+
24
+ if pred is None:
25
+ return 0.0
26
+
27
+ if abs(target) < 1e-9:
28
+ return 1.0 if abs(pred) < 1e-9 else 0.0
29
+
30
+ relative_error = abs(pred - target) / abs(target)
31
+ return 1.0 if relative_error <= tolerance else 0.0
32
+
33
+
34
def _flatten_metrics(data: dict[str, Any]) -> dict[str, float | None]:
    """Collapse the agent's nested {section: {metric: value}} JSON into one flat dict.

    Only the three known statement sections are read; each metric value is
    coerced with _safe_float. Sections that are missing or not dicts are
    silently skipped.
    """
    flat: dict[str, float | None] = {}
    for section in ("income_statement", "balance_sheet", "cash_flow"):
        block = data.get(section) if isinstance(data, dict) else None
        if isinstance(block, dict):
            flat.update({name: _safe_float(raw) for name, raw in block.items()})
    return flat
42
 
43
 
44
def grade(predicted: str, ground_truth: str, label_values: list[str]) -> float:
    """
    Score the agent's extraction performance across multiple financial metrics.

    Args:
        predicted: Agent's response string (expected JSON).
        ground_truth: Environment's packed JSON string of XBRL values.
        label_values: Unused; kept for grader-interface compatibility.

    Returns:
        Average score (0.0 to 1.0) across all metrics; 0.0 on malformed input.
    """
    try:
        pred_data = json.loads(predicted)
        target_data = json.loads(ground_truth)
    except (json.JSONDecodeError, TypeError):
        return 0.0

    # Both payloads must be JSON objects. A valid-JSON list/scalar/string
    # would otherwise raise (TypeError/AttributeError) further down instead
    # of simply scoring zero, crashing the grader on malformed agent output.
    if not isinstance(pred_data, dict) or not isinstance(target_data, dict):
        return 0.0

    # Flatten the agent's nested response into {metric_key: float | None}.
    pred_metrics = _flatten_metrics(pred_data)

    # Environment's target_data is already flat (mapping column_name -> value).
    # Imported lazily to avoid a circular import at module load time.
    from .spec import METRIC_TO_COLUMN

    # Map each canonical metric key to its dataset column and score the pair.
    scores = [
        _get_score(pred_metrics.get(metric_key), _safe_float(target_data.get(col_name)))
        for metric_key, col_name in METRIC_TO_COLUMN.items()
    ]

    if not scores:
        return 0.0

    return sum(scores) / len(scores)
tasks/get_figures/spec.py CHANGED
@@ -5,25 +5,65 @@ from ..types import TaskSpec
5
 
6
  CANONICAL_TASK_ID = "get_figures"
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  SPEC: TaskSpec = {
9
  "task_id": CANONICAL_TASK_ID,
10
- "implemented": False, # Safety gate: set to True once ground truth column is confirmed.
11
  "text_cols": [
12
  "earnings_transcript",
13
  "press_release_8k_body",
14
  "press_release_ex991",
15
  "press_release_ex992",
16
- "press_release_sources",
17
  ],
18
- "numerical_cols": [],
19
- "label_col": "symbol", # Placeholder
 
 
 
20
  "label_values": [],
21
  "task_instruction": (
22
  "Extract key financial figures from the provided earnings call materials.\n\n"
23
- "Return a JSON object matching this exact schema:\n"
24
- '{"revenue": <float>, "net_income": <float>, "eps": <float>}\n\n'
25
- "Use the currency specified in the documents. If a figure is not found, use null.\n"
26
- "Do not include any other keys or explanation."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  ),
28
- "kind": "other",
 
 
29
  }
 
5
 
6
  CANONICAL_TASK_ID = "get_figures"
7
 
8
# Mapping from JSON metric keys to dataset XBRL columns. Every dataset
# column name is the metric key with an "xbrl_" prefix, so the mapping is
# derived from the key tuple to keep the two sides in lockstep.
_METRIC_KEYS: tuple[str, ...] = (
    "revenue",
    "cost_of_revenue",
    "gross_profit",
    "operating_income",
    "net_income",
    "eps_basic",
    "eps_diluted",
    "cash_and_cash_equivalents",
    "total_assets",
    "total_liabilities",
    "net_cash_operating_activities",
    "capital_expenditures",
)
METRIC_TO_COLUMN: dict[str, str] = {key: f"xbrl_{key}" for key in _METRIC_KEYS}
23
+
24
# Task specification consumed by the task registry / environment.
SPEC: TaskSpec = {
    "task_id": CANONICAL_TASK_ID,
    "implemented": True,
    # Document columns surfaced to the agent as source material.
    "text_cols": [
        "earnings_transcript",
        "press_release_8k_body",
        "press_release_ex991",
        "press_release_ex992",
    ],
    # NOTE(review): presumably extra numeric context shown to the agent —
    # confirm how the environment consumes numerical_cols.
    "numerical_cols": [
        "price_momentum_30d",
        "avg_volume_20d",
    ],
    "label_col": "xbrl_revenue",  # Primary ground truth column
    "label_values": [],
    # Prompt sent to the agent; the grader expects the response to parse as
    # this exact JSON shape (see tasks/get_figures/grader.py).
    "task_instruction": (
        "Extract key financial figures from the provided earnings call materials.\n\n"
        "Return a JSON object matching this exact US-GAAP taxonomy schema:\n"
        "{\n"
        ' "taxonomy_version": "us-gaap-2024",\n'
        ' "income_statement": {\n'
        ' "revenue": <float>,\n'
        ' "cost_of_revenue": <float>,\n'
        ' "gross_profit": <float>,\n'
        ' "operating_income": <float>,\n'
        ' "net_income": <float>,\n'
        ' "eps_basic": <float>,\n'
        ' "eps_diluted": <float>\n'
        " },\n"
        ' "balance_sheet": {\n'
        ' "cash_and_cash_equivalents": <float>,\n'
        ' "total_assets": <float>,\n'
        ' "total_liabilities": <float>\n'
        " },\n"
        ' "cash_flow": {\n'
        ' "net_cash_operating_activities": <float>,\n'
        ' "capital_expenditures": <float>\n'
        " }\n"
        "}\n\n"
        "Values should ideally be in USD. If a figure is not found or not mentioned, use null.\n"
        "Do not include any other keys, explanations, or markdown blocks."
    ),
    "kind": "extraction",
    # Metadata for the environment to pack all ground truth figures
    # into a single JSON ground-truth string.
    "xbrl_columns": list(METRIC_TO_COLUMN.values()),
}
  }