RudrakshNanavaty committed on
Commit
fb8bae3
·
1 Parent(s): f9e8ac9

Improved UI

Browse files
app.py CHANGED
@@ -1,10 +1,15 @@
 
1
  import os
2
  import sys
3
  import json
4
  import gradio as gr
5
- from typing import Optional
 
6
  from dotenv import load_dotenv
7
  from openai import AsyncOpenAI
 
 
 
8
 
9
  # Ensure the project root is in sys.path
10
  sys.path.append(os.path.abspath(os.path.dirname(__file__)))
@@ -13,21 +18,799 @@ try:
13
  from earnings_analyst.server.earnings_analyst_environment import (
14
  EarningsAnalystEnvironment,
15
  )
 
 
 
 
 
16
  from earnings_analyst.models import (
17
  EarningsAnalystAction,
18
  EarningsAnalystObservation,
19
  )
 
20
  from earnings_analyst.tasks.registry import TASK_IDS
21
  except (ImportError, ModuleNotFoundError):
22
  from server.earnings_analyst_environment import EarningsAnalystEnvironment
 
 
 
 
 
23
  from models import EarningsAnalystAction, EarningsAnalystObservation
 
24
  from tasks.registry import TASK_IDS
25
 
26
  load_dotenv()
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  class State:
30
- def __init__(self):
31
  self.env: Optional[EarningsAnalystEnvironment] = None
32
  self.obs: Optional[EarningsAnalystObservation] = None
33
  self.task_id: str = "sentiment_label"
@@ -36,65 +819,100 @@ class State:
36
  state = State()
37
 
38
 
39
- async def reset_env(task_id: str):
 
 
 
 
 
40
  state.task_id = task_id
41
  state.env = EarningsAnalystEnvironment(task_id=task_id)
42
- state.obs = state.env.reset()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- # Format observation for display
45
- text_context = ""
46
- if state.obs.text_context:
47
- for name, text in sorted(state.obs.text_context.items()):
48
- text_context += f"### {name}\n{text}\n\n"
49
 
50
- numerical_context = (
51
- json.dumps(state.obs.numerical_context, indent=2)
52
- if state.obs.numerical_context
53
- else "No numerical data."
54
  )
55
 
56
  return [
 
57
  state.obs.task_instruction,
58
- text_context,
59
- numerical_context,
60
- gr.update(visible=True), # Prediction row
61
- gr.update(visible=False), # Result row
62
- "", # Prediction input
63
- "", # Log/Message
 
64
  ]
65
 
66
 
67
- async def step_env(prediction: str):
68
  if not state.env or not state.obs:
69
- return [gr.update(), "Error: Environment not initialized. Click Reset."]
 
 
 
 
 
 
 
 
70
 
71
  action = EarningsAnalystAction(prediction=prediction)
72
  state.obs = state.env.step(action)
73
 
74
- reward = state.obs.reward
 
75
  ground_truth = getattr(state.obs, "ground_truth", "N/A")
76
 
77
- result_text = f"**Reward:** {reward:.4f} \n\n**Ground Truth:** {ground_truth}"
78
-
79
  return [
80
- gr.update(visible=False), # Hide prediction row
81
- gr.update(visible=True, value=result_text), # Show result row
82
  ]
83
 
84
 
85
- async def run_agent(task_id: str, api_key: str, model: str, base_url: str):
 
 
 
 
 
 
 
 
 
86
  if not api_key:
87
- return [gr.update()] * 6 + ["Please provide an API Key."]
88
 
89
- # 1. Reset
90
- out = await reset_env(task_id)
 
91
 
92
- # 2. Predict with LLM
93
- client_params = {"api_key": api_key}
94
- if base_url:
95
- client_params["base_url"] = base_url
96
 
97
- client = AsyncOpenAI(**client_params)
 
 
 
98
 
99
  user_content = f"{state.obs.task_instruction}\n\n"
100
  if state.obs.text_context:
@@ -125,17 +943,10 @@ async def run_agent(task_id: str, api_key: str, model: str, base_url: str):
125
  parsed = json.loads(response_text)
126
  prediction = str(parsed.get("prediction", response_text))
127
 
128
- # 3. Step
129
  step_out = await step_env(prediction)
130
 
131
- # Update UI components
132
- # out: [instr, text, num, pred_row, res_row, pred_input, msg]
133
- # step_out: [pred_row, res_row]
134
-
135
  return [
136
- out[0],
137
- out[1],
138
- out[2],
139
  step_out[0],
140
  step_out[1],
141
  prediction,
@@ -143,42 +954,84 @@ async def run_agent(task_id: str, api_key: str, model: str, base_url: str):
143
  ]
144
 
145
  except Exception as e:
146
- return [out[0], out[1], out[2], out[3], out[4], "", f"Error: {str(e)}"]
147
 
148
 
149
- # Custom CSS for a premium look
150
- custom_css = """
151
- footer {visibility: hidden}
152
- .container {
153
- max-width: 1100px;
154
- margin: auto;
155
- padding-top: 2rem;
156
- font-family: 'Inter', system-ui, -apple-system, sans-serif;
157
- }
158
- .header { text-align: center; margin-bottom: 2rem; }
159
- .header h1 { font-weight: 800; font-size: 2.5rem; background: linear-gradient(90deg, #ff4d94, #ff85b3); -webkit-background-clip: text; -webkit-text-fill-color: transparent; }
160
- .header p { color: #666; font-size: 1.1rem; }
161
- .card { border-radius: 12px; border: 1px solid #eee; background: white; padding: 1.5rem; box-shadow: 0 4px 6px -1px rgba(0,0,0,0.05); }
162
- .context-box { height: 400px; overflow-y: auto; border: 1px solid #f0f0f0; border-radius: 8px; padding: 1rem; background: #fafafa; }
163
- """
164
 
165
- with gr.Blocks(css=custom_css, title="Earnings Analyst - OpenEnv") as demo:
166
- with gr.Column(elem_classes="container"):
167
- with gr.Column(elem_classes="header"):
168
- gr.Markdown("# 🏑 Earnings Analyst")
169
- gr.Markdown(
170
- "Interactive environment for financial analysis tasks using [OpenEnv](https://github.com/meta-pytorch/OpenEnv). Evaluate agents or your own analysis on earnings call data."
171
- )
172
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  with gr.Row():
174
- with gr.Column(scale=1):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  with gr.Group():
176
  task_select = gr.Dropdown(
177
- choices=TASK_IDS, value="sentiment_label", label="Active Task"
 
 
178
  )
179
- reset_btn = gr.Button("🔄 New Episode", variant="primary")
180
 
181
- gr.Markdown("### 🤖 Auto-Agent Settings")
182
  with gr.Group():
183
  api_key = gr.Textbox(
184
  label="OpenAI API Key",
@@ -192,77 +1045,167 @@ with gr.Blocks(css=custom_css, title="Earnings Analyst - OpenEnv") as demo:
192
  placeholder="https://api.openai.com/v1",
193
  value=os.environ.get("API_BASE_URL", ""),
194
  )
195
- agent_btn = gr.Button("🚀 Run LLM Agent", variant="secondary")
196
 
 
197
  with gr.Column(scale=2):
 
 
 
198
  with gr.Tabs():
199
- with gr.TabItem("Observation"):
200
- instr_view = gr.Markdown("### Instruction\n*N/A*")
201
- with gr.Row():
202
- with gr.Column():
203
- gr.Markdown("#### Text Context")
204
- text_view = gr.Markdown(elem_classes="context-box")
205
- with gr.Column():
206
- gr.Markdown("#### Numerical Context")
207
- num_view = gr.Code(
208
- label="JSON",
209
- language="json",
210
- elem_classes="context-box",
211
- )
212
-
213
- with gr.TabItem("Analysis"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  with gr.Column(visible=False) as prediction_row:
215
  pred_input = gr.Textbox(
216
- label="Your Prediction / Analysis Output",
217
  placeholder="e.g. bullish, or 0.05",
218
  )
219
- submit_btn = gr.Button("Submit Analysis", variant="primary")
220
 
221
- result_view = gr.Markdown(visible=False, elem_classes="card")
222
  message_view = gr.Textbox(
223
- label="Agent Log / Error Messages", interactive=False
 
 
 
 
 
224
  )
225
 
226
- # Event handlers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  reset_btn.click(
228
  fn=reset_env,
229
- inputs=[task_select],
230
- outputs=[
231
- instr_view,
232
- text_view,
233
- num_view,
234
- prediction_row,
235
- result_view,
236
- pred_input,
237
- message_view,
238
- ],
239
  )
240
-
241
  submit_btn.click(
242
  fn=step_env, inputs=[pred_input], outputs=[prediction_row, result_view]
243
  )
244
-
245
  agent_btn.click(
246
  fn=run_agent,
247
- inputs=[task_select, api_key, model_name, base_url],
248
- outputs=[
249
- instr_view,
250
- text_view,
251
- num_view,
252
- prediction_row,
253
- result_view,
254
- pred_input,
255
- message_view,
256
  ],
 
257
  )
258
 
259
- from fastapi import FastAPI
260
- from fastapi.middleware.cors import CORSMiddleware
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
 
262
- # Create the main FastAPI app
263
  api_app = FastAPI(title="Earnings Analyst API")
264
 
