Badumetsibb committed on
Commit
2de937d
·
verified ·
1 Parent(s): 0959af5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -30
app.py CHANGED
@@ -18,6 +18,8 @@ import torch.nn.functional as F
18
  from sentence_transformers import SentenceTransformer
19
  from sklearn.preprocessing import StandardScaler
20
  import gradio as gr
 
 
21
 
22
  # --- Configuration ---
23
  DEVICE = 'cpu'
@@ -26,15 +28,15 @@ VQ_DIM = 64
26
  REGIME_EMBED_DIM = 8
27
  EVENT_EMBED_DIM = 384
28
  BASE_JSON_WEEK = "https://nfs.faireconomy.media/ff_calendar_thisweek.json"
29
- CACHE_DURATION_SECONDS = 600 # Cache data for 10 minutes
30
 
31
  BASE_CHANNELS = ['open', 'high', 'low', 'close']
32
  DERIVED_CHANNELS = ['Volume', 'VolMissing', 'Logclose', 'Return']
33
  ALL_CHANNELS = BASE_CHANNELS + DERIVED_CHANNELS
34
  IMPACT_MAP = {"High": 3, "Medium": 2, "Low": 1, "Holiday": 0}
35
 
36
- # --- Simple In-Memory Cache ---
37
- _CACHE = {"data": None, "timestamp": 0}
38
 
39
  # --- Model Architecture Definitions ---
40
  class SmallTCN(nn.Module):
@@ -78,24 +80,36 @@ class ModelSingleton:
78
 
79
  # --- Data Fetching and Processing Functions ---
80
  def fetch_week_json_with_cache():
81
- """Fetches data from the API, using a time-based cache to avoid rate limits."""
82
  current_time = time.time()
83
- # Check if cache is valid
84
- if _CACHE["data"] and (current_time - _CACHE["timestamp"] < CACHE_DURATION_SECONDS):
85
  print("Using cached event data.")
86
- return _CACHE["data"]
87
-
88
  print("Fetching live economic events (cache was stale)...")
89
  try:
90
  response = requests.get(BASE_JSON_WEEK, headers={"User-Agent": "kairos-demo/1.0"}, timeout=10)
91
  response.raise_for_status()
92
  data = response.json()
93
- # Update cache
94
- _CACHE["data"] = data
95
- _CACHE["timestamp"] = current_time
96
  return data
