RudrakshNanavaty commited on
Commit
e8a3e4b
·
1 Parent(s): 59d9fa9

updated pyproject.toml

Browse files
Files changed (3) hide show
  1. app.py +30 -31
  2. pyproject.toml +0 -7
  3. uv.lock +6 -0
app.py CHANGED
@@ -1,9 +1,8 @@
1
  import os
2
  import sys
3
- import asyncio
4
  import json
5
  import gradio as gr
6
- from typing import Optional, Dict, Any, List
7
  from dotenv import load_dotenv
8
  from openai import AsyncOpenAI
9
 
@@ -13,11 +12,11 @@ sys.path.append(os.path.abspath(os.path.dirname(__file__)))
13
  try:
14
  from earnings_analyst.server.earnings_analyst_environment import EarningsAnalystEnvironment
15
  from earnings_analyst.models import EarningsAnalystAction, EarningsAnalystObservation
16
- from earnings_analyst.tasks.registry import TASK_IDS, TASKS
17
  except (ImportError, ModuleNotFoundError):
18
  from server.earnings_analyst_environment import EarningsAnalystEnvironment
19
  from models import EarningsAnalystAction, EarningsAnalystObservation
20
- from tasks.registry import TASK_IDS, TASKS
21
 
22
  load_dotenv()
23
 
@@ -33,15 +32,15 @@ async def reset_env(task_id: str):
33
  state.task_id = task_id
34
  state.env = EarningsAnalystEnvironment(task_id=task_id)
35
  state.obs = state.env.reset()
36
-
37
  # Format observation for display
38
  text_context = ""
39
  if state.obs.text_context:
40
  for name, text in sorted(state.obs.text_context.items()):
41
  text_context += f"### {name}\n{text}\n\n"
42
-
43
  numerical_context = json.dumps(state.obs.numerical_context, indent=2) if state.obs.numerical_context else "No numerical data."