265
- # Add CORS middleware
266
  api_app.add_middleware(
267
  CORSMiddleware,
268
  allow_origins=["*"],
@@ -270,35 +1213,37 @@ api_app.add_middleware(
270
  allow_headers=["*"],
271
  )
272
 
273
- from fastapi.responses import RedirectResponse
274
-
275
 
276
  @api_app.get("/health")
277
- async def health():
278
  return {"status": "ok", "environment": "earnings_analyst"}
279
 
280
 
281
  @api_app.get("/web")
282
- async def web_redirect():
283
  return RedirectResponse(url="/")
284
 
285
 
286
- # Import the environment app
287
  try:
288
  from earnings_analyst.server.app import app as env_app
289
  except (ImportError, ModuleNotFoundError):
290
  from server.app import app as env_app
291
 
292
- # Mount the environment API at both root and /api for compatibility
293
- # Root level is needed for OpenEnv standard validation
294
  api_app.include_router(env_app.router)
295
  api_app.mount("/api", env_app)
296
 
297
- # Wrap Gradio with FastAPI
298
- app = gr.mount_gradio_app(api_app, demo, path="/")
 
 
 
 
 
 
 
299
 
300
 
301
- def main():
302
  import uvicorn
303
 
304
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ import html
2
  import os
3
  import sys
4
  import json
5
  import gradio as gr
6
+ from gradio.themes.utils import colors, sizes, fonts
7
+ from typing import Any, Optional
8
  from dotenv import load_dotenv
9
  from openai import AsyncOpenAI
10
+ from fastapi import FastAPI
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+ from fastapi.responses import RedirectResponse
13
 
14
  # Ensure the project root is in sys.path
15
  sys.path.append(os.path.abspath(os.path.dirname(__file__)))
 
18
  from earnings_analyst.server.earnings_analyst_environment import (
19
  EarningsAnalystEnvironment,
20
  )
21
+ from earnings_analyst.server.episode_index import (
22
+ RANDOM_EPISODE_LABEL,
23
+ format_episode_id,
24
+ get_episode_index,
25
+ )
26
  from earnings_analyst.models import (
27
  EarningsAnalystAction,
28
  EarningsAnalystObservation,
29
  )
30
+ from earnings_analyst.tasks.exceptions import TaskNotImplementedError
31
  from earnings_analyst.tasks.registry import TASK_IDS
32
  except (ImportError, ModuleNotFoundError):
33
  from server.earnings_analyst_environment import EarningsAnalystEnvironment
34
+ from server.episode_index import (
35
+ RANDOM_EPISODE_LABEL,
36
+ format_episode_id,
37
+ get_episode_index,
38
+ )
39
  from models import EarningsAnalystAction, EarningsAnalystObservation
40
+ from tasks.exceptions import TaskNotImplementedError
41
  from tasks.registry import TASK_IDS
42
 
43
  load_dotenv()
44
 
45
+ # One accordion per possible text column (task specs use these keys).
46
+ TEXT_KEYS = (
47
+ "earnings_transcript",
48
+ "press_release_8k_body",
49
+ "press_release_ex991",
50
+ "press_release_ex992",
51
+ "press_release_sources",
52
+ )
53
+
54
+ # ---------------------------------------------------------------------------
55
+ # Theme — Zerodha/Kite-inspired
56
+ # ---------------------------------------------------------------------------
57
+
58
+ zerodha_blue = colors.Color(
59
+ c50="#eff6ff",
60
+ c100="#dbeafe",
61
+ c200="#bfdbfe",
62
+ c300="#93c5fd",
63
+ c400="#60a5fa",
64
+ c500="#387ED1",
65
+ c600="#2563eb",
66
+ c700="#1d4ed8",
67
+ c800="#1e40af",
68
+ c900="#1e3a8a",
69
+ c950="#172554",
70
+ )
71
+
72
+ theme = gr.themes.Base( # type: ignore[attr-defined]
73
+ primary_hue=zerodha_blue,
74
+ secondary_hue=zerodha_blue,
75
+ neutral_hue=colors.slate,
76
+ radius_size=sizes.radius_md,
77
+ font=(fonts.GoogleFont("Inter"), "system-ui", "sans-serif"),
78
+ font_mono=(fonts.GoogleFont("JetBrains Mono"), "ui-monospace", "monospace"),
79
+ )
80
+
81
+ theme.set(
82
+ # Body
83
+ body_background_fill="#F8F9FA",
84
+ body_background_fill_dark="#18181b",
85
+ body_text_color="#1a1a1a",
86
+ body_text_color_dark="#f4f4f5",
87
+ body_text_color_subdued="#6b7280",
88
+ body_text_color_subdued_dark="#a1a1aa",
89
+ # Blocks / cards
90
+ block_background_fill="#ffffff",
91
+ block_background_fill_dark="#27272a",
92
+ block_border_color="#e5e7eb",
93
+ block_border_color_dark="#3f3f46",
94
+ block_border_width="1px",
95
+ block_radius="12px",
96
+ block_shadow="none",
97
+ block_shadow_dark="none",
98
+ block_label_background_fill="#f9fafb",
99
+ block_label_background_fill_dark="#3f3f46",
100
+ block_label_border_color="#e5e7eb",
101
+ block_label_border_color_dark="#52525b",
102
+ block_label_text_color="#374151",
103
+ block_label_text_color_dark="#d4d4d8",
104
+ # Inputs
105
+ input_background_fill="#ffffff",
106
+ input_background_fill_dark="#3f3f46",
107
+ input_background_fill_focus="#ffffff",
108
+ input_background_fill_focus_dark="#52525b",
109
+ input_border_color="#d1d5db",
110
+ input_border_color_dark="#52525b",
111
+ input_border_color_focus="#387ED1",
112
+ input_border_color_focus_dark="#60a5fa",
113
+ input_border_width="1px",
114
+ input_radius="8px",
115
+ input_shadow="none",
116
+ input_shadow_dark="none",
117
+ input_shadow_focus="0 0 0 3px rgba(56, 126, 209, 0.15)",
118
+ input_shadow_focus_dark="0 0 0 3px rgba(96, 165, 250, 0.2)",
119
+ input_placeholder_color="#9ca3af",
120
+ input_placeholder_color_dark="#71717a",
121
+ # Primary buttons — solid Zerodha blue
122
+ button_primary_background_fill="#387ED1",
123
+ button_primary_background_fill_dark="#387ED1",
124
+ button_primary_background_fill_hover="#2563eb",
125
+ button_primary_background_fill_hover_dark="#2563eb",
126
+ button_primary_text_color="#ffffff",
127
+ button_primary_text_color_dark="#ffffff",
128
+ button_primary_text_color_hover="#ffffff",
129
+ button_primary_text_color_hover_dark="#ffffff",
130
+ button_primary_border_color="#387ED1",
131
+ button_primary_border_color_dark="#387ED1",
132
+ button_primary_border_color_hover="#2563eb",
133
+ button_primary_border_color_hover_dark="#2563eb",
134
+ button_primary_shadow="none",
135
+ button_primary_shadow_dark="none",
136
+ button_primary_shadow_hover="none",
137
+ button_primary_shadow_hover_dark="none",
138
+ # Secondary buttons — subtle outline
139
+ button_secondary_background_fill="#ffffff",
140
+ button_secondary_background_fill_dark="#3f3f46",
141
+ button_secondary_background_fill_hover="#eff6ff",
142
+ button_secondary_background_fill_hover_dark="#52525b",
143
+ button_secondary_text_color="#387ED1",
144
+ button_secondary_text_color_dark="#93c5fd",
145
+ button_secondary_text_color_hover="#2563eb",
146
+ button_secondary_text_color_hover_dark="#bfdbfe",
147
+ button_secondary_border_color="#d1d5db",
148
+ button_secondary_border_color_dark="#52525b",
149
+ button_secondary_border_color_hover="#387ED1",
150
+ button_secondary_border_color_hover_dark="#60a5fa",
151
+ button_secondary_shadow="none",
152
+ button_secondary_shadow_dark="none",
153
+ button_secondary_shadow_hover="none",
154
+ button_secondary_shadow_hover_dark="none",
155
+ # Links
156
+ link_text_color="#387ED1",
157
+ link_text_color_dark="#60a5fa",
158
+ link_text_color_hover="#2563eb",
159
+ link_text_color_hover_dark="#93c5fd",
160
+ link_text_color_active="#1d4ed8",
161
+ link_text_color_active_dark="#bfdbfe",
162
+ link_text_color_visited="#387ED1",
163
+ link_text_color_visited_dark="#60a5fa",
164
+ # Accent
165
+ color_accent="#387ED1",
166
+ color_accent_soft="#eff6ff",
167
+ color_accent_soft_dark="#1e3a8a",
168
+ border_color_accent="#387ED1",
169
+ border_color_accent_dark="#60a5fa",
170
+ border_color_accent_subdued="#bfdbfe",
171
+ border_color_accent_subdued_dark="#1d4ed8",
172
+ )
173
+
174
+ # ---------------------------------------------------------------------------
175
+ # CSS overrides
176
+ # ---------------------------------------------------------------------------
177
+
178
+ custom_css = """
179
+ /* Align custom panels with theme block/input radii */
180
+ :root {
181
+ --ea-radius-block: 12px;
182
+ --ea-radius-control: 8px;
183
+ }
184
+
185
+ footer { visibility: hidden; }
186
+
187
+ /* Page shell */
188
+ .app-shell {
189
+ max-width: 1140px;
190
+ margin: 0 auto;
191
+ padding: 1.25rem 1rem 2rem;
192
+ }
193
+
194
+ /* Header strip */
195
+ .header-row {
196
+ display: flex;
197
+ align-items: flex-start;
198
+ justify-content: space-between;
199
+ margin-bottom: 1.25rem;
200
+ padding-bottom: 1rem;
201
+ border-bottom: 1px solid #e5e7eb;
202
+ }
203
+ .dark .header-row { border-color: #3f3f46; }
204
+
205
+ /* Status badge */
206
+ .status-badge > .prose p {
207
+ margin: 0;
208
+ font-size: 0.82rem;
209
+ color: #6b7280;
210
+ text-align: right;
211
+ }
212
+ .dark .status-badge > .prose p { color: #a1a1aa; }
213
+
214
+ /* Theme toggle — rounding only on outer frame; Gradio may add per-button radii (fix asymmetry) */
215
+ .theme-toggle .wrap {
216
+ gap: 0 !important;
217
+ display: inline-flex !important;
218
+ flex-direction: row !important;
219
+ align-items: stretch;
220
+ border-radius: var(--ea-radius-control);
221
+ overflow: hidden;
222
+ box-sizing: border-box;
223
+ }
224
+ /* Prefer fieldset / radiogroup as the stroked frame; otherwise stroke the wrap (e.g. button UI) */
225
+ .theme-toggle fieldset,
226
+ .theme-toggle [role="radiogroup"] {
227
+ display: inline-flex !important;
228
+ flex-direction: row !important;
229
+ align-items: stretch;
230
+ padding: 0 !important;
231
+ margin: 0 !important;
232
+ border: 1px solid #d1d5db !important;
233
+ border-radius: var(--ea-radius-control);
234
+ overflow: hidden;
235
+ gap: 0 !important;
236
+ box-sizing: border-box;
237
+ }
238
+ .theme-toggle .wrap:not(:has(fieldset)):not(:has([role="radiogroup"])) {
239
+ border: 1px solid #d1d5db;
240
+ }
241
+ .theme-toggle .gap-2 { gap: 0 !important; }
242
+ .theme-toggle input[type=radio] { display: none; }
243
+ .theme-toggle label {
244
+ display: inline-flex;
245
+ align-items: center;
246
+ gap: 4px;
247
+ padding: 4px 12px;
248
+ font-size: 0.78rem;
249
+ font-weight: 500;
250
+ cursor: pointer;
251
+ border: none !important;
252
+ border-radius: 0 !important;
253
+ border-right: 1px solid #d1d5db !important;
254
+ background: #fff;
255
+ color: #374151;
256
+ transition: background 0.15s, color 0.15s, border-color 0.15s;
257
+ white-space: nowrap;
258
+ }
259
+ .theme-toggle label:last-of-type { border-right: none !important; }
260
+ /* Segments: zero radius so only the group clips to --ea-radius-control (overrides Gradio/Tailwind per-side radii) */
261
+ .theme-toggle button,
262
+ .theme-toggle [role="radiogroup"] button,
263
+ .theme-toggle button:first-child,
264
+ .theme-toggle button:last-child {
265
+ border: none !important;
266
+ border-radius: 0 !important;
267
+ border-top-left-radius: 0 !important;
268
+ border-top-right-radius: 0 !important;
269
+ border-bottom-left-radius: 0 !important;
270
+ border-bottom-right-radius: 0 !important;
271
+ border-right: 1px solid #d1d5db !important;
272
+ margin: 0 !important;
273
+ box-shadow: none !important;
274
+ }
275
+ .theme-toggle button:last-child { border-right: none !important; }
276
+ /* Gradio 4+ uses buttons; Zerodha blue for resolved/selected segment */
277
+ .theme-toggle button.selected,
278
+ .theme-toggle button.ea-theme-active,
279
+ .theme-toggle button[aria-checked="true"],
280
+ .theme-toggle button[data-state="checked"] {
281
+ background: #387ED1 !important;
282
+ color: #fff !important;
283
+ border-color: #387ED1 !important;
284
+ z-index: 1;
285
+ }
286
+ .theme-toggle input[type=radio]:checked + label,
287
+ .theme-toggle label.selected {
288
+ background: #387ED1;
289
+ color: #fff;
290
+ border-color: #387ED1;
291
+ z-index: 1;
292
+ }
293
+ /* Same highlight when .dark is on body/html (do not rely on html.dark alone) */
294
+ .dark .theme-toggle button.selected,
295
+ .dark .theme-toggle button.ea-theme-active,
296
+ .dark .theme-toggle button[aria-checked="true"],
297
+ .dark .theme-toggle button[data-state="checked"] {
298
+ background: #387ED1 !important;
299
+ color: #fff !important;
300
+ border-color: #387ED1 !important;
301
+ }
302
+ body.dark .theme-toggle button.selected,
303
+ body.dark .theme-toggle button.ea-theme-active,
304
+ body.dark .theme-toggle button[aria-checked="true"],
305
+ body.dark .theme-toggle button[data-state="checked"],
306
+ html.dark .theme-toggle button.selected,
307
+ html.dark .theme-toggle button.ea-theme-active,
308
+ html.dark .theme-toggle button[aria-checked="true"],
309
+ html.dark .theme-toggle button[data-state="checked"] {
310
+ background: #387ED1 !important;
311
+ color: #fff !important;
312
+ border-color: #387ED1 !important;
313
+ }
314
+ /* System: Gradio keeps "System" selected — neutralize that segment; .ea-theme-active marks resolved Light/Dark */
315
+ html[data-ea-theme="system"][data-ea-resolved="light"] .theme-toggle button.selected:not(.ea-theme-active) {
316
+ background: #fff !important;
317
+ color: #374151 !important;
318
+ border-right-color: #d1d5db !important;
319
+ }
320
+ html[data-ea-theme="system"][data-ea-resolved="dark"] .theme-toggle button.selected:not(.ea-theme-active) {
321
+ background: #3f3f46 !important;
322
+ color: #d4d4d8 !important;
323
+ border-right-color: #52525b !important;
324
+ }
325
+ .dark .theme-toggle .wrap:not(:has(fieldset)):not(:has([role="radiogroup"])) {
326
+ border-color: #52525b;
327
+ }
328
+ .dark .theme-toggle fieldset,
329
+ .dark .theme-toggle [role="radiogroup"] {
330
+ border-color: #52525b !important;
331
+ }
332
+ .dark .theme-toggle button { border-right-color: #52525b !important; }
333
+ .dark .theme-toggle label {
334
+ background: #3f3f46;
335
+ color: #d4d4d8;
336
+ border-right-color: #52525b !important;
337
+ }
338
+ .dark .theme-toggle input[type=radio]:checked + label,
339
+ .dark .theme-toggle label.selected {
340
+ background: #387ED1;
341
+ color: #fff;
342
+ border-color: #387ED1;
343
+ }
344
+
345
+ /* Gradio index.html uses @media (prefers-color-scheme: dark) on body — override when forcing light */
346
+ @media (prefers-color-scheme: dark) {
347
+ html[data-ea-theme="light"] body {
348
+ background: var(--bg, #f8f9fa) !important;
349
+ color: var(--col, #1a1a1a) !important;
350
+ }
351
+ }
352
+
353
+ /* Task instruction callout — top accent + even border (no heavy left margin) */
354
+ .task-instruction {
355
+ border: 1px solid #bfdbfe;
356
+ border-top: 3px solid #387ED1;
357
+ background: #f0f6ff;
358
+ border-radius: var(--ea-radius-block);
359
+ padding: 0.9rem 1.1rem;
360
+ margin-bottom: 1rem;
361
+ box-sizing: border-box;
362
+ }
363
+ .dark .task-instruction {
364
+ background: rgba(30, 58, 138, 0.14);
365
+ border-color: #3f3f46;
366
+ border-top-color: #60a5fa;
367
+ }
368
+
369
+ /* Market Data JSON — full width, aligned with section (no label-column indent) */
370
+ .market-data-json {
371
+ width: 100% !important;
372
+ border: 1px solid #e5e7eb;
373
+ border-radius: var(--ea-radius-block);
374
+ overflow: hidden;
375
+ box-sizing: border-box;
376
+ }
377
+ .dark .market-data-json {
378
+ border-color: #3f3f46;
379
+ }
380
+ .market-data-json > .wrap {
381
+ padding: 0 !important;
382
+ gap: 0 !important;
383
+ }
384
+ .market-data-json pre {
385
+ margin: 0 !important;
386
+ width: 100%;
387
+ box-sizing: border-box;
388
+ border-radius: 0 !important;
389
+ }
390
+
391
+ /* Accordion content scroll */
392
+ .accordion-scroll > .prose {
393
+ max-height: 280px;
394
+ overflow-y: auto;
395
+ padding-right: 4px;
396
+ }
397
+
398
+ /* Tab active indicator */
399
+ .tabs > .tab-nav button.selected {
400
+ border-bottom: 2px solid #387ED1 !important;
401
+ color: #387ED1 !important;
402
+ font-weight: 600;
403
+ }
404
+ .dark .tabs > .tab-nav button.selected {
405
+ border-bottom-color: #60a5fa !important;
406
+ color: #60a5fa !important;
407
+ }
408
+
409
+ /* Result / error cards */
410
+ .result-card {
411
+ border-radius: var(--ea-radius-block);
412
+ padding: 0.9rem 1.1rem;
413
+ border-width: 1.5px;
414
+ border-style: solid;
415
+ background: #fff;
416
+ margin-top: 0.5rem;
417
+ }
418
+ .dark .result-card { background: #27272a; }
419
+ .result-card .rc-title {
420
+ font-weight: 600;
421
+ font-size: 0.92rem;
422
+ margin-bottom: 0.4rem;
423
+ color: #1a1a1a;
424
+ }
425
+ .dark .result-card .rc-title { color: #f4f4f5; }
426
+ .result-card .rc-body {
427
+ font-size: 0.9rem;
428
+ color: #374151;
429
+ line-height: 1.6;
430
+ }
431
+ .dark .result-card .rc-body { color: #d4d4d8; }
432
+
433
+ /* Result variants (Analysis tab HTML) — avoid inline light bg + .dark text */
434
+ .result-card.result-ok {
435
+ background: #f0fdf4;
436
+ border-color: #bbf7d0;
437
+ }
438
+ .result-card.result-ok .rc-title { color: #14532d; }
439
+ .result-card.result-ok .rc-body { color: #374151; }
440
+ .dark .result-card.result-ok {
441
+ background: rgba(20, 83, 45, 0.42);
442
+ border-color: #15803d;
443
+ }
444
+ .dark .result-card.result-ok .rc-title { color: #bbf7d0; }
445
+ .dark .result-card.result-ok .rc-body { color: #d4d4d8; }
446
+
447
+ .result-card.result-bad {
448
+ background: #fff1f2;
449
+ border-color: #fecaca;
450
+ }
451
+ .result-card.result-bad .rc-title { color: #991b1b; }
452
+ .result-card.result-bad .rc-body { color: #374151; }
453
+ .dark .result-card.result-bad {
454
+ background: rgba(127, 29, 29, 0.38);
455
+ border-color: #b91c1c;
456
+ }
457
+ .dark .result-card.result-bad .rc-title { color: #fecaca; }
458
+ .dark .result-card.result-bad .rc-body { color: #d4d4d8; }
459
+
460
+ .result-card.result-notice {
461
+ background: #fffbeb;
462
+ border-color: #fde68a;
463
+ }
464
+ .result-card.result-notice .rc-title { color: #92400e; }
465
+ .result-card.result-notice .rc-body { color: #374151; }
466
+ .dark .result-card.result-notice {
467
+ background: rgba(120, 53, 15, 0.42);
468
+ border-color: #d97706;
469
+ }
470
+ .dark .result-card.result-notice .rc-title { color: #fde68a; }
471
+ .dark .result-card.result-notice .rc-body { color: #d4d4d8; }
472
+
473
+ /* Sidebar section labels */
474
+ .section-label {
475
+ font-size: 0.78rem;
476
+ font-weight: 700;
477
+ letter-spacing: 0.06em;
478
+ text-transform: uppercase;
479
+ color: #9ca3af;
480
+ margin: 1rem 0 0.4rem;
481
+ }
482
+ .dark .section-label { color: #71717a; }
483
+ """
484
+
485
+ # ---------------------------------------------------------------------------
486
+ # JS — dark / light / system toggle
487
+ # ---------------------------------------------------------------------------
488
+
489
+ theme_toggle_js = """
490
+ (() => {
491
+ const KEY = "ea-theme";
492
+ const PREF_INDEX = { system: 0, light: 1, dark: 2 };
493
+
494
+ function wantsDark(pref) {
495
+ if (pref === "dark") return true;
496
+ if (pref === "light") return false;
497
+ return window.matchMedia("(prefers-color-scheme: dark)").matches;
498
+ }
499
+
500
+ function getToggleRoot() {
501
+ return document.querySelector(".theme-toggle");
502
+ }
503
+
504
+ function getToggleButtons() {
505
+ const root = getToggleRoot();
506
+ if (!root) return [];
507
+ return [...root.querySelectorAll("button")];
508
+ }
509
+
510
+ function isButtonSelected(b) {
511
+ return (
512
+ b.classList.contains("selected") ||
513
+ b.getAttribute("aria-checked") === "true" ||
514
+ b.dataset.state === "checked"
515
+ );
516
+ }
517
+
518
+ /** Gradio Radio value ↔ saved preference (fixes default "System" after reload). */
519
+ function syncRadioToStorage() {
520
+ const pref = localStorage.getItem(KEY) || "system";
521
+ const want = PREF_INDEX[pref];
522
+ if (want === undefined) return;
523
+ const buttons = getToggleButtons();
524
+ if (buttons.length < 3) return;
525
+ const selected = buttons.findIndex(isButtonSelected);
526
+ if (selected !== want) buttons[want].click();
527
+ }
528
+
529
+ /** Highlight effective Light/Dark when preference is system; match explicit light/dark. */
530
+ function updateThemeToggleVisual() {
531
+ const pref = localStorage.getItem(KEY) || "system";
532
+ const buttons = getToggleButtons();
533
+ if (buttons.length < 3) return;
534
+
535
+ // Reset all buttons: remove class, clear highlight, and force radius to 0 via inline
536
+ // !important (beats Gradio's Tailwind rounded-l-lg / rounded-r-lg cascade)
537
+ buttons.forEach((b) => {
538
+ b.classList.remove("ea-theme-active");
539
+ b.style.removeProperty("background");
540
+ b.style.removeProperty("color");
541
+ b.style.removeProperty("border-color");
542
+ ["border-radius",
543
+ "border-top-left-radius", "border-top-right-radius",
544
+ "border-bottom-left-radius", "border-bottom-right-radius",
545
+ ].forEach((p) => b.style.setProperty(p, "0", "important"));
546
+ });
547
+
548
+ const resolvedDark = wantsDark(pref);
549
+ let activeBtn = null;
550
+ if (pref === "system") {
551
+ activeBtn = resolvedDark ? buttons[2] : buttons[1];
552
+ } else if (pref === "dark") {
553
+ activeBtn = buttons[2];
554
+ } else {
555
+ activeBtn = buttons[1];
556
+ }
557
+ if (activeBtn) {
558
+ activeBtn.classList.add("ea-theme-active");
559
+ // Inline !important overrides Gradio's dark-mode stylesheet — no cascade battle
560
+ activeBtn.style.setProperty("background", "#387ED1", "important");
561
+ activeBtn.style.setProperty("color", "#fff", "important");
562
+ activeBtn.style.setProperty("border-color", "#387ED1", "important");
563
+ }
564
+ }
565
+
566
+ function applyTheme(pref) {
567
+ const dark = wantsDark(pref);
568
+ const html = document.documentElement;
569
+ html.classList.toggle("dark", dark);
570
+ if (document.body) {
571
+ document.body.classList.toggle("dark", dark);
572
+ }
573
+ html.setAttribute("data-ea-theme", pref);
574
+ html.setAttribute("data-ea-resolved", dark ? "dark" : "light");
575
+ if (pref === "system") {
576
+ html.style.removeProperty("color-scheme");
577
+ } else {
578
+ html.style.colorScheme = dark ? "dark" : "light";
579
+ }
580
+ queueMicrotask(() => {
581
+ syncRadioToStorage();
582
+ updateThemeToggleVisual();
583
+ });
584
+ }
585
+
586
+ let toggleSyncScheduled = false;
587
+ function scheduleToggleSync() {
588
+ if (toggleSyncScheduled) return;
589
+ toggleSyncScheduled = true;
590
+ queueMicrotask(() => {
591
+ toggleSyncScheduled = false;
592
+ syncRadioToStorage();
593
+ updateThemeToggleVisual();
594
+ });
595
+ }
596
+
597
+ let themeToggleOuterObs = false;
598
+ function attachThemeToggleObserver() {
599
+ const hook = (root) => {
600
+ if (!root || root.__eaToggleObserved) return;
601
+ root.__eaToggleObserved = true;
602
+ const mo = new MutationObserver(() => scheduleToggleSync());
603
+ mo.observe(root, {
604
+ childList: true,
605
+ subtree: true,
606
+ attributes: true,
607
+ attributeFilter: ["class", "aria-checked", "data-state"],
608
+ });
609
+ scheduleToggleSync();
610
+ };
611
+ hook(getToggleRoot());
612
+ if (themeToggleOuterObs) return;
613
+ themeToggleOuterObs = true;
614
+ const outer = new MutationObserver(() => {
615
+ const r = getToggleRoot();
616
+ if (r && !r.__eaToggleObserved) hook(r);
617
+ });
618
+ if (document.body) {
619
+ outer.observe(document.body, { childList: true, subtree: true });
620
+ }
621
+ }
622
+
623
+ function syncFromStorage() {
624
+ applyTheme(localStorage.getItem(KEY) || "system");
625
+ }
626
+
627
+ function attachBodyObserver() {
628
+ if (!document.body) return;
629
+ new MutationObserver(() => {
630
+ const pref = localStorage.getItem(KEY) || "system";
631
+ const need = wantsDark(pref);
632
+ if (document.body.classList.contains("dark") !== need) {
633
+ applyTheme(pref);
634
+ }
635
+ }).observe(document.body, { attributes: true, attributeFilter: ["class"] });
636
+ }
637
+
638
+ syncFromStorage();
639
+ if (document.body) {
640
+ attachBodyObserver();
641
+ attachThemeToggleObserver();
642
+ } else {
643
+ document.addEventListener("DOMContentLoaded", () => {
644
+ attachBodyObserver();
645
+ attachThemeToggleObserver();
646
+ });
647
+ }
648
+
649
+ function beatGradioInit() {
650
+ syncFromStorage();
651
+ queueMicrotask(syncFromStorage);
652
+ requestAnimationFrame(syncFromStorage);
653
+ }
654
+ if (document.readyState === "loading") {
655
+ document.addEventListener("DOMContentLoaded", beatGradioInit);
656
+ } else {
657
+ beatGradioInit();
658
+ }
659
+ window.addEventListener("load", () => {
660
+ syncFromStorage();
661
+ attachThemeToggleObserver();
662
+ [0, 50, 200, 500, 1000].forEach((ms) => setTimeout(syncFromStorage, ms));
663
+ });
664
+
665
+ window.matchMedia("(prefers-color-scheme: dark)").addEventListener("change", () => {
666
+ if ((localStorage.getItem(KEY) || "system") === "system") {
667
+ applyTheme("system");
668
+ }
669
+ });
670
+
671
+ window.__setEATheme = (pref) => {
672
+ localStorage.setItem(KEY, pref);
673
+ applyTheme(pref);
674
+ };
675
+ })();
676
+ """
677
+
678
+ # ---------------------------------------------------------------------------
679
+ # Helper functions
680
+ # ---------------------------------------------------------------------------
681
+
682
+
683
def _format_result_html(reward: float, ground_truth: Any) -> str:
    """Render the graded outcome of a step as an HTML result card.

    Args:
        reward: Score returned by the environment; non-negative values get
            the ``result-ok`` variant, negative values ``result-bad``.
        ground_truth: Expected answer; stringified and HTML-escaped before
            being embedded.

    Returns:
        A single-line HTML fragment.
    """
    if reward >= 0:
        card_class, sign = "result-ok", "+"
    else:
        card_class, sign = "result-bad", ""
    safe_truth = html.escape(str(ground_truth))
    pieces = [
        f'<div class="result-card {card_class}">',
        f'<div class="rc-title">Result {sign}</div>',
        '<div class="rc-body">',
        f"<strong>Reward:</strong> {reward:.4f}<br/>",
        f"<strong>Ground truth:</strong> {safe_truth}",
        "</div></div>",
    ]
    return "".join(pieces)
695
+
696
+
697
def _format_error_html(message: str) -> str:
    """Wrap *message* (HTML-escaped) in a ``result-notice`` card."""
    safe_message = html.escape(message)
    return "".join(
        (
            '<div class="result-card result-notice">',
            '<div class="rc-title">Notice</div>',
            f'<div class="rc-body">{safe_message}</div></div>',
        )
    )
704
+
705
+
706
def _episode_status_markdown(
    task_id: str,
    company: str | None = None,
    year: str | None = None,
    quarter: str | None = None,
) -> str:
    """Build the status-badge Markdown shown after an episode loads.

    When a company, year and quarter are all supplied (and the company is not
    the random-row label) the badge names the canonical episode id; otherwise
    it reports a random sample.

    Raises:
        KeyError: If *company* is not a known display label.
        ValueError: If *year* or *quarter* is not an integer string.
    """
    is_random = not (company and year and quarter) or company == RANDOM_EPISODE_LABEL
    if is_random:
        row_line = "**Row:** random sample \n"
    else:
        assert company is not None and year is not None and quarter is not None
        symbol = get_episode_index().symbol_for_display(company)
        episode = format_episode_id(symbol, int(year), int(quarter))
        row_line = f"**Row:** `{html.escape(episode)}` \n"
    header = f"**Episode loaded** &mdash; `{html.escape(task_id)}` \n"
    return header + row_line + "Use **Observation** / **Analysis** tabs."
730
+
731
+
732
def _reset_failed_outputs(message: str) -> list[Any]:
    """Return the ``reset_outputs`` values for an episode that failed to load.

    The list follows the ``reset_outputs`` component order: status badge,
    instruction text, one text per document, one accordion update per
    document, market-data JSON, prediction row, result view, prediction input
    and message box. *message* is HTML-escaped in the badge and passed through
    verbatim to the message box.
    """
    blank_texts, accordion_updates = _text_rows_and_accordion_updates(None)
    outputs: list[Any] = [
        f"**Episode error** &mdash; {html.escape(message)}",
        "*Could not load episode. Fix the selection or try again.*",
    ]
    outputs.extend(blank_texts)
    outputs.extend(accordion_updates)
    outputs.append({})  # no market data
    outputs.append(gr.update(visible=False))  # hide the prediction row
    outputs.append(gr.update(visible=False, value=""))  # clear the result card
    outputs.append("")  # empty prediction input
    outputs.append(message)
    return outputs
747
+
748
+
749
def _init_episode_dropdowns() -> tuple[Any, Any, Any]:
    """Populate the company dropdown on page load and disable year/quarter.

    Returns:
        Updates for ``(company_dd, year_dd, quarter_dd)``: the company list is
        the random-row label followed by the sorted company labels; year and
        quarter start empty and non-interactive.
    """
    company_choices = [
        RANDOM_EPISODE_LABEL,
        *get_episode_index().sorted_company_displays(),
    ]
    company_update = gr.update(choices=company_choices, value=RANDOM_EPISODE_LABEL)
    year_update = gr.update(choices=[], value=None, interactive=False)
    quarter_update = gr.update(choices=[], value=None, interactive=False)
    return company_update, year_update, quarter_update
758
+
759
+
760
def _on_company_change(company: str | None) -> tuple[Any, Any]:
    """Refresh the year and quarter dropdowns after the company changes.

    With no company (or the random-row label) both dropdowns are emptied and
    disabled. Otherwise the first available year is pre-selected, together
    with the first quarter available for that year.

    Raises:
        KeyError: If *company* is not a known display label.
    """
    if not company or company == RANDOM_EPISODE_LABEL:
        disabled_year = gr.update(choices=[], value=None, interactive=False)
        disabled_quarter = gr.update(choices=[], value=None, interactive=False)
        return disabled_year, disabled_quarter

    index = get_episode_index()
    symbol = index.symbol_for_display(company)
    year_options = [str(y) for y in index.years_for_symbol(symbol)]

    selected_year = None
    quarter_options: list[str] = []
    selected_quarter = None
    if year_options:
        selected_year = year_options[0]
        quarter_options = [
            str(q) for q in index.quarters_for(symbol, int(selected_year))
        ]
        if quarter_options:
            selected_quarter = quarter_options[0]

    year_update = gr.update(
        choices=year_options, value=selected_year, interactive=bool(year_options)
    )
    quarter_update = gr.update(
        choices=quarter_options,
        value=selected_quarter,
        interactive=bool(quarter_options),
    )
    return year_update, quarter_update
781
+
782
+
783
def _on_year_change(company: str | None, year_str: str | None) -> Any:
    """Refresh the quarter dropdown after the year changes.

    Returns an emptied, disabled dropdown update when no specific company or
    year is selected; otherwise lists the quarters for that company/year and
    pre-selects the first one.

    Raises:
        KeyError: If *company* is not a known display label.
        ValueError: If *year_str* is not an integer string.
    """
    nothing_selected = (
        not company or company == RANDOM_EPISODE_LABEL or not year_str
    )
    if nothing_selected:
        return gr.update(choices=[], value=None, interactive=False)

    index = get_episode_index()
    quarters = index.quarters_for(index.symbol_for_display(company), int(year_str))
    options = [str(q) for q in quarters]
    first = options[0] if options else None
    return gr.update(choices=options, value=first, interactive=bool(options))
792
+
793
+
794
def _text_rows_and_accordion_updates(
    obs: Optional[EarningsAnalystObservation],
) -> tuple[list[str], list[dict[str, Any]]]:
    """Derive per-document text and accordion visibility from an observation.

    For every key in ``TEXT_KEYS`` (in order) this yields the document text
    (empty string when absent, ``str()`` of the value when it is not already
    a string) and an update that shows the accordion only when the key is
    present in the observation's ``text_context``.

    Args:
        obs: Current observation, or ``None`` to get all-empty, all-hidden
            results.

    Returns:
        ``(texts, accordion_updates)``, both aligned with ``TEXT_KEYS``.
    """
    context = (obs.text_context if obs else None) or {}

    def _as_text(value: Any) -> str:
        return value if isinstance(value, str) else str(value)

    texts: list[str] = [_as_text(context.get(key, "")) for key in TEXT_KEYS]
    updates: list[dict[str, Any]] = [
        gr.update(visible=key in context) for key in TEXT_KEYS
    ]
    return texts, updates
805
+
806
+
807
+ # ---------------------------------------------------------------------------
808
+ # State & environment helpers
809
+ # ---------------------------------------------------------------------------
810
+
811
 
812
class State:
    """Mutable holder for the UI's current environment, observation and task."""

    def __init__(self) -> None:
        # Task shown before the user picks one.
        self.task_id: str = "sentiment_label"
        # Both stay ``None`` until an episode has been loaded.
        self.obs: Optional[EarningsAnalystObservation] = None
        self.env: Optional[EarningsAnalystEnvironment] = None
 
819
  state = State()
820
 
821
 
822
async def reset_env(
    task_id: str,
    company: str | None,
    year: str | None,
    quarter: str | None,
) -> list[Any]:
    """Start a new episode and return values for every ``reset_outputs`` component.

    A fresh environment is created for *task_id*. When a company, year and
    quarter are all selected (and the company is not the random-row label)
    that specific row is loaded; otherwise the environment samples a row.

    Any selection or task problem is reported through the failure outputs
    instead of propagating to Gradio: an unknown company label, a non-integer
    year/quarter, a missing row, or an unimplemented task.

    Returns:
        A list ordered like ``reset_outputs``: status badge, instruction,
        document texts, accordion updates, market data, prediction-row
        update, result-view update, prediction input, message box.
    """
    state.task_id = task_id
    state.env = EarningsAnalystEnvironment(task_id=task_id)
    try:
        if company and year and quarter and company != RANDOM_EPISODE_LABEL:
            try:
                sym = get_episode_index().symbol_for_display(company)
            except KeyError as e:
                # symbol_for_display signals an unknown label with KeyError,
                # which the handler below would not catch. str(KeyError) is
                # the repr of its argument, so unwrap the message first.
                detail = e.args[0] if e.args else e
                raise ValueError(str(detail)) from None
            state.obs = state.env.reset(
                pick_symbol=sym,
                pick_year=int(year),
                pick_quarter=int(quarter),
            )
        else:
            state.obs = state.env.reset()
    except (ValueError, TaskNotImplementedError) as e:
        state.obs = None
        return _reset_failed_outputs(str(e))

    texts, acc_updates = _text_rows_and_accordion_updates(state.obs)

    numerical_value: Any = state.obs.numerical_context or {}

    return [
        _episode_status_markdown(task_id, company, year, quarter),
        state.obs.task_instruction,
        *texts,
        *acc_updates,
        numerical_value,
        gr.update(visible=True),  # prediction_row
        gr.update(visible=False, value=""),  # result_view
        "",  # pred_input
        "",  # message_view
    ]
866
 
867
 
868
async def step_env(prediction: str) -> list[Any]:
    """Submit *prediction* to the current episode and show the graded result.

    Returns:
        Updates for ``[prediction_row, result_view]``. Without a loaded
        episode the prediction row is left untouched and a notice card is
        shown; otherwise the prediction row is hidden and the result card
        displays the reward (``0.0`` when the environment reports none) and
        the ground truth (``"N/A"`` when the observation has none).
    """
    if not (state.env and state.obs):
        notice = _format_error_html("Environment not initialized. Click New Episode.")
        return [gr.update(), gr.update(visible=True, value=notice)]

    state.obs = state.env.step(EarningsAnalystAction(prediction=prediction))

    reward = state.obs.reward
    reward_value = 0.0 if reward is None else float(reward)
    truth = getattr(state.obs, "ground_truth", "N/A")
    result_card = _format_result_html(reward_value, truth)

    return [
        gr.update(visible=False),
        gr.update(visible=True, value=result_card),
    ]
891
 
892
 
893
+ async def run_agent(
894
+ task_id: str,
895
+ api_key: str,
896
+ model: str,
897
+ base_url: str,
898
+ company: str | None,
899
+ year: str | None,
900
+ quarter: str | None,
901
+ ) -> list[Any]:
902
+ n_out = 17
903
  if not api_key:
904
+ return [gr.update()] * (n_out - 1) + ["Please provide an API Key."]
905
 
906
+ out = await reset_env(task_id, company, year, quarter)
907
+ if state.obs is None:
908
+ return out
909
 
910
+ assert state.obs is not None
 
 
 
911
 
912
+ if base_url:
913
+ client = AsyncOpenAI(api_key=api_key, base_url=base_url)
914
+ else:
915
+ client = AsyncOpenAI(api_key=api_key)
916
 
917
  user_content = f"{state.obs.task_instruction}\n\n"
918
  if state.obs.text_context:
 
943
  parsed = json.loads(response_text)
944
  prediction = str(parsed.get("prediction", response_text))
945
 
 
946
  step_out = await step_env(prediction)
947
 
 
 
 
 
948
  return [
949
+ *out[:13],
 
 
950
  step_out[0],
951
  step_out[1],
952
  prediction,
 
954
  ]
955
 
956
  except Exception as e:
957
+ return [*out[:13], out[13], out[14], "", f"Error: {str(e)}"]
958
 
959
 
960
+ # ---------------------------------------------------------------------------
961
+ # Layout helpers
962
+ # ---------------------------------------------------------------------------
963
+
964
+
965
def _accordion_title(key: str) -> str:
    """Return the accordion heading for a text-context key.

    Known keys map to hand-written titles; any other key is title-cased with
    underscores turned into spaces.
    """
    known_titles = {
        "earnings_transcript": "Earnings Call Transcript",
        "press_release_8k_body": "Press Release — 8-K Body",
        "press_release_ex991": "Exhibit 99.1",
        "press_release_ex992": "Exhibit 99.2",
        "press_release_sources": "Press Release Sources",
    }
    try:
        return known_titles[key]
    except KeyError:
        return key.replace("_", " ").title()
 
974
 
 
 
 
 
 
 
 
975
 
976
+ # ---------------------------------------------------------------------------
977
+ # Gradio layout
978
+ # ---------------------------------------------------------------------------
979
+
980
+ with gr.Blocks(title="Earnings Analyst - OpenEnv") as demo:
981
+ with gr.Column(elem_classes="app-shell"):
982
+ # ── Header ──────────────────────────────────────────────────────────
983
+ with gr.Row(elem_classes="header-row"):
984
+ with gr.Column(scale=5, min_width=0):
985
+ gr.Markdown("## Earnings Analyst")
986
+ gr.Markdown(
987
+ "Interactive environment for financial analysis using "
988
+ "[OpenEnv](https://github.com/meta-pytorch/OpenEnv). "
989
+ "Evaluate agents or your own predictions on earnings call data."
990
+ )
991
+ with gr.Column(scale=2, min_width=0):
992
+ theme_toggle = gr.Radio(
993
+ choices=["System", "Light", "Dark"],
994
+ value="System",
995
+ label="Theme",
996
+ elem_classes="theme-toggle",
997
+ container=False,
998
+ )
999
+ status_badge = gr.Markdown(
1000
+ "No episode loaded. Click **New Episode**.",
1001
+ elem_classes="status-badge",
1002
+ )
1003
+
1004
+ # ── Main content ─────────────────────────────────────────────────────
1005
  with gr.Row():
1006
+ # Sidebar
1007
+ with gr.Column(scale=1, min_width=220):
1008
+ with gr.Group():
1009
+ company_dd = gr.Dropdown(
1010
+ label="Company",
1011
+ choices=[RANDOM_EPISODE_LABEL],
1012
+ value=RANDOM_EPISODE_LABEL,
1013
+ )
1014
+ with gr.Row():
1015
+ year_dd = gr.Dropdown(
1016
+ label="Year",
1017
+ choices=[],
1018
+ interactive=False,
1019
+ )
1020
+ quarter_dd = gr.Dropdown(
1021
+ label="Quarter",
1022
+ choices=[],
1023
+ interactive=False,
1024
+ )
1025
+
1026
  with gr.Group():
1027
  task_select = gr.Dropdown(
1028
+ choices=TASK_IDS,
1029
+ value="sentiment_label",
1030
+ label="Active Task",
1031
  )
1032
+ reset_btn = gr.Button("New Episode", variant="primary")
1033
 
1034
+ gr.Markdown("Auto-Agent", elem_classes="section-label")
1035
  with gr.Group():
1036
  api_key = gr.Textbox(
1037
  label="OpenAI API Key",
 
1045
  placeholder="https://api.openai.com/v1",
1046
  value=os.environ.get("API_BASE_URL", ""),
1047
  )
1048
+ agent_btn = gr.Button("Run LLM Agent", variant="secondary")
1049
 
1050
+ # Main panel
1051
  with gr.Column(scale=2):
1052
+ with gr.Column(elem_classes="task-instruction"):
1053
+ instr_view = gr.Markdown("*No task loaded. Click **New Episode**.*")
1054
+
1055
  with gr.Tabs():
1056
+ with gr.Tab("Observation"):
1057
+ gr.Markdown("**Documents** — expand a section to read.")
1058
+
1059
+ with gr.Accordion(
1060
+ _accordion_title(TEXT_KEYS[0]),
1061
+ open=True,
1062
+ visible=False,
1063
+ ) as acc_earnings_transcript:
1064
+ md_earnings_transcript = gr.Markdown(
1065
+ "", elem_classes="accordion-scroll"
1066
+ )
1067
+
1068
+ with gr.Accordion(
1069
+ _accordion_title(TEXT_KEYS[1]),
1070
+ open=False,
1071
+ visible=False,
1072
+ ) as acc_press_release_8k_body:
1073
+ md_press_release_8k_body = gr.Markdown(
1074
+ "", elem_classes="accordion-scroll"
1075
+ )
1076
+
1077
+ with gr.Accordion(
1078
+ _accordion_title(TEXT_KEYS[2]),
1079
+ open=False,
1080
+ visible=False,
1081
+ ) as acc_press_release_ex991:
1082
+ md_press_release_ex991 = gr.Markdown(
1083
+ "", elem_classes="accordion-scroll"
1084
+ )
1085
+
1086
+ with gr.Accordion(
1087
+ _accordion_title(TEXT_KEYS[3]),
1088
+ open=False,
1089
+ visible=False,
1090
+ ) as acc_press_release_ex992:
1091
+ md_press_release_ex992 = gr.Markdown(
1092
+ "", elem_classes="accordion-scroll"
1093
+ )
1094
+
1095
+ with gr.Accordion(
1096
+ _accordion_title(TEXT_KEYS[4]),
1097
+ open=False,
1098
+ visible=False,
1099
+ ) as acc_press_release_sources:
1100
+ md_press_release_sources = gr.Markdown(
1101
+ "", elem_classes="accordion-scroll"
1102
+ )
1103
+
1104
+ gr.Markdown("**Market Data**", elem_classes="section-label")
1105
+ num_view = gr.JSON(
1106
+ value={},
1107
+ label=None,
1108
+ show_label=False,
1109
+ container=False,
1110
+ elem_classes="market-data-json",
1111
+ )
1112
+
1113
+ with gr.Tab("Analysis"):
1114
  with gr.Column(visible=False) as prediction_row:
1115
  pred_input = gr.Textbox(
1116
+ label="Your prediction / analysis output",
1117
  placeholder="e.g. bullish, or 0.05",
1118
  )
1119
+ submit_btn = gr.Button("Submit analysis", variant="primary")
1120
 
1121
+ result_view = gr.HTML(visible=False, value="")
1122
  message_view = gr.Textbox(
1123
+ label="Agent log / messages",
1124
+ interactive=False,
1125
+ buttons=["copy"],
1126
+ lines=5,
1127
+ max_lines=12,
1128
+ autoscroll=True,
1129
  )
1130
 
1131
+ # ── Event outputs list ───────────────────────────────────────────────────
1132
+ reset_outputs = [
1133
+ status_badge,
1134
+ instr_view,
1135
+ md_earnings_transcript,
1136
+ md_press_release_8k_body,
1137
+ md_press_release_ex991,
1138
+ md_press_release_ex992,
1139
+ md_press_release_sources,
1140
+ acc_earnings_transcript,
1141
+ acc_press_release_8k_body,
1142
+ acc_press_release_ex991,
1143
+ acc_press_release_ex992,
1144
+ acc_press_release_sources,
1145
+ num_view,
1146
+ prediction_row,
1147
+ result_view,
1148
+ pred_input,
1149
+ message_view,
1150
+ ]
1151
+
1152
+ demo.load(
1153
+ _init_episode_dropdowns,
1154
+ outputs=[company_dd, year_dd, quarter_dd],
1155
+ )
1156
+ company_dd.change(
1157
+ _on_company_change,
1158
+ inputs=[company_dd],
1159
+ outputs=[year_dd, quarter_dd],
1160
+ )
1161
+ year_dd.change(
1162
+ _on_year_change,
1163
+ inputs=[company_dd, year_dd],
1164
+ outputs=[quarter_dd],
1165
+ )
1166
+
1167
  reset_btn.click(
1168
  fn=reset_env,
1169
+ inputs=[task_select, company_dd, year_dd, quarter_dd],
1170
+ outputs=reset_outputs,
 
 
 
 
 
 
 
 
1171
  )
 
1172
  submit_btn.click(
1173
  fn=step_env, inputs=[pred_input], outputs=[prediction_row, result_view]
1174
  )
 
1175
  agent_btn.click(
1176
  fn=run_agent,
1177
+ inputs=[
1178
+ task_select,
1179
+ api_key,
1180
+ model_name,
1181
+ base_url,
1182
+ company_dd,
1183
+ year_dd,
1184
+ quarter_dd,
 
1185
  ],
1186
+ outputs=reset_outputs,
1187
  )
1188
 
1189
+ # Theme toggle — client-side only, no server round-trip
1190
+ theme_toggle.change(
1191
+ fn=None,
1192
+ inputs=[theme_toggle],
1193
+ js=(
1194
+ "(choice) => { "
1195
+ "const v = Array.isArray(choice) ? choice[0] : choice; "
1196
+ "const key = String(v ?? 'system').toLowerCase(); "
1197
+ "window.__setEATheme?.(key); "
1198
+ "return choice; "
1199
+ "}"
1200
+ ),
1201
+ )
1202
+
1203
+ # ---------------------------------------------------------------------------
1204
+ # FastAPI app
1205
+ # ---------------------------------------------------------------------------
1206
 
 
1207
  api_app = FastAPI(title="Earnings Analyst API")
1208
 
 
1209
  api_app.add_middleware(
1210
  CORSMiddleware,
1211
  allow_origins=["*"],
 
1213
  allow_headers=["*"],
1214
  )
1215
 
 
 
1216
 
1217
@api_app.get("/health")
async def health() -> dict[str, str]:
    """Liveness probe: report a fixed OK status and the environment name."""
    payload = {"status": "ok", "environment": "earnings_analyst"}
    return payload
1220
 
1221
 
1222
@api_app.get("/web")
async def web_redirect() -> RedirectResponse:
    """Redirect ``/web`` to the root path."""
    target = "/"
    return RedirectResponse(url=target)
1225
 
1226
 
 
1227
  try:
1228
  from earnings_analyst.server.app import app as env_app
1229
  except (ImportError, ModuleNotFoundError):
1230
  from server.app import app as env_app
1231
 
 
 
1232
  api_app.include_router(env_app.router)
1233
  api_app.mount("/api", env_app)
1234
 
1235
+ # Mount Gradio theme, css, and js passed here (Gradio 6 API)
1236
+ app = gr.mount_gradio_app(
1237
+ api_app,
1238
+ demo,
1239
+ path="/",
1240
+ theme=theme,
1241
+ css=custom_css,
1242
+ js=theme_toggle_js,
1243
+ )
1244
 
1245
 
1246
def main() -> None:
    """Serve ``app`` with uvicorn on all interfaces, port 8000."""
    # Imported lazily so importing this module does not require uvicorn.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8000)
environment_config.py CHANGED
@@ -11,4 +11,21 @@ from earnings_analyst.tasks.registry import DEFAULT_TASK, TASKS
11
  DATASET_ID = "RudrakshNanavaty/earnings-call-data"
12
  DATASET_FILE = "episodes_press_release_8k.parquet"
13
 
14
- __all__ = ["DATASET_ID", "DATASET_FILE", "DEFAULT_TASK", "TASKS"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  DATASET_ID = "RudrakshNanavaty/earnings-call-data"
12
  DATASET_FILE = "episodes_press_release_8k.parquet"
13
 
14
+ # Columns for Gradio episode selection (must exist in the parquet).
15
+ DATASET_EPISODE_ID_COLUMN = "episode_id"
16
+ DATASET_SYMBOL_COLUMN = "symbol"
17
+ DATASET_COMPANY_NAME_COLUMN = "company_name"
18
+ DATASET_YEAR_COLUMN = "year"
19
+ DATASET_QUARTER_COLUMN = "quarter"
20
+
21
+ __all__ = [
22
+ "DATASET_ID",
23
+ "DATASET_FILE",
24
+ "DATASET_EPISODE_ID_COLUMN",
25
+ "DATASET_SYMBOL_COLUMN",
26
+ "DATASET_COMPANY_NAME_COLUMN",
27
+ "DATASET_YEAR_COLUMN",
28
+ "DATASET_QUARTER_COLUMN",
29
+ "DEFAULT_TASK",
30
+ "TASKS",
31
+ ]
server/__init__.py CHANGED
@@ -1,5 +1,22 @@
1
- """Earnings Analyst environment server components."""
2
 
3
- from .earnings_analyst_environment import EarningsAnalystEnvironment
 
 
 
 
 
 
 
 
 
4
 
5
  __all__ = ["EarningsAnalystEnvironment"]
 
 
 
 
 
 
 
 
 
1
+ """Earnings Analyst environment server components.
2
 
3
+ ``EarningsAnalystEnvironment`` is loaded lazily so importing sibling modules
4
+ (e.g. ``episode_index``) does not pull in the Hugging Face dataset.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from typing import TYPE_CHECKING, Any
10
+
11
+ if TYPE_CHECKING:
12
+ from .earnings_analyst_environment import EarningsAnalystEnvironment
13
 
14
  __all__ = ["EarningsAnalystEnvironment"]
15
+
16
+
17
def __getattr__(name: str) -> Any:
    """Resolve ``EarningsAnalystEnvironment`` on first access.

    The environment module is imported only when the attribute is requested;
    every other missing name raises ``AttributeError`` as usual.
    """
    if name != "EarningsAnalystEnvironment":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    from .earnings_analyst_environment import EarningsAnalystEnvironment as _Env

    return _Env
server/earnings_analyst_environment.py CHANGED
@@ -11,7 +11,7 @@ import math
11
  import os
12
  import json
13
  import random
14
- from typing import Any
15
  from uuid import uuid4
16
 
17
  from openenv.core.env_server.interfaces import Environment
@@ -23,6 +23,7 @@ from earnings_analyst.tasks.exceptions import TaskNotImplementedError
23
  from earnings_analyst.tasks.registry import get_grader
24
 
25
  from .dataset_loader import dataset
 
26
 
27
 
28
  def _resolve_task_id(explicit: str | None) -> str:
@@ -68,18 +69,62 @@ class EarningsAnalystEnvironment(Environment):
68
  self._state = State(episode_id=str(uuid4()), step_count=0)
69
  self._current_row: dict[str, Any] | None = None
70
 
71
- def reset(self) -> EarningsAnalystObservation:
72
- """Sample one dataset row and return the agent-visible observation bundle."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  if not self._cfg["implemented"]:
74
  raise TaskNotImplementedError(
75
  f"Task {self._task_id!r} is not implemented yet. "
76
  f"Set implemented=True and fill spec/grader under tasks/ when ready."
77
  )
78
 
79
- self._state = State(episode_id=str(uuid4()), step_count=0)
80
- idx = random.randrange(len(dataset))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  row = dataset[idx]
82
- # Normalize to a plain dict for grading and column access
83
  self._current_row = dict(row)
84
 
85
  text_context = {
@@ -118,7 +163,8 @@ class EarningsAnalystEnvironment(Environment):
118
 
119
  # Handle composite ground truth if multiple columns are specified (e.g. for get_figures)
120
  if "xbrl_columns" in self._cfg:
121
- gt_data = {col: row.get(col) for col in self._cfg["xbrl_columns"]}
 
122
  ground_truth = json.dumps(gt_data)
123
  else:
124
  ground_truth = str(row.get(label_col, "")).strip()
@@ -132,7 +178,6 @@ class EarningsAnalystEnvironment(Environment):
132
  )
133
  )
134
 
135
-
136
  return EarningsAnalystObservation(
137
  text_context={},
138
  numerical_context={},
@@ -146,7 +191,6 @@ class EarningsAnalystEnvironment(Environment):
146
  },
147
  )
