Commit ·
fb8bae3
1
Parent(s): f9e8ac9
Improved UI
Browse files- app.py +1071 -126
- environment_config.py +18 -1
- server/__init__.py +19 -2
- server/earnings_analyst_environment.py +53 -9
- server/episode_index.py +160 -0
- tests/test_episode_index.py +78 -0
app.py
CHANGED
|
@@ -1,10 +1,15 @@
|
|
|
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
import json
|
| 4 |
import gradio as gr
|
| 5 |
-
from
|
|
|
|
| 6 |
from dotenv import load_dotenv
|
| 7 |
from openai import AsyncOpenAI
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
# Ensure the project root is in sys.path
|
| 10 |
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
|
|
@@ -13,21 +18,799 @@ try:
|
|
| 13 |
from earnings_analyst.server.earnings_analyst_environment import (
|
| 14 |
EarningsAnalystEnvironment,
|
| 15 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
from earnings_analyst.models import (
|
| 17 |
EarningsAnalystAction,
|
| 18 |
EarningsAnalystObservation,
|
| 19 |
)
|
|
|
|
| 20 |
from earnings_analyst.tasks.registry import TASK_IDS
|
| 21 |
except (ImportError, ModuleNotFoundError):
|
| 22 |
from server.earnings_analyst_environment import EarningsAnalystEnvironment
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
from models import EarningsAnalystAction, EarningsAnalystObservation
|
|
|
|
| 24 |
from tasks.registry import TASK_IDS
|
| 25 |
|
| 26 |
load_dotenv()
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
class State:
|
| 30 |
-
def __init__(self):
|
| 31 |
self.env: Optional[EarningsAnalystEnvironment] = None
|
| 32 |
self.obs: Optional[EarningsAnalystObservation] = None
|
| 33 |
self.task_id: str = "sentiment_label"
|
|
@@ -36,65 +819,100 @@ class State:
|
|
| 36 |
state = State()
|
| 37 |
|
| 38 |
|
| 39 |
-
async def reset_env(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
state.task_id = task_id
|
| 41 |
state.env = EarningsAnalystEnvironment(task_id=task_id)
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
-
|
| 45 |
-
text_context = ""
|
| 46 |
-
if state.obs.text_context:
|
| 47 |
-
for name, text in sorted(state.obs.text_context.items()):
|
| 48 |
-
text_context += f"### {name}\n{text}\n\n"
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
if state.obs.numerical_context
|
| 53 |
-
else "No numerical data."
|
| 54 |
)
|
| 55 |
|
| 56 |
return [
|
|
|
|
| 57 |
state.obs.task_instruction,
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
gr.update(visible=
|
| 62 |
-
"", #
|
| 63 |
-
"", #
|
|
|
|
| 64 |
]
|
| 65 |
|
| 66 |
|
| 67 |
-
async def step_env(prediction: str):
|
| 68 |
if not state.env or not state.obs:
|
| 69 |
-
return [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
action = EarningsAnalystAction(prediction=prediction)
|
| 72 |
state.obs = state.env.step(action)
|
| 73 |
|
| 74 |
-
|
|
|
|
| 75 |
ground_truth = getattr(state.obs, "ground_truth", "N/A")
|
| 76 |
|
| 77 |
-
result_text = f"**Reward:** {reward:.4f} \n\n**Ground Truth:** {ground_truth}"
|
| 78 |
-
|
| 79 |
return [
|
| 80 |
-
gr.update(visible=False),
|
| 81 |
-
gr.update(visible=True, value=
|
| 82 |
]
|
| 83 |
|
| 84 |
|
| 85 |
-
async def run_agent(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
if not api_key:
|
| 87 |
-
return [gr.update()] *
|
| 88 |
|
| 89 |
-
|
| 90 |
-
|
|
|
|
| 91 |
|
| 92 |
-
|
| 93 |
-
client_params = {"api_key": api_key}
|
| 94 |
-
if base_url:
|
| 95 |
-
client_params["base_url"] = base_url
|
| 96 |
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
user_content = f"{state.obs.task_instruction}\n\n"
|
| 100 |
if state.obs.text_context:
|
|
@@ -125,17 +943,10 @@ async def run_agent(task_id: str, api_key: str, model: str, base_url: str):
|
|
| 125 |
parsed = json.loads(response_text)
|
| 126 |
prediction = str(parsed.get("prediction", response_text))
|
| 127 |
|
| 128 |
-
# 3. Step
|
| 129 |
step_out = await step_env(prediction)
|
| 130 |
|
| 131 |
-
# Update UI components
|
| 132 |
-
# out: [instr, text, num, pred_row, res_row, pred_input, msg]
|
| 133 |
-
# step_out: [pred_row, res_row]
|
| 134 |
-
|
| 135 |
return [
|
| 136 |
-
out[
|
| 137 |
-
out[1],
|
| 138 |
-
out[2],
|
| 139 |
step_out[0],
|
| 140 |
step_out[1],
|
| 141 |
prediction,
|
|
@@ -143,42 +954,84 @@ async def run_agent(task_id: str, api_key: str, model: str, base_url: str):
|
|
| 143 |
]
|
| 144 |
|
| 145 |
except Exception as e:
|
| 146 |
-
return [out[
|
| 147 |
|
| 148 |
|
| 149 |
-
#
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
.
|
| 163 |
-
"""
|
| 164 |
|
| 165 |
-
with gr.Blocks(css=custom_css, title="Earnings Analyst - OpenEnv") as demo:
|
| 166 |
-
with gr.Column(elem_classes="container"):
|
| 167 |
-
with gr.Column(elem_classes="header"):
|
| 168 |
-
gr.Markdown("# 🏑 Earnings Analyst")
|
| 169 |
-
gr.Markdown(
|
| 170 |
-
"Interactive environment for financial analysis tasks using [OpenEnv](https://github.com/meta-pytorch/OpenEnv). Evaluate agents or your own analysis on earnings call data."
|
| 171 |
-
)
|
| 172 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
with gr.Row():
|
| 174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
with gr.Group():
|
| 176 |
task_select = gr.Dropdown(
|
| 177 |
-
choices=TASK_IDS,
|
|
|
|
|
|
|
| 178 |
)
|
| 179 |
-
reset_btn = gr.Button("
|
| 180 |
|
| 181 |
-
gr.Markdown("
|
| 182 |
with gr.Group():
|
| 183 |
api_key = gr.Textbox(
|
| 184 |
label="OpenAI API Key",
|
|
@@ -192,77 +1045,167 @@ with gr.Blocks(css=custom_css, title="Earnings Analyst - OpenEnv") as demo:
|
|
| 192 |
placeholder="https://api.openai.com/v1",
|
| 193 |
value=os.environ.get("API_BASE_URL", ""),
|
| 194 |
)
|
| 195 |
-
agent_btn = gr.Button("
|
| 196 |
|
|
|
|
| 197 |
with gr.Column(scale=2):
|
|
|
|
|
|
|
|
|
|
| 198 |
with gr.Tabs():
|
| 199 |
-
with gr.
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
with gr.Column(visible=False) as prediction_row:
|
| 215 |
pred_input = gr.Textbox(
|
| 216 |
-
label="Your
|
| 217 |
placeholder="e.g. bullish, or 0.05",
|
| 218 |
)
|
| 219 |
-
submit_btn = gr.Button("Submit
|
| 220 |
|
| 221 |
-
result_view = gr.
|
| 222 |
message_view = gr.Textbox(
|
| 223 |
-
label="Agent
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
)
|
| 225 |
|
| 226 |
-
# Event
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
reset_btn.click(
|
| 228 |
fn=reset_env,
|
| 229 |
-
inputs=[task_select],
|
| 230 |
-
outputs=
|
| 231 |
-
instr_view,
|
| 232 |
-
text_view,
|
| 233 |
-
num_view,
|
| 234 |
-
prediction_row,
|
| 235 |
-
result_view,
|
| 236 |
-
pred_input,
|
| 237 |
-
message_view,
|
| 238 |
-
],
|
| 239 |
)
|
| 240 |
-
|
| 241 |
submit_btn.click(
|
| 242 |
fn=step_env, inputs=[pred_input], outputs=[prediction_row, result_view]
|
| 243 |
)
|
| 244 |
-
|
| 245 |
agent_btn.click(
|
| 246 |
fn=run_agent,
|
| 247 |
-
inputs=[
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
message_view,
|
| 256 |
],
|
|
|
|
| 257 |
)
|
| 258 |
|
| 259 |
-
|
| 260 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
|
| 262 |
-
# Create the main FastAPI app
|
| 263 |
api_app = FastAPI(title="Earnings Analyst API")
|
| 264 |
|
| 265 |
-
# Add CORS middleware
|
| 266 |
api_app.add_middleware(
|
| 267 |
CORSMiddleware,
|
| 268 |
allow_origins=["*"],
|
|
@@ -270,35 +1213,37 @@ api_app.add_middleware(
|
|
| 270 |
allow_headers=["*"],
|
| 271 |
)
|
| 272 |
|
| 273 |
-
from fastapi.responses import RedirectResponse
|
| 274 |
-
|
| 275 |
|
| 276 |
@api_app.get("/health")
|
| 277 |
-
async def health():
|
| 278 |
return {"status": "ok", "environment": "earnings_analyst"}
|
| 279 |
|
| 280 |
|
| 281 |
@api_app.get("/web")
|
| 282 |
-
async def web_redirect():
|
| 283 |
return RedirectResponse(url="/")
|
| 284 |
|
| 285 |
|
| 286 |
-
# Import the environment app
|
| 287 |
try:
|
| 288 |
from earnings_analyst.server.app import app as env_app
|
| 289 |
except (ImportError, ModuleNotFoundError):
|
| 290 |
from server.app import app as env_app
|
| 291 |
|
| 292 |
-
# Mount the environment API at both root and /api for compatibility
|
| 293 |
-
# Root level is needed for OpenEnv standard validation
|
| 294 |
api_app.include_router(env_app.router)
|
| 295 |
api_app.mount("/api", env_app)
|
| 296 |
|
| 297 |
-
#
|
| 298 |
-
app = gr.mount_gradio_app(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
|
| 300 |
|
| 301 |
-
def main():
|
| 302 |
import uvicorn
|
| 303 |
|
| 304 |
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|
|
| 1 |
+
import html
|
| 2 |
import os
|
| 3 |
import sys
|
| 4 |
import json
|
| 5 |
import gradio as gr
|
| 6 |
+
from gradio.themes.utils import colors, sizes, fonts
|
| 7 |
+
from typing import Any, Optional
|
| 8 |
from dotenv import load_dotenv
|
| 9 |
from openai import AsyncOpenAI
|
| 10 |
+
from fastapi import FastAPI
|
| 11 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 12 |
+
from fastapi.responses import RedirectResponse
|
| 13 |
|
| 14 |
# Ensure the project root is in sys.path
|
| 15 |
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
|
|
|
|
| 18 |
from earnings_analyst.server.earnings_analyst_environment import (
|
| 19 |
EarningsAnalystEnvironment,
|
| 20 |
)
|
| 21 |
+
from earnings_analyst.server.episode_index import (
|
| 22 |
+
RANDOM_EPISODE_LABEL,
|
| 23 |
+
format_episode_id,
|
| 24 |
+
get_episode_index,
|
| 25 |
+
)
|
| 26 |
from earnings_analyst.models import (
|
| 27 |
EarningsAnalystAction,
|
| 28 |
EarningsAnalystObservation,
|
| 29 |
)
|
| 30 |
+
from earnings_analyst.tasks.exceptions import TaskNotImplementedError
|
| 31 |
from earnings_analyst.tasks.registry import TASK_IDS
|
| 32 |
except (ImportError, ModuleNotFoundError):
|
| 33 |
from server.earnings_analyst_environment import EarningsAnalystEnvironment
|
| 34 |
+
from server.episode_index import (
|
| 35 |
+
RANDOM_EPISODE_LABEL,
|
| 36 |
+
format_episode_id,
|
| 37 |
+
get_episode_index,
|
| 38 |
+
)
|
| 39 |
from models import EarningsAnalystAction, EarningsAnalystObservation
|
| 40 |
+
from tasks.exceptions import TaskNotImplementedError
|
| 41 |
from tasks.registry import TASK_IDS
|
| 42 |
|
| 43 |
load_dotenv()
|
| 44 |
|
| 45 |
+
# One accordion per possible text column (task specs use these keys).
|
| 46 |
+
TEXT_KEYS = (
|
| 47 |
+
"earnings_transcript",
|
| 48 |
+
"press_release_8k_body",
|
| 49 |
+
"press_release_ex991",
|
| 50 |
+
"press_release_ex992",
|
| 51 |
+
"press_release_sources",
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
# ---------------------------------------------------------------------------
|
| 55 |
+
# Theme — Zerodha/Kite-inspired
|
| 56 |
+
# ---------------------------------------------------------------------------
|
| 57 |
+
|
| 58 |
+
zerodha_blue = colors.Color(
|
| 59 |
+
c50="#eff6ff",
|
| 60 |
+
c100="#dbeafe",
|
| 61 |
+
c200="#bfdbfe",
|
| 62 |
+
c300="#93c5fd",
|
| 63 |
+
c400="#60a5fa",
|
| 64 |
+
c500="#387ED1",
|
| 65 |
+
c600="#2563eb",
|
| 66 |
+
c700="#1d4ed8",
|
| 67 |
+
c800="#1e40af",
|
| 68 |
+
c900="#1e3a8a",
|
| 69 |
+
c950="#172554",
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
theme = gr.themes.Base( # type: ignore[attr-defined]
|
| 73 |
+
primary_hue=zerodha_blue,
|
| 74 |
+
secondary_hue=zerodha_blue,
|
| 75 |
+
neutral_hue=colors.slate,
|
| 76 |
+
radius_size=sizes.radius_md,
|
| 77 |
+
font=(fonts.GoogleFont("Inter"), "system-ui", "sans-serif"),
|
| 78 |
+
font_mono=(fonts.GoogleFont("JetBrains Mono"), "ui-monospace", "monospace"),
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
theme.set(
|
| 82 |
+
# Body
|
| 83 |
+
body_background_fill="#F8F9FA",
|
| 84 |
+
body_background_fill_dark="#18181b",
|
| 85 |
+
body_text_color="#1a1a1a",
|
| 86 |
+
body_text_color_dark="#f4f4f5",
|
| 87 |
+
body_text_color_subdued="#6b7280",
|
| 88 |
+
body_text_color_subdued_dark="#a1a1aa",
|
| 89 |
+
# Blocks / cards
|
| 90 |
+
block_background_fill="#ffffff",
|
| 91 |
+
block_background_fill_dark="#27272a",
|
| 92 |
+
block_border_color="#e5e7eb",
|
| 93 |
+
block_border_color_dark="#3f3f46",
|
| 94 |
+
block_border_width="1px",
|
| 95 |
+
block_radius="12px",
|
| 96 |
+
block_shadow="none",
|
| 97 |
+
block_shadow_dark="none",
|
| 98 |
+
block_label_background_fill="#f9fafb",
|
| 99 |
+
block_label_background_fill_dark="#3f3f46",
|
| 100 |
+
block_label_border_color="#e5e7eb",
|
| 101 |
+
block_label_border_color_dark="#52525b",
|
| 102 |
+
block_label_text_color="#374151",
|
| 103 |
+
block_label_text_color_dark="#d4d4d8",
|
| 104 |
+
# Inputs
|
| 105 |
+
input_background_fill="#ffffff",
|
| 106 |
+
input_background_fill_dark="#3f3f46",
|
| 107 |
+
input_background_fill_focus="#ffffff",
|
| 108 |
+
input_background_fill_focus_dark="#52525b",
|
| 109 |
+
input_border_color="#d1d5db",
|
| 110 |
+
input_border_color_dark="#52525b",
|
| 111 |
+
input_border_color_focus="#387ED1",
|
| 112 |
+
input_border_color_focus_dark="#60a5fa",
|
| 113 |
+
input_border_width="1px",
|
| 114 |
+
input_radius="8px",
|
| 115 |
+
input_shadow="none",
|
| 116 |
+
input_shadow_dark="none",
|
| 117 |
+
input_shadow_focus="0 0 0 3px rgba(56, 126, 209, 0.15)",
|
| 118 |
+
input_shadow_focus_dark="0 0 0 3px rgba(96, 165, 250, 0.2)",
|
| 119 |
+
input_placeholder_color="#9ca3af",
|
| 120 |
+
input_placeholder_color_dark="#71717a",
|
| 121 |
+
# Primary buttons — solid Zerodha blue
|
| 122 |
+
button_primary_background_fill="#387ED1",
|
| 123 |
+
button_primary_background_fill_dark="#387ED1",
|
| 124 |
+
button_primary_background_fill_hover="#2563eb",
|
| 125 |
+
button_primary_background_fill_hover_dark="#2563eb",
|
| 126 |
+
button_primary_text_color="#ffffff",
|
| 127 |
+
button_primary_text_color_dark="#ffffff",
|
| 128 |
+
button_primary_text_color_hover="#ffffff",
|
| 129 |
+
button_primary_text_color_hover_dark="#ffffff",
|
| 130 |
+
button_primary_border_color="#387ED1",
|
| 131 |
+
button_primary_border_color_dark="#387ED1",
|
| 132 |
+
button_primary_border_color_hover="#2563eb",
|
| 133 |
+
button_primary_border_color_hover_dark="#2563eb",
|
| 134 |
+
button_primary_shadow="none",
|
| 135 |
+
button_primary_shadow_dark="none",
|
| 136 |
+
button_primary_shadow_hover="none",
|
| 137 |
+
button_primary_shadow_hover_dark="none",
|
| 138 |
+
# Secondary buttons — subtle outline
|
| 139 |
+
button_secondary_background_fill="#ffffff",
|
| 140 |
+
button_secondary_background_fill_dark="#3f3f46",
|
| 141 |
+
button_secondary_background_fill_hover="#eff6ff",
|
| 142 |
+
button_secondary_background_fill_hover_dark="#52525b",
|
| 143 |
+
button_secondary_text_color="#387ED1",
|
| 144 |
+
button_secondary_text_color_dark="#93c5fd",
|
| 145 |
+
button_secondary_text_color_hover="#2563eb",
|
| 146 |
+
button_secondary_text_color_hover_dark="#bfdbfe",
|
| 147 |
+
button_secondary_border_color="#d1d5db",
|
| 148 |
+
button_secondary_border_color_dark="#52525b",
|
| 149 |
+
button_secondary_border_color_hover="#387ED1",
|
| 150 |
+
button_secondary_border_color_hover_dark="#60a5fa",
|
| 151 |
+
button_secondary_shadow="none",
|
| 152 |
+
button_secondary_shadow_dark="none",
|
| 153 |
+
button_secondary_shadow_hover="none",
|
| 154 |
+
button_secondary_shadow_hover_dark="none",
|
| 155 |
+
# Links
|
| 156 |
+
link_text_color="#387ED1",
|
| 157 |
+
link_text_color_dark="#60a5fa",
|
| 158 |
+
link_text_color_hover="#2563eb",
|
| 159 |
+
link_text_color_hover_dark="#93c5fd",
|
| 160 |
+
link_text_color_active="#1d4ed8",
|
| 161 |
+
link_text_color_active_dark="#bfdbfe",
|
| 162 |
+
link_text_color_visited="#387ED1",
|
| 163 |
+
link_text_color_visited_dark="#60a5fa",
|
| 164 |
+
# Accent
|
| 165 |
+
color_accent="#387ED1",
|
| 166 |
+
color_accent_soft="#eff6ff",
|
| 167 |
+
color_accent_soft_dark="#1e3a8a",
|
| 168 |
+
border_color_accent="#387ED1",
|
| 169 |
+
border_color_accent_dark="#60a5fa",
|
| 170 |
+
border_color_accent_subdued="#bfdbfe",
|
| 171 |
+
border_color_accent_subdued_dark="#1d4ed8",
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
# ---------------------------------------------------------------------------
|
| 175 |
+
# CSS overrides
|
| 176 |
+
# ---------------------------------------------------------------------------
|
| 177 |
+
|
| 178 |
+
custom_css = """
|
| 179 |
+
/* Align custom panels with theme block/input radii */
|
| 180 |
+
:root {
|
| 181 |
+
--ea-radius-block: 12px;
|
| 182 |
+
--ea-radius-control: 8px;
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
footer { visibility: hidden; }
|
| 186 |
+
|
| 187 |
+
/* Page shell */
|
| 188 |
+
.app-shell {
|
| 189 |
+
max-width: 1140px;
|
| 190 |
+
margin: 0 auto;
|
| 191 |
+
padding: 1.25rem 1rem 2rem;
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
/* Header strip */
|
| 195 |
+
.header-row {
|
| 196 |
+
display: flex;
|
| 197 |
+
align-items: flex-start;
|
| 198 |
+
justify-content: space-between;
|
| 199 |
+
margin-bottom: 1.25rem;
|
| 200 |
+
padding-bottom: 1rem;
|
| 201 |
+
border-bottom: 1px solid #e5e7eb;
|
| 202 |
+
}
|
| 203 |
+
.dark .header-row { border-color: #3f3f46; }
|
| 204 |
+
|
| 205 |
+
/* Status badge */
|
| 206 |
+
.status-badge > .prose p {
|
| 207 |
+
margin: 0;
|
| 208 |
+
font-size: 0.82rem;
|
| 209 |
+
color: #6b7280;
|
| 210 |
+
text-align: right;
|
| 211 |
+
}
|
| 212 |
+
.dark .status-badge > .prose p { color: #a1a1aa; }
|
| 213 |
+
|
| 214 |
+
/* Theme toggle — rounding only on outer frame; Gradio may add per-button radii (fix asymmetry) */
|
| 215 |
+
.theme-toggle .wrap {
|
| 216 |
+
gap: 0 !important;
|
| 217 |
+
display: inline-flex !important;
|
| 218 |
+
flex-direction: row !important;
|
| 219 |
+
align-items: stretch;
|
| 220 |
+
border-radius: var(--ea-radius-control);
|
| 221 |
+
overflow: hidden;
|
| 222 |
+
box-sizing: border-box;
|
| 223 |
+
}
|
| 224 |
+
/* Prefer fieldset / radiogroup as the stroked frame; otherwise stroke the wrap (e.g. button UI) */
|
| 225 |
+
.theme-toggle fieldset,
|
| 226 |
+
.theme-toggle [role="radiogroup"] {
|
| 227 |
+
display: inline-flex !important;
|
| 228 |
+
flex-direction: row !important;
|
| 229 |
+
align-items: stretch;
|
| 230 |
+
padding: 0 !important;
|
| 231 |
+
margin: 0 !important;
|
| 232 |
+
border: 1px solid #d1d5db !important;
|
| 233 |
+
border-radius: var(--ea-radius-control);
|
| 234 |
+
overflow: hidden;
|
| 235 |
+
gap: 0 !important;
|
| 236 |
+
box-sizing: border-box;
|
| 237 |
+
}
|
| 238 |
+
.theme-toggle .wrap:not(:has(fieldset)):not(:has([role="radiogroup"])) {
|
| 239 |
+
border: 1px solid #d1d5db;
|
| 240 |
+
}
|
| 241 |
+
.theme-toggle .gap-2 { gap: 0 !important; }
|
| 242 |
+
.theme-toggle input[type=radio] { display: none; }
|
| 243 |
+
.theme-toggle label {
|
| 244 |
+
display: inline-flex;
|
| 245 |
+
align-items: center;
|
| 246 |
+
gap: 4px;
|
| 247 |
+
padding: 4px 12px;
|
| 248 |
+
font-size: 0.78rem;
|
| 249 |
+
font-weight: 500;
|
| 250 |
+
cursor: pointer;
|
| 251 |
+
border: none !important;
|
| 252 |
+
border-radius: 0 !important;
|
| 253 |
+
border-right: 1px solid #d1d5db !important;
|
| 254 |
+
background: #fff;
|
| 255 |
+
color: #374151;
|
| 256 |
+
transition: background 0.15s, color 0.15s, border-color 0.15s;
|
| 257 |
+
white-space: nowrap;
|
| 258 |
+
}
|
| 259 |
+
.theme-toggle label:last-of-type { border-right: none !important; }
|
| 260 |
+
/* Segments: zero radius so only the group clips to --ea-radius-control (overrides Gradio/Tailwind per-side radii) */
|
| 261 |
+
.theme-toggle button,
|
| 262 |
+
.theme-toggle [role="radiogroup"] button,
|
| 263 |
+
.theme-toggle button:first-child,
|
| 264 |
+
.theme-toggle button:last-child {
|
| 265 |
+
border: none !important;
|
| 266 |
+
border-radius: 0 !important;
|
| 267 |
+
border-top-left-radius: 0 !important;
|
| 268 |
+
border-top-right-radius: 0 !important;
|
| 269 |
+
border-bottom-left-radius: 0 !important;
|
| 270 |
+
border-bottom-right-radius: 0 !important;
|
| 271 |
+
border-right: 1px solid #d1d5db !important;
|
| 272 |
+
margin: 0 !important;
|
| 273 |
+
box-shadow: none !important;
|
| 274 |
+
}
|
| 275 |
+
.theme-toggle button:last-child { border-right: none !important; }
|
| 276 |
+
/* Gradio 4+ uses buttons; Zerodha blue for resolved/selected segment */
|
| 277 |
+
.theme-toggle button.selected,
|
| 278 |
+
.theme-toggle button.ea-theme-active,
|
| 279 |
+
.theme-toggle button[aria-checked="true"],
|
| 280 |
+
.theme-toggle button[data-state="checked"] {
|
| 281 |
+
background: #387ED1 !important;
|
| 282 |
+
color: #fff !important;
|
| 283 |
+
border-color: #387ED1 !important;
|
| 284 |
+
z-index: 1;
|
| 285 |
+
}
|
| 286 |
+
.theme-toggle input[type=radio]:checked + label,
|
| 287 |
+
.theme-toggle label.selected {
|
| 288 |
+
background: #387ED1;
|
| 289 |
+
color: #fff;
|
| 290 |
+
border-color: #387ED1;
|
| 291 |
+
z-index: 1;
|
| 292 |
+
}
|
| 293 |
+
/* Same highlight when .dark is on body/html (do not rely on html.dark alone) */
|
| 294 |
+
.dark .theme-toggle button.selected,
|
| 295 |
+
.dark .theme-toggle button.ea-theme-active,
|
| 296 |
+
.dark .theme-toggle button[aria-checked="true"],
|
| 297 |
+
.dark .theme-toggle button[data-state="checked"] {
|
| 298 |
+
background: #387ED1 !important;
|
| 299 |
+
color: #fff !important;
|
| 300 |
+
border-color: #387ED1 !important;
|
| 301 |
+
}
|
| 302 |
+
body.dark .theme-toggle button.selected,
|
| 303 |
+
body.dark .theme-toggle button.ea-theme-active,
|
| 304 |
+
body.dark .theme-toggle button[aria-checked="true"],
|
| 305 |
+
body.dark .theme-toggle button[data-state="checked"],
|
| 306 |
+
html.dark .theme-toggle button.selected,
|
| 307 |
+
html.dark .theme-toggle button.ea-theme-active,
|
| 308 |
+
html.dark .theme-toggle button[aria-checked="true"],
|
| 309 |
+
html.dark .theme-toggle button[data-state="checked"] {
|
| 310 |
+
background: #387ED1 !important;
|
| 311 |
+
color: #fff !important;
|
| 312 |
+
border-color: #387ED1 !important;
|
| 313 |
+
}
|
| 314 |
+
/* System: Gradio keeps "System" selected — neutralize that segment; .ea-theme-active marks resolved Light/Dark */
|
| 315 |
+
html[data-ea-theme="system"][data-ea-resolved="light"] .theme-toggle button.selected:not(.ea-theme-active) {
|
| 316 |
+
background: #fff !important;
|
| 317 |
+
color: #374151 !important;
|
| 318 |
+
border-right-color: #d1d5db !important;
|
| 319 |
+
}
|
| 320 |
+
html[data-ea-theme="system"][data-ea-resolved="dark"] .theme-toggle button.selected:not(.ea-theme-active) {
|
| 321 |
+
background: #3f3f46 !important;
|
| 322 |
+
color: #d4d4d8 !important;
|
| 323 |
+
border-right-color: #52525b !important;
|
| 324 |
+
}
|
| 325 |
+
.dark .theme-toggle .wrap:not(:has(fieldset)):not(:has([role="radiogroup"])) {
|
| 326 |
+
border-color: #52525b;
|
| 327 |
+
}
|
| 328 |
+
.dark .theme-toggle fieldset,
|
| 329 |
+
.dark .theme-toggle [role="radiogroup"] {
|
| 330 |
+
border-color: #52525b !important;
|
| 331 |
+
}
|
| 332 |
+
.dark .theme-toggle button { border-right-color: #52525b !important; }
|
| 333 |
+
.dark .theme-toggle label {
|
| 334 |
+
background: #3f3f46;
|
| 335 |
+
color: #d4d4d8;
|
| 336 |
+
border-right-color: #52525b !important;
|
| 337 |
+
}
|
| 338 |
+
.dark .theme-toggle input[type=radio]:checked + label,
|
| 339 |
+
.dark .theme-toggle label.selected {
|
| 340 |
+
background: #387ED1;
|
| 341 |
+
color: #fff;
|
| 342 |
+
border-color: #387ED1;
|
| 343 |
+
}
|
| 344 |
+
|
| 345 |
+
/* Gradio index.html uses @media (prefers-color-scheme: dark) on body — override when forcing light */
|
| 346 |
+
@media (prefers-color-scheme: dark) {
|
| 347 |
+
html[data-ea-theme="light"] body {
|
| 348 |
+
background: var(--bg, #f8f9fa) !important;
|
| 349 |
+
color: var(--col, #1a1a1a) !important;
|
| 350 |
+
}
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
/* Task instruction callout — top accent + even border (no heavy left margin) */
|
| 354 |
+
.task-instruction {
|
| 355 |
+
border: 1px solid #bfdbfe;
|
| 356 |
+
border-top: 3px solid #387ED1;
|
| 357 |
+
background: #f0f6ff;
|
| 358 |
+
border-radius: var(--ea-radius-block);
|
| 359 |
+
padding: 0.9rem 1.1rem;
|
| 360 |
+
margin-bottom: 1rem;
|
| 361 |
+
box-sizing: border-box;
|
| 362 |
+
}
|
| 363 |
+
.dark .task-instruction {
|
| 364 |
+
background: rgba(30, 58, 138, 0.14);
|
| 365 |
+
border-color: #3f3f46;
|
| 366 |
+
border-top-color: #60a5fa;
|
| 367 |
+
}
|
| 368 |
+
|
| 369 |
+
/* Market Data JSON — full width, aligned with section (no label-column indent) */
|
| 370 |
+
.market-data-json {
|
| 371 |
+
width: 100% !important;
|
| 372 |
+
border: 1px solid #e5e7eb;
|
| 373 |
+
border-radius: var(--ea-radius-block);
|
| 374 |
+
overflow: hidden;
|
| 375 |
+
box-sizing: border-box;
|
| 376 |
+
}
|
| 377 |
+
.dark .market-data-json {
|
| 378 |
+
border-color: #3f3f46;
|
| 379 |
+
}
|
| 380 |
+
.market-data-json > .wrap {
|
| 381 |
+
padding: 0 !important;
|
| 382 |
+
gap: 0 !important;
|
| 383 |
+
}
|
| 384 |
+
.market-data-json pre {
|
| 385 |
+
margin: 0 !important;
|
| 386 |
+
width: 100%;
|
| 387 |
+
box-sizing: border-box;
|
| 388 |
+
border-radius: 0 !important;
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
/* Accordion content scroll */
|
| 392 |
+
.accordion-scroll > .prose {
|
| 393 |
+
max-height: 280px;
|
| 394 |
+
overflow-y: auto;
|
| 395 |
+
padding-right: 4px;
|
| 396 |
+
}
|
| 397 |
+
|
| 398 |
+
/* Tab active indicator */
|
| 399 |
+
.tabs > .tab-nav button.selected {
|
| 400 |
+
border-bottom: 2px solid #387ED1 !important;
|
| 401 |
+
color: #387ED1 !important;
|
| 402 |
+
font-weight: 600;
|
| 403 |
+
}
|
| 404 |
+
.dark .tabs > .tab-nav button.selected {
|
| 405 |
+
border-bottom-color: #60a5fa !important;
|
| 406 |
+
color: #60a5fa !important;
|
| 407 |
+
}
|
| 408 |
+
|
| 409 |
+
/* Result / error cards */
|
| 410 |
+
.result-card {
|
| 411 |
+
border-radius: var(--ea-radius-block);
|
| 412 |
+
padding: 0.9rem 1.1rem;
|
| 413 |
+
border-width: 1.5px;
|
| 414 |
+
border-style: solid;
|
| 415 |
+
background: #fff;
|
| 416 |
+
margin-top: 0.5rem;
|
| 417 |
+
}
|
| 418 |
+
.dark .result-card { background: #27272a; }
|
| 419 |
+
.result-card .rc-title {
|
| 420 |
+
font-weight: 600;
|
| 421 |
+
font-size: 0.92rem;
|
| 422 |
+
margin-bottom: 0.4rem;
|
| 423 |
+
color: #1a1a1a;
|
| 424 |
+
}
|
| 425 |
+
.dark .result-card .rc-title { color: #f4f4f5; }
|
| 426 |
+
.result-card .rc-body {
|
| 427 |
+
font-size: 0.9rem;
|
| 428 |
+
color: #374151;
|
| 429 |
+
line-height: 1.6;
|
| 430 |
+
}
|
| 431 |
+
.dark .result-card .rc-body { color: #d4d4d8; }
|
| 432 |
+
|
| 433 |
+
/* Result variants (Analysis tab HTML) — avoid inline light bg + .dark text */
|
| 434 |
+
.result-card.result-ok {
|
| 435 |
+
background: #f0fdf4;
|
| 436 |
+
border-color: #bbf7d0;
|
| 437 |
+
}
|
| 438 |
+
.result-card.result-ok .rc-title { color: #14532d; }
|
| 439 |
+
.result-card.result-ok .rc-body { color: #374151; }
|
| 440 |
+
.dark .result-card.result-ok {
|
| 441 |
+
background: rgba(20, 83, 45, 0.42);
|
| 442 |
+
border-color: #15803d;
|
| 443 |
+
}
|
| 444 |
+
.dark .result-card.result-ok .rc-title { color: #bbf7d0; }
|
| 445 |
+
.dark .result-card.result-ok .rc-body { color: #d4d4d8; }
|
| 446 |
+
|
| 447 |
+
.result-card.result-bad {
|
| 448 |
+
background: #fff1f2;
|
| 449 |
+
border-color: #fecaca;
|
| 450 |
+
}
|
| 451 |
+
.result-card.result-bad .rc-title { color: #991b1b; }
|
| 452 |
+
.result-card.result-bad .rc-body { color: #374151; }
|
| 453 |
+
.dark .result-card.result-bad {
|
| 454 |
+
background: rgba(127, 29, 29, 0.38);
|
| 455 |
+
border-color: #b91c1c;
|
| 456 |
+
}
|
| 457 |
+
.dark .result-card.result-bad .rc-title { color: #fecaca; }
|
| 458 |
+
.dark .result-card.result-bad .rc-body { color: #d4d4d8; }
|
| 459 |
+
|
| 460 |
+
.result-card.result-notice {
|
| 461 |
+
background: #fffbeb;
|
| 462 |
+
border-color: #fde68a;
|
| 463 |
+
}
|
| 464 |
+
.result-card.result-notice .rc-title { color: #92400e; }
|
| 465 |
+
.result-card.result-notice .rc-body { color: #374151; }
|
| 466 |
+
.dark .result-card.result-notice {
|
| 467 |
+
background: rgba(120, 53, 15, 0.42);
|
| 468 |
+
border-color: #d97706;
|
| 469 |
+
}
|
| 470 |
+
.dark .result-card.result-notice .rc-title { color: #fde68a; }
|
| 471 |
+
.dark .result-card.result-notice .rc-body { color: #d4d4d8; }
|
| 472 |
+
|
| 473 |
+
/* Sidebar section labels */
|
| 474 |
+
.section-label {
|
| 475 |
+
font-size: 0.78rem;
|
| 476 |
+
font-weight: 700;
|
| 477 |
+
letter-spacing: 0.06em;
|
| 478 |
+
text-transform: uppercase;
|
| 479 |
+
color: #9ca3af;
|
| 480 |
+
margin: 1rem 0 0.4rem;
|
| 481 |
+
}
|
| 482 |
+
.dark .section-label { color: #71717a; }
|
| 483 |
+
"""
|
| 484 |
+
|
| 485 |
+
# ---------------------------------------------------------------------------
|
| 486 |
+
# JS — dark / light / system toggle
|
| 487 |
+
# ---------------------------------------------------------------------------
|
| 488 |
+
|
| 489 |
+
theme_toggle_js = """
|
| 490 |
+
(() => {
|
| 491 |
+
const KEY = "ea-theme";
|
| 492 |
+
const PREF_INDEX = { system: 0, light: 1, dark: 2 };
|
| 493 |
+
|
| 494 |
+
function wantsDark(pref) {
|
| 495 |
+
if (pref === "dark") return true;
|
| 496 |
+
if (pref === "light") return false;
|
| 497 |
+
return window.matchMedia("(prefers-color-scheme: dark)").matches;
|
| 498 |
+
}
|
| 499 |
+
|
| 500 |
+
function getToggleRoot() {
|
| 501 |
+
return document.querySelector(".theme-toggle");
|
| 502 |
+
}
|
| 503 |
+
|
| 504 |
+
function getToggleButtons() {
|
| 505 |
+
const root = getToggleRoot();
|
| 506 |
+
if (!root) return [];
|
| 507 |
+
return [...root.querySelectorAll("button")];
|
| 508 |
+
}
|
| 509 |
+
|
| 510 |
+
function isButtonSelected(b) {
|
| 511 |
+
return (
|
| 512 |
+
b.classList.contains("selected") ||
|
| 513 |
+
b.getAttribute("aria-checked") === "true" ||
|
| 514 |
+
b.dataset.state === "checked"
|
| 515 |
+
);
|
| 516 |
+
}
|
| 517 |
+
|
| 518 |
+
/** Gradio Radio value ↔ saved preference (fixes default "System" after reload). */
|
| 519 |
+
function syncRadioToStorage() {
|
| 520 |
+
const pref = localStorage.getItem(KEY) || "system";
|
| 521 |
+
const want = PREF_INDEX[pref];
|
| 522 |
+
if (want === undefined) return;
|
| 523 |
+
const buttons = getToggleButtons();
|
| 524 |
+
if (buttons.length < 3) return;
|
| 525 |
+
const selected = buttons.findIndex(isButtonSelected);
|
| 526 |
+
if (selected !== want) buttons[want].click();
|
| 527 |
+
}
|
| 528 |
+
|
| 529 |
+
/** Highlight effective Light/Dark when preference is system; match explicit light/dark. */
|
| 530 |
+
function updateThemeToggleVisual() {
|
| 531 |
+
const pref = localStorage.getItem(KEY) || "system";
|
| 532 |
+
const buttons = getToggleButtons();
|
| 533 |
+
if (buttons.length < 3) return;
|
| 534 |
+
|
| 535 |
+
// Reset all buttons: remove class, clear highlight, and force radius to 0 via inline
|
| 536 |
+
// !important (beats Gradio's Tailwind rounded-l-lg / rounded-r-lg cascade)
|
| 537 |
+
buttons.forEach((b) => {
|
| 538 |
+
b.classList.remove("ea-theme-active");
|
| 539 |
+
b.style.removeProperty("background");
|
| 540 |
+
b.style.removeProperty("color");
|
| 541 |
+
b.style.removeProperty("border-color");
|
| 542 |
+
["border-radius",
|
| 543 |
+
"border-top-left-radius", "border-top-right-radius",
|
| 544 |
+
"border-bottom-left-radius", "border-bottom-right-radius",
|
| 545 |
+
].forEach((p) => b.style.setProperty(p, "0", "important"));
|
| 546 |
+
});
|
| 547 |
+
|
| 548 |
+
const resolvedDark = wantsDark(pref);
|
| 549 |
+
let activeBtn = null;
|
| 550 |
+
if (pref === "system") {
|
| 551 |
+
activeBtn = resolvedDark ? buttons[2] : buttons[1];
|
| 552 |
+
} else if (pref === "dark") {
|
| 553 |
+
activeBtn = buttons[2];
|
| 554 |
+
} else {
|
| 555 |
+
activeBtn = buttons[1];
|
| 556 |
+
}
|
| 557 |
+
if (activeBtn) {
|
| 558 |
+
activeBtn.classList.add("ea-theme-active");
|
| 559 |
+
// Inline !important overrides Gradio's dark-mode stylesheet — no cascade battle
|
| 560 |
+
activeBtn.style.setProperty("background", "#387ED1", "important");
|
| 561 |
+
activeBtn.style.setProperty("color", "#fff", "important");
|
| 562 |
+
activeBtn.style.setProperty("border-color", "#387ED1", "important");
|
| 563 |
+
}
|
| 564 |
+
}
|
| 565 |
+
|
| 566 |
+
function applyTheme(pref) {
|
| 567 |
+
const dark = wantsDark(pref);
|
| 568 |
+
const html = document.documentElement;
|
| 569 |
+
html.classList.toggle("dark", dark);
|
| 570 |
+
if (document.body) {
|
| 571 |
+
document.body.classList.toggle("dark", dark);
|
| 572 |
+
}
|
| 573 |
+
html.setAttribute("data-ea-theme", pref);
|
| 574 |
+
html.setAttribute("data-ea-resolved", dark ? "dark" : "light");
|
| 575 |
+
if (pref === "system") {
|
| 576 |
+
html.style.removeProperty("color-scheme");
|
| 577 |
+
} else {
|
| 578 |
+
html.style.colorScheme = dark ? "dark" : "light";
|
| 579 |
+
}
|
| 580 |
+
queueMicrotask(() => {
|
| 581 |
+
syncRadioToStorage();
|
| 582 |
+
updateThemeToggleVisual();
|
| 583 |
+
});
|
| 584 |
+
}
|
| 585 |
+
|
| 586 |
+
let toggleSyncScheduled = false;
|
| 587 |
+
function scheduleToggleSync() {
|
| 588 |
+
if (toggleSyncScheduled) return;
|
| 589 |
+
toggleSyncScheduled = true;
|
| 590 |
+
queueMicrotask(() => {
|
| 591 |
+
toggleSyncScheduled = false;
|
| 592 |
+
syncRadioToStorage();
|
| 593 |
+
updateThemeToggleVisual();
|
| 594 |
+
});
|
| 595 |
+
}
|
| 596 |
+
|
| 597 |
+
let themeToggleOuterObs = false;
|
| 598 |
+
function attachThemeToggleObserver() {
|
| 599 |
+
const hook = (root) => {
|
| 600 |
+
if (!root || root.__eaToggleObserved) return;
|
| 601 |
+
root.__eaToggleObserved = true;
|
| 602 |
+
const mo = new MutationObserver(() => scheduleToggleSync());
|
| 603 |
+
mo.observe(root, {
|
| 604 |
+
childList: true,
|
| 605 |
+
subtree: true,
|
| 606 |
+
attributes: true,
|
| 607 |
+
attributeFilter: ["class", "aria-checked", "data-state"],
|
| 608 |
+
});
|
| 609 |
+
scheduleToggleSync();
|
| 610 |
+
};
|
| 611 |
+
hook(getToggleRoot());
|
| 612 |
+
if (themeToggleOuterObs) return;
|
| 613 |
+
themeToggleOuterObs = true;
|
| 614 |
+
const outer = new MutationObserver(() => {
|
| 615 |
+
const r = getToggleRoot();
|
| 616 |
+
if (r && !r.__eaToggleObserved) hook(r);
|
| 617 |
+
});
|
| 618 |
+
if (document.body) {
|
| 619 |
+
outer.observe(document.body, { childList: true, subtree: true });
|
| 620 |
+
}
|
| 621 |
+
}
|
| 622 |
+
|
| 623 |
+
function syncFromStorage() {
|
| 624 |
+
applyTheme(localStorage.getItem(KEY) || "system");
|
| 625 |
+
}
|
| 626 |
+
|
| 627 |
+
function attachBodyObserver() {
|
| 628 |
+
if (!document.body) return;
|
| 629 |
+
new MutationObserver(() => {
|
| 630 |
+
const pref = localStorage.getItem(KEY) || "system";
|
| 631 |
+
const need = wantsDark(pref);
|
| 632 |
+
if (document.body.classList.contains("dark") !== need) {
|
| 633 |
+
applyTheme(pref);
|
| 634 |
+
}
|
| 635 |
+
}).observe(document.body, { attributes: true, attributeFilter: ["class"] });
|
| 636 |
+
}
|
| 637 |
+
|
| 638 |
+
syncFromStorage();
|
| 639 |
+
if (document.body) {
|
| 640 |
+
attachBodyObserver();
|
| 641 |
+
attachThemeToggleObserver();
|
| 642 |
+
} else {
|
| 643 |
+
document.addEventListener("DOMContentLoaded", () => {
|
| 644 |
+
attachBodyObserver();
|
| 645 |
+
attachThemeToggleObserver();
|
| 646 |
+
});
|
| 647 |
+
}
|
| 648 |
+
|
| 649 |
+
function beatGradioInit() {
|
| 650 |
+
syncFromStorage();
|
| 651 |
+
queueMicrotask(syncFromStorage);
|
| 652 |
+
requestAnimationFrame(syncFromStorage);
|
| 653 |
+
}
|
| 654 |
+
if (document.readyState === "loading") {
|
| 655 |
+
document.addEventListener("DOMContentLoaded", beatGradioInit);
|
| 656 |
+
} else {
|
| 657 |
+
beatGradioInit();
|
| 658 |
+
}
|
| 659 |
+
window.addEventListener("load", () => {
|
| 660 |
+
syncFromStorage();
|
| 661 |
+
attachThemeToggleObserver();
|
| 662 |
+
[0, 50, 200, 500, 1000].forEach((ms) => setTimeout(syncFromStorage, ms));
|
| 663 |
+
});
|
| 664 |
+
|
| 665 |
+
window.matchMedia("(prefers-color-scheme: dark)").addEventListener("change", () => {
|
| 666 |
+
if ((localStorage.getItem(KEY) || "system") === "system") {
|
| 667 |
+
applyTheme("system");
|
| 668 |
+
}
|
| 669 |
+
});
|
| 670 |
+
|
| 671 |
+
window.__setEATheme = (pref) => {
|
| 672 |
+
localStorage.setItem(KEY, pref);
|
| 673 |
+
applyTheme(pref);
|
| 674 |
+
};
|
| 675 |
+
})();
|
| 676 |
+
"""
|
| 677 |
+
|
| 678 |
+
# ---------------------------------------------------------------------------
|
| 679 |
+
# Helper functions
|
| 680 |
+
# ---------------------------------------------------------------------------
|
| 681 |
+
|
| 682 |
+
|
| 683 |
+
def _format_result_html(reward: float, ground_truth: Any) -> str:
|
| 684 |
+
variant = "result-ok" if reward >= 0 else "result-bad"
|
| 685 |
+
icon = "+" if reward >= 0 else ""
|
| 686 |
+
gt = html.escape(str(ground_truth))
|
| 687 |
+
return (
|
| 688 |
+
f'<div class="result-card {variant}">'
|
| 689 |
+
f'<div class="rc-title">Result {icon}</div>'
|
| 690 |
+
f'<div class="rc-body">'
|
| 691 |
+
f"<strong>Reward:</strong> {reward:.4f}<br/>"
|
| 692 |
+
f"<strong>Ground truth:</strong> {gt}"
|
| 693 |
+
f"</div></div>"
|
| 694 |
+
)
|
| 695 |
+
|
| 696 |
+
|
| 697 |
+
def _format_error_html(message: str) -> str:
|
| 698 |
+
msg = html.escape(message)
|
| 699 |
+
return (
|
| 700 |
+
'<div class="result-card result-notice">'
|
| 701 |
+
'<div class="rc-title">Notice</div>'
|
| 702 |
+
f'<div class="rc-body">{msg}</div></div>'
|
| 703 |
+
)
|
| 704 |
+
|
| 705 |
+
|
| 706 |
+
def _episode_status_markdown(
|
| 707 |
+
task_id: str,
|
| 708 |
+
company: str | None = None,
|
| 709 |
+
year: str | None = None,
|
| 710 |
+
quarter: str | None = None,
|
| 711 |
+
) -> str:
|
| 712 |
+
use_specific = bool(
|
| 713 |
+
company and year and quarter and company != RANDOM_EPISODE_LABEL
|
| 714 |
+
)
|
| 715 |
+
if use_specific:
|
| 716 |
+
assert company is not None and year is not None and quarter is not None
|
| 717 |
+
idx = get_episode_index()
|
| 718 |
+
sym = idx.symbol_for_display(company)
|
| 719 |
+
yi = int(year)
|
| 720 |
+
qi = int(quarter)
|
| 721 |
+
eid = format_episode_id(sym, yi, qi)
|
| 722 |
+
row_line = f"**Row:** `{html.escape(eid)}` \n"
|
| 723 |
+
else:
|
| 724 |
+
row_line = "**Row:** random sample \n"
|
| 725 |
+
return (
|
| 726 |
+
f"**Episode loaded** — `{html.escape(task_id)}` \n"
|
| 727 |
+
f"{row_line}"
|
| 728 |
+
"Use **Observation** / **Analysis** tabs."
|
| 729 |
+
)
|
| 730 |
+
|
| 731 |
+
|
| 732 |
+
def _reset_failed_outputs(message: str) -> list[Any]:
|
| 733 |
+
texts, acc_updates = _text_rows_and_accordion_updates(None)
|
| 734 |
+
err_status = f"**Episode error** — {html.escape(message)}"
|
| 735 |
+
idle = "*Could not load episode. Fix the selection or try again.*"
|
| 736 |
+
return [
|
| 737 |
+
err_status,
|
| 738 |
+
idle,
|
| 739 |
+
*texts,
|
| 740 |
+
*acc_updates,
|
| 741 |
+
{},
|
| 742 |
+
gr.update(visible=False),
|
| 743 |
+
gr.update(visible=False, value=""),
|
| 744 |
+
"",
|
| 745 |
+
message,
|
| 746 |
+
]
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
def _init_episode_dropdowns() -> tuple[Any, Any, Any]:
|
| 750 |
+
idx = get_episode_index()
|
| 751 |
+
displays = idx.sorted_company_displays()
|
| 752 |
+
choices = [RANDOM_EPISODE_LABEL] + displays
|
| 753 |
+
return (
|
| 754 |
+
gr.update(choices=choices, value=RANDOM_EPISODE_LABEL),
|
| 755 |
+
gr.update(choices=[], value=None, interactive=False),
|
| 756 |
+
gr.update(choices=[], value=None, interactive=False),
|
| 757 |
+
)
|
| 758 |
+
|
| 759 |
+
|
| 760 |
+
def _on_company_change(company: str | None) -> tuple[Any, Any]:
|
| 761 |
+
if not company or company == RANDOM_EPISODE_LABEL:
|
| 762 |
+
return (
|
| 763 |
+
gr.update(choices=[], value=None, interactive=False),
|
| 764 |
+
gr.update(choices=[], value=None, interactive=False),
|
| 765 |
+
)
|
| 766 |
+
idx = get_episode_index()
|
| 767 |
+
sym = idx.symbol_for_display(company)
|
| 768 |
+
years = idx.years_for_symbol(sym)
|
| 769 |
+
year_choices = [str(y) for y in years]
|
| 770 |
+
y0 = year_choices[0] if year_choices else None
|
| 771 |
+
if y0 is not None:
|
| 772 |
+
qs = idx.quarters_for(sym, int(y0))
|
| 773 |
+
q_choices = [str(q) for q in qs]
|
| 774 |
+
q0 = q_choices[0] if q_choices else None
|
| 775 |
+
else:
|
| 776 |
+
q_choices, q0 = [], None
|
| 777 |
+
return (
|
| 778 |
+
gr.update(choices=year_choices, value=y0, interactive=bool(year_choices)),
|
| 779 |
+
gr.update(choices=q_choices, value=q0, interactive=bool(q_choices)),
|
| 780 |
+
)
|
| 781 |
+
|
| 782 |
+
|
| 783 |
+
def _on_year_change(company: str | None, year_str: str | None) -> Any:
|
| 784 |
+
if not company or company == RANDOM_EPISODE_LABEL or not year_str:
|
| 785 |
+
return gr.update(choices=[], value=None, interactive=False)
|
| 786 |
+
idx = get_episode_index()
|
| 787 |
+
sym = idx.symbol_for_display(company)
|
| 788 |
+
qs = idx.quarters_for(sym, int(year_str))
|
| 789 |
+
q_choices = [str(q) for q in qs]
|
| 790 |
+
q0 = q_choices[0] if q_choices else None
|
| 791 |
+
return gr.update(choices=q_choices, value=q0, interactive=bool(q_choices))
|
| 792 |
+
|
| 793 |
+
|
| 794 |
+
def _text_rows_and_accordion_updates(
|
| 795 |
+
obs: Optional[EarningsAnalystObservation],
|
| 796 |
+
) -> tuple[list[str], list[dict[str, Any]]]:
|
| 797 |
+
texts: list[str] = []
|
| 798 |
+
updates: list[dict[str, Any]] = []
|
| 799 |
+
ctx = obs.text_context if obs and obs.text_context else {}
|
| 800 |
+
for key in TEXT_KEYS:
|
| 801 |
+
raw = ctx.get(key, "") if ctx else ""
|
| 802 |
+
texts.append(raw if isinstance(raw, str) else str(raw))
|
| 803 |
+
updates.append(gr.update(visible=(key in ctx)))
|
| 804 |
+
return texts, updates
|
| 805 |
+
|
| 806 |
+
|
| 807 |
+
# ---------------------------------------------------------------------------
|
| 808 |
+
# State & environment helpers
|
| 809 |
+
# ---------------------------------------------------------------------------
|
| 810 |
+
|
| 811 |
|
| 812 |
class State:
|
| 813 |
+
def __init__(self) -> None:
|
| 814 |
self.env: Optional[EarningsAnalystEnvironment] = None
|
| 815 |
self.obs: Optional[EarningsAnalystObservation] = None
|
| 816 |
self.task_id: str = "sentiment_label"
|
|
|
|
| 819 |
state = State()
|
| 820 |
|
| 821 |
|
| 822 |
+
async def reset_env(
|
| 823 |
+
task_id: str,
|
| 824 |
+
company: str | None,
|
| 825 |
+
year: str | None,
|
| 826 |
+
quarter: str | None,
|
| 827 |
+
) -> list[Any]:
|
| 828 |
state.task_id = task_id
|
| 829 |
state.env = EarningsAnalystEnvironment(task_id=task_id)
|
| 830 |
+
try:
|
| 831 |
+
use_specific = bool(
|
| 832 |
+
company and year and quarter and company != RANDOM_EPISODE_LABEL
|
| 833 |
+
)
|
| 834 |
+
if use_specific:
|
| 835 |
+
assert company is not None and year is not None and quarter is not None
|
| 836 |
+
idx = get_episode_index()
|
| 837 |
+
sym = idx.symbol_for_display(company)
|
| 838 |
+
state.obs = state.env.reset(
|
| 839 |
+
pick_symbol=sym,
|
| 840 |
+
pick_year=int(year),
|
| 841 |
+
pick_quarter=int(quarter),
|
| 842 |
+
)
|
| 843 |
+
else:
|
| 844 |
+
state.obs = state.env.reset()
|
| 845 |
+
except (ValueError, TaskNotImplementedError) as e:
|
| 846 |
+
state.obs = None
|
| 847 |
+
return _reset_failed_outputs(str(e))
|
| 848 |
|
| 849 |
+
texts, acc_updates = _text_rows_and_accordion_updates(state.obs)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 850 |
|
| 851 |
+
numerical_value: Any = (
|
| 852 |
+
state.obs.numerical_context if state.obs.numerical_context else {}
|
|
|
|
|
|
|
| 853 |
)
|
| 854 |
|
| 855 |
return [
|
| 856 |
+
_episode_status_markdown(task_id, company, year, quarter),
|
| 857 |
state.obs.task_instruction,
|
| 858 |
+
*texts,
|
| 859 |
+
*acc_updates,
|
| 860 |
+
numerical_value,
|
| 861 |
+
gr.update(visible=True), # prediction_row
|
| 862 |
+
gr.update(visible=False, value=""), # result_view
|
| 863 |
+
"", # pred_input
|
| 864 |
+
"", # message_view
|
| 865 |
]
|
| 866 |
|
| 867 |
|
| 868 |
+
async def step_env(prediction: str) -> list[Any]:
|
| 869 |
if not state.env or not state.obs:
|
| 870 |
+
return [
|
| 871 |
+
gr.update(),
|
| 872 |
+
gr.update(
|
| 873 |
+
visible=True,
|
| 874 |
+
value=_format_error_html(
|
| 875 |
+
"Environment not initialized. Click New Episode."
|
| 876 |
+
),
|
| 877 |
+
),
|
| 878 |
+
]
|
| 879 |
|
| 880 |
action = EarningsAnalystAction(prediction=prediction)
|
| 881 |
state.obs = state.env.step(action)
|
| 882 |
|
| 883 |
+
raw_reward = state.obs.reward
|
| 884 |
+
reward_f = float(raw_reward) if raw_reward is not None else 0.0
|
| 885 |
ground_truth = getattr(state.obs, "ground_truth", "N/A")
|
| 886 |
|
|
|
|
|
|
|
| 887 |
return [
|
| 888 |
+
gr.update(visible=False),
|
| 889 |
+
gr.update(visible=True, value=_format_result_html(reward_f, ground_truth)),
|
| 890 |
]
|
| 891 |
|
| 892 |
|
| 893 |
+
async def run_agent(
|
| 894 |
+
task_id: str,
|
| 895 |
+
api_key: str,
|
| 896 |
+
model: str,
|
| 897 |
+
base_url: str,
|
| 898 |
+
company: str | None,
|
| 899 |
+
year: str | None,
|
| 900 |
+
quarter: str | None,
|
| 901 |
+
) -> list[Any]:
|
| 902 |
+
n_out = 17
|
| 903 |
if not api_key:
|
| 904 |
+
return [gr.update()] * (n_out - 1) + ["Please provide an API Key."]
|
| 905 |
|
| 906 |
+
out = await reset_env(task_id, company, year, quarter)
|
| 907 |
+
if state.obs is None:
|
| 908 |
+
return out
|
| 909 |
|
| 910 |
+
assert state.obs is not None
|
|
|
|
|
|
|
|
|
|
| 911 |
|
| 912 |
+
if base_url:
|
| 913 |
+
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
| 914 |
+
else:
|
| 915 |
+
client = AsyncOpenAI(api_key=api_key)
|
| 916 |
|
| 917 |
user_content = f"{state.obs.task_instruction}\n\n"
|
| 918 |
if state.obs.text_context:
|
|
|
|
| 943 |
parsed = json.loads(response_text)
|
| 944 |
prediction = str(parsed.get("prediction", response_text))
|
| 945 |
|
|
|
|
| 946 |
step_out = await step_env(prediction)
|
| 947 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 948 |
return [
|
| 949 |
+
*out[:13],
|
|
|
|
|
|
|
| 950 |
step_out[0],
|
| 951 |
step_out[1],
|
| 952 |
prediction,
|
|
|
|
| 954 |
]
|
| 955 |
|
| 956 |
except Exception as e:
|
| 957 |
+
return [*out[:13], out[13], out[14], "", f"Error: {str(e)}"]
|
| 958 |
|
| 959 |
|
| 960 |
+
# ---------------------------------------------------------------------------
|
| 961 |
+
# Layout helpers
|
| 962 |
+
# ---------------------------------------------------------------------------
|
| 963 |
+
|
| 964 |
+
|
| 965 |
+
def _accordion_title(key: str) -> str:
|
| 966 |
+
labels = {
|
| 967 |
+
"earnings_transcript": "Earnings Call Transcript",
|
| 968 |
+
"press_release_8k_body": "Press Release — 8-K Body",
|
| 969 |
+
"press_release_ex991": "Exhibit 99.1",
|
| 970 |
+
"press_release_ex992": "Exhibit 99.2",
|
| 971 |
+
"press_release_sources": "Press Release Sources",
|
| 972 |
+
}
|
| 973 |
+
return labels.get(key, key.replace("_", " ").title())
|
|
|
|
| 974 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 975 |
|
| 976 |
+
# ---------------------------------------------------------------------------
|
| 977 |
+
# Gradio layout
|
| 978 |
+
# ---------------------------------------------------------------------------
|
| 979 |
+
|
| 980 |
+
with gr.Blocks(title="Earnings Analyst - OpenEnv") as demo:
|
| 981 |
+
with gr.Column(elem_classes="app-shell"):
|
| 982 |
+
# ── Header ──────────────────────────────────────────────────────────
|
| 983 |
+
with gr.Row(elem_classes="header-row"):
|
| 984 |
+
with gr.Column(scale=5, min_width=0):
|
| 985 |
+
gr.Markdown("## Earnings Analyst")
|
| 986 |
+
gr.Markdown(
|
| 987 |
+
"Interactive environment for financial analysis using "
|
| 988 |
+
"[OpenEnv](https://github.com/meta-pytorch/OpenEnv). "
|
| 989 |
+
"Evaluate agents or your own predictions on earnings call data."
|
| 990 |
+
)
|
| 991 |
+
with gr.Column(scale=2, min_width=0):
|
| 992 |
+
theme_toggle = gr.Radio(
|
| 993 |
+
choices=["System", "Light", "Dark"],
|
| 994 |
+
value="System",
|
| 995 |
+
label="Theme",
|
| 996 |
+
elem_classes="theme-toggle",
|
| 997 |
+
container=False,
|
| 998 |
+
)
|
| 999 |
+
status_badge = gr.Markdown(
|
| 1000 |
+
"No episode loaded. Click **New Episode**.",
|
| 1001 |
+
elem_classes="status-badge",
|
| 1002 |
+
)
|
| 1003 |
+
|
| 1004 |
+
# ── Main content ─────────────────────────────────────────────────────
|
| 1005 |
with gr.Row():
|
| 1006 |
+
# Sidebar
|
| 1007 |
+
with gr.Column(scale=1, min_width=220):
|
| 1008 |
+
with gr.Group():
|
| 1009 |
+
company_dd = gr.Dropdown(
|
| 1010 |
+
label="Company",
|
| 1011 |
+
choices=[RANDOM_EPISODE_LABEL],
|
| 1012 |
+
value=RANDOM_EPISODE_LABEL,
|
| 1013 |
+
)
|
| 1014 |
+
with gr.Row():
|
| 1015 |
+
year_dd = gr.Dropdown(
|
| 1016 |
+
label="Year",
|
| 1017 |
+
choices=[],
|
| 1018 |
+
interactive=False,
|
| 1019 |
+
)
|
| 1020 |
+
quarter_dd = gr.Dropdown(
|
| 1021 |
+
label="Quarter",
|
| 1022 |
+
choices=[],
|
| 1023 |
+
interactive=False,
|
| 1024 |
+
)
|
| 1025 |
+
|
| 1026 |
with gr.Group():
|
| 1027 |
task_select = gr.Dropdown(
|
| 1028 |
+
choices=TASK_IDS,
|
| 1029 |
+
value="sentiment_label",
|
| 1030 |
+
label="Active Task",
|
| 1031 |
)
|
| 1032 |
+
reset_btn = gr.Button("New Episode", variant="primary")
|
| 1033 |
|
| 1034 |
+
gr.Markdown("Auto-Agent", elem_classes="section-label")
|
| 1035 |
with gr.Group():
|
| 1036 |
api_key = gr.Textbox(
|
| 1037 |
label="OpenAI API Key",
|
|
|
|
| 1045 |
placeholder="https://api.openai.com/v1",
|
| 1046 |
value=os.environ.get("API_BASE_URL", ""),
|
| 1047 |
)
|
| 1048 |
+
agent_btn = gr.Button("Run LLM Agent", variant="secondary")
|
| 1049 |
|
| 1050 |
+
# Main panel
|
| 1051 |
with gr.Column(scale=2):
|
| 1052 |
+
with gr.Column(elem_classes="task-instruction"):
|
| 1053 |
+
instr_view = gr.Markdown("*No task loaded. Click **New Episode**.*")
|
| 1054 |
+
|
| 1055 |
with gr.Tabs():
|
| 1056 |
+
with gr.Tab("Observation"):
|
| 1057 |
+
gr.Markdown("**Documents** — expand a section to read.")
|
| 1058 |
+
|
| 1059 |
+
with gr.Accordion(
|
| 1060 |
+
_accordion_title(TEXT_KEYS[0]),
|
| 1061 |
+
open=True,
|
| 1062 |
+
visible=False,
|
| 1063 |
+
) as acc_earnings_transcript:
|
| 1064 |
+
md_earnings_transcript = gr.Markdown(
|
| 1065 |
+
"", elem_classes="accordion-scroll"
|
| 1066 |
+
)
|
| 1067 |
+
|
| 1068 |
+
with gr.Accordion(
|
| 1069 |
+
_accordion_title(TEXT_KEYS[1]),
|
| 1070 |
+
open=False,
|
| 1071 |
+
visible=False,
|
| 1072 |
+
) as acc_press_release_8k_body:
|
| 1073 |
+
md_press_release_8k_body = gr.Markdown(
|
| 1074 |
+
"", elem_classes="accordion-scroll"
|
| 1075 |
+
)
|
| 1076 |
+
|
| 1077 |
+
with gr.Accordion(
|
| 1078 |
+
_accordion_title(TEXT_KEYS[2]),
|
| 1079 |
+
open=False,
|
| 1080 |
+
visible=False,
|
| 1081 |
+
) as acc_press_release_ex991:
|
| 1082 |
+
md_press_release_ex991 = gr.Markdown(
|
| 1083 |
+
"", elem_classes="accordion-scroll"
|
| 1084 |
+
)
|
| 1085 |
+
|
| 1086 |
+
with gr.Accordion(
|
| 1087 |
+
_accordion_title(TEXT_KEYS[3]),
|
| 1088 |
+
open=False,
|
| 1089 |
+
visible=False,
|
| 1090 |
+
) as acc_press_release_ex992:
|
| 1091 |
+
md_press_release_ex992 = gr.Markdown(
|
| 1092 |
+
"", elem_classes="accordion-scroll"
|
| 1093 |
+
)
|
| 1094 |
+
|
| 1095 |
+
with gr.Accordion(
|
| 1096 |
+
_accordion_title(TEXT_KEYS[4]),
|
| 1097 |
+
open=False,
|
| 1098 |
+
visible=False,
|
| 1099 |
+
) as acc_press_release_sources:
|
| 1100 |
+
md_press_release_sources = gr.Markdown(
|
| 1101 |
+
"", elem_classes="accordion-scroll"
|
| 1102 |
+
)
|
| 1103 |
+
|
| 1104 |
+
gr.Markdown("**Market Data**", elem_classes="section-label")
|
| 1105 |
+
num_view = gr.JSON(
|
| 1106 |
+
value={},
|
| 1107 |
+
label=None,
|
| 1108 |
+
show_label=False,
|
| 1109 |
+
container=False,
|
| 1110 |
+
elem_classes="market-data-json",
|
| 1111 |
+
)
|
| 1112 |
+
|
| 1113 |
+
with gr.Tab("Analysis"):
|
| 1114 |
with gr.Column(visible=False) as prediction_row:
|
| 1115 |
pred_input = gr.Textbox(
|
| 1116 |
+
label="Your prediction / analysis output",
|
| 1117 |
placeholder="e.g. bullish, or 0.05",
|
| 1118 |
)
|
| 1119 |
+
submit_btn = gr.Button("Submit analysis", variant="primary")
|
| 1120 |
|
| 1121 |
+
result_view = gr.HTML(visible=False, value="")
|
| 1122 |
message_view = gr.Textbox(
|
| 1123 |
+
label="Agent log / messages",
|
| 1124 |
+
interactive=False,
|
| 1125 |
+
buttons=["copy"],
|
| 1126 |
+
lines=5,
|
| 1127 |
+
max_lines=12,
|
| 1128 |
+
autoscroll=True,
|
| 1129 |
)
|
| 1130 |
|
| 1131 |
+
# ── Event outputs list ───────────────────────────────────────────────────
|
| 1132 |
+
reset_outputs = [
|
| 1133 |
+
status_badge,
|
| 1134 |
+
instr_view,
|
| 1135 |
+
md_earnings_transcript,
|
| 1136 |
+
md_press_release_8k_body,
|
| 1137 |
+
md_press_release_ex991,
|
| 1138 |
+
md_press_release_ex992,
|
| 1139 |
+
md_press_release_sources,
|
| 1140 |
+
acc_earnings_transcript,
|
| 1141 |
+
acc_press_release_8k_body,
|
| 1142 |
+
acc_press_release_ex991,
|
| 1143 |
+
acc_press_release_ex992,
|
| 1144 |
+
acc_press_release_sources,
|
| 1145 |
+
num_view,
|
| 1146 |
+
prediction_row,
|
| 1147 |
+
result_view,
|
| 1148 |
+
pred_input,
|
| 1149 |
+
message_view,
|
| 1150 |
+
]
|
| 1151 |
+
|
| 1152 |
+
demo.load(
|
| 1153 |
+
_init_episode_dropdowns,
|
| 1154 |
+
outputs=[company_dd, year_dd, quarter_dd],
|
| 1155 |
+
)
|
| 1156 |
+
company_dd.change(
|
| 1157 |
+
_on_company_change,
|
| 1158 |
+
inputs=[company_dd],
|
| 1159 |
+
outputs=[year_dd, quarter_dd],
|
| 1160 |
+
)
|
| 1161 |
+
year_dd.change(
|
| 1162 |
+
_on_year_change,
|
| 1163 |
+
inputs=[company_dd, year_dd],
|
| 1164 |
+
outputs=[quarter_dd],
|
| 1165 |
+
)
|
| 1166 |
+
|
| 1167 |
reset_btn.click(
|
| 1168 |
fn=reset_env,
|
| 1169 |
+
inputs=[task_select, company_dd, year_dd, quarter_dd],
|
| 1170 |
+
outputs=reset_outputs,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1171 |
)
|
|
|
|
| 1172 |
submit_btn.click(
|
| 1173 |
fn=step_env, inputs=[pred_input], outputs=[prediction_row, result_view]
|
| 1174 |
)
|
|
|
|
| 1175 |
agent_btn.click(
|
| 1176 |
fn=run_agent,
|
| 1177 |
+
inputs=[
|
| 1178 |
+
task_select,
|
| 1179 |
+
api_key,
|
| 1180 |
+
model_name,
|
| 1181 |
+
base_url,
|
| 1182 |
+
company_dd,
|
| 1183 |
+
year_dd,
|
| 1184 |
+
quarter_dd,
|
|
|
|
| 1185 |
],
|
| 1186 |
+
outputs=reset_outputs,
|
| 1187 |
)
|
| 1188 |
|
| 1189 |
+
# Theme toggle — client-side only, no server round-trip
|
| 1190 |
+
theme_toggle.change(
|
| 1191 |
+
fn=None,
|
| 1192 |
+
inputs=[theme_toggle],
|
| 1193 |
+
js=(
|
| 1194 |
+
"(choice) => { "
|
| 1195 |
+
"const v = Array.isArray(choice) ? choice[0] : choice; "
|
| 1196 |
+
"const key = String(v ?? 'system').toLowerCase(); "
|
| 1197 |
+
"window.__setEATheme?.(key); "
|
| 1198 |
+
"return choice; "
|
| 1199 |
+
"}"
|
| 1200 |
+
),
|
| 1201 |
+
)
|
| 1202 |
+
|
| 1203 |
+
# ---------------------------------------------------------------------------
|
| 1204 |
+
# FastAPI app
|
| 1205 |
+
# ---------------------------------------------------------------------------
|
| 1206 |
|
|
|
|
| 1207 |
api_app = FastAPI(title="Earnings Analyst API")
|
| 1208 |
|
|
|
|
| 1209 |
api_app.add_middleware(
|
| 1210 |
CORSMiddleware,
|
| 1211 |
allow_origins=["*"],
|
|
|
|
| 1213 |
allow_headers=["*"],
|
| 1214 |
)
|
| 1215 |
|
|
|
|
|
|
|
| 1216 |
|
| 1217 |
@api_app.get("/health")
|
| 1218 |
+
async def health() -> dict[str, str]:
|
| 1219 |
return {"status": "ok", "environment": "earnings_analyst"}
|
| 1220 |
|
| 1221 |
|
| 1222 |
@api_app.get("/web")
|
| 1223 |
+
async def web_redirect() -> RedirectResponse:
|
| 1224 |
return RedirectResponse(url="/")
|
| 1225 |
|
| 1226 |
|
|
|
|
| 1227 |
try:
|
| 1228 |
from earnings_analyst.server.app import app as env_app
|
| 1229 |
except (ImportError, ModuleNotFoundError):
|
| 1230 |
from server.app import app as env_app
|
| 1231 |
|
|
|
|
|
|
|
| 1232 |
api_app.include_router(env_app.router)
|
| 1233 |
api_app.mount("/api", env_app)
|
| 1234 |
|
| 1235 |
+
# Mount Gradio — theme, css, and js passed here (Gradio 6 API)
|
| 1236 |
+
app = gr.mount_gradio_app(
|
| 1237 |
+
api_app,
|
| 1238 |
+
demo,
|
| 1239 |
+
path="/",
|
| 1240 |
+
theme=theme,
|
| 1241 |
+
css=custom_css,
|
| 1242 |
+
js=theme_toggle_js,
|
| 1243 |
+
)
|
| 1244 |
|
| 1245 |
|
| 1246 |
+
def main() -> None:
|
| 1247 |
import uvicorn
|
| 1248 |
|
| 1249 |
uvicorn.run(app, host="0.0.0.0", port=8000)
|
environment_config.py
CHANGED
|
@@ -11,4 +11,21 @@ from earnings_analyst.tasks.registry import DEFAULT_TASK, TASKS
|
|
| 11 |
DATASET_ID = "RudrakshNanavaty/earnings-call-data"
|
| 12 |
DATASET_FILE = "episodes_press_release_8k.parquet"
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
DATASET_ID = "RudrakshNanavaty/earnings-call-data"
|
| 12 |
DATASET_FILE = "episodes_press_release_8k.parquet"
|
| 13 |
|
| 14 |
+
# Columns for Gradio episode selection (must exist in the parquet).
|
| 15 |
+
DATASET_EPISODE_ID_COLUMN = "episode_id"
|
| 16 |
+
DATASET_SYMBOL_COLUMN = "symbol"
|
| 17 |
+
DATASET_COMPANY_NAME_COLUMN = "company_name"
|
| 18 |
+
DATASET_YEAR_COLUMN = "year"
|
| 19 |
+
DATASET_QUARTER_COLUMN = "quarter"
|
| 20 |
+
|
| 21 |
+
__all__ = [
|
| 22 |
+
"DATASET_ID",
|
| 23 |
+
"DATASET_FILE",
|
| 24 |
+
"DATASET_EPISODE_ID_COLUMN",
|
| 25 |
+
"DATASET_SYMBOL_COLUMN",
|
| 26 |
+
"DATASET_COMPANY_NAME_COLUMN",
|
| 27 |
+
"DATASET_YEAR_COLUMN",
|
| 28 |
+
"DATASET_QUARTER_COLUMN",
|
| 29 |
+
"DEFAULT_TASK",
|
| 30 |
+
"TASKS",
|
| 31 |
+
]
|
server/__init__.py
CHANGED
|
@@ -1,5 +1,22 @@
|
|
| 1 |
-
"""Earnings Analyst environment server components.
|
| 2 |
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
__all__ = ["EarningsAnalystEnvironment"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Earnings Analyst environment server components.
|
| 2 |
|
| 3 |
+
``EarningsAnalystEnvironment`` is loaded lazily so importing sibling modules
|
| 4 |
+
(e.g. ``episode_index``) does not pull in the Hugging Face dataset.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from typing import TYPE_CHECKING, Any
|
| 10 |
+
|
| 11 |
+
if TYPE_CHECKING:
|
| 12 |
+
from .earnings_analyst_environment import EarningsAnalystEnvironment
|
| 13 |
|
| 14 |
__all__ = ["EarningsAnalystEnvironment"]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def __getattr__(name: str) -> Any:
|
| 18 |
+
if name == "EarningsAnalystEnvironment":
|
| 19 |
+
from .earnings_analyst_environment import EarningsAnalystEnvironment as _Env
|
| 20 |
+
|
| 21 |
+
return _Env
|
| 22 |
+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
server/earnings_analyst_environment.py
CHANGED
|
@@ -11,7 +11,7 @@ import math
|
|
| 11 |
import os
|
| 12 |
import json
|
| 13 |
import random
|
| 14 |
-
from typing import Any
|
| 15 |
from uuid import uuid4
|
| 16 |
|
| 17 |
from openenv.core.env_server.interfaces import Environment
|
|
@@ -23,6 +23,7 @@ from earnings_analyst.tasks.exceptions import TaskNotImplementedError
|
|
| 23 |
from earnings_analyst.tasks.registry import get_grader
|
| 24 |
|
| 25 |
from .dataset_loader import dataset
|
|
|
|
| 26 |
|
| 27 |
|
| 28 |
def _resolve_task_id(explicit: str | None) -> str:
|
|
@@ -68,18 +69,62 @@ class EarningsAnalystEnvironment(Environment):
|
|
| 68 |
self._state = State(episode_id=str(uuid4()), step_count=0)
|
| 69 |
self._current_row: dict[str, Any] | None = None
|
| 70 |
|
| 71 |
-
def reset(
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
if not self._cfg["implemented"]:
|
| 74 |
raise TaskNotImplementedError(
|
| 75 |
f"Task {self._task_id!r} is not implemented yet. "
|
| 76 |
f"Set implemented=True and fill spec/grader under tasks/ when ready."
|
| 77 |
)
|
| 78 |
|
| 79 |
-
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
row = dataset[idx]
|
| 82 |
-
# Normalize to a plain dict for grading and column access
|
| 83 |
self._current_row = dict(row)
|
| 84 |
|
| 85 |
text_context = {
|
|
@@ -118,7 +163,8 @@ class EarningsAnalystEnvironment(Environment):
|
|
| 118 |
|
| 119 |
# Handle composite ground truth if multiple columns are specified (e.g. for get_figures)
|
| 120 |
if "xbrl_columns" in self._cfg:
|
| 121 |
-
|
|
|
|
| 122 |
ground_truth = json.dumps(gt_data)
|
| 123 |
else:
|
| 124 |
ground_truth = str(row.get(label_col, "")).strip()
|
|
@@ -132,7 +178,6 @@ class EarningsAnalystEnvironment(Environment):
|
|
| 132 |
)
|
| 133 |
)
|
| 134 |
|
| 135 |
-
|
| 136 |
return EarningsAnalystObservation(
|
| 137 |
text_context={},
|
| 138 |
numerical_context={},
|
|
@@ -146,7 +191,6 @@ class EarningsAnalystEnvironment(Environment):
|
|
| 146 |
},
|
| 147 |
)
|
| 148 |
|
| 149 |
-
|
| 150 |
@property
|
| 151 |
def state(self) -> State:
|
| 152 |
"""Current environment state."""
|
|
|
|
| 11 |
import os
|
| 12 |
import json
|
| 13 |
import random
|
| 14 |
+
from typing import Any, cast
|
| 15 |
from uuid import uuid4
|
| 16 |
|
| 17 |
from openenv.core.env_server.interfaces import Environment
|
|
|
|
| 23 |
from earnings_analyst.tasks.registry import get_grader
|
| 24 |
|
| 25 |
from .dataset_loader import dataset
|
| 26 |
+
from .episode_index import get_episode_index
|
| 27 |
|
| 28 |
|
| 29 |
def _resolve_task_id(explicit: str | None) -> str:
|
|
|
|
| 69 |
self._state = State(episode_id=str(uuid4()), step_count=0)
|
| 70 |
self._current_row: dict[str, Any] | None = None
|
| 71 |
|
| 72 |
+
def reset(
|
| 73 |
+
self,
|
| 74 |
+
seed: int | None = None,
|
| 75 |
+
episode_id: str | None = None,
|
| 76 |
+
**kwargs: Any,
|
| 77 |
+
) -> EarningsAnalystObservation:
|
| 78 |
+
"""Sample one dataset row and return the agent-visible observation bundle.
|
| 79 |
+
|
| 80 |
+
If ``pick_symbol``, ``pick_year``, and ``pick_quarter`` are passed (Gradio),
|
| 81 |
+
load that row by ``(symbol, year, quarter)``. Otherwise sample uniformly at
|
| 82 |
+
random (default for HTTP clients).
|
| 83 |
+
|
| 84 |
+
Optional ``seed`` affects only the random row path.
|
| 85 |
+
"""
|
| 86 |
+
pick_symbol = kwargs.pop("pick_symbol", None)
|
| 87 |
+
pick_year = kwargs.pop("pick_year", None)
|
| 88 |
+
pick_quarter = kwargs.pop("pick_quarter", None)
|
| 89 |
+
if kwargs:
|
| 90 |
+
raise TypeError(f"Unexpected keyword arguments: {sorted(kwargs)!r}")
|
| 91 |
+
|
| 92 |
if not self._cfg["implemented"]:
|
| 93 |
raise TaskNotImplementedError(
|
| 94 |
f"Task {self._task_id!r} is not implemented yet. "
|
| 95 |
f"Set implemented=True and fill spec/grader under tasks/ when ready."
|
| 96 |
)
|
| 97 |
|
| 98 |
+
eid = episode_id if episode_id is not None else str(uuid4())
|
| 99 |
+
self._state = State(episode_id=eid, step_count=0)
|
| 100 |
+
if (
|
| 101 |
+
pick_symbol is not None
|
| 102 |
+
and str(pick_symbol).strip() != ""
|
| 103 |
+
and pick_year is not None
|
| 104 |
+
and pick_quarter is not None
|
| 105 |
+
):
|
| 106 |
+
sym = str(pick_symbol).strip()
|
| 107 |
+
try:
|
| 108 |
+
yi = int(pick_year) if not isinstance(pick_year, int) else pick_year
|
| 109 |
+
qi = (
|
| 110 |
+
int(pick_quarter)
|
| 111 |
+
if not isinstance(pick_quarter, int)
|
| 112 |
+
else pick_quarter
|
| 113 |
+
)
|
| 114 |
+
except (TypeError, ValueError) as e:
|
| 115 |
+
raise ValueError("Year and quarter must be integers.") from e
|
| 116 |
+
try:
|
| 117 |
+
idx = get_episode_index().row_index(sym, yi, qi)
|
| 118 |
+
except KeyError as e:
|
| 119 |
+
raise ValueError(str(e)) from None
|
| 120 |
+
else:
|
| 121 |
+
rng = random.Random(seed) if seed is not None else random
|
| 122 |
+
idx = rng.randrange(len(dataset))
|
| 123 |
+
|
| 124 |
+
return self._load_row_at(idx)
|
| 125 |
+
|
| 126 |
+
def _load_row_at(self, idx: int) -> EarningsAnalystObservation:
|
| 127 |
row = dataset[idx]
|
|
|
|
| 128 |
self._current_row = dict(row)
|
| 129 |
|
| 130 |
text_context = {
|
|
|
|
| 163 |
|
| 164 |
# Handle composite ground truth if multiple columns are specified (e.g. for get_figures)
|
| 165 |
if "xbrl_columns" in self._cfg:
|
| 166 |
+
xbrl_cols = cast(list[str], self._cfg["xbrl_columns"])
|
| 167 |
+
gt_data = {col: row.get(col) for col in xbrl_cols}
|
| 168 |
ground_truth = json.dumps(gt_data)
|
| 169 |
else:
|
| 170 |
ground_truth = str(row.get(label_col, "")).strip()
|
|
|
|
| 178 |
)
|
| 179 |
)
|
| 180 |
|
|
|
|
| 181 |
return EarningsAnalystObservation(
|
| 182 |
text_context={},
|
| 183 |
numerical_context={},
|
|
|
|
| 191 |
},
|
| 192 |
)
|
| 193 |
|
|
|
|
| 194 |
@property
|
| 195 |
def state(self) -> State:
|
| 196 |
"""Current environment state."""
|
server/episode_index.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Build lookups over the loaded Hugging Face split for Gradio (company / year / quarter).
|
| 3 |
+
|
| 4 |
+
Rows are keyed by ``(symbol, year, quarter)`` and must match ``episode_id`` when present
|
| 5 |
+
(``symbol_year_Qquarter``, e.g. ``A_2006_Q1``).
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import warnings
|
| 11 |
+
from collections import defaultdict
|
| 12 |
+
from typing import Any
|
| 13 |
+
|
| 14 |
+
from earnings_analyst.environment_config import (
|
| 15 |
+
DATASET_COMPANY_NAME_COLUMN,
|
| 16 |
+
DATASET_EPISODE_ID_COLUMN,
|
| 17 |
+
DATASET_QUARTER_COLUMN,
|
| 18 |
+
DATASET_SYMBOL_COLUMN,
|
| 19 |
+
DATASET_YEAR_COLUMN,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
# First sidebar option: sample a random row (same behavior as HTTP clients).
|
| 23 |
+
RANDOM_EPISODE_LABEL = "— Random row —"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _normalize_cell(value: Any) -> str | None:
|
| 27 |
+
if value is None:
|
| 28 |
+
return None
|
| 29 |
+
s = str(value).strip()
|
| 30 |
+
return s if s else None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _as_int(value: Any) -> int | None:
|
| 34 |
+
if value is None:
|
| 35 |
+
return None
|
| 36 |
+
try:
|
| 37 |
+
x = int(round(float(value)))
|
| 38 |
+
except (TypeError, ValueError):
|
| 39 |
+
return None
|
| 40 |
+
return x
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def format_episode_id(symbol: str, year: int, quarter: int) -> str:
|
| 44 |
+
"""Canonical ``episode_id`` pattern: ``{symbol}_{year}_Q{quarter}``."""
|
| 45 |
+
return f"{symbol}_{year}_Q{quarter}"
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class EpisodeIndex:
|
| 49 |
+
"""Maps ticker symbol + fiscal year + quarter to a dataset row index."""
|
| 50 |
+
|
| 51 |
+
_REQUIRED_COLUMNS: tuple[str, ...] = (
|
| 52 |
+
DATASET_EPISODE_ID_COLUMN,
|
| 53 |
+
DATASET_SYMBOL_COLUMN,
|
| 54 |
+
DATASET_COMPANY_NAME_COLUMN,
|
| 55 |
+
DATASET_YEAR_COLUMN,
|
| 56 |
+
DATASET_QUARTER_COLUMN,
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
def __init__(
|
| 60 |
+
self,
|
| 61 |
+
dataset: Any,
|
| 62 |
+
episode_id_col: str = DATASET_EPISODE_ID_COLUMN,
|
| 63 |
+
symbol_col: str = DATASET_SYMBOL_COLUMN,
|
| 64 |
+
company_name_col: str = DATASET_COMPANY_NAME_COLUMN,
|
| 65 |
+
year_col: str = DATASET_YEAR_COLUMN,
|
| 66 |
+
quarter_col: str = DATASET_QUARTER_COLUMN,
|
| 67 |
+
) -> None:
|
| 68 |
+
names = dataset.column_names
|
| 69 |
+
missing = [c for c in self._REQUIRED_COLUMNS if c not in names]
|
| 70 |
+
if missing:
|
| 71 |
+
raise ValueError(
|
| 72 |
+
f"Dataset is missing column(s) {missing}. "
|
| 73 |
+
f"Check environment_config.py. Available columns: {names!r}."
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
self._symbol_to_display: dict[str, str] = {}
|
| 77 |
+
self._display_to_symbol: dict[str, str] = {}
|
| 78 |
+
self._symbol_to_years: dict[str, set[int]] = defaultdict(set)
|
| 79 |
+
self._symbol_year_to_quarters: dict[tuple[str, int], set[int]] = defaultdict(
|
| 80 |
+
set
|
| 81 |
+
)
|
| 82 |
+
self._triple_to_index: dict[tuple[str, int, int], int] = {}
|
| 83 |
+
duplicate_triples: set[tuple[str, int, int]] = set()
|
| 84 |
+
|
| 85 |
+
n = len(dataset)
|
| 86 |
+
for i in range(n):
|
| 87 |
+
row = dataset[i]
|
| 88 |
+
sym = _normalize_cell(row.get(symbol_col))
|
| 89 |
+
cname = _normalize_cell(row.get(company_name_col))
|
| 90 |
+
yr = _as_int(row.get(year_col))
|
| 91 |
+
qn = _as_int(row.get(quarter_col))
|
| 92 |
+
if not sym or yr is None or qn is None:
|
| 93 |
+
continue
|
| 94 |
+
|
| 95 |
+
if sym not in self._symbol_to_display:
|
| 96 |
+
label = cname or sym
|
| 97 |
+
display = f"{label} ({sym})"
|
| 98 |
+
self._symbol_to_display[sym] = display
|
| 99 |
+
self._display_to_symbol[display] = sym
|
| 100 |
+
|
| 101 |
+
triple = (sym, yr, qn)
|
| 102 |
+
self._symbol_to_years[sym].add(yr)
|
| 103 |
+
self._symbol_year_to_quarters[(sym, yr)].add(qn)
|
| 104 |
+
|
| 105 |
+
if triple in self._triple_to_index:
|
| 106 |
+
duplicate_triples.add(triple)
|
| 107 |
+
else:
|
| 108 |
+
self._triple_to_index[triple] = i
|
| 109 |
+
# episode_id_col is validated to exist; row lookup uses (symbol, year, quarter).
|
| 110 |
+
|
| 111 |
+
if duplicate_triples:
|
| 112 |
+
warnings.warn(
|
| 113 |
+
"Duplicate (symbol, year, quarter) rows; using the first index for each.",
|
| 114 |
+
UserWarning,
|
| 115 |
+
stacklevel=1,
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
def sorted_company_displays(self) -> list[str]:
|
| 119 |
+
return sorted(self._display_to_symbol.keys())
|
| 120 |
+
|
| 121 |
+
def symbol_for_display(self, display: str) -> str:
|
| 122 |
+
if display not in self._display_to_symbol:
|
| 123 |
+
raise KeyError(f"Unknown company selection: {display!r}")
|
| 124 |
+
return self._display_to_symbol[display]
|
| 125 |
+
|
| 126 |
+
def years_for_symbol(self, symbol: str) -> list[int]:
|
| 127 |
+
return sorted(self._symbol_to_years.get(symbol, []))
|
| 128 |
+
|
| 129 |
+
def quarters_for(self, symbol: str, year: int) -> list[int]:
|
| 130 |
+
return sorted(self._symbol_year_to_quarters.get((symbol, year), []))
|
| 131 |
+
|
| 132 |
+
def row_index(self, symbol: str, year: int, quarter: int) -> int:
|
| 133 |
+
key = (symbol, year, quarter)
|
| 134 |
+
if key not in self._triple_to_index:
|
| 135 |
+
raise KeyError(
|
| 136 |
+
f"No row for symbol={symbol!r}, year={year!r}, quarter={quarter!r}. "
|
| 137 |
+
"Pick a valid combination from the dropdowns."
|
| 138 |
+
)
|
| 139 |
+
return self._triple_to_index[key]
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
_instance: EpisodeIndex | None = None
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def get_episode_index() -> EpisodeIndex:
|
| 146 |
+
"""Lazily build the index over ``dataset_loader.dataset`` (single pass)."""
|
| 147 |
+
global _instance
|
| 148 |
+
if _instance is None:
|
| 149 |
+
from earnings_analyst.server.dataset_loader import dataset
|
| 150 |
+
|
| 151 |
+
_instance = EpisodeIndex(dataset)
|
| 152 |
+
return _instance
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
__all__ = [
|
| 156 |
+
"EpisodeIndex",
|
| 157 |
+
"format_episode_id",
|
| 158 |
+
"get_episode_index",
|
| 159 |
+
"RANDOM_EPISODE_LABEL",
|
| 160 |
+
]
|
tests/test_episode_index.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unit tests for :mod:`earnings_analyst.server.episode_index`."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
from datasets import Dataset
|
| 7 |
+
|
| 8 |
+
from earnings_analyst.server.episode_index import (
|
| 9 |
+
EpisodeIndex,
|
| 10 |
+
RANDOM_EPISODE_LABEL,
|
| 11 |
+
format_episode_id,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def test_format_episode_id() -> None:
|
| 16 |
+
assert format_episode_id("A", 2006, 1) == "A_2006_Q1"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def test_episode_index_lookup_and_cascades() -> None:
|
| 20 |
+
ds = Dataset.from_dict(
|
| 21 |
+
{
|
| 22 |
+
"episode_id": ["AAA_2024_Q1", "AAA_2024_Q2", "BBB_2024_Q1"],
|
| 23 |
+
"symbol": ["AAA", "AAA", "BBB"],
|
| 24 |
+
"company_name": ["Alpha Co", "Alpha Co", "Beta Co"],
|
| 25 |
+
"year": [2024, 2024, 2024],
|
| 26 |
+
"quarter": [1, 2, 1],
|
| 27 |
+
}
|
| 28 |
+
)
|
| 29 |
+
idx = EpisodeIndex(ds)
|
| 30 |
+
d_aaa = "Alpha Co (AAA)"
|
| 31 |
+
d_bbb = "Beta Co (BBB)"
|
| 32 |
+
assert idx.sorted_company_displays() == [d_aaa, d_bbb]
|
| 33 |
+
assert idx.symbol_for_display(d_aaa) == "AAA"
|
| 34 |
+
assert idx.years_for_symbol("AAA") == [2024]
|
| 35 |
+
assert idx.quarters_for("AAA", 2024) == [1, 2]
|
| 36 |
+
assert idx.row_index("AAA", 2024, 1) == 0
|
| 37 |
+
assert idx.row_index("AAA", 2024, 2) == 1
|
| 38 |
+
assert idx.row_index("BBB", 2024, 1) == 2
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def test_episode_index_duplicate_triple_warns() -> None:
|
| 42 |
+
ds = Dataset.from_dict(
|
| 43 |
+
{
|
| 44 |
+
"episode_id": ["X_2024_Q1", "X_2024_Q1_dup"],
|
| 45 |
+
"symbol": ["X", "X"],
|
| 46 |
+
"company_name": ["X", "X"],
|
| 47 |
+
"year": [2024, 2024],
|
| 48 |
+
"quarter": [1, 1],
|
| 49 |
+
}
|
| 50 |
+
)
|
| 51 |
+
with pytest.warns(UserWarning, match="Duplicate"):
|
| 52 |
+
idx = EpisodeIndex(ds)
|
| 53 |
+
assert idx.row_index("X", 2024, 1) == 0
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def test_episode_index_skips_incomplete_rows() -> None:
|
| 57 |
+
ds = Dataset.from_dict(
|
| 58 |
+
{
|
| 59 |
+
"episode_id": ["_", "Z_2024_Q1"],
|
| 60 |
+
"symbol": ["", "ZZZ"],
|
| 61 |
+
"company_name": ["", "Zed"],
|
| 62 |
+
"year": [2024, 2024],
|
| 63 |
+
"quarter": [1, 1],
|
| 64 |
+
}
|
| 65 |
+
)
|
| 66 |
+
idx = EpisodeIndex(ds)
|
| 67 |
+
assert idx.sorted_company_displays() == ["Zed (ZZZ)"]
|
| 68 |
+
assert idx.row_index("ZZZ", 2024, 1) == 1
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def test_episode_index_missing_column_raises() -> None:
|
| 72 |
+
ds = Dataset.from_dict({"a": [1]})
|
| 73 |
+
with pytest.raises(ValueError, match="missing column"):
|
| 74 |
+
EpisodeIndex(ds)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def test_random_label_constant() -> None:
|
| 78 |
+
assert "Random" in RANDOM_EPISODE_LABEL
|