44
-
45
  return [
46
  state.obs.task_instruction,
47
  text_context,
@@ -55,15 +54,15 @@ async def reset_env(task_id: str):
55
  async def step_env(prediction: str):
56
  if not state.env or not state.obs:
57
  return [gr.update(), "Error: Environment not initialized. Click Reset."]
58
-
59
  action = EarningsAnalystAction(prediction=prediction)
60
  state.obs = state.env.step(action)
61
-
62
  reward = state.obs.reward
63
  ground_truth = getattr(state.obs, "ground_truth", "N/A")
64
-
65
  result_text = f"**Reward:** {reward:.4f} \n\n**Ground Truth:** {ground_truth}"
66
-
67
  return [
68
  gr.update(visible=False), # Hide prediction row
69
  gr.update(visible=True, value=result_text), # Show result row
@@ -72,17 +71,17 @@ async def step_env(prediction: str):
72
  async def run_agent(task_id: str, api_key: str, model: str, base_url: str):
73
  if not api_key:
74
  return [gr.update()] * 6 + ["Please provide an API Key."]
75
-
76
  # 1. Reset
77
  out = await reset_env(task_id)
78
-
79
  # 2. Predict with LLM
80
  client_params = {"api_key": api_key}
81
  if base_url:
82
  client_params["base_url"] = base_url
83
-
84
  client = AsyncOpenAI(**client_params)
85
-
86
  user_content = f"{state.obs.task_instruction}\n\n"
87
  if state.obs.text_context:
88
  user_content += "## Text context\n"
@@ -90,13 +89,13 @@ async def run_agent(task_id: str, api_key: str, model: str, base_url: str):
90
  user_content += f"### {name}\n{text}\n"
91
  if state.obs.numerical_context:
92
  user_content += f"\n## Numerical context\n{json.dumps(state.obs.numerical_context)}\n"
93
-
94
  system_prompt = (
95
  "You are a financial analyst assistant. "
96
  "Analyze the data and respond EXACTLY as instructed. "
97
  "Reply with a single JSON object containing 'prediction' key."
98
  )
99
-
100
  try:
101
  completion = await client.chat.completions.create(
102
  model=model,
@@ -109,31 +108,31 @@ async def run_agent(task_id: str, api_key: str, model: str, base_url: str):
109
  response_text = completion.choices[0].message.content or "{}"
110
  parsed = json.loads(response_text)
111
  prediction = str(parsed.get("prediction", response_text))
112
-
113
  # 3. Step
114
  step_out = await step_env(prediction)
115
-
116
  # Update UI components
117
  # out: [instr, text, num, pred_row, res_row, pred_input, msg]
118
  # step_out: [pred_row, res_row]
119
-
120
  return [
121
- out[0], out[1], out[2],
122
- step_out[0], step_out[1],
123
- prediction,
124
  f"Agent used {model}. Raw response: {response_text}"
125
  ]
126
-
127
  except Exception as e:
128
  return [out[0], out[1], out[2], out[3], out[4], "", f"Error: {str(e)}"]
129
 
130
  # Custom CSS for a premium look
131
  custom_css = """
132
  footer {visibility: hidden}
133
- .container {
134
- max-width: 1100px;
135
- margin: auto;
136
- padding-top: 2rem;
137
  font-family: 'Inter', system-ui, -apple-system, sans-serif;
138
  }
139
  .header { text-align: center; margin-bottom: 2rem; }
@@ -154,7 +153,7 @@ with gr.Blocks(css=custom_css, title="Earnings Analyst - OpenEnv") as demo:
154
  with gr.Group():
155
  task_select = gr.Dropdown(choices=TASK_IDS, value="sentiment_label", label="Active Task")
156
  reset_btn = gr.Button("🔄 New Episode", variant="primary")
157
-
158
  gr.Markdown("### 🤖 Auto-Agent Settings")
159
  with gr.Group():
160
  api_key = gr.Textbox(label="OpenAI API Key", type="password", placeholder="sk-...", value=os.environ.get("OPENAI_API_KEY", ""))
@@ -173,12 +172,12 @@ with gr.Blocks(css=custom_css, title="Earnings Analyst - OpenEnv") as demo:
173
  with gr.Column():
174
  gr.Markdown("#### Numerical Context")
175
  num_view = gr.Code(label="JSON", language="json", elem_classes="context-box")
176
-
177
  with gr.TabItem("Analysis"):
178
  with gr.Column(visible=False) as prediction_row:
179
  pred_input = gr.Textbox(label="Your Prediction / Analysis Output", placeholder="e.g. bullish, or 0.05")
180
  submit_btn = gr.Button("Submit Analysis", variant="primary")
181
-
182
  result_view = gr.Markdown(visible=False, elem_classes="card")
183
  message_view = gr.Textbox(label="Agent Log / Error Messages", interactive=False)
184
 
 
1
  import os
2
  import sys
 
3
  import json
4
  import gradio as gr
5
+ from typing import Optional
6
  from dotenv import load_dotenv
7
  from openai import AsyncOpenAI
8
 
 
12
  try:
13
  from earnings_analyst.server.earnings_analyst_environment import EarningsAnalystEnvironment
14
  from earnings_analyst.models import EarningsAnalystAction, EarningsAnalystObservation
15
+ from earnings_analyst.tasks.registry import TASK_IDS
16
  except (ImportError, ModuleNotFoundError):
17
  from server.earnings_analyst_environment import EarningsAnalystEnvironment
18
  from models import EarningsAnalystAction, EarningsAnalystObservation
19
+ from tasks.registry import TASK_IDS
20
 
21
  load_dotenv()
22
 
 
32
  state.task_id = task_id
33
  state.env = EarningsAnalystEnvironment(task_id=task_id)
34
  state.obs = state.env.reset()
35
+
36
  # Format observation for display
37
  text_context = ""
38
  if state.obs.text_context:
39
  for name, text in sorted(state.obs.text_context.items()):
40
  text_context += f"### {name}\n{text}\n\n"
41
+
42
  numerical_context = json.dumps(state.obs.numerical_context, indent=2) if state.obs.numerical_context else "No numerical data."
43
+
44
  return [
45
  state.obs.task_instruction,
46
  text_context,
 
54
  async def step_env(prediction: str):
55
  if not state.env or not state.obs:
56
  return [gr.update(), "Error: Environment not initialized. Click Reset."]
57
+
58
  action = EarningsAnalystAction(prediction=prediction)
59
  state.obs = state.env.step(action)
60
+
61
  reward = state.obs.reward
62
  ground_truth = getattr(state.obs, "ground_truth", "N/A")
63
+
64
  result_text = f"**Reward:** {reward:.4f} \n\n**Ground Truth:** {ground_truth}"
65
+
66
  return [
67
  gr.update(visible=False), # Hide prediction row
68
  gr.update(visible=True, value=result_text), # Show result row
 
71
  async def run_agent(task_id: str, api_key: str, model: str, base_url: str):
72
  if not api_key:
73
  return [gr.update()] * 6 + ["Please provide an API Key."]
74
+
75
  # 1. Reset
76
  out = await reset_env(task_id)
77
+
78
  # 2. Predict with LLM
79
  client_params = {"api_key": api_key}
80
  if base_url:
81
  client_params["base_url"] = base_url
82
+
83
  client = AsyncOpenAI(**client_params)
84
+
85
  user_content = f"{state.obs.task_instruction}\n\n"
86
  if state.obs.text_context:
87
  user_content += "## Text context\n"
 
89
  user_content += f"### {name}\n{text}\n"
90
  if state.obs.numerical_context:
91
  user_content += f"\n## Numerical context\n{json.dumps(state.obs.numerical_context)}\n"
92
+
93
  system_prompt = (
94
  "You are a financial analyst assistant. "
95
  "Analyze the data and respond EXACTLY as instructed. "
96
  "Reply with a single JSON object containing 'prediction' key."
97
  )
98
+
99
  try:
100
  completion = await client.chat.completions.create(
101
  model=model,
 
108
  response_text = completion.choices[0].message.content or "{}"
109
  parsed = json.loads(response_text)
110
  prediction = str(parsed.get("prediction", response_text))
111
+
112
  # 3. Step
113
  step_out = await step_env(prediction)
114
+
115
  # Update UI components
116
  # out: [instr, text, num, pred_row, res_row, pred_input, msg]
117
  # step_out: [pred_row, res_row]
118
+
119
  return [
120
+ out[0], out[1], out[2],
121
+ step_out[0], step_out[1],
122
+ prediction,
123
  f"Agent used {model}. Raw response: {response_text}"
124
  ]
125
+
126
  except Exception as e:
127
  return [out[0], out[1], out[2], out[3], out[4], "", f"Error: {str(e)}"]
128
 
129
  # Custom CSS for a premium look
130
  custom_css = """
131
  footer {visibility: hidden}
132
+ .container {
133
+ max-width: 1100px;
134
+ margin: auto;
135
+ padding-top: 2rem;
136
  font-family: 'Inter', system-ui, -apple-system, sans-serif;
137
  }
138
  .header { text-align: center; margin-bottom: 2rem; }
 
153
  with gr.Group():
154
  task_select = gr.Dropdown(choices=TASK_IDS, value="sentiment_label", label="Active Task")
155
  reset_btn = gr.Button("🔄 New Episode", variant="primary")
156
+
157
  gr.Markdown("### 🤖 Auto-Agent Settings")
158
  with gr.Group():
159
  api_key = gr.Textbox(label="OpenAI API Key", type="password", placeholder="sk-...", value=os.environ.get("OPENAI_API_KEY", ""))
 
172
  with gr.Column():
173
  gr.Markdown("#### Numerical Context")
174
  num_view = gr.Code(label="JSON", language="json", elem_classes="context-box")
175
+
176
  with gr.TabItem("Analysis"):
177
  with gr.Column(visible=False) as prediction_row:
178
  pred_input = gr.Textbox(label="Your Prediction / Analysis Output", placeholder="e.g. bullish, or 0.05")
179
  submit_btn = gr.Button("Submit Analysis", variant="primary")
180
+
181
  result_view = gr.Markdown(visible=False, elem_classes="card")
182
  message_view = gr.Textbox(label="Agent Log / Error Messages", interactive=False)
183
 
pyproject.toml CHANGED
@@ -44,13 +44,6 @@ packages = [
44
  # Root of this repo is the `earnings_analyst` package (setuptools package-dir `.`).
45
  package-dir = { "earnings_analyst" = ".", "earnings_analyst.server" = "server", "earnings_analyst.tasks" = "tasks", "earnings_analyst.tasks.get_figures" = "tasks/get_figures", "earnings_analyst.tasks.sentiment_label" = "tasks/sentiment_label", "earnings_analyst.tasks.next_quarter_move" = "tasks/next_quarter_move", "earnings_analyst.tasks.one_day_move" = "tasks/one_day_move", "earnings_analyst.tasks.thirty_day_move" = "tasks/thirty_day_move" }
46
 
47
- # Pyright resolves imports from disk; parent on extraPaths makes `earnings_analyst` resolve
48
- # to this repo folder when it is named `earnings_analyst` (matches setuptools mapping above).
49
- [tool.pyright]
50
- pythonVersion = "3.12"
51
- venvPath = ".venv"
52
- extraPaths = [".."]
53
-
54
  [tool.basedpyright]
55
  pythonVersion = "3.12"
56
  venvPath = ".venv"
 
44
  # Root of this repo is the `earnings_analyst` package (setuptools package-dir `.`).
45
  package-dir = { "earnings_analyst" = ".", "earnings_analyst.server" = "server", "earnings_analyst.tasks" = "tasks", "earnings_analyst.tasks.get_figures" = "tasks/get_figures", "earnings_analyst.tasks.sentiment_label" = "tasks/sentiment_label", "earnings_analyst.tasks.next_quarter_move" = "tasks/next_quarter_move", "earnings_analyst.tasks.one_day_move" = "tasks/one_day_move", "earnings_analyst.tasks.thirty_day_move" = "tasks/thirty_day_move" }
46
 
 
 
 
 
 
 
 
47
  [tool.basedpyright]
48
  pythonVersion = "3.12"
49
  venvPath = ".venv"
uv.lock CHANGED
@@ -1665,10 +1665,13 @@ version = "0.1.0"
1665
  source = { editable = "." }
1666
  dependencies = [
1667
  { name = "datasets" },
 
 
1668
  { name = "huggingface-hub" },
1669
  { name = "openai" },
1670
  { name = "openenv-core", extra = ["core"] },
1671
  { name = "python-dotenv" },
 
1672
  ]
1673
 
1674
  [package.optional-dependencies]
@@ -1680,12 +1683,15 @@ dev = [
1680
  [package.metadata]
1681
  requires-dist = [
1682
  { name = "datasets", specifier = ">=4.8.4" },
 
 
1683
  { name = "huggingface-hub", specifier = ">=1.10.1" },
1684
  { name = "openai", specifier = ">=2.31.0" },
1685
  { name = "openenv-core", extras = ["core"], specifier = ">=0.2.2" },
1686
  { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" },
1687
  { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0.0" },
1688
  { name = "python-dotenv", specifier = ">=1.2.2" },
 
1689
  ]
1690
  provides-extras = ["dev"]
1691
 
 
1665
  source = { editable = "." }
1666
  dependencies = [
1667
  { name = "datasets" },
1668
+ { name = "fastapi" },
1669
+ { name = "gradio" },
1670
  { name = "huggingface-hub" },
1671
  { name = "openai" },
1672
  { name = "openenv-core", extra = ["core"] },
1673
  { name = "python-dotenv" },
1674
+ { name = "uvicorn" },
1675
  ]
1676
 
1677
  [package.optional-dependencies]
 
1683
  [package.metadata]
1684
  requires-dist = [
1685
  { name = "datasets", specifier = ">=4.8.4" },
1686
+ { name = "fastapi", specifier = ">=0.111.0" },
1687
+ { name = "gradio", specifier = ">=4.0.0" },
1688
  { name = "huggingface-hub", specifier = ">=1.10.1" },
1689
  { name = "openai", specifier = ">=2.31.0" },
1690
  { name = "openenv-core", extras = ["core"], specifier = ">=0.2.2" },
1691
  { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" },
1692
  { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0.0" },
1693
  { name = "python-dotenv", specifier = ">=1.2.2" },
1694
+ { name = "uvicorn", specifier = ">=0.30.0" },
1695
  ]
1696
  provides-extras = ["dev"]
1697