148
 
149
-
150
  @property
151
  def state(self) -> State:
152
  """Current environment state."""
 
11
  import os
12
  import json
13
  import random
14
+ from typing import Any, cast
15
  from uuid import uuid4
16
 
17
  from openenv.core.env_server.interfaces import Environment
 
23
  from earnings_analyst.tasks.registry import get_grader
24
 
25
  from .dataset_loader import dataset
26
+ from .episode_index import get_episode_index
27
 
28
 
29
  def _resolve_task_id(explicit: str | None) -> str:
 
69
  self._state = State(episode_id=str(uuid4()), step_count=0)
70
  self._current_row: dict[str, Any] | None = None
71
 
72
+ def reset(
73
+ self,
74
+ seed: int | None = None,
75
+ episode_id: str | None = None,
76
+ **kwargs: Any,
77
+ ) -> EarningsAnalystObservation:
78
+ """Sample one dataset row and return the agent-visible observation bundle.
79
+
80
+ If ``pick_symbol``, ``pick_year``, and ``pick_quarter`` are passed (Gradio),
81
+ load that row by ``(symbol, year, quarter)``. Otherwise sample uniformly at
82
+ random (default for HTTP clients).
83
+
84
+ Optional ``seed`` affects only the random row path.
85
+ """
86
+ pick_symbol = kwargs.pop("pick_symbol", None)
87
+ pick_year = kwargs.pop("pick_year", None)
88
+ pick_quarter = kwargs.pop("pick_quarter", None)
89
+ if kwargs:
90
+ raise TypeError(f"Unexpected keyword arguments: {sorted(kwargs)!r}")
91
+
92
  if not self._cfg["implemented"]:
