DanielKiani committed on
Commit
7b9d8a3
·
verified ·
1 Parent(s): dbb03ae

Update scripts/app.py

Browse files
Files changed (1) hide show
  1. scripts/app.py +56 -127
scripts/app.py CHANGED
@@ -139,93 +139,51 @@ def calculate_metrics_pro(portfolio_values, freq=252, rf=0.0):
139
  # XAI: Feature Importance Function
140
  # =========================================
141
  def calculate_feature_importance(model, obs):
142
- """
143
- Calculates feature importance using Integrated Gradients on the RL agent's policy network.
144
- """
145
- # Convert observation to torch tensor and enable gradient tracking
146
  obs_tensor = torch.as_tensor(obs, dtype=torch.float32, device=model.device)
 
147
  obs_tensor.requires_grad_()
148
-
149
- # Get the policy network (actor)
150
  actor = model.policy.actor
151
-
152
- # Define a baseline (e.g., a zero observation)
153
  baseline = torch.zeros_like(obs_tensor)
154
-
155
- # Number of steps for integral approximation
156
  steps = 50
157
-
158
- # Generate scaled inputs along the path from baseline to input
159
  scaled_inputs = [baseline + (float(i) / steps) * (obs_tensor - baseline) for i in range(steps + 1)]
160
-
161
  grads = []
162
  for scaled_input in scaled_inputs:
163
- # Forward pass to get action distribution parameters (mean)
164
  action_mean = actor(scaled_input)
165
-
166
- # We need a scalar output to calculate gradients against.
167
- # Here we sum, representing overall sensitivity of the action vector.
168
  target_output = action_mean.sum()
169
-
170
- # Calculate gradients of the target output with respect to the input features
171
  grad = torch.autograd.grad(outputs=target_output, inputs=scaled_input)[0]
172
  grads.append(grad)
173
 
174
- # Average the gradients using the trapezoidal rule approximation
175
- avg_grads = (grads[:-1] + grads[1:]) / 2.0
176
- avg_grads = torch.stack(avg_grads).mean(dim=0)
 
 
177
 
178
- # Calculate Integrated Gradients: (input - baseline) * average_gradients
179
  integrated_grads = (obs_tensor - baseline) * avg_grads
180
-
181
- # Detach, move to cpu, and convert to numpy array
182
  importance_scores = integrated_grads.detach().cpu().numpy().flatten()
183
-
184
- # Feature Names mapping
185
- num_assets = len(ASSETS)
186
- num_macro = len(MACRO_COLS)
187
-
188
- # Create feature names based on the observation structure
189
  feature_names = []
190
  for i in range(WINDOW_SIZE):
191
- for asset in ASSETS:
192
- feature_names.append(f"{asset}_t-{WINDOW_SIZE-1-i}")
193
  for i in range(WINDOW_SIZE):
194
- for macro in MACRO_COLS:
195
- feature_names.append(f"{macro}_t-{WINDOW_SIZE-1-i}")
196
 
197
- # Combine into a dictionary and sort by absolute importance
198
  feature_importance_dict = dict(zip(feature_names, importance_scores))
199
-
200
- # Aggregate importance by feature type (sum of absolute values across time steps)
201
  aggregated_importance = {}
202
  for base_feature in ASSETS + MACRO_COLS:
203
  total_imp = sum(abs(val) for key, val in feature_importance_dict.items() if key.startswith(base_feature))
204
  aggregated_importance[base_feature] = total_imp
205
 
206
- # Sort and take top N for display
207
  top_features = dict(sorted(aggregated_importance.items(), key=lambda item: item[1], reverse=True)[:8])
208
 
209
- # Create a Plotly bar chart
210
- fig = px.bar(
211
- x=list(top_features.values()),
212
- y=list(top_features.keys()),
213
- orientation='h',
214
- title="Top Influential Features (XAI)",
215
- labels={'x': 'Relative Importance Score', 'y': 'Feature'},
216
- color=list(top_features.values()),
217
- color_continuous_scale=px.colors.sequential.Viridis
218
- )
219
- fig.update_layout(
220
- template="plotly_dark",
221
- paper_bgcolor='rgba(0,0,0,0)',
222
- plot_bgcolor='rgba(0,0,0,0)',
223
- yaxis={'categoryorder':'total ascending'},
224
- coloraxis_showscale=False,
225
- margin=dict(l=10, r=10, t=40, b=10),
226
- height=300 # Keep it compact
227
- )
228
-
229
  return fig
230
 
231
  # =========================================
