|
|
|
|
|
|
|
|
import os
|
|
|
import pandas as pd
|
|
|
from dotenv import load_dotenv
|
|
|
from fastapi import FastAPI, HTTPException
|
|
|
from supabase import create_client, Client
|
|
|
from datetime import datetime, timedelta
|
|
|
from typing import List, Dict
|
|
|
|
|
|
|
|
|
from config import CONFIG
|
|
|
from utils import fetch_yahoo, candles_to_dataframe, create_features_for_df
|
|
|
from catboost import CatBoostRegressor
|
|
|
|
|
|
|
|
|
# Load environment variables (e.g. SUPABASE_URL, SUPABASE_KEY) from a local .env file.
load_dotenv()

app = FastAPI(title="Stock Prediction API", version="1.0.0")

# Supabase credentials are read from the environment; os.environ.get returns
# None when a variable is unset, in which case create_client below fails at
# import time rather than at first request.
url: str = os.environ.get("SUPABASE_URL")

key: str = os.environ.get("SUPABASE_KEY")

# Module-level Supabase client shared by all request handlers.
supabase: Client = create_client(url, key)

# In-memory registry of loaded CatBoost regressors, keyed by target name.
# Populated by the startup hook below; empty until then.
MODELS: Dict[str, CatBoostRegressor] = {}

# One model per prediction horizon: 1 day, 3 days, 1 week, 2 weeks.
TARGETS: List[str] = ['target_1d', 'target_3d', 'target_1w', 'target_2w']

# Training-time feature column order; set at startup from the first loaded
# model so prediction-time DataFrames can be column-ordered consistently.
MODEL_FEATURE_ORDER: List[str] = []
|
|
|
|
|
|
|
|
|
def bound_prediction(value: float, min_val: float = 0.0, max_val: float = 1.0) -> float:
    """Clamp *value* into the closed interval [min_val, max_val]."""
    # Cap at the upper bound first, then floor at the lower bound.
    capped = min(value, max_val)
    return max(min_val, capped)
|
|
|
|
|
|
|
|
|
@app.on_event("startup")
def load_models():
    """Load all CatBoost models from the ./models directory into memory.

    Populates the module-level MODELS dict (one entry per TARGETS item) and
    records MODEL_FEATURE_ORDER from the first loaded model so prediction-time
    DataFrames can be column-ordered to match training. Missing model files
    are logged and skipped rather than raising, so the API can still start.
    """
    # Declare the global up front (the original declared it inside a
    # conditional, which is legal but fragile if the function is edited).
    global MODEL_FEATURE_ORDER

    print("--- Loading models at startup ---")
    for target in TARGETS:
        model_path = os.path.join("models", f"catboost_regressor_{target}.cbm")
        if not os.path.exists(model_path):
            print(f"🚨 WARNING: Model file not found at {model_path}")
            continue
        model = CatBoostRegressor()
        model.load_model(model_path)
        MODELS[target] = model
        print(f"✅ Model loaded for {target}")

    if MODELS:
        # All models were trained on the same feature set; take the order from
        # any one of them without materializing the full values list.
        MODEL_FEATURE_ORDER = next(iter(MODELS.values())).feature_names_
        print(f"Feature order set with {len(MODEL_FEATURE_ORDER)} features.")
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/")
def read_root():
    """Health-check endpoint reporting how many models are currently loaded."""
    loaded_count = len(MODELS)
    return {
        "status": "ok",
        "message": f"Prediction API is live. {loaded_count} models loaded.",
    }
|
|
|
|
|
|
@app.post("/run-prediction-batch")
async def run_prediction_batch():
    """
    Triggers a batch prediction job.

    For every configured ticker: fetches recent candles, builds features,
    runs each loaded CatBoost model, bounds the scores to [0, 1], and
    persists the resulting rows to Supabase in a single batched insert.

    Returns:
        dict with "status", "processed_count", and the prediction rows.

    Raises:
        HTTPException(500) when no models are loaded.
    """
    if not MODELS:
        raise HTTPException(status_code=500, detail="Models are not loaded. Cannot run predictions.")

    print(f"\n--- Starting new prediction batch at {datetime.now().isoformat()} ---")

    tickers_to_predict = CONFIG["IDX_TICKERS"]
    all_predictions = []

    for ticker in tickers_to_predict:
        try:
            # Pull enough history for the rolling features (buffer exceeds the
            # longest feature lookback); +1 day on the end so today's candle is
            # included by an exclusive end bound.
            fetch_days = CONFIG['HISTORY_BUFFER_DAYS']
            start_date = (datetime.now() - timedelta(days=fetch_days)).strftime('%Y-%m-%d')
            end_date = (datetime.now() + timedelta(days=1)).strftime('%Y-%m-%d')
            candles = fetch_yahoo(ticker, CONFIG['PROCESS_TIMEFRAMES'][0], start_date, end_date)
            df_live = candles_to_dataframe(candles)

            # 250 rows ≈ one trading year; with fewer, the long-window
            # features would be unreliable.
            if df_live.empty or len(df_live) < 250:
                print(f"Skipping {ticker}, not enough recent data.")
                continue

            latest_features_dict = create_features_for_df(df_live, CONFIG['PROCESS_TIMEFRAMES'][0])
            if not latest_features_dict:
                print(f"Skipping {ticker}, feature creation failed.")
                continue

            # Re-order columns to the training-time feature order the models expect.
            features_for_pred = pd.DataFrame([latest_features_dict])[MODEL_FEATURE_ORDER]

            prediction_results = {}
            for target, model in MODELS.items():
                pred_score = model.predict(features_for_pred)[0]
                # Regressors can extrapolate outside the training range; clamp to [0, 1].
                prediction_results[f'predicted_{target}'] = bound_prediction(pred_score)

            db_row = {
                "prediction_time": datetime.now().isoformat(),
                "ticker": ticker,
                "predicted_target_1d": prediction_results.get('predicted_target_1d'),
                "predicted_target_3d": prediction_results.get('predicted_target_3d'),
                "predicted_target_1w": prediction_results.get('predicted_target_1w'),
                "predicted_target_2w": prediction_results.get('predicted_target_2w'),
            }

            print(f"✅ Successfully predicted for {ticker}")
            all_predictions.append(db_row)

        except Exception as e:
            # Best-effort batch: one bad ticker must not abort the whole run.
            print(f"🚨 An error occurred while processing {ticker}: {e}")
            continue

    # BUG FIX: the docstring promised persistence but nothing was ever written
    # to Supabase. Insert the whole batch in one request; a failure is logged
    # rather than raised so the caller still receives the computed predictions.
    if all_predictions:
        try:
            # NOTE(review): table name "predictions" assumed — confirm against the Supabase schema.
            supabase.table("predictions").insert(all_predictions).execute()
            print(f"💾 Saved {len(all_predictions)} rows to Supabase.")
        except Exception as e:
            print(f"🚨 Failed to save predictions to Supabase: {e}")

    return {
        "status": "success",
        "processed_count": len(all_predictions),
        "data": all_predictions
    }