93
  raise TaskNotImplementedError(
94
  f"Task {self._task_id!r} is not implemented yet. "
95
  f"Set implemented=True and fill spec/grader under tasks/ when ready."
96
  )
97
 
98
+ eid = episode_id if episode_id is not None else str(uuid4())
99
+ self._state = State(episode_id=eid, step_count=0)
100
+ if (
101
+ pick_symbol is not None
102
+ and str(pick_symbol).strip() != ""
103
+ and pick_year is not None
104
+ and pick_quarter is not None
105
+ ):
106
+ sym = str(pick_symbol).strip()
107
+ try:
108
+ yi = int(pick_year) if not isinstance(pick_year, int) else pick_year
109
+ qi = (
110
+ int(pick_quarter)
111
+ if not isinstance(pick_quarter, int)
112
+ else pick_quarter
113
+ )
114
+ except (TypeError, ValueError) as e:
115
+ raise ValueError("Year and quarter must be integers.") from e
116
+ try:
117
+ idx = get_episode_index().row_index(sym, yi, qi)
118
+ except KeyError as e:
119
+ raise ValueError(str(e)) from None
120
+ else:
121
+ rng = random.Random(seed) if seed is not None else random
122
+ idx = rng.randrange(len(dataset))
123
+
124
+ return self._load_row_at(idx)
125
+
126
+ def _load_row_at(self, idx: int) -> EarningsAnalystObservation:
127
  row = dataset[idx]
 
