intelkishan committed on
Commit
dd9bd58
·
1 Parent(s): 919d04b

Add Gradio UI and Dockerfile for HF Space deployment

Browse files
Files changed (4) hide show
  1. Dockerfile +33 -0
  2. app.py +232 -0
  3. pyproject.toml +3 -0
  4. scratch/inspect_dataset.py +0 -26
Dockerfile ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Use a standard Python slim image as the base
FROM python:3.12-slim

# Install system dependencies: git for VCS-based packages, curl to fetch uv
RUN apt-get update && apt-get install -y \
    git \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Install uv, then move the binary onto the global PATH
RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
    mv /root/.local/bin/uv /usr/local/bin/uv

# Set working directory
WORKDIR /app

# Copy project files
COPY . .

# Install dependencies using uv
# We use --frozen if uv.lock is present, but for HF Space we might want to be flexible
RUN uv sync --no-cache

# Put the uv-managed virtualenv first on PATH so `python` resolves to it
ENV PATH="/app/.venv/bin:$PATH"
ENV PYTHONPATH="/app:$PYTHONPATH"
# Unbuffered stdout/stderr so logs show up immediately in the Space console
ENV PYTHONUNBUFFERED=1

# Expose the port Gradio will run on (must match uvicorn's port in app.py)
EXPOSE 8000

# Default command to run the Gradio app
CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import asyncio
4
+ import json
5
+ import gradio as gr
6
+ from typing import Optional, Dict, Any, List
7
+ from dotenv import load_dotenv
8
+ from openai import AsyncOpenAI
9
+
10
# Ensure the project root is in sys.path so local modules import cleanly
sys.path.append(os.path.abspath(os.path.dirname(__file__)))

# Support both layouts: an installed package (earnings_analyst.*) and a flat
# checkout where server/, models, and tasks/ sit next to this file.
try:
    from earnings_analyst.server.earnings_analyst_environment import EarningsAnalystEnvironment
    from earnings_analyst.models import EarningsAnalystAction, EarningsAnalystObservation
    from earnings_analyst.tasks.registry import TASK_IDS, TASKS
except (ImportError, ModuleNotFoundError):
    from server.earnings_analyst_environment import EarningsAnalystEnvironment
    from models import EarningsAnalystAction, EarningsAnalystObservation
    from tasks.registry import TASK_IDS, TASKS

# Load variables from a local .env file (e.g. OPENAI_API_KEY) if present
load_dotenv()
23
+
24
class State:
    """Mutable session state shared by all Gradio callbacks in this app."""

    def __init__(self):
        # Live environment instance; None until the first reset.
        self.env: Optional[EarningsAnalystEnvironment] = None
        # Latest observation from reset()/step(); None before the first reset.
        self.obs: Optional[EarningsAnalystObservation] = None
        # Currently selected task identifier (one of TASK_IDS).
        self.task_id: str = "sentiment_label"

# Module-level singleton; all event handlers below read and write this.
state = State()
31
+
32
async def reset_env(task_id: str):
    """Start a new episode for *task_id* and return the UI output values.

    The returned list is positional and must match the ``outputs=`` wiring of
    reset_btn (and the first part of agent_btn): [instruction, text context,
    numerical context, prediction-row update, result-row update,
    prediction input, log message].
    """
    state.task_id = task_id
    state.env = EarningsAnalystEnvironment(task_id=task_id)
    state.obs = state.env.reset()

    # Format observation for display: one markdown section per named document,
    # sorted by name for a stable rendering order.
    text_context = ""
    if state.obs.text_context:
        for name, text in sorted(state.obs.text_context.items()):
            text_context += f"### {name}\n{text}\n\n"

    # Pretty-print the numerical payload, or show a placeholder when absent.
    numerical_context = json.dumps(state.obs.numerical_context, indent=2) if state.obs.numerical_context else "No numerical data."

    return [
        state.obs.task_instruction,
        text_context,
        numerical_context,
        gr.update(visible=True),  # Prediction row
        gr.update(visible=False),  # Result row
        "",  # Prediction input
        "",  # Log/Message
    ]
54
+
55
async def step_env(prediction: str):
    """Submit *prediction* to the environment and return UI updates.

    Returns [prediction-row update, result-row update], matching the
    ``outputs=`` wiring of submit_btn.

    NOTE(review): the error branch returns a plain string as the second
    output; that sets result_view's value but the component stays hidden
    (visible=False) — confirm whether the message should be shown instead.
    """
    if not state.env or not state.obs:
        return [gr.update(), "Error: Environment not initialized. Click Reset."]

    action = EarningsAnalystAction(prediction=prediction)
    state.obs = state.env.step(action)

    reward = state.obs.reward
    # ground_truth may be absent on the observation type; fall back to "N/A".
    ground_truth = getattr(state.obs, "ground_truth", "N/A")

    result_text = f"**Reward:** {reward:.4f} \n\n**Ground Truth:** {ground_truth}"

    return [
        gr.update(visible=False),  # Hide prediction row
        gr.update(visible=True, value=result_text),  # Show result row
    ]
