intelkishan committed on
Commit
abbea70
·
1 Parent(s): 2a048ac

Implemented quarter prediction

Browse files
client.py CHANGED
@@ -53,6 +53,8 @@ class EarningsAnalystEnv(
53
  }
54
 
55
  def _parse_result(self, payload: Dict) -> StepResult[EarningsAnalystObservation]:
 
 
56
  """
57
  Parse server response into StepResult[EarningsAnalystObservation].
58
 
@@ -69,7 +71,9 @@ class EarningsAnalystEnv(
69
  task_instruction=obs_data.get("task_instruction", ""),
70
  done=payload.get("done", False),
71
  reward=payload.get("reward"),
 
72
  metadata=obs_data.get("metadata", {}),
 
73
  )
74
 
75
  return StepResult(
 
53
  }
54
 
55
  def _parse_result(self, payload: Dict) -> StepResult[EarningsAnalystObservation]:
56
+
57
+
58
  """
59
  Parse server response into StepResult[EarningsAnalystObservation].
60
 
 
71
  task_instruction=obs_data.get("task_instruction", ""),
72
  done=payload.get("done", False),
73
  reward=payload.get("reward"),
74
+ ground_truth=obs_data.get("ground_truth", ""),
75
  metadata=obs_data.get("metadata", {}),
76
+
77
  )
78
 