128
  self._current_row = dict(row)
129
 
130
  text_context = {
 
163
 
164
  # Handle composite ground truth if multiple columns are specified (e.g. for get_figures)
165
  if "xbrl_columns" in self._cfg:
166
+ xbrl_cols = cast(list[str], self._cfg["xbrl_columns"])
167
+ gt_data = {col: row.get(col) for col in xbrl_cols}
168
  ground_truth = json.dumps(gt_data)
169
  else:
170
  ground_truth = str(row.get(label_col, "")).strip()
 
178
  )
179
  )
180
 
 
181
  return EarningsAnalystObservation(
182
  text_context={},
183
  numerical_context={},
 
191
  },
192
  )
193
 
 
194
  @property
195
  def state(self) -> State:
196
  """Current environment state."""
server/episode_index.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Build lookups over the loaded Hugging Face split for Gradio (company / year / quarter).
3
+
4
+ Rows are keyed by ``(symbol, year, quarter)`` and must match ``episode_id`` when present
5
+ (``symbol_year_Qquarter``, e.g. ``A_2006_Q1``).
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import warnings
11
+ from collections import defaultdict
12
+ from typing import Any
13
+
14
+ from earnings_analyst.environment_config import (
15
+ DATASET_COMPANY_NAME_COLUMN,
16
+ DATASET_EPISODE_ID_COLUMN,
17
+ DATASET_QUARTER_COLUMN,
18
+ DATASET_SYMBOL_COLUMN,
19
+ DATASET_YEAR_COLUMN,
20
+ )
21
+
22
+ # First sidebar option: sample a random row (same behavior as HTTP clients).
23
+ RANDOM_EPISODE_LABEL = "— Random row —"
24
+
25
+
26
def _normalize_cell(value: Any) -> str | None:
    """Return *value* as a stripped string, or ``None`` when missing or blank."""
    text = "" if value is None else str(value).strip()
    return text or None
