Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -18,6 +18,8 @@ import torch.nn.functional as F
|
|
| 18 |
from sentence_transformers import SentenceTransformer
|
| 19 |
from sklearn.preprocessing import StandardScaler
|
| 20 |
import gradio as gr
|
|
|
|
|
|
|
| 21 |
|
| 22 |
# --- Configuration ---
|
| 23 |
DEVICE = 'cpu'
|
|
@@ -26,15 +28,15 @@ VQ_DIM = 64
|
|
| 26 |
REGIME_EMBED_DIM = 8
|
| 27 |
EVENT_EMBED_DIM = 384
|
| 28 |
BASE_JSON_WEEK = "https://nfs.faireconomy.media/ff_calendar_thisweek.json"
|
| 29 |
-
CACHE_DURATION_SECONDS = 600 # Cache data for 10 minutes
|
| 30 |
|
| 31 |
BASE_CHANNELS = ['open', 'high', 'low', 'close']
|
| 32 |
DERIVED_CHANNELS = ['Volume', 'VolMissing', 'Logclose', 'Return']
|
| 33 |
ALL_CHANNELS = BASE_CHANNELS + DERIVED_CHANNELS
|
| 34 |
IMPACT_MAP = {"High": 3, "Medium": 2, "Low": 1, "Holiday": 0}
|
| 35 |
|
| 36 |
-
# --- Simple In-Memory Cache ---
|
| 37 |
-
|
| 38 |
|
| 39 |
# --- Model Architecture Definitions ---
|
| 40 |
class SmallTCN(nn.Module):
|
|
@@ -78,24 +80,36 @@ class ModelSingleton:
|
|
| 78 |
|
| 79 |
# --- Data Fetching and Processing Functions ---
|
| 80 |
def fetch_week_json_with_cache():
|
| 81 |
-
"""Fetches data from the API, using a time-based cache to avoid rate limits."""
|
| 82 |
current_time = time.time()
|
| 83 |
-
|
| 84 |
-
if _CACHE["data"] and (current_time - _CACHE["timestamp"] < CACHE_DURATION_SECONDS):
|
| 85 |
print("Using cached event data.")
|
| 86 |
-
return
|
| 87 |
-
|
| 88 |
print("Fetching live economic events (cache was stale)...")
|
| 89 |
try:
|
| 90 |
response = requests.get(BASE_JSON_WEEK, headers={"User-Agent": "kairos-demo/1.0"}, timeout=10)
|
| 91 |
response.raise_for_status()
|
| 92 |
data = response.json()
|
| 93 |
-
|
| 94 |
-
_CACHE["data"] = data
|
| 95 |
-
_CACHE["timestamp"] = current_time
|
| 96 |
return data
|
| 97 |
-
except requests.RequestException as e:
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
def parse_financial_number(s: str) -> float:
|
| 101 |
if not isinstance(s, str) or not s: return np.nan
|
|
@@ -113,53 +127,78 @@ def format_events_for_kairos(events: list, encoder: SentenceTransformer) -> torc
|
|
| 113 |
features.append(torch.cat((embedding, nums)))
|
| 114 |
return torch.nan_to_num(torch.stack(features))
|
| 115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
# --- Main Prediction Function ---
|
| 117 |
@torch.no_grad()
|
| 118 |
def predict_action():
|
|
|
|
| 119 |
models = ModelSingleton()
|
| 120 |
kairos_model, sentence_encoder = models.kairos_model, models.sentence_encoder
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
| 122 |
live_events = fetch_week_json_with_cache()
|
|
|
|
|
|
|
|
|
|
| 123 |
event_tensor = format_events_for_kairos(live_events, sentence_encoder)
|
| 124 |
if event_tensor.numel() > 0: event_tensor = event_tensor.unsqueeze(0)
|
| 125 |
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
|
|
|
| 133 |
log_probs = kairos_model.reasoner(tokens=tokens, regime_emb=regime_emb, event_features=event_tensor, memory_features=None)
|
| 134 |
|
|
|
|
| 135 |
probs = torch.exp(log_probs).squeeze().cpu().numpy()
|
| 136 |
action_map = {0: "BUY", 1: "SELL", 2: "HOLD"}
|
| 137 |
prediction = action_map[np.argmax(probs)]
|
| 138 |
confidences = {action_map[i]: f"{p:.2%}" for i, p in enumerate(probs)}
|
| 139 |
timestamp = pd.Timestamp.now(tz=pytz.UTC).strftime('%Y-%m-%d %H:%M:%S %Z')
|
| 140 |
-
print(f"Prediction
|
| 141 |
|
| 142 |
events_df = pd.DataFrame(live_events)
|
| 143 |
events_df['currency'] = events_df.get('currency', events_df.get('country', 'N/A'))
|
| 144 |
events_df['event'] = events_df.get('title', events_df.get('event', 'No Title'))
|
| 145 |
display_df = events_df[['date', 'currency', 'impact', 'event']]
|
| 146 |
|
| 147 |
-
return prediction, confidences, display_df, f"
|
| 148 |
|
| 149 |
# --- Gradio Interface ---
|
| 150 |
-
with gr.Blocks(css="footer {display: none !important}", title="Kairos Live
|
| 151 |
-
gr.Markdown("# Kairos Live
|
| 152 |
-
gr.Markdown("This
|
|
|
|
|
|
|
| 153 |
with gr.Row():
|
| 154 |
-
predict_btn = gr.Button("
|
| 155 |
with gr.Row():
|
| 156 |
action_label = gr.Label(label="Recommended Action")
|
| 157 |
timestamp_label = gr.Textbox(label="Analysis Timestamp", interactive=False)
|
| 158 |
confidence_json = gr.JSON(label="Confidence Scores")
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
|
|
|
| 162 |
|
| 163 |
if __name__ == "__main__":
|
| 164 |
-
ModelSingleton()
|
| 165 |
demo.launch()
|
|
|
|
| 18 |
from sentence_transformers import SentenceTransformer
|
| 19 |
from sklearn.preprocessing import StandardScaler
|
| 20 |
import gradio as gr
|
| 21 |
+
from twelvedata import TDClient
|
| 22 |
+
import matplotlib
|
| 23 |
|
| 24 |
# --- Configuration ---
|
| 25 |
DEVICE = 'cpu'
|
|
|
|
| 28 |
REGIME_EMBED_DIM = 8
|
| 29 |
EVENT_EMBED_DIM = 384
|
| 30 |
BASE_JSON_WEEK = "https://nfs.faireconomy.media/ff_calendar_thisweek.json"
|
| 31 |
+
CACHE_DURATION_SECONDS = 600 # Cache event data for 10 minutes
|
| 32 |
|
| 33 |
BASE_CHANNELS = ['open', 'high', 'low', 'close']
|
| 34 |
DERIVED_CHANNELS = ['Volume', 'VolMissing', 'Logclose', 'Return']
|
| 35 |
ALL_CHANNELS = BASE_CHANNELS + DERIVED_CHANNELS
|
| 36 |
IMPACT_MAP = {"High": 3, "Medium": 2, "Low": 1, "Holiday": 0}
|
| 37 |
|
| 38 |
+
# --- Simple In-Memory Cache for Events ---
|
| 39 |
+
_EVENT_CACHE = {"data": None, "timestamp": 0}
|
| 40 |
|
| 41 |
# --- Model Architecture Definitions ---
|
| 42 |
class SmallTCN(nn.Module):
|
|
|
|
| 80 |
|
| 81 |
# --- Data Fetching and Processing Functions ---
|
| 82 |
def fetch_week_json_with_cache():
|
|
|
|
| 83 |
current_time = time.time()
|
| 84 |
+
if _EVENT_CACHE["data"] and (current_time - _EVENT_CACHE["timestamp"] < CACHE_DURATION_SECONDS):
|
|
|
|
| 85 |
print("Using cached event data.")
|
| 86 |
+
return _EVENT_CACHE["data"]
|
|
|
|
| 87 |
print("Fetching live economic events (cache was stale)...")
|
| 88 |
try:
|
| 89 |
response = requests.get(BASE_JSON_WEEK, headers={"User-Agent": "kairos-demo/1.0"}, timeout=10)
|
| 90 |
response.raise_for_status()
|
| 91 |
data = response.json()
|
| 92 |
+
_EVENT_CACHE["data"], _EVENT_CACHE["timestamp"] = data, current_time
|
|
|
|
|
|
|
| 93 |
return data
|
| 94 |
+
except requests.RequestException as e: raise gr.Error(f"Failed to fetch event data. Error: {e}")
|
| 95 |
+
|
| 96 |
+
def fetch_twelvedata(api_key, symbol, interval, output_size):
|
| 97 |
+
try:
|
| 98 |
+
td = TDClient(apikey=api_key)
|
| 99 |
+
ts = td.time_series(symbol=symbol, interval=interval, outputsize=output_size)
|
| 100 |
+
df = ts.as_pandas()
|
| 101 |
+
df.reset_index(inplace=True)
|
| 102 |
+
df = df.sort_values(by='datetime').reset_index(drop=True)
|
| 103 |
+
df.columns = [c.lower() for c in df.columns]
|
| 104 |
+
return df
|
| 105 |
+
except Exception as e: raise gr.Error(f"Failed to fetch price data. Error: {e}")
|
| 106 |
+
|
| 107 |
+
def preprocess_price_df(df):
|
| 108 |
+
if df is None: return None
|
| 109 |
+
df['Volume'], df['VolMissing'] = 0.0, 1.0
|
| 110 |
+
df['Logclose'] = np.log(df['close'].astype(float) + 1e-12)
|
| 111 |
+
df['Return'] = df['Logclose'].diff().fillna(0.0)
|
| 112 |
+
return df
|
| 113 |
|
| 114 |
def parse_financial_number(s: str) -> float:
|
| 115 |
if not isinstance(s, str) or not s: return np.nan
|
|
|
|
| 127 |
features.append(torch.cat((embedding, nums)))
|
| 128 |
return torch.nan_to_num(torch.stack(features))
|
| 129 |
|
| 130 |
+
def get_regime(price_slice):
|
| 131 |
+
log_return = np.log(price_slice['close'] + 1e-9).diff().dropna()
|
| 132 |
+
volatility = log_return.std()
|
| 133 |
+
if volatility > 0.0008: return 0
|
| 134 |
+
if volatility > 0.0004: return 1
|
| 135 |
+
return 2
|
| 136 |
+
|
| 137 |
# --- Main Prediction Function ---
|
| 138 |
@torch.no_grad()
|
| 139 |
def predict_action():
|
| 140 |
+
# 1. Load Models and API Key
|
| 141 |
models = ModelSingleton()
|
| 142 |
kairos_model, sentence_encoder = models.kairos_model, models.sentence_encoder
|
| 143 |
+
api_key = os.environ.get('TWELVE_DATA_API_KEY')
|
| 144 |
+
if not api_key: raise gr.Error("TWELVE_DATA_API_KEY not found. Please set it in Space secrets.")
|
| 145 |
+
|
| 146 |
+
# 2. Fetch Live Events and Prices
|
| 147 |
live_events = fetch_week_json_with_cache()
|
| 148 |
+
price_data_store = {res: preprocess_price_df(fetch_twelvedata(api_key, 'EUR/USD', res, HISTORY_LEN)) for res in ['15min', '30min', '1h']}
|
| 149 |
+
|
| 150 |
+
# 3. Prepare Inputs for the Model
|
| 151 |
event_tensor = format_events_for_kairos(live_events, sentence_encoder)
|
| 152 |
if event_tensor.numel() > 0: event_tensor = event_tensor.unsqueeze(0)
|
| 153 |
|
| 154 |
+
base_df = price_data_store['15min']
|
| 155 |
+
hist_slice = base_df.iloc[-HISTORY_LEN:]
|
| 156 |
+
|
| 157 |
+
scaler = StandardScaler()
|
| 158 |
+
price_history_tensor = torch.tensor(scaler.fit_transform(hist_slice[ALL_CHANNELS].values), dtype=torch.float32).unsqueeze(0).to(DEVICE)
|
| 159 |
+
aligned_tensors = {res: torch.tensor(pd.merge_asof(hist_slice['datetime'].to_frame(), price_data_store[res][['datetime'] + ALL_CHANNELS], on='datetime', direction='backward')[ALL_CHANNELS].ffill().bfill().fillna(0.0).values, dtype=torch.float32).unsqueeze(0).to(DEVICE) for res in ['15min', '30min', '1h']}
|
| 160 |
+
|
| 161 |
+
regime = get_regime(hist_slice)
|
| 162 |
+
regime_tensor = torch.tensor([regime], device=DEVICE)
|
| 163 |
|
| 164 |
+
# 4. Run Model Inference
|
| 165 |
+
print("Running model inference with live prices and events...")
|
| 166 |
+
_, _, tokens, _ = kairos_model.forward_encode(price_history_tensor, aligned_tensors)
|
| 167 |
+
regime_emb = kairos_model.regime_embedding(regime_tensor)
|
| 168 |
log_probs = kairos_model.reasoner(tokens=tokens, regime_emb=regime_emb, event_features=event_tensor, memory_features=None)
|
| 169 |
|
| 170 |
+
# 5. Interpret and Return Results
|
| 171 |
probs = torch.exp(log_probs).squeeze().cpu().numpy()
|
| 172 |
action_map = {0: "BUY", 1: "SELL", 2: "HOLD"}
|
| 173 |
prediction = action_map[np.argmax(probs)]
|
| 174 |
confidences = {action_map[i]: f"{p:.2%}" for i, p in enumerate(probs)}
|
| 175 |
timestamp = pd.Timestamp.now(tz=pytz.UTC).strftime('%Y-%m-%d %H:%M:%S %Z')
|
| 176 |
+
print(f"Prediction for {timestamp}: {prediction}")
|
| 177 |
|
| 178 |
events_df = pd.DataFrame(live_events)
|
| 179 |
events_df['currency'] = events_df.get('currency', events_df.get('country', 'N/A'))
|
| 180 |
events_df['event'] = events_df.get('title', events_df.get('event', 'No Title'))
|
| 181 |
display_df = events_df[['date', 'currency', 'impact', 'event']]
|
| 182 |
|
| 183 |
+
return prediction, confidences, hist_slice, display_df, f"Analysis based on data up to: {timestamp}"
|
| 184 |
|
| 185 |
# --- Gradio Interface ---
|
| 186 |
+
with gr.Blocks(css="footer {display: none !important}", title="Kairos Live AI") as demo:
|
| 187 |
+
gr.Markdown("# Kairos Live Trading AI")
|
| 188 |
+
gr.Markdown("This interface fetches **live price data** (from `twelvedata`) and **live economic events** (from `Forex Factory`). The Kairos AI analyzes both to generate a real-time trading signal for **EUR/USD**.")
|
| 189 |
+
with gr.Accordion("⚠️ How to Use", open=False):
|
| 190 |
+
gr.Markdown("1. **Get a free API key** from [twelvedata.com](https://twelvedata.com/apikey).\n2. On this Space, go to **Settings** > **Repository secrets**.\n3. Create a new secret named `TWELVE_DATA_API_KEY` and paste your key as the value.")
|
| 191 |
with gr.Row():
|
| 192 |
+
predict_btn = gr.Button("Get Live Signal", variant="primary")
|
| 193 |
with gr.Row():
|
| 194 |
action_label = gr.Label(label="Recommended Action")
|
| 195 |
timestamp_label = gr.Textbox(label="Analysis Timestamp", interactive=False)
|
| 196 |
confidence_json = gr.JSON(label="Confidence Scores")
|
| 197 |
+
price_plot = gr.LinePlot(label="Live Market Data Used for Prediction", x="datetime", y="close")
|
| 198 |
+
events_table = gr.DataFrame(label="Live Events Considered in This Analysis", interactive=False)
|
| 199 |
+
predict_btn.click(fn=predict_action, inputs=[], outputs=[action_label, confidence_json, price_plot, events_table, timestamp_label])
|
| 200 |
+
gr.Markdown("--- \n *Model: Kairos Stage 4 Policy Model. Not financial advice.*")
|
| 201 |
|
| 202 |
if __name__ == "__main__":
|
| 203 |
+
ModelSingleton()
|
| 204 |
demo.launch()
|