71
+
72
async def run_agent(task_id: str, api_key: str, model: str, base_url: str):
    """Run one full episode with an LLM agent: reset, query the model, step.

    Returns the 7 positional values expected by agent_btn's ``outputs=``:
    [instruction, text context, numerical context, prediction-row update,
    result-row update, prediction text, log message].
    """
    if not api_key:
        # Leave the first six components unchanged; only surface the message.
        return [gr.update()] * 6 + ["Please provide an API Key."]

    # 1. Reset
    out = await reset_env(task_id)

    # 2. Predict with LLM
    client_params = {"api_key": api_key}
    if base_url:
        # Optional override for OpenAI-compatible endpoints.
        client_params["base_url"] = base_url

    client = AsyncOpenAI(**client_params)

    # Flatten the observation (instruction + text docs + numbers) into one prompt.
    user_content = f"{state.obs.task_instruction}\n\n"
    if state.obs.text_context:
        user_content += "## Text context\n"
        for name, text in sorted(state.obs.text_context.items()):
            user_content += f"### {name}\n{text}\n"
    if state.obs.numerical_context:
        user_content += f"\n## Numerical context\n{json.dumps(state.obs.numerical_context)}\n"

    system_prompt = (
        "You are a financial analyst assistant. "
        "Analyze the data and respond EXACTLY as instructed. "
        "Reply with a single JSON object containing 'prediction' key."
    )

    try:
        completion = await client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_content},
            ],
            # Force JSON mode so the 'prediction' key can be parsed reliably.
            response_format={"type": "json_object"},
        )
        response_text = completion.choices[0].message.content or "{}"
        parsed = json.loads(response_text)
        # Fall back to the raw response text if the expected key is missing.
        prediction = str(parsed.get("prediction", response_text))

        # 3. Step
        step_out = await step_env(prediction)

        # Update UI components
        # out: [instr, text, num, pred_row, res_row, pred_input, msg]
        # step_out: [pred_row, res_row]

        return [
            out[0], out[1], out[2],
            step_out[0], step_out[1],
            prediction,
            f"Agent used {model}. Raw response: {response_text}"
        ]

    except Exception as e:
        # Surface the failure in the log box but keep the freshly-reset views.
        return [out[0], out[1], out[2], out[3], out[4], "", f"Error: {str(e)}"]