31
+
32
+
33
def _as_int(value: Any) -> int | None:
    """Coerce a dataset cell to ``int``, or return ``None`` when impossible.

    The value goes through ``float`` first, so numeric strings and
    float-typed cells (e.g. ``2024.0``) are accepted and rounded to the
    nearest integer. ``None``, unparsable values, NaN and infinities all
    yield ``None``.
    """
    if value is None:
        return None
    try:
        return int(round(float(value)))
    except (TypeError, ValueError, OverflowError):
        # ValueError: unparsable text or NaN. OverflowError: +/-infinity,
        # which round() cannot convert to an integer.
        return None
41
+
42
+
43
def format_episode_id(symbol: str, year: int, quarter: int) -> str:
    """Build the canonical ``episode_id``: ``{symbol}_{year}_Q{quarter}``."""
    parts = (str(symbol), str(year), "Q" + str(quarter))
    return "_".join(parts)
46
+
47
+
48
class EpisodeIndex:
    """Maps ticker symbol + fiscal year + quarter to a dataset row index.

    The index is built in one pass over the dataset. Rows without a symbol,
    year or quarter are skipped. When several rows share the same
    ``(symbol, year, quarter)`` the first one wins and a ``UserWarning`` is
    emitted once after the scan.
    """

    # Default column names, as configured in environment_config.
    _REQUIRED_COLUMNS: tuple[str, ...] = (
        DATASET_EPISODE_ID_COLUMN,
        DATASET_SYMBOL_COLUMN,
        DATASET_COMPANY_NAME_COLUMN,
        DATASET_YEAR_COLUMN,
        DATASET_QUARTER_COLUMN,
    )

    def __init__(
        self,
        dataset: Any,
        episode_id_col: str = DATASET_EPISODE_ID_COLUMN,
        symbol_col: str = DATASET_SYMBOL_COLUMN,
        company_name_col: str = DATASET_COMPANY_NAME_COLUMN,
        year_col: str = DATASET_YEAR_COLUMN,
        quarter_col: str = DATASET_QUARTER_COLUMN,
    ) -> None:
        """Scan *dataset* and build the lookup tables.

        Args:
            dataset: Object exposing ``column_names``, ``len()`` and integer
                indexing that returns a mapping with ``.get``.
            episode_id_col: Column that must exist; not used for lookups.
            symbol_col: Column holding the ticker symbol.
            company_name_col: Column holding the company name (falls back to
                the symbol in the display label when blank).
            year_col: Column holding the fiscal year.
            quarter_col: Column holding the fiscal quarter.

        Raises:
            ValueError: If any of the requested columns is absent.
        """
        names = dataset.column_names
        # Validate the columns this instance will actually read, so custom
        # column names are checked rather than the module defaults.
        required = (
            episode_id_col,
            symbol_col,
            company_name_col,
            year_col,
            quarter_col,
        )
        missing = [c for c in required if c not in names]
        if missing:
            raise ValueError(
                f"Dataset is missing column(s) {missing}. "
                f"Check environment_config.py. Available columns: {names!r}."
            )

        self._symbol_to_display: dict[str, str] = {}
        self._display_to_symbol: dict[str, str] = {}
        self._symbol_to_years: dict[str, set[int]] = defaultdict(set)
        self._symbol_year_to_quarters: dict[tuple[str, int], set[int]] = defaultdict(
            set
        )
        self._triple_to_index: dict[tuple[str, int, int], int] = {}
        duplicate_triples: set[tuple[str, int, int]] = set()

        for i in range(len(dataset)):
            row = dataset[i]
            sym = _normalize_cell(row.get(symbol_col))
            cname = _normalize_cell(row.get(company_name_col))
            yr = _as_int(row.get(year_col))
            qn = _as_int(row.get(quarter_col))
            if not sym or yr is None or qn is None:
                continue  # incomplete rows cannot be selected

            if sym not in self._symbol_to_display:
                # The first row seen for a symbol fixes its display label.
                display = f"{cname or sym} ({sym})"
                self._symbol_to_display[sym] = display
                self._display_to_symbol[display] = sym

            triple = (sym, yr, qn)
            self._symbol_to_years[sym].add(yr)
            self._symbol_year_to_quarters[(sym, yr)].add(qn)

            if triple in self._triple_to_index:
                duplicate_triples.add(triple)
            else:
                self._triple_to_index[triple] = i
            # episode_id_col is validated to exist; row lookup uses
            # (symbol, year, quarter).

        if duplicate_triples:
            warnings.warn(
                "Duplicate (symbol, year, quarter) rows; using the first index for each.",
                UserWarning,
                stacklevel=2,  # attribute the warning to whoever built the index
            )

    def sorted_company_displays(self) -> list[str]:
        """Return all company display labels in sorted order."""
        return sorted(self._display_to_symbol)

    def symbol_for_display(self, display: str) -> str:
        """Return the ticker symbol for a display label.

        Raises:
            KeyError: If *display* is not a known label.
        """
        try:
            return self._display_to_symbol[display]
        except KeyError:
            raise KeyError(f"Unknown company selection: {display!r}") from None

    def years_for_symbol(self, symbol: str) -> list[int]:
        """Return the sorted years available for *symbol* (empty if unknown)."""
        return sorted(self._symbol_to_years.get(symbol, ()))

    def quarters_for(self, symbol: str, year: int) -> list[int]:
        """Return the sorted quarters for *symbol* in *year* (empty if none)."""
        return sorted(self._symbol_year_to_quarters.get((symbol, year), ()))

    def row_index(self, symbol: str, year: int, quarter: int) -> int:
        """Return the dataset row index for ``(symbol, year, quarter)``.

        Raises:
            KeyError: If no row matches the combination.
        """
        try:
            return self._triple_to_index[(symbol, year, quarter)]
        except KeyError:
            raise KeyError(
                f"No row for symbol={symbol!r}, year={year!r}, quarter={quarter!r}. "
                "Pick a valid combination from the dropdowns."
            ) from None
