Update scripts/app.py
Browse files- scripts/app.py +56 -127
scripts/app.py
CHANGED
|
@@ -139,93 +139,51 @@ def calculate_metrics_pro(portfolio_values, freq=252, rf=0.0):
|
|
| 139 |
# XAI: Feature Importance Function
|
| 140 |
# =========================================
|
| 141 |
def calculate_feature_importance(model, obs):
|
| 142 |
-
"""
|
| 143 |
-
Calculates feature importance using Integrated Gradients on the RL agent's policy network.
|
| 144 |
-
"""
|
| 145 |
-
# Convert observation to torch tensor and enable gradient tracking
|
| 146 |
obs_tensor = torch.as_tensor(obs, dtype=torch.float32, device=model.device)
|
|
|
|
| 147 |
obs_tensor.requires_grad_()
|
| 148 |
-
|
| 149 |
-
# Get the policy network (actor)
|
| 150 |
actor = model.policy.actor
|
| 151 |
-
|
| 152 |
-
# Define a baseline (e.g., a zero observation)
|
| 153 |
baseline = torch.zeros_like(obs_tensor)
|
| 154 |
-
|
| 155 |
-
# Number of steps for integral approximation
|
| 156 |
steps = 50
|
| 157 |
-
|
| 158 |
-
# Generate scaled inputs along the path from baseline to input
|
| 159 |
scaled_inputs = [baseline + (float(i) / steps) * (obs_tensor - baseline) for i in range(steps + 1)]
|
| 160 |
-
|
| 161 |
grads = []
|
| 162 |
for scaled_input in scaled_inputs:
|
| 163 |
-
# Forward pass to get action distribution parameters (mean)
|
| 164 |
action_mean = actor(scaled_input)
|
| 165 |
-
|
| 166 |
-
# We need a scalar output to calculate gradients against.
|
| 167 |
-
# Here we sum, representing overall sensitivity of the action vector.
|
| 168 |
target_output = action_mean.sum()
|
| 169 |
-
|
| 170 |
-
# Calculate gradients of the target output with respect to the input features
|
| 171 |
grad = torch.autograd.grad(outputs=target_output, inputs=scaled_input)[0]
|
| 172 |
grads.append(grad)
|
| 173 |
|
| 174 |
-
#
|
| 175 |
-
|
| 176 |
-
avg_grads =
|
|
|
|
|
|
|
| 177 |
|
| 178 |
-
# Calculate Integrated Gradients: (input - baseline) * average_gradients
|
| 179 |
integrated_grads = (obs_tensor - baseline) * avg_grads
|
| 180 |
-
|
| 181 |
-
# Detach, move to cpu, and convert to numpy array
|
| 182 |
importance_scores = integrated_grads.detach().cpu().numpy().flatten()
|
| 183 |
-
|
| 184 |
-
# Feature Names mapping
|
| 185 |
-
num_assets = len(ASSETS)
|
| 186 |
-
num_macro = len(MACRO_COLS)
|
| 187 |
-
|
| 188 |
-
# Create feature names based on the observation structure
|
| 189 |
feature_names = []
|
| 190 |
for i in range(WINDOW_SIZE):
|
| 191 |
-
for asset in ASSETS:
|
| 192 |
-
feature_names.append(f"{asset}_t-{WINDOW_SIZE-1-i}")
|
| 193 |
for i in range(WINDOW_SIZE):
|
| 194 |
-
for macro in MACRO_COLS:
|
| 195 |
-
feature_names.append(f"{macro}_t-{WINDOW_SIZE-1-i}")
|
| 196 |
|
| 197 |
-
# Combine into a dictionary and sort by absolute importance
|
| 198 |
feature_importance_dict = dict(zip(feature_names, importance_scores))
|
| 199 |
-
|
| 200 |
-
# Aggregate importance by feature type (sum of absolute values across time steps)
|
| 201 |
aggregated_importance = {}
|
| 202 |
for base_feature in ASSETS + MACRO_COLS:
|
| 203 |
total_imp = sum(abs(val) for key, val in feature_importance_dict.items() if key.startswith(base_feature))
|
| 204 |
aggregated_importance[base_feature] = total_imp
|
| 205 |
|
| 206 |
-
# Sort and take top N for display
|
| 207 |
top_features = dict(sorted(aggregated_importance.items(), key=lambda item: item[1], reverse=True)[:8])
|
| 208 |
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
labels={'x': 'Relative Importance Score', 'y': 'Feature'},
|
| 216 |
-
color=list(top_features.values()),
|
| 217 |
-
color_continuous_scale=px.colors.sequential.Viridis
|
| 218 |
-
)
|
| 219 |
-
fig.update_layout(
|
| 220 |
-
template="plotly_dark",
|
| 221 |
-
paper_bgcolor='rgba(0,0,0,0)',
|
| 222 |
-
plot_bgcolor='rgba(0,0,0,0)',
|
| 223 |
-
yaxis={'categoryorder':'total ascending'},
|
| 224 |
-
coloraxis_showscale=False,
|
| 225 |
-
margin=dict(l=10, r=10, t=40, b=10),
|
| 226 |
-
height=300 # Keep it compact
|
| 227 |
-
)
|
| 228 |
-
|
| 229 |
return fig
|
| 230 |
|
| 231 |
# =========================================
|
|
@@ -411,83 +369,54 @@ def prepare_observation(data_window):
|
|
| 411 |
return obs.flatten().astype(np.float32), obs.astype(np.float32), data_window
|
| 412 |
|
| 413 |
def predict_and_analyze():
|
| 414 |
-
""
|
| 415 |
-
status_msg = "Starting process..."
|
| 416 |
-
loading_html = """<div style="color: #9ca3af;">🔄 Fetching data & running prediction...</div>"""
|
| 417 |
-
# Update to yield an empty plot for the XAI chart initially
|
| 418 |
-
yield status_msg, None, go.Figure(), loading_html
|
| 419 |
-
|
| 420 |
try:
|
| 421 |
-
data_window =
|
| 422 |
-
# Get flattened obs for prediction and raw obs for XAI
|
| 423 |
flat_obs, raw_obs, df_window_for_analyst = prepare_observation(data_window)
|
| 424 |
-
|
| 425 |
-
if not os.path.exists(MODEL_PATH): raise FileNotFoundError(
|
| 426 |
model = SAC.load(MODEL_PATH)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 427 |
|
| 428 |
-
# --- XAI: Calculate Feature Importance ---
|
| 429 |
-
status_msg = "Calculating feature importance..."
|
| 430 |
-
yield status_msg, None, go.Figure(), loading_html
|
| 431 |
-
xai_plot = calculate_feature_importance(model, raw_obs)
|
| 432 |
-
|
| 433 |
-
# --- Prediction ---
|
| 434 |
action, _ = model.predict(flat_obs, deterministic=True)
|
| 435 |
-
|
| 436 |
-
weights =
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
alloc_df
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
status_text = "TRADE BLOCKED: High Risk Detected"
|
| 462 |
-
else:
|
| 463 |
-
risk_css = "color: #10b981; font-weight: bold;"
|
| 464 |
-
status_bg = "#064e3b"
|
| 465 |
-
status_border = "#10b981"
|
| 466 |
-
status_icon = "🚀"
|
| 467 |
-
status_text = "TRADE APPROVED"
|
| 468 |
-
|
| 469 |
-
report_html = f"""
|
| 470 |
-
<div style="background-color: #1f2937; padding: 20px; border-radius: 12px 12px 0 0; border: 1px solid #374151; border-bottom: none;">
|
| 471 |
-
<h3 style="margin-top: 0; color: #e5e7eb;">🤖 AI Risk Analyst Report</h3>
|
| 472 |
-
<div style="margin-bottom: 15px;"><strong style="color: #9ca3af;">Strategy:</strong><br><span style="color: #d1d5db;">{strat}</span></div>
|
| 473 |
-
<div style="margin-bottom: 15px;"><strong style="color: #9ca3af;">Risk Level:</strong><span style="margin-left: 8px; {risk_css}">{risk}</span></div>
|
| 474 |
-
<div style="margin-bottom: 15px;"><strong style="color: #9ca3af;">Justification:</strong><br><span style="color: #d1d5db;">{just}</span></div>
|
| 475 |
-
<div><strong style="color: #9ca3af;">Confidence:</strong> <span style="color: #d1d5db;">{conf}/10</span></div>
|
| 476 |
-
</div>
|
| 477 |
-
<div style="background-color: {status_bg}; color: white; padding: 15px; border-radius: 0 0 12px 12px; border: 2px solid {status_border}; text-align: center; font-size: 1.2em; font-weight: bold; display: flex; align-items: center; justify-content: center;">
|
| 478 |
-
<span style="margin-right: 10px; font-size: 1.4em;">{status_icon}</span>{status_text}
|
| 479 |
-
</div>"""
|
| 480 |
else:
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
yield
|
| 484 |
except Exception as e:
|
| 485 |
import traceback
|
| 486 |
traceback.print_exc()
|
| 487 |
-
|
| 488 |
-
error_html = f"""<div style="padding: 20px; background-color: #7f1d1d; color: #fca5a5; border-radius: 12px;"><h3>❌ Process Error</h3><p>{str(e)}</p></div>"""
|
| 489 |
-
# Final yield in case of error
|
| 490 |
-
yield status_msg, None, go.Figure(), error_html
|
| 491 |
|
| 492 |
|
| 493 |
# =========================================
|
|
|
|
| 139 |
# XAI: Feature Importance Function
|
| 140 |
# =========================================
|
| 141 |
def calculate_feature_importance(model, obs):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
obs_tensor = torch.as_tensor(obs, dtype=torch.float32, device=model.device)
|
| 143 |
+
if obs_tensor.dim() == 1: obs_tensor = obs_tensor.unsqueeze(0)
|
| 144 |
obs_tensor.requires_grad_()
|
| 145 |
+
|
|
|
|
| 146 |
actor = model.policy.actor
|
|
|
|
|
|
|
| 147 |
baseline = torch.zeros_like(obs_tensor)
|
|
|
|
|
|
|
| 148 |
steps = 50
|
|
|
|
|
|
|
| 149 |
scaled_inputs = [baseline + (float(i) / steps) * (obs_tensor - baseline) for i in range(steps + 1)]
|
| 150 |
+
|
| 151 |
grads = []
|
| 152 |
for scaled_input in scaled_inputs:
|
|
|
|
| 153 |
action_mean = actor(scaled_input)
|
|
|
|
|
|
|
|
|
|
| 154 |
target_output = action_mean.sum()
|
|
|
|
|
|
|
| 155 |
grad = torch.autograd.grad(outputs=target_output, inputs=scaled_input)[0]
|
| 156 |
grads.append(grad)
|
| 157 |
|
| 158 |
+
# --- Stack gradients first, then perform arithmetic ---
|
| 159 |
+
stacked_grads = torch.stack(grads)
|
| 160 |
+
avg_grads = (stacked_grads[:-1] + stacked_grads[1:]) / 2.0
|
| 161 |
+
avg_grads = avg_grads.mean(dim=0)
|
| 162 |
+
# -----------------------------------------------------------
|
| 163 |
|
|
|
|
| 164 |
integrated_grads = (obs_tensor - baseline) * avg_grads
|
|
|
|
|
|
|
| 165 |
importance_scores = integrated_grads.detach().cpu().numpy().flatten()
|
| 166 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
feature_names = []
|
| 168 |
for i in range(WINDOW_SIZE):
|
| 169 |
+
for asset in ASSETS: feature_names.append(f"{asset}_t-{WINDOW_SIZE-1-i}")
|
|
|
|
| 170 |
for i in range(WINDOW_SIZE):
|
| 171 |
+
for macro in MACRO_COLS: feature_names.append(f"{macro}_t-{WINDOW_SIZE-1-i}")
|
|
|
|
| 172 |
|
|
|
|
| 173 |
feature_importance_dict = dict(zip(feature_names, importance_scores))
|
|
|
|
|
|
|
| 174 |
aggregated_importance = {}
|
| 175 |
for base_feature in ASSETS + MACRO_COLS:
|
| 176 |
total_imp = sum(abs(val) for key, val in feature_importance_dict.items() if key.startswith(base_feature))
|
| 177 |
aggregated_importance[base_feature] = total_imp
|
| 178 |
|
|
|
|
| 179 |
top_features = dict(sorted(aggregated_importance.items(), key=lambda item: item[1], reverse=True)[:8])
|
| 180 |
|
| 181 |
+
fig = px.bar(x=list(top_features.values()), y=list(top_features.keys()), orientation='h',
|
| 182 |
+
title="Top Influential Features (XAI)", labels={'x': 'Importance', 'y': 'Feature'},
|
| 183 |
+
color=list(top_features.values()), color_continuous_scale=px.colors.sequential.Viridis)
|
| 184 |
+
fig.update_layout(template="plotly_dark", paper_bgcolor='rgba(0,0,0,0)', plot_bgcolor='rgba(0,0,0,0)',
|
| 185 |
+
yaxis={'categoryorder':'total ascending'}, coloraxis_showscale=False, margin=dict(l=10, r=10, t=40, b=10), height=300,
|
| 186 |
+
hoverlabel=dict(bgcolor="white", font_size=14, font_family="Roboto", font_color="black"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
return fig
|
| 188 |
|
| 189 |
# =========================================
|
|
|
|
| 369 |
return obs.flatten().astype(np.float32), obs.astype(np.float32), data_window
|
| 370 |
|
| 371 |
def predict_and_analyze():
|
| 372 |
+
yield "Starting...", None, go.Figure(), "Loading..."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
try:
|
| 374 |
+
data_window = get_prediction_data(WINDOW_SIZE)
|
|
|
|
| 375 |
flat_obs, raw_obs, df_window_for_analyst = prepare_observation(data_window)
|
| 376 |
+
|
| 377 |
+
if not os.path.exists(MODEL_PATH): raise FileNotFoundError("Model not found.")
|
| 378 |
model = SAC.load(MODEL_PATH)
|
| 379 |
+
|
| 380 |
+
# --- Pass the FLATTENED observation to XAI function ---
|
| 381 |
+
# The XAI function logic expects an input that matches the model's input layer.
|
| 382 |
+
yield "XAI Calc...", None, go.Figure(), "Calculating XAI..."
|
| 383 |
+
xai_plot = calculate_feature_importance(model, flat_obs)
|
| 384 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
action, _ = model.predict(flat_obs, deterministic=True)
|
| 386 |
+
exp_act = np.exp(np.asarray(action).flatten())
|
| 387 |
+
weights = exp_act / np.sum(exp_act)
|
| 388 |
+
|
| 389 |
+
allocs = {ASSETS[i]: weights[i] for i in range(len(ASSETS))}
|
| 390 |
+
allocs['Cash'] = weights[-1]
|
| 391 |
+
alloc_df = pd.DataFrame(list(allocs.items()), columns=['Asset', 'Alloc'])
|
| 392 |
+
alloc_df['Alloc'] = alloc_df['Alloc'].apply(lambda x: f"{x:.2%}")
|
| 393 |
+
|
| 394 |
+
yield "AI Analysis...", alloc_df, xai_plot, "Running AI..."
|
| 395 |
+
llm_allocs = {k: float(v) for k, v in allocs.items()}
|
| 396 |
+
res = analyze_agent_decision(df_window_for_analyst, llm_allocs)
|
| 397 |
+
|
| 398 |
+
if isinstance(res, dict):
|
| 399 |
+
strat, risk, just, conf = res.get('strategy_summary','N/A'), res.get('risk_level','N/A').upper(), res.get('justification','N/A'), res.get('confidence_score','N/A')
|
| 400 |
+
border_col = "#ef4444" if 'HIGH' in risk else "#10b981"
|
| 401 |
+
bg_col = "#7f1d1d" if 'HIGH' in risk else "#064e3b"
|
| 402 |
+
icon = "⛔" if 'HIGH' in risk else "🚀"
|
| 403 |
+
status = "TRADE BLOCKED" if 'HIGH' in risk else "TRADE APPROVED"
|
| 404 |
+
|
| 405 |
+
html = f"""<div style="background-color: #1f2937; padding: 20px; border-radius: 12px; border: 1px solid #374151;">
|
| 406 |
+
<h3 style="margin-top: 0; color: #e5e7eb;">🤖 AI Report</h3>
|
| 407 |
+
<p><strong>Strategy:</strong> <span style="color:#d1d5db">{strat}</span></p>
|
| 408 |
+
<p><strong>Risk:</strong> <span style="color:{border_col}; font-weight:bold">{risk}</span></p>
|
| 409 |
+
<p><strong>Reason:</strong> <span style="color:#d1d5db">{just}</span></p>
|
| 410 |
+
<p><strong>Conf:</strong> <span style="color:#d1d5db">{conf}/10</span></p></div>
|
| 411 |
+
<div style="background-color:{bg_col}; color:white; padding:15px; margin-top:10px; border-radius:12px; text-align:center; font-weight:bold;">{icon} {status}</div>"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
else:
|
| 413 |
+
html = f"<div style='color:red'>{str(res)}</div>"
|
| 414 |
+
|
| 415 |
+
yield "Done", alloc_df, xai_plot, html
|
| 416 |
except Exception as e:
|
| 417 |
import traceback
|
| 418 |
traceback.print_exc()
|
| 419 |
+
yield f"Error: {str(e)}", None, go.Figure(), f"Error: {str(e)}"
|
|
|
|
|
|
|
|
|
|
| 420 |
|
| 421 |
|
| 422 |
# =========================================
|