File size: 5,372 Bytes
c072ec7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9bb5e85
c072ec7
9bb5e85
 
 
 
 
c072ec7
9bb5e85
 
 
 
c072ec7
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# main.py
#
# FastAPI service that batch-scores stock tickers with pre-trained CatBoost
# regressors and (optionally) persists predictions to Supabase.

import os
import pandas as pd
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from supabase import create_client, Client
from datetime import datetime, timedelta
from typing import List, Dict

# Local imports
from config import CONFIG
from utils import fetch_yahoo, candles_to_dataframe, create_features_for_df
from catboost import CatBoostRegressor

# --- INITIALIZATION (Same as before) ---
load_dotenv()  # load SUPABASE_URL / SUPABASE_KEY from a local .env if present
app = FastAPI(title="Stock Prediction API", version="1.0.0")
# NOTE(review): os.environ.get returns None when a variable is unset, so
# create_client would then fail at import time — confirm env vars are
# guaranteed in every deployment, or fail fast with a clearer message.
url: str = os.environ.get("SUPABASE_URL")
key: str = os.environ.get("SUPABASE_KEY")
supabase: Client = create_client(url, key)
# One loaded CatBoost model per prediction horizon; filled by the startup hook.
MODELS: Dict[str, CatBoostRegressor] = {}
# Prediction horizons the service scores: 1 day, 3 days, 1 week, 2 weeks.
TARGETS = ['target_1d', 'target_3d', 'target_1w', 'target_2w']
# Feature-column order the models were trained with; set once models load.
MODEL_FEATURE_ORDER: List[str] = []

# --- HELPER FUNCTION ---
def bound_prediction(value: float, min_val: float = 0.0, max_val: float = 1.0) -> float:
    """Clamp *value* into the closed interval [min_val, max_val]."""
    # Cap from above first, then floor the result — same as the classic
    # max(min_val, min(value, max_val)) clamp, just in two named steps.
    capped = min(value, max_val)
    return max(min_val, capped)

# --- STARTUP EVENT (Same as before) ---
@app.on_event("startup")
def load_models():
    """Load all CatBoost models from the /models directory into memory."""
    print("--- Loading models at startup ---")
    for target in TARGETS:
        model_path = os.path.join("models", f"catboost_regressor_{target}.cbm")
        # Guard clause: a missing model file is logged, not fatal.
        if not os.path.exists(model_path):
            print(f"🚨 WARNING: Model file not found at {model_path}")
            continue
        regressor = CatBoostRegressor()
        regressor.load_model(model_path)
        MODELS[target] = regressor
        print(f"✅ Model loaded for {target}")
    if MODELS:
        # All models share one feature schema, so any loaded model works here.
        global MODEL_FEATURE_ORDER
        MODEL_FEATURE_ORDER = next(iter(MODELS.values())).feature_names_
        print(f"Feature order set with {len(MODEL_FEATURE_ORDER)} features.")

# --- API ENDPOINTS ---

@app.get("/")
def read_root():
    """Health-check endpoint reporting how many models are in memory."""
    loaded = len(MODELS)
    return {"status": "ok", "message": f"Prediction API is live. {loaded} models loaded."}

@app.post("/run-prediction-batch")
async def run_prediction_batch():
    """Run one prediction pass over every configured ticker.

    For each ticker: fetch recent candles, build features, score with each
    loaded model, clamp scores to [0, 1], and collect the result rows. A
    failing ticker (bad data, feature error) is logged and skipped so one
    failure cannot abort the whole batch.

    Returns:
        dict with "status", "processed_count", and "data" (the rows).

    Raises:
        HTTPException: 500 when no models were loaded at startup.
    """
    if not MODELS:
        raise HTTPException(status_code=500, detail="Models are not loaded. Cannot run predictions.")

    print(f"\n--- Starting new prediction batch at {datetime.now().isoformat()} ---")

    tickers_to_predict = CONFIG["IDX_TICKERS"]
    all_predictions = []

    # The fetch window and timeframe are the same for every ticker — compute
    # them once instead of per iteration.
    fetch_days = CONFIG['HISTORY_BUFFER_DAYS']
    now = datetime.now()
    start_date = (now - timedelta(days=fetch_days)).strftime('%Y-%m-%d')
    end_date = (now + timedelta(days=1)).strftime('%Y-%m-%d')
    timeframe = CONFIG['PROCESS_TIMEFRAMES'][0]

    for ticker in tickers_to_predict:
        try:
            # 1. Fetch & prepare data.
            candles = fetch_yahoo(ticker, timeframe, start_date, end_date)
            df_live = candles_to_dataframe(candles)

            # Require roughly a trading year of history so the long-window
            # features have enough data to be meaningful.
            if df_live.empty or len(df_live) < 250:
                print(f"Skipping {ticker}, not enough recent data.")
                continue

            latest_features_dict = create_features_for_df(df_live, timeframe)
            if not latest_features_dict:
                print(f"Skipping {ticker}, feature creation failed.")
                continue

            # 2. Predict with each model, clamped to [0, 1]. float() matters:
            # CatBoost's predict() yields numpy scalars, which are not
            # JSON-serializable by the stdlib json encoder.
            features_for_pred = pd.DataFrame([latest_features_dict])[MODEL_FEATURE_ORDER]
            prediction_results = {}
            for target, model in MODELS.items():
                raw_score = model.predict(features_for_pred)[0]
                prediction_results[f'predicted_{target}'] = float(bound_prediction(raw_score))

            # 3. One row per ticker; a target whose model failed to load
            # still gets its key, mapped to None, so the row schema is stable.
            db_row = {
                "prediction_time": datetime.now().isoformat(),
                "ticker": ticker,
            }
            for target in TARGETS:
                db_row[f'predicted_{target}'] = prediction_results.get(f'predicted_{target}')

            # 4. Collect the row; persistence to Supabase is handled elsewhere.
            print(f"✅ Successfully predicted for {ticker}")
            all_predictions.append(db_row)

        except Exception as e:
            # Best-effort batch: log the failure and move to the next ticker.
            print(f"🚨 An error occurred while processing {ticker}: {e}")
            continue

    return {
        "status": "success",
        "processed_count": len(all_predictions),
        "data": all_predictions
    }