140
+
141
+
142
+ _instance: EpisodeIndex | None = None
143
+
144
+
145
def get_episode_index() -> EpisodeIndex:
    """Return the shared index over ``dataset_loader.dataset``, building it once.

    The dataset module is imported on the first call only, so importing this
    module does not import ``dataset_loader``.
    """
    global _instance
    if _instance is not None:
        return _instance

    from earnings_analyst.server.dataset_loader import dataset as loaded_split

    _instance = EpisodeIndex(loaded_split)
    return _instance
153
+
154
+
155
+ __all__ = [
156
+ "EpisodeIndex",
157
+ "format_episode_id",
158
+ "get_episode_index",
159
+ "RANDOM_EPISODE_LABEL",
160
+ ]
tests/test_episode_index.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for :mod:`earnings_analyst.server.episode_index`."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import pytest
6
+ from datasets import Dataset
7
+
8
+ from earnings_analyst.server.episode_index import (
9
+ EpisodeIndex,
10
+ RANDOM_EPISODE_LABEL,
11
+ format_episode_id,
12
+ )
13
+
14
+
15
def test_format_episode_id() -> None:
    """A symbol, year and quarter are joined as ``<symbol>_<year>_Q<quarter>``."""
    episode_id = format_episode_id("A", 2006, 1)
    assert episode_id == "A_2006_Q1"