97
- except requests.RequestException as e:
98
- raise gr.Error(f"Failed to fetch data. Error: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
  def parse_financial_number(s: str) -> float:
101
  if not isinstance(s, str) or not s: return np.nan
@@ -113,53 +127,78 @@ def format_events_for_kairos(events: list, encoder: SentenceTransformer) -> torc
113
  features.append(torch.cat((embedding, nums)))
114
  return torch.nan_to_num(torch.stack(features))
115
 
 
 
 
 
 
 
 
116
  # --- Main Prediction Function ---
117
  @torch.no_grad()
118
  def predict_action():
 
119
  models = ModelSingleton()
120
  kairos_model, sentence_encoder = models.kairos_model, models.sentence_encoder
121
-
 
 
 
122
  live_events = fetch_week_json_with_cache()
 
 
 
123
  event_tensor = format_events_for_kairos(live_events, sentence_encoder)
124
  if event_tensor.numel() > 0: event_tensor = event_tensor.unsqueeze(0)
125
 
126
- dummy_price = torch.randn(1, HISTORY_LEN, len(ALL_CHANNELS), device=DEVICE)
127
- dummy_aligned = {res: torch.randn(1, HISTORY_LEN, len(ALL_CHANNELS), device=DEVICE) for res in ['15min', '30min', '1h']}
128
- dummy_regime = torch.tensor([1], device=DEVICE)
 
 
 
 
 
 
129
 
130
- print("Running model inference...")
131
- _, _, tokens, _ = kairos_model.forward_encode(dummy_price, dummy_aligned)
132
- regime_emb = kairos_model.regime_embedding(dummy_regime)
 
133
  log_probs = kairos_model.reasoner(tokens=tokens, regime_emb=regime_emb, event_features=event_tensor, memory_features=None)
134
 
 
135
  probs = torch.exp(log_probs).squeeze().cpu().numpy()
136
  action_map = {0: "BUY", 1: "SELL", 2: "HOLD"}
137
  prediction = action_map[np.argmax(probs)]
138
  confidences = {action_map[i]: f"{p:.2%}" for i, p in enumerate(probs)}
139
  timestamp = pd.Timestamp.now(tz=pytz.UTC).strftime('%Y-%m-%d %H:%M:%S %Z')
140
- print(f"Prediction based on events up to {timestamp}: {prediction}")
141
 
142
  events_df = pd.DataFrame(live_events)
143
  events_df['currency'] = events_df.get('currency', events_df.get('country', 'N/A'))
144
  events_df['event'] = events_df.get('title', events_df.get('event', 'No Title'))
145
  display_df = events_df[['date', 'currency', 'impact', 'event']]
146
 
147
- return prediction, confidences, display_df, f"Prediction based on events up to: {timestamp}"
148
 
149
  # --- Gradio Interface ---
150
- with gr.Blocks(css="footer {display: none !important}", title="Kairos Live Event Analyzer") as demo:
151
- gr.Markdown("# Kairos Live Economic Event Analyzer")
152
- gr.Markdown("This demo fetches the latest economic calendar from Forex Factory. It then uses the **Kairos AI model** to analyze the *meaning* of these events and recommend a trading action for **EUR/USD** based on the current news context. **Data is cached for 10 minutes.**")
 
 
153
  with gr.Row():
154
- predict_btn = gr.Button("Analyze Live Events", variant="primary")
155
  with gr.Row():
156
  action_label = gr.Label(label="Recommended Action")
157
  timestamp_label = gr.Textbox(label="Analysis Timestamp", interactive=False)
158
  confidence_json = gr.JSON(label="Confidence Scores")
159
- events_table = gr.DataFrame(label="Events Considered in This Analysis", interactive=False)
160
- predict_btn.click(fn=predict_action, inputs=[], outputs=[action_label, confidence_json, events_table, timestamp_label])
161
- gr.Markdown("--- \n *Model: Kairos Stage 4 Policy Model. Not financial advice. For educational purposes only.*")
 
162
 
163
  if __name__ == "__main__":
164
- ModelSingleton()
165
  demo.launch()
 
18
  from sentence_transformers import SentenceTransformer
19
  from sklearn.preprocessing import StandardScaler
20
  import gradio as gr
21
+ from twelvedata import TDClient
22
+ import matplotlib
23
 
24
  # --- Configuration ---
25
  DEVICE = 'cpu'
 
28
  REGIME_EMBED_DIM = 8
29
  EVENT_EMBED_DIM = 384
30
  BASE_JSON_WEEK = "https://nfs.faireconomy.media/ff_calendar_thisweek.json"
31
+ CACHE_DURATION_SECONDS = 600 # Cache event data for 10 minutes
32
 
33
  BASE_CHANNELS = ['open', 'high', 'low', 'close']
34
  DERIVED_CHANNELS = ['Volume', 'VolMissing', 'Logclose', 'Return']
35
  ALL_CHANNELS = BASE_CHANNELS + DERIVED_CHANNELS
36
  IMPACT_MAP = {"High": 3, "Medium": 2, "Low": 1, "Holiday": 0}
37
 
38
+ # --- Simple In-Memory Cache for Events ---
39
+ _EVENT_CACHE = {"data": None, "timestamp": 0}
40
 
41
  # --- Model Architecture Definitions ---
42
  class SmallTCN(nn.Module):
 
80
 
81
  # --- Data Fetching and Processing Functions ---
82
def fetch_week_json_with_cache():
    """Return this week's economic-calendar JSON, with a simple in-memory cache.

    Serves the module-level ``_EVENT_CACHE`` copy while it is younger than
    ``CACHE_DURATION_SECONDS``; otherwise re-fetches from ``BASE_JSON_WEEK``
    and refreshes the cache.

    Raises:
        gr.Error: when the HTTP request fails, so Gradio surfaces the message.
    """
    current_time = time.time()
    # `is not None` rather than truthiness: an empty event list is still a
    # valid cached response and should not trigger a pointless re-fetch.
    if _EVENT_CACHE["data"] is not None and (current_time - _EVENT_CACHE["timestamp"] < CACHE_DURATION_SECONDS):
        print("Using cached event data.")
        return _EVENT_CACHE["data"]
    print("Fetching live economic events (cache was stale)...")
    try:
        response = requests.get(BASE_JSON_WEEK, headers={"User-Agent": "kairos-demo/1.0"}, timeout=10)
        response.raise_for_status()
        data = response.json()
        _EVENT_CACHE["data"], _EVENT_CACHE["timestamp"] = data, current_time
        return data
    except requests.RequestException as e:
        raise gr.Error(f"Failed to fetch event data. Error: {e}")
95
+
96
def fetch_twelvedata(api_key, symbol, interval, output_size):
    """Pull an OHLC time series from Twelve Data as a pandas DataFrame.

    Rows come back sorted oldest-to-newest with lower-cased column names so
    they line up with BASE_CHANNELS. Any client or API failure is re-raised
    as a gr.Error for display in the UI.
    """
    try:
        client = TDClient(apikey=api_key)
        series = client.time_series(symbol=symbol, interval=interval, outputsize=output_size)
        frame = series.as_pandas()
        frame.reset_index(inplace=True)
        frame = frame.sort_values(by='datetime').reset_index(drop=True)
        frame.columns = [col.lower() for col in frame.columns]
        return frame
    except Exception as e:
        raise gr.Error(f"Failed to fetch price data. Error: {e}")
106
+
107
def preprocess_price_df(df):
    """Augment an OHLC DataFrame with the derived channels the model expects.

    Adds constant Volume/VolMissing columns (no live volume feed), a log-close
    channel, and the per-bar log return. Mutates *df* in place and returns it;
    a None input passes straight through.
    """
    if df is None:
        return None
    df['Volume'] = 0.0
    df['VolMissing'] = 1.0
    closes = df['close'].astype(float)
    # Epsilon keeps the log finite if a bad tick reports a zero close.
    df['Logclose'] = np.log(closes + 1e-12)
    df['Return'] = df['Logclose'].diff().fillna(0.0)
    return df
113
 
114
  def parse_financial_number(s: str) -> float:
115
  if not isinstance(s, str) or not s: return np.nan
 
127
  features.append(torch.cat((embedding, nums)))
128
  return torch.nan_to_num(torch.stack(features))
129
 
130
def get_regime(price_slice):
    """Bucket recent close-to-close volatility into a regime id.

    Returns 0 (high), 1 (medium), or 2 (low) based on the standard deviation
    of log returns over the given price slice.
    """
    returns = np.log(price_slice['close'] + 1e-9).diff().dropna()
    vol = returns.std()
    if vol > 0.0008:
        return 0
    # NaN vol (too few bars) compares False on both thresholds -> low regime.
    return 1 if vol > 0.0004 else 2
136
+
137
# --- Main Prediction Function ---
@torch.no_grad()
def predict_action():
    """Run one end-to-end live inference and return the Gradio outputs.

    Returns a 5-tuple matching the UI wiring: (action label, confidence dict,
    price slice for the line plot, events DataFrame, timestamp caption).

    Raises:
        gr.Error: if the API key is missing or an upstream fetch fails.
    """
    # 1. Load Models and API Key
    models = ModelSingleton()
    kairos_model, sentence_encoder = models.kairos_model, models.sentence_encoder
    api_key = os.environ.get('TWELVE_DATA_API_KEY')
    if not api_key: raise gr.Error("TWELVE_DATA_API_KEY not found. Please set it in Space secrets.")

    # 2. Fetch Live Events and Prices
    live_events = fetch_week_json_with_cache()
    # One DataFrame per resolution, each already augmented with the derived channels.
    price_data_store = {res: preprocess_price_df(fetch_twelvedata(api_key, 'EUR/USD', res, HISTORY_LEN)) for res in ['15min', '30min', '1h']}

    # 3. Prepare Inputs for the Model
    event_tensor = format_events_for_kairos(live_events, sentence_encoder)
    if event_tensor.numel() > 0: event_tensor = event_tensor.unsqueeze(0)

    # The 15-minute series is the base timeline; coarser resolutions are
    # as-of joined onto it below.
    base_df = price_data_store['15min']
    hist_slice = base_df.iloc[-HISTORY_LEN:]

    scaler = StandardScaler()
    # NOTE(review): scaler is fit on the inference window itself — presumably
    # this mirrors how the model was trained; confirm against training code.
    price_history_tensor = torch.tensor(scaler.fit_transform(hist_slice[ALL_CHANNELS].values), dtype=torch.float32).unsqueeze(0).to(DEVICE)
    # Backward as-of merge aligns each coarser bar to the most recent bar at
    # or before each 15-minute timestamp; ffill/bfill/0.0 plug residual gaps.
    aligned_tensors = {res: torch.tensor(pd.merge_asof(hist_slice['datetime'].to_frame(), price_data_store[res][['datetime'] + ALL_CHANNELS], on='datetime', direction='backward')[ALL_CHANNELS].ffill().bfill().fillna(0.0).values, dtype=torch.float32).unsqueeze(0).to(DEVICE) for res in ['15min', '30min', '1h']}

    regime = get_regime(hist_slice)
    regime_tensor = torch.tensor([regime], device=DEVICE)

    # 4. Run Model Inference
    print("Running model inference with live prices and events...")
    _, _, tokens, _ = kairos_model.forward_encode(price_history_tensor, aligned_tensors)
    regime_emb = kairos_model.regime_embedding(regime_tensor)
    log_probs = kairos_model.reasoner(tokens=tokens, regime_emb=regime_emb, event_features=event_tensor, memory_features=None)

    # 5. Interpret and Return Results
    # Model emits log-probabilities; exponentiate to get a proper distribution.
    probs = torch.exp(log_probs).squeeze().cpu().numpy()
    action_map = {0: "BUY", 1: "SELL", 2: "HOLD"}
    prediction = action_map[np.argmax(probs)]
    confidences = {action_map[i]: f"{p:.2%}" for i, p in enumerate(probs)}
    timestamp = pd.Timestamp.now(tz=pytz.UTC).strftime('%Y-%m-%d %H:%M:%S %Z')
    print(f"Prediction for {timestamp}: {prediction}")

    events_df = pd.DataFrame(live_events)
    # .get falls back to alternate field names observed in the calendar feed.
    events_df['currency'] = events_df.get('currency', events_df.get('country', 'N/A'))
    events_df['event'] = events_df.get('title', events_df.get('event', 'No Title'))
    display_df = events_df[['date', 'currency', 'impact', 'event']]

    return prediction, confidences, hist_slice, display_df, f"Analysis based on data up to: {timestamp}"
184
 
185
# --- Gradio Interface ---
with gr.Blocks(css="footer {display: none !important}", title="Kairos Live AI") as demo:
    gr.Markdown("# Kairos Live Trading AI")
    gr.Markdown("This interface fetches **live price data** (from `twelvedata`) and **live economic events** (from `Forex Factory`). The Kairos AI analyzes both to generate a real-time trading signal for **EUR/USD**.")
    with gr.Accordion("⚠️ How to Use", open=False):
        gr.Markdown("1. **Get a free API key** from [twelvedata.com](https://twelvedata.com/apikey).\n2. On this Space, go to **Settings** > **Repository secrets**.\n3. Create a new secret named `TWELVE_DATA_API_KEY` and paste your key as the value.")
    with gr.Row():
        predict_btn = gr.Button("Get Live Signal", variant="primary")
    with gr.Row():
        action_label = gr.Label(label="Recommended Action")
        timestamp_label = gr.Textbox(label="Analysis Timestamp", interactive=False)
        confidence_json = gr.JSON(label="Confidence Scores")
    # Plot expects the hist_slice DataFrame returned by predict_action,
    # plotted as close price over the 'datetime' column.
    price_plot = gr.LinePlot(label="Live Market Data Used for Prediction", x="datetime", y="close")
    events_table = gr.DataFrame(label="Live Events Considered in This Analysis", interactive=False)
    # Output order must mirror predict_action's 5-tuple return.
    predict_btn.click(fn=predict_action, inputs=[], outputs=[action_label, confidence_json, price_plot, events_table, timestamp_label])
    gr.Markdown("--- \n *Model: Kairos Stage 4 Policy Model. Not financial advice.*")
201
 
202
if __name__ == "__main__":
    # NOTE(review): presumably warms the singleton so model weights load at
    # startup rather than on the first click — confirm against ModelSingleton.
    ModelSingleton()
    demo.launch()