# File size: 3,897 Bytes
# 349ad65
# scripts/predict_tomorrow.py

import pandas as pd
import numpy as np
import os
import sys
from datetime import datetime, timedelta
from stable_baselines3 import SAC

# --- Imports ---
# Make modules in the current working directory importable for the local
# imports just below (presumably fetch_market_data / llm_analysis_rag live
# there -- confirm). Only os.getcwd() can realistically fail (e.g. the cwd was
# deleted), so catch just that instead of a bare `except:` that would also
# swallow KeyboardInterrupt/SystemExit and hide real bugs.
try:
    _cwd = os.getcwd()
except OSError:
    # Best-effort: nothing sensible to add; the local imports below will
    # raise a clear ImportError if they actually needed this entry.
    pass
else:
    # Avoid piling up duplicate entries if this module is re-imported.
    if _cwd not in sys.path:
        sys.path.append(_cwd)

from fetch_market_data import fetch_market_data, ASSETS, FRED_IDS
from llm_analysis_rag import analyze_agent_decision

# --- Configuration ---
# Checkpoint of the trained SAC policy that main() loads.
MODEL_PATH = "checkpoints/sac_portfolio_model.zip"
# Number of trailing data rows that form one observation; presumably has to
# match the window used during training -- TODO confirm against the env.
WINDOW_SIZE = 30
# Macro feature column names, in FRED_IDS insertion order (assumed to be
# something like ['Federal Funds Rate', 'CPI', 'VIX'] -- verify in
# fetch_market_data).
MACRO_COLS = [*FRED_IDS.values()]