17
+
18
+
19
def test_episode_index_lookup_and_cascades() -> None:
    """Three rows over two companies resolve through every lookup level."""
    columns = {
        "episode_id": ["AAA_2024_Q1", "AAA_2024_Q2", "BBB_2024_Q1"],
        "symbol": ["AAA", "AAA", "BBB"],
        "company_name": ["Alpha Co", "Alpha Co", "Beta Co"],
        "year": [2024, 2024, 2024],
        "quarter": [1, 2, 1],
    }
    index = EpisodeIndex(Dataset.from_dict(columns))
    alpha_label = "Alpha Co (AAA)"
    beta_label = "Beta Co (BBB)"

    # Company -> symbol -> years -> quarters cascade.
    assert index.sorted_company_displays() == [alpha_label, beta_label]
    assert index.symbol_for_display(alpha_label) == "AAA"
    assert index.years_for_symbol("AAA") == [2024]
    assert index.quarters_for("AAA", 2024) == [1, 2]

    # Each triple maps to the position of its row in the dataset.
    expected_rows = {
        ("AAA", 2024, 1): 0,
        ("AAA", 2024, 2): 1,
        ("BBB", 2024, 1): 2,
    }
    for (symbol, year, quarter), row in expected_rows.items():
        assert index.row_index(symbol, year, quarter) == row
39
+
40
+
41
def test_episode_index_duplicate_triple_warns() -> None:
    """Two rows sharing a triple emit a warning and the first row is kept."""
    columns = {
        "episode_id": ["X_2024_Q1", "X_2024_Q1_dup"],
        "symbol": ["X", "X"],
        "company_name": ["X", "X"],
        "year": [2024, 2024],
        "quarter": [1, 1],
    }
    duplicated = Dataset.from_dict(columns)
    with pytest.warns(UserWarning, match="Duplicate"):
        index = EpisodeIndex(duplicated)
    assert index.row_index("X", 2024, 1) == 0
54
+
55
+
56
def test_episode_index_skips_incomplete_rows() -> None:
    """A row with empty symbol and company name is left out of the index."""
    columns = {
        "episode_id": ["_", "Z_2024_Q1"],
        "symbol": ["", "ZZZ"],
        "company_name": ["", "Zed"],
        "year": [2024, 2024],
        "quarter": [1, 1],
    }
    index = EpisodeIndex(Dataset.from_dict(columns))
    assert index.sorted_company_displays() == ["Zed (ZZZ)"]
    # The surviving row keeps its original dataset position (1, not 0).
    assert index.row_index("ZZZ", 2024, 1) == 1
69
+
70
+
71
def test_episode_index_missing_column_raises() -> None:
    """Building an index from a dataset lacking the expected columns fails."""
    unrelated = Dataset.from_dict({"a": [1]})
    with pytest.raises(ValueError, match="missing column"):
        EpisodeIndex(unrelated)
75
+
76
+
77
def test_random_label_constant() -> None:
    """The random-episode label contains the word "Random"."""
    expected_word = "Random"
    assert expected_word in RANDOM_EPISODE_LABEL