@@ -411,83 +369,54 @@ def prepare_observation(data_window):
411
  return obs.flatten().astype(np.float32), obs.astype(np.float32), data_window
412
 
413
  def predict_and_analyze():
414
- """Main function for Forecast Tab."""
415
- status_msg = "Starting process..."
416
- loading_html = """<div style="color: #9ca3af;">🔄 Fetching data & running prediction...</div>"""
417
- # Update to yield an empty plot for the XAI chart initially
418
- yield status_msg, None, go.Figure(), loading_html
419
-
420
  try:
421
- data_window = get_latest_data_window(WINDOW_SIZE)
422
- # Get flattened obs for prediction and raw obs for XAI
423
  flat_obs, raw_obs, df_window_for_analyst = prepare_observation(data_window)
424
-
425
- if not os.path.exists(MODEL_PATH): raise FileNotFoundError(f"Model not found: {MODEL_PATH}")
426
  model = SAC.load(MODEL_PATH)
 
 
 
 
 
427
 
428
- # --- XAI: Calculate Feature Importance ---
429
- status_msg = "Calculating feature importance..."
430
- yield status_msg, None, go.Figure(), loading_html
431
- xai_plot = calculate_feature_importance(model, raw_obs)
432
-
433
- # --- Prediction ---
434
  action, _ = model.predict(flat_obs, deterministic=True)
435
- exp_action = np.exp(np.asarray(action).flatten())
436
- weights = exp_action / np.sum(exp_action)
437
- allocations_dict = {asset: weights[i] for i, asset in enumerate(ASSETS)}
438
- allocations_dict['Cash'] = weights[-1]
439
- alloc_df = pd.DataFrame(list(allocations_dict.items()), columns=['Asset', 'Proposed Allocation'])
440
- alloc_df['Proposed Allocation'] = alloc_df['Proposed Allocation'].apply(lambda x: f"{x:.2%}")
441
-
442
- status_msg = "Prediction done. Running AI Risk Analysis..."
443
- analysing_html = """<div style="color: #9ca3af;">🤖 Running Qwen-2.5-3B Risk Analysis...</div>"""
444
- # Yield XAI plot along with other outputs
445
- yield status_msg, alloc_df, xai_plot, analysing_html
446
-
447
- allocations_for_llm = {k: float(v) for k, v in allocations_dict.items()}
448
- analysis_result = analyze_agent_decision(df_window_for_analyst, allocations_for_llm)
449
- status_msg = "Analysis complete!"
450
-
451
- if isinstance(analysis_result, dict):
452
- strat = analysis_result.get('strategy_summary', 'N/A')
453
- risk = analysis_result.get('risk_level', 'N/A').upper()
454
- just = analysis_result.get('justification', 'N/A')
455
- conf = analysis_result.get('confidence_score', 'N/A')
456
- if 'HIGH' in risk:
457
- risk_css = "color: #ef4444; font-weight: bold;"
458
- status_bg = "#7f1d1d"
459
- status_border = "#ef4444"
460
- status_icon = ""
461
- status_text = "TRADE BLOCKED: High Risk Detected"
462
- else:
463
- risk_css = "color: #10b981; font-weight: bold;"
464
- status_bg = "#064e3b"
465
- status_border = "#10b981"
466
- status_icon = "🚀"
467
- status_text = "TRADE APPROVED"
468
-
469
- report_html = f"""
470
- <div style="background-color: #1f2937; padding: 20px; border-radius: 12px 12px 0 0; border: 1px solid #374151; border-bottom: none;">
471
- <h3 style="margin-top: 0; color: #e5e7eb;">🤖 AI Risk Analyst Report</h3>
472
- <div style="margin-bottom: 15px;"><strong style="color: #9ca3af;">Strategy:</strong><br><span style="color: #d1d5db;">{strat}</span></div>
473
- <div style="margin-bottom: 15px;"><strong style="color: #9ca3af;">Risk Level:</strong><span style="margin-left: 8px; {risk_css}">{risk}</span></div>
474
- <div style="margin-bottom: 15px;"><strong style="color: #9ca3af;">Justification:</strong><br><span style="color: #d1d5db;">{just}</span></div>
475
- <div><strong style="color: #9ca3af;">Confidence:</strong> <span style="color: #d1d5db;">{conf}/10</span></div>
476
- </div>
477
- <div style="background-color: {status_bg}; color: white; padding: 15px; border-radius: 0 0 12px 12px; border: 2px solid {status_border}; text-align: center; font-size: 1.2em; font-weight: bold; display: flex; align-items: center; justify-content: center;">
478
- <span style="margin-right: 10px; font-size: 1.4em;">{status_icon}</span>{status_text}
479
- </div>"""
480
  else:
481
- report_html = f"""<div style="padding: 20px; background-color: #7f1d1d; color: #fca5a5; border-radius: 12px;"><h3>❌ Analysis Failed to Parse</h3><p>{str(analysis_result)}</p></div>"""
482
- # Final yield with all outputs including XAI plot
483
- yield status_msg, alloc_df, xai_plot, report_html
484
  except Exception as e:
485
  import traceback
486
  traceback.print_exc()
487
- status_msg = f"Error: {str(e)}"
488
- error_html = f"""<div style="padding: 20px; background-color: #7f1d1d; color: #fca5a5; border-radius: 12px;"><h3>❌ Process Error</h3><p>{str(e)}</p></div>"""
489
- # Final yield in case of error
490
- yield status_msg, None, go.Figure(), error_html
491
 
492
 
493
  # =========================================
 
139
  # XAI: Feature Importance Function
140
  # =========================================
141
def calculate_feature_importance(model, obs):
    """
    Approximate per-feature importance of the RL agent's policy using
    Integrated Gradients, and return a Plotly bar chart of the top features.

    Parameters
    ----------
    model : stable-baselines3 SAC model whose ``policy.actor`` maps a
        (batch, features) observation tensor to an action mean.
    obs : array-like
        Flattened observation; length is expected to equal
        ``WINDOW_SIZE * (len(ASSETS) + len(MACRO_COLS))``.

    Returns
    -------
    plotly.graph_objects.Figure
        Horizontal bar chart of the top-8 features by aggregated
        absolute attribution.
    """
    obs_tensor = torch.as_tensor(obs, dtype=torch.float32, device=model.device)
    # The actor expects a batch dimension; promote a flat observation to (1, F).
    if obs_tensor.dim() == 1:
        obs_tensor = obs_tensor.unsqueeze(0)
    obs_tensor.requires_grad_()

    actor = model.policy.actor
    baseline = torch.zeros_like(obs_tensor)  # all-zero reference input
    steps = 50  # resolution of the path-integral approximation

    # Points along the straight line from the baseline to the actual input.
    scaled_inputs = [baseline + (float(i) / steps) * (obs_tensor - baseline)
                     for i in range(steps + 1)]

    grads = []
    for scaled_input in scaled_inputs:
        # Reduce the action vector to a scalar so autograd has a single target.
        action_mean = actor(scaled_input)
        target_output = action_mean.sum()
        grad = torch.autograd.grad(outputs=target_output, inputs=scaled_input)[0]
        grads.append(grad)

    # Trapezoidal rule: stack into one tensor first, then average neighbours.
    stacked_grads = torch.stack(grads)
    avg_grads = ((stacked_grads[:-1] + stacked_grads[1:]) / 2.0).mean(dim=0)

    # Integrated Gradients attribution: (input - baseline) * mean gradient.
    integrated_grads = (obs_tensor - baseline) * avg_grads
    importance_scores = integrated_grads.detach().cpu().numpy().flatten()

    # Feature names mirror the flattened observation layout: WINDOW_SIZE
    # rows of asset columns first, then WINDOW_SIZE rows of macro columns.
    # NOTE(review): assumes prepare_observation flattens in this order — confirm.
    feature_names = []
    for i in range(WINDOW_SIZE):
        for asset in ASSETS:
            feature_names.append(f"{asset}_t-{WINDOW_SIZE-1-i}")
    for i in range(WINDOW_SIZE):
        for macro in MACRO_COLS:
            feature_names.append(f"{macro}_t-{WINDOW_SIZE-1-i}")

    feature_importance_dict = dict(zip(feature_names, importance_scores))

    # Aggregate |importance| per base feature across all time steps.
    # Match the exact "<name>_t-" prefix so a feature whose name happens to be
    # a prefix of another feature's name is not double-counted.
    aggregated_importance = {}
    for base_feature in ASSETS + MACRO_COLS:
        prefix = f"{base_feature}_t-"
        aggregated_importance[base_feature] = sum(
            abs(val) for key, val in feature_importance_dict.items()
            if key.startswith(prefix)
        )

    # Keep only the top 8 features so the chart stays compact.
    top_features = dict(sorted(aggregated_importance.items(),
                               key=lambda item: item[1], reverse=True)[:8])

    fig = px.bar(x=list(top_features.values()), y=list(top_features.keys()), orientation='h',
                 title="Top Influential Features (XAI)", labels={'x': 'Importance', 'y': 'Feature'},
                 color=list(top_features.values()), color_continuous_scale=px.colors.sequential.Viridis)
    fig.update_layout(template="plotly_dark", paper_bgcolor='rgba(0,0,0,0)', plot_bgcolor='rgba(0,0,0,0)',
                      yaxis={'categoryorder': 'total ascending'}, coloraxis_showscale=False,
                      margin=dict(l=10, r=10, t=40, b=10), height=300,
                      hoverlabel=dict(bgcolor="white", font_size=14, font_family="Roboto", font_color="black"))
    return fig