129
+
130
# Custom CSS for a premium look: hides the Gradio footer, centers the layout,
# and styles the header, card, and scrollable context-box elements.
custom_css = """
footer {visibility: hidden}
.container {
    max-width: 1100px;
    margin: auto;
    padding-top: 2rem;
    font-family: 'Inter', system-ui, -apple-system, sans-serif;
}
.header { text-align: center; margin-bottom: 2rem; }
.header h1 { font-weight: 800; font-size: 2.5rem; background: linear-gradient(90deg, #ff4d94, #ff85b3); -webkit-background-clip: text; -webkit-text-fill-color: transparent; }
.header p { color: #666; font-size: 1.1rem; }
.card { border-radius: 12px; border: 1px solid #eee; background: white; padding: 1.5rem; box-shadow: 0 4px 6px -1px rgba(0,0,0,0.05); }
.context-box { height: 400px; overflow-y: auto; border: 1px solid #f0f0f0; border-radius: 8px; padding: 1rem; background: #fafafa; }
"""
145
+
146
# Build the Gradio UI. Layout: settings column on the left, observation /
# analysis tabs on the right; event handlers are wired at the bottom.
with gr.Blocks(css=custom_css, title="Earnings Analyst - OpenEnv") as demo:
    # BUG FIX: Gradio has no `gr.Div` component — `with gr.Div(...)` raises
    # AttributeError at import time, so the app never starts. `gr.Column` is
    # the standard block-level container and accepts `elem_classes`, so the
    # custom CSS (.container / .header) still applies.
    with gr.Column(elem_classes="container"):
        with gr.Column(elem_classes="header"):
            gr.Markdown("# πŸ‘ Earnings Analyst")
            gr.Markdown("Interactive environment for financial analysis tasks using [OpenEnv](https://github.com/meta-pytorch/OpenEnv). Evaluate agents or your own analysis on earnings call data.")

        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group():
                    task_select = gr.Dropdown(choices=TASK_IDS, value="sentiment_label", label="Active Task")
                    reset_btn = gr.Button("πŸ”„ New Episode", variant="primary")

                gr.Markdown("### πŸ€– Auto-Agent Settings")
                with gr.Group():
                    api_key = gr.Textbox(label="OpenAI API Key", type="password", placeholder="sk-...", value=os.environ.get("OPENAI_API_KEY", ""))
                    model_name = gr.Textbox(label="Model", value="gpt-4o-mini")
                    base_url = gr.Textbox(label="Base URL (optional)", placeholder="https://api.openai.com/v1", value=os.environ.get("OPENAI_BASE_URL", ""))
                    agent_btn = gr.Button("πŸš€ Run LLM Agent", variant="secondary")

            with gr.Column(scale=2):
                with gr.Tabs():
                    with gr.TabItem("Observation"):
                        instr_view = gr.Markdown("### Instruction\n*N/A*")
                        with gr.Row():
                            with gr.Column():
                                gr.Markdown("#### Text Context")
                                text_view = gr.Markdown(elem_classes="context-box")
                            with gr.Column():
                                gr.Markdown("#### Numerical Context")
                                num_view = gr.Code(label="JSON", language="json", elem_classes="context-box")

                    with gr.TabItem("Analysis"):
                        # Hidden until the first reset makes an episode active.
                        with gr.Column(visible=False) as prediction_row:
                            pred_input = gr.Textbox(label="Your Prediction / Analysis Output", placeholder="e.g. bullish, or 0.05")
                            submit_btn = gr.Button("Submit Analysis", variant="primary")

                        result_view = gr.Markdown(visible=False, elem_classes="card")
                        message_view = gr.Textbox(label="Agent Log / Error Messages", interactive=False)

    # Event handlers — output lists are positional and must match the order
    # each handler returns its values in.
    reset_btn.click(
        fn=reset_env,
        inputs=[task_select],
        outputs=[instr_view, text_view, num_view, prediction_row, result_view, pred_input, message_view]
    )

    submit_btn.click(
        fn=step_env,
        inputs=[pred_input],
        outputs=[prediction_row, result_view]
    )

    agent_btn.click(
        fn=run_agent,
        inputs=[task_select, api_key, model_name, base_url],
        outputs=[instr_view, text_view, num_view, prediction_row, result_view, pred_input, message_view]
    )
203
+
204
# FastAPI is imported here (after the UI is defined) so the Gradio app can be
# mounted onto a parent API application below.
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

# Create the main FastAPI app that hosts both the env API and the Gradio UI
api_app = FastAPI(title="Earnings Analyst API")

# Add CORS middleware so the API accepts cross-origin requests from any host
api_app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Import the environment app, supporting both package and flat layouts
try:
    from earnings_analyst.server.app import app as env_app
except (ImportError, ModuleNotFoundError):
    from server.app import app as env_app

# Mount the environment API at /api (the Gradio UI is mounted at / below)
api_app.mount("/api", env_app)

# Wrap Gradio with FastAPI
app = gr.mount_gradio_app(api_app, demo, path="/")

if __name__ == "__main__":
    import uvicorn
    # Port must match the Dockerfile's EXPOSE (8000)
    uvicorn.run(app, host="0.0.0.0", port=8000)
pyproject.toml CHANGED
@@ -16,6 +16,9 @@ dependencies = [
16
  "openai>=2.31.0",
17
  "openenv-core[core]>=0.2.2",
18
  "python-dotenv>=1.2.2",
 
 
 
19
  ]
20
 
21
  [project.optional-dependencies]
 
16
  "openai>=2.31.0",
17
  "openenv-core[core]>=0.2.2",
18
  "python-dotenv>=1.2.2",
19
+ "gradio>=4.0.0",
20
+ "uvicorn>=0.30.0",
21
+ "fastapi>=0.111.0",
22
  ]
23
 
24
  [project.optional-dependencies]
scratch/inspect_dataset.py DELETED
@@ -1,26 +0,0 @@
1
- from datasets import load_dataset
2
- import os
3
-
4
- DATASET_ID = "RudrakshNanavaty/earnings-call-data"
5
- DATASET_FILE = "episodes_press_release_8k.parquet"
6
-
7
- print(f"Loading dataset {DATASET_ID} file {DATASET_FILE}...")
8
- dataset = load_dataset(
9
- DATASET_ID,
10
- data_files={"train": DATASET_FILE},
11
- split="train",
12
- )
13
-
14
- print("\nColumns:")
15
- print(dataset.column_names)
16
-
17
- print("\nFirst row summary:")
18
- row = dataset[0]
19
- for k, v in row.items():
20
- if v is not None:
21
- val_str = str(v)
22
- if len(val_str) > 100:
23
- val_str = val_str[:100] + "..."
24
- print(f"{k}: {val_str}")
25
- else:
26
- print(f"{k}: None")