def get_latest_data_window(window_size=30):
    """
    Fetch recent market/macro data and return the last ``window_size`` rows.

    Args:
        window_size: Number of trailing rows to return; must be >= 1.

    Returns:
        A copy of the last ``window_size`` rows of the frame produced by
        ``fetch_market_data``, or ``None`` if nothing (or too little) was
        fetched.

    Raises:
        ValueError: If ``window_size`` < 1. Without this check a value of 0
            would make ``df.iloc[-0:]`` silently return the *entire* frame.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    print("--- πŸ”„ Fetching Real-Time Data for Prediction ---")

    # The date range is in calendar days, but rows are presumably trading
    # days (~5 per 7 calendar days, minus holidays). Keep the original
    # +100-day buffer for small windows, and scale with the window for large
    # ones so the range cannot come up short of window_size rows.
    lookback_days = max(window_size + 100, 2 * window_size)
    # Take one timestamp so start/end cannot straddle midnight.
    now = datetime.now()
    end_date = now.strftime('%Y-%m-%d')
    start_date = (now - timedelta(days=lookback_days)).strftime('%Y-%m-%d')

    # Only the in-memory frame is needed; filename=None presumably tells
    # fetch_market_data not to persist it.
    df = fetch_market_data(start_date, end_date, filename=None)

    n_rows = 0 if df is None else len(df)
    if n_rows < window_size:
        print(f"❌ Not enough data fetched. Got {n_rows} rows, needed {window_size}.")
        return None

    # Observation window: exactly the last window_size rows.
    return df.iloc[-window_size:].copy()

def prepare_observation(data_window):
    """
    Build the flat float32 observation vector that the SAC policy consumes.

    Each column of the asset-price block (``ASSETS``) and of the macro block
    (``MACRO_COLS``) is divided by its value in the window's first row
    (plus 1e-8). The two blocks are then laid side by side, prices first,
    and flattened row by row.

    NOTE(review): this scaling presumably has to mirror the training
    environment exactly. A first-row value of 0 (e.g. a zero policy rate)
    turns the ratio into value * 1e8 -- confirm how training handles that.

    Args:
        data_window: DataFrame holding the ``ASSETS`` and ``MACRO_COLS``
            columns, one row per time step.

    Returns:
        1-D ``np.float32`` array of length
        ``len(data_window) * (len(ASSETS) + len(MACRO_COLS))``.
    """
    def _relative_to_first_row(matrix):
        # Selecting the columns by explicit label lists pins the feature order.
        return matrix / (matrix[0] + 1e-8)

    features = np.hstack([
        _relative_to_first_row(data_window[ASSETS].values),
        _relative_to_first_row(data_window[MACRO_COLS].values),
    ])
    # Row-major flatten for the MLP input.
    return features.ravel().astype(np.float32)

def get_allocations(action):
    """
    Convert a raw policy action into portfolio weights via softmax.

    The maximum is subtracted before exponentiating (the standard
    numerically stable softmax). Mathematically the weights are unchanged,
    but ``np.exp`` can no longer overflow to ``inf`` -- which previously
    turned the weights into ``nan`` -- for large-magnitude actions.

    Args:
        action: Array-like of raw action values of any shape; it is
            flattened to 1-D.

    Returns:
        1-D float64 array of non-negative weights summing to 1, or an empty
        array if ``action`` is empty.
    """
    logits = np.asarray(action, dtype=np.float64).flatten()
    if logits.size == 0:
        # Nothing to normalise (and .max() would raise on an empty array).
        return logits
    shifted = np.exp(logits - logits.max())
    return shifted / shifted.sum()

def main():
    """
    Run one end-to-end prediction job and print the results.

    The job checks that the model checkpoint exists and fetches the latest
    ``WINDOW_SIZE``-row data window. It builds the flattened observation,
    runs the SAC policy deterministically, and maps the softmax weights to
    per-asset plus Cash allocations. Finally it passes them to the LLM risk
    analyst, whose ``risk_level`` decides whether the trade is printed as
    blocked or approved.

    Returns ``None``. Every failure path prints a message and returns early.
    """
    print(f"πŸš€ Prediction Job: {datetime.now().strftime('%Y-%m-%d')}")

    # 1. Check the model first so a missing checkpoint doesn't cost a data fetch.
    if not os.path.exists(MODEL_PATH):
        print(f"❌ Model not found at {MODEL_PATH}")
        return

    # 2. Get data
    data_window = get_latest_data_window(WINDOW_SIZE)
    if data_window is None:
        return

    # 3. Prepare observation
    obs = prepare_observation(data_window)

    # 4. Load the MLP policy
    print("Loading MLP SAC model...")
    model = SAC.load(MODEL_PATH)

    # 5. Predict
    action, _ = model.predict(obs, deterministic=True)
    weights = get_allocations(action)

    # 6. Map weights to named allocations. The assumed layout is one weight
    # per entry of ASSETS followed by a trailing Cash weight. Without this
    # check a mismatch would silently mislabel (or double-count) weights.
    expected = len(ASSETS) + 1
    if len(weights) != expected:
        print(f"❌ Model returned {len(weights)} weights, expected {expected} "
              f"({len(ASSETS)} assets + Cash).")
        return
    # Plain Python floats rather than NumPy scalars, which the analysis call
    # presumably needs (e.g. for JSON serialisation).
    allocations = {asset: float(w) for asset, w in zip(ASSETS, weights)}
    allocations['Cash'] = float(weights[-1])

    # 7. Output results
    print("\n" + "=" * 40)
    print("πŸ€– SAC MLP MODEL RECOMMENDATION")
    print("=" * 40)
    for asset, weight in allocations.items():
        print(f"{asset:<10} : {weight:6.2%}")
    print("=" * 40)

    # 8. AI risk analyst
    print("\n🧠 Running AI Risk Analysis...")
    analysis = analyze_agent_decision(data_window, allocations)

    if not isinstance(analysis, dict):
        # A non-dict result (presumably an error message) is shown verbatim
        # with no trade verdict.
        print(analysis)
        return

    print(f"\nStrategy:      {analysis.get('strategy_summary')}")
    print(f"Risk Level:    {analysis.get('risk_level')}")
    print(f"Justification: {analysis.get('justification')}")

    # Compare case-insensitively so that 'HIGH' or ' high' from the LLM still
    # trips the guardrail instead of falling through to approval.
    risk_level = str(analysis.get('risk_level', '')).strip().lower()
    if risk_level == 'high':
        print("\nβ›” BLOCKING TRADE: High Risk detected by AI Guardrail.")
    else:
        # NOTE(review): a missing or unrecognised risk_level is still treated
        # as approved (fail-open) -- confirm this is the intended policy.
        print("\nβœ… TRADE APPROVED.")

if __name__ == "__main__":
    # Script entry point: run a single prediction job.
    main()