79
  return StepResult(
evaluate.py CHANGED
@@ -82,7 +82,8 @@ async def run_evaluation(
82
  episode_result = await run_episode(
83
  base_url=base_url,
84
  model=model,
85
- verbose=False,
 
86
  )
87
  episode_reward = float(
88
  episode_result.reward if episode_result.reward is not None else 0.0
 
82
  episode_result = await run_episode(
83
  base_url=base_url,
84
  model=model,
85
+ verbose=True,
86
+
87
  )
88
  episode_reward = float(
89
  episode_result.reward if episode_result.reward is not None else 0.0
inference.py CHANGED
@@ -54,9 +54,12 @@ class EpisodeResult:
54
  model_response_text: str | None = None
55
 
56
 
57
- def _normalize_sentiment(model_text: str, valid: list[str] | None = None) -> str:
58
- """Map model output to a canonical label; fallback to neutral."""
59
- labels = valid or DEFAULT_LABELS
 
 
 
60
  normalized_model_text = str(model_text).strip().lower()
61
  for canonical_label in labels:
62
  if normalized_model_text == canonical_label.lower():
@@ -91,18 +94,14 @@ async def predict_with_openai(
91
  valid_labels: list[str] | None = None,
92
  ) -> tuple[str, str]:
93
  """
94
- Example Chat Completions call returning a JSON object; maps to a canonical label.
95
-
96
- Replace or parameterize this when you implement tasks beyond placeholder demos.
97
  """
98
- labels = valid_labels or DEFAULT_LABELS
99
  user_content = build_user_content(obs)
100
  system_prompt = (
101
  "You are a financial analyst assistant. "
102
- "Reply with a single JSON object only, no markdown or extra text, "
103
- 'with key "sentiment" whose value is exactly one of: '
104
- + ", ".join(f'"{lab}"' for lab in labels)
105
- + "."
106
  )
107
  completion = await client.chat.completions.create(
108
  model=model,
@@ -113,13 +112,24 @@ async def predict_with_openai(
113
  response_format={"type": "json_object"},
114
  )
115
  response_text = (completion.choices[0].message.content or "").strip()
116
- predicted = "neutral"
 
 
117
  try:
118
  parsed: dict[str, Any] = json.loads(response_text)
119
- if isinstance(parsed, dict) and "sentiment" in parsed:
120
- predicted = _normalize_sentiment(str(parsed["sentiment"]), labels)
 
 
 
 
 
 
 
121
  except (json.JSONDecodeError, TypeError, ValueError):
122
- predicted = _normalize_sentiment(response_text, labels)
 
 
123
  return predicted, response_text
124
 
125
 
@@ -147,18 +157,30 @@ async def run_episode(
147
  openai_client_options: dict[str, Any] = {"api_key": api_key}
148
  if resolved_openai_base_url:
149
  openai_client_options["base_url"] = resolved_openai_base_url
 
 
 
 
150
  client = AsyncOpenAI(**openai_client_options)
151
 
 
 
152
  async with EarningsAnalystEnv(base_url=environment_base_url) as env:
153
  reset_out = await env.reset()
154
  observation = reset_out.observation
 
 
 
 
 
 
155
  predicted, response_text = await predict_with_openai(
156
- observation, client=client, model=model_name
157
  )
158
  step_out = await env.step(EarningsAnalystAction(prediction=predicted))
159
  step_observation = step_out.observation
160
- observation_metadata = getattr(step_observation, "metadata", None) or {}
161
- ground_truth_label = str(observation_metadata.get("ground_truth", ""))
162
  reward = step_out.reward
163
  if verbose:
164
  print(
 
54
  model_response_text: str | None = None
55
 
56
 
57
+ def _normalize_prediction(model_text: str, valid: list[str] | None = None) -> str:
58
+ """Map model output to a canonical label or return as is for regression."""
59
+ if not valid:
60
+ return model_text.strip()
61
+
62
+ labels = valid
63
  normalized_model_text = str(model_text).strip().lower()
64
  for canonical_label in labels:
65
  if normalized_model_text == canonical_label.lower():
 
94
  valid_labels: list[str] | None = None,
95
  ) -> tuple[str, str]:
96
  """
97
+ Example Chat Completions call returning a JSON object.
 
 
98
  """
 
99
  user_content = build_user_content(obs)
100
  system_prompt = (
101
  "You are a financial analyst assistant. "
102
+ "Your task is to analyze the provided financial data and respond "
103
+ "EXACTLY as instructed in the Task Instruction. "
104
+ "Reply with a single JSON object only, no markdown or extra text."
 
105
  )
106
  completion = await client.chat.completions.create(
107
  model=model,
 
112
  response_format={"type": "json_object"},
113
  )
114
  response_text = (completion.choices[0].message.content or "").strip()
115
+
116
+ # Try to extract the primary value based on common keys
117
+ predicted = response_text
118
  try:
119
  parsed: dict[str, Any] = json.loads(response_text)
120
+ if isinstance(parsed, dict):
121
+ # Check for common return keys
122
+ for key in ["sentiment", "move", "label", "prediction"]:
123
+ if key in parsed:
124
+ if valid_labels:
125
+ predicted = _normalize_prediction(str(parsed[key]), valid_labels)
126
+ else:
127
+ predicted = str(parsed[key])
128
+ break
129
  except (json.JSONDecodeError, TypeError, ValueError):
130
+ if valid_labels:
131
+ predicted = _normalize_prediction(response_text, valid_labels)
132
+
133
  return predicted, response_text
134
 
135
 
 
157
  openai_client_options: dict[str, Any] = {"api_key": api_key}
158
  if resolved_openai_base_url:
159
  openai_client_options["base_url"] = resolved_openai_base_url
160
+
161
+ if verbose:
162
+ print(f"DEBUG: Using base_url={resolved_openai_base_url or 'default'} model={model_name}")
163
+
164
  client = AsyncOpenAI(**openai_client_options)
165
 
166
+
167
+
168
  async with EarningsAnalystEnv(base_url=environment_base_url) as env:
169
  reset_out = await env.reset()
170
  observation = reset_out.observation
171
+ # We pass valid_labels if they exist in the observation/registry
172
+ # This implementation assumes the client can fetch labels, or that they are hardcoded.
173
+ # For simplicity, we use the observation's label_values attribute if present on reset
174
+ # Or just use None for regression.
175
+ valid_labels = getattr(observation, "label_values", None)
176
+
177
  predicted, response_text = await predict_with_openai(
178
+ observation, client=client, model=model_name, valid_labels=valid_labels
179
  )
180
  step_out = await env.step(EarningsAnalystAction(prediction=predicted))
181
  step_observation = step_out.observation
182
+ ground_truth_label = str(getattr(step_observation, "ground_truth", ""))
183
+
184
  reward = step_out.reward
185
  if verbose:
186
  print(
server/app.py CHANGED
@@ -32,7 +32,7 @@ except Exception as e: # pragma: no cover
32
  try:
33
  from ..models import EarningsAnalystAction, EarningsAnalystObservation
34
  from .earnings_analyst_environment import EarningsAnalystEnvironment
35
- except ModuleNotFoundError:
36
  from models import EarningsAnalystAction, EarningsAnalystObservation
37
  from server.earnings_analyst_environment import EarningsAnalystEnvironment
38
 
 
32
  try:
33
  from ..models import EarningsAnalystAction, EarningsAnalystObservation
34
  from .earnings_analyst_environment import EarningsAnalystEnvironment
35
+ except (ImportError, ModuleNotFoundError):
36
  from models import EarningsAnalystAction, EarningsAnalystObservation
37
  from server.earnings_analyst_environment import EarningsAnalystEnvironment
38
 
server/earnings_analyst_environment.py CHANGED
@@ -134,19 +134,21 @@ class EarningsAnalystEnvironment(Environment):
134
  )
135
  )
136
 
 
137
  return EarningsAnalystObservation(
138
  text_context={},
139
  numerical_context={},
140
  task_instruction=self._cfg["task_instruction"],
141
  done=True,
142
  reward=reward,
 
143
  metadata={
144
  "task_id": self._task_id,
145
  "predicted": action.prediction,
146
- "ground_truth": ground_truth,
147
  },
148
  )
149
 
 
150
  @property
151
  def state(self) -> State:
152
  """Current environment state."""
 
134
  )
135
  )
136
 
137
+
138
  return EarningsAnalystObservation(
139
  text_context={},
140
  numerical_context={},
141
  task_instruction=self._cfg["task_instruction"],
142
  done=True,
143
  reward=reward,
144
+ ground_truth=ground_truth,
145
  metadata={
146
  "task_id": self._task_id,
147
  "predicted": action.prediction,
 
148
  },
149
  )
150
 
151
+
152
  @property
153
  def state(self) -> State:
154
  """Current environment state."""
tasks/__init__.py CHANGED
@@ -3,21 +3,11 @@
3
  from __future__ import annotations
4
 
5
  from .exceptions import TaskNotImplementedError
6
- from .registry import (
7
- DEFAULT_TASK,
8
- GRADERS,
9
- TASKS,
10
- TASK_IDS,
11
- get_grader,
12
- get_task_spec,
13
- )
14
 
15
  __all__ = [
16
- "DEFAULT_TASK",
17
- "GRADERS",
18
- "TASKS",
19
- "TASK_IDS",
20
  "TaskNotImplementedError",
21
- "get_grader",
22
- "get_task_spec",
23
  ]
 
 
3
  from __future__ import annotations
4
 
5
  from .exceptions import TaskNotImplementedError
6
+ # Registry exports removed to avoid circular imports during dynamic task loading.
7
+ # Use 'from tasks.registry import ...' instead.
8
+
 
 
 
 
 
9
 
10
  __all__ = [
 
 
 
 
11
  "TaskNotImplementedError",
 
 
12
  ]
13
+
tasks/grading.py CHANGED
@@ -94,3 +94,32 @@ def grade_exact(
94
  if _normalize_text(predicted) == _normalize_text(ground_truth):
95
  return 1.0
96
  return 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  if _normalize_text(predicted) == _normalize_text(ground_truth):
95
  return 1.0
96
  return 0.0
97
+
98
+
99
+ def grade_regression(
100
+ predicted: str,
101
+ ground_truth: str,
102
+ scale: float = 0.1,
103
+ ) -> float:
104
+ """
105
+ Score a numerical prediction: exp(-abs(pred - gt) / scale).
106
+ Returns 1.0 for exact, decaying towards 0.0.
107
+ """
108
+ import math
109
+
110
+ try:
111
+ # Ground truth is passed as str(float) from the environment
112
+ gt_val = float(ground_truth)
113
+ except (ValueError, TypeError):
114
+ return 0.0
115
+
116
+ # Try to parse predicted as a pure number if it's not JSON
117
+ # (Though usually the task asks for JSON)
118
+ try:
119
+ pred_val = float(predicted)
120
+ except (ValueError, TypeError):
121
+ # Not a parseable number; give up and return 0.0
122
+ return 0.0
123
+
124
+ error = abs(pred_val - gt_val)
125
+ return math.exp(-error / scale)
tasks/next_quarter_move/grader.py CHANGED
@@ -1,10 +1,29 @@
1
- """Grading for ``next_quarter_move`` — implement when the task is ready."""
2
 
3
  from __future__ import annotations
 
 
 
 
4
 
5
 
6
  def grade(predicted: str, ground_truth: str, label_values: list[str]) -> float:
7
- raise NotImplementedError(
8
- "Task 'next_quarter_move' is not implemented yet. "
9
- "Implement grader logic in tasks/next_quarter_move/grader.py."
10
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Grading for ``next_quarter_move`` (regression)."""
2
 
3
  from __future__ import annotations
4
+ import json
5
+ import re
6
+
7
+ from ..grading import grade_regression
8
 
9
 
10
  def grade(predicted: str, ground_truth: str, label_values: list[str]) -> float:
11
+ """
12
+ Parses predicted string for a 'move' key or a numeric value,
13
+ then grades against ground_truth via exponential decay.
14
+ """
15
+ _ = label_values
16
+
17
+ # Try to extract number from JSON if possible
18
+ pred_val_str = predicted
19
+ try:
20
+ data = json.loads(predicted)
21
+ if isinstance(data, dict) and "move" in data:
22
+ pred_val_str = str(data["move"])
23
+ except (json.JSONDecodeError, TypeError):
24
+ # Fallback: find the first float-like thing in the string
25
+ match = re.search(r"[-+]?\d*\.\d+|\d+", predicted)
26
+ if match:
27
+ pred_val_str = match.group()
28
+
29
+ return grade_regression(pred_val_str, ground_truth, scale=0.1)
tasks/next_quarter_move/spec.py CHANGED
@@ -1,4 +1,4 @@
1
- """Task specification for ``next_quarter_move`` fill in when implementing."""
2
 
3
  from __future__ import annotations
4
 
@@ -8,11 +8,28 @@ CANONICAL_TASK_ID = "next_quarter_move"
8
 
9
  SPEC: TaskSpec = {
10
  "task_id": CANONICAL_TASK_ID,
11
- "implemented": False,
12
- "text_cols": [],
13
- "numerical_cols": [],
14
- "label_col": "",
15
- "label_values": [],
16
- "task_instruction": "",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  "kind": "regression",
18
  }
 
1
+ """Task specification for ``next_quarter_move`` (predicting return until next qtr earnings)."""
2
 
3
  from __future__ import annotations
4
 
 
8
 
9
  SPEC: TaskSpec = {
10
  "task_id": CANONICAL_TASK_ID,
11
+ "implemented": True,
12
+ "text_cols": [
13
+ "earnings_transcript",
14
+ "press_release_8k_body",
15
+ "press_release_ex991",
16
+ "press_release_ex992",
17
+ ],
18
+ "numerical_cols": [
19
+ "price_momentum_30d",
20
+ "price_momentum_90d",
21
+ "pct_from_52w_high_pt",
22
+ "avg_volume_20d",
23
+ "d_minus_1_close",
24
+ ],
25
+ "label_col": "move_next_qtr",
26
+ "label_values": [], # Regression tasks don't use categorical labels
27
+ "task_instruction": (
28
+ "Analyse the provided earnings call materials and predict the stock price movement "
29
+ "from this quarter's earnings date until the day before the next quarter's earnings date.\n\n"
30
+ "Returns a JSON object matching this exact schema:\n"
31
+ '{"move": <predicted float, e.g. 0.05 for 5% gain or -0.02 for 2% loss>}\n\n'
32
+ "Do not include any other keys or explanation."
33
+ ),
34
  "kind": "regression",
35
  }