188
 
189
  # =========================================
 
369
  return obs.flatten().astype(np.float32), obs.astype(np.float32), data_window
370
 
371
def predict_and_analyze():
    """
    Generator driving the Forecast tab.

    Yields ``(status_message, allocation_dataframe, xai_figure, report_html)``
    tuples so the UI can update progressively: initial loading state, XAI
    progress, the intermediate allocation table, and the final AI risk report
    — or a single error state if any step fails.
    """
    yield "Starting...", None, go.Figure(), "Loading..."
    try:
        data_window = get_prediction_data(WINDOW_SIZE)
        flat_obs, raw_obs, df_window_for_analyst = prepare_observation(data_window)

        # Fail early and name the offending path (a bare "Model not found."
        # gives the operator nothing to act on).
        if not os.path.exists(MODEL_PATH):
            raise FileNotFoundError(f"Model not found: {MODEL_PATH}")
        model = SAC.load(MODEL_PATH)

        # --- Pass the FLATTENED observation to the XAI function ---
        # Its logic expects an input matching the model's input layer.
        yield "XAI Calc...", None, go.Figure(), "Calculating XAI..."
        xai_plot = calculate_feature_importance(model, flat_obs)

        # --- Prediction: softmax-normalise raw actions into portfolio weights ---
        action, _ = model.predict(flat_obs, deterministic=True)
        exp_act = np.exp(np.asarray(action).flatten())
        weights = exp_act / np.sum(exp_act)

        allocs = {asset: weights[i] for i, asset in enumerate(ASSETS)}
        # Last action slot is treated as cash — TODO confirm action layout.
        allocs['Cash'] = weights[-1]
        alloc_df = pd.DataFrame(list(allocs.items()), columns=['Asset', 'Alloc'])
        alloc_df['Alloc'] = alloc_df['Alloc'].apply(lambda x: f"{x:.2%}")

        yield "AI Analysis...", alloc_df, xai_plot, "Running AI..."
        llm_allocs = {k: float(v) for k, v in allocs.items()}
        res = analyze_agent_decision(df_window_for_analyst, llm_allocs)

        if isinstance(res, dict):
            strat, risk, just, conf = res.get('strategy_summary','N/A'), res.get('risk_level','N/A').upper(), res.get('justification','N/A'), res.get('confidence_score','N/A')
            # HIGH risk blocks the trade; everything else is approved.
            border_col = "#ef4444" if 'HIGH' in risk else "#10b981"
            bg_col = "#7f1d1d" if 'HIGH' in risk else "#064e3b"
            icon = "⛔" if 'HIGH' in risk else "🚀"
            status = "TRADE BLOCKED" if 'HIGH' in risk else "TRADE APPROVED"

            html = f"""<div style="background-color: #1f2937; padding: 20px; border-radius: 12px; border: 1px solid #374151;">
<h3 style="margin-top: 0; color: #e5e7eb;">🤖 AI Report</h3>
<p><strong>Strategy:</strong> <span style="color:#d1d5db">{strat}</span></p>
<p><strong>Risk:</strong> <span style="color:{border_col}; font-weight:bold">{risk}</span></p>
<p><strong>Reason:</strong> <span style="color:#d1d5db">{just}</span></p>
<p><strong>Conf:</strong> <span style="color:#d1d5db">{conf}/10</span></p></div>
<div style="background-color:{bg_col}; color:white; padding:15px; margin-top:10px; border-radius:12px; text-align:center; font-weight:bold;">{icon} {status}</div>"""
        else:
            html = f"<div style='color:red'>{str(res)}</div>"

        yield "Done", alloc_df, xai_plot, html
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield f"Error: {str(e)}", None, go.Figure(), f"Error: {str(e)}"
 
 
 
420
 
421
 
422
  # =========================================