# main.py
"""FastAPI service that serves stock-movement predictions from CatBoost models.

On startup the app loads one CatBoost regressor per prediction horizon from
the ./models directory.  The /run-prediction-batch endpoint fetches recent
candles for each configured ticker, builds features, predicts with every
model, clips each score to [0, 1], and returns the assembled rows.
"""
import os
from datetime import datetime, timedelta
from typing import Dict, List

import pandas as pd
from catboost import CatBoostRegressor
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from supabase import Client, create_client

# Local imports
from config import CONFIG
from utils import fetch_yahoo, candles_to_dataframe, create_features_for_df

# --- INITIALIZATION ---
load_dotenv()

app = FastAPI(title="Stock Prediction API", version="1.0.0")

url: str = os.environ.get("SUPABASE_URL")
key: str = os.environ.get("SUPABASE_KEY")
# Fail fast with a clear message instead of letting create_client raise a
# confusing error deep inside the library when the environment is misconfigured.
if not url or not key:
    raise RuntimeError("SUPABASE_URL and SUPABASE_KEY environment variables must be set.")
supabase: Client = create_client(url, key)

# One regressor per prediction horizon, keyed by target name.
MODELS: Dict[str, CatBoostRegressor] = {}
TARGETS = ['target_1d', 'target_3d', 'target_1w', 'target_2w']
# Feature column order the models were trained with; populated at startup
# from the first loaded model (all models share the same feature set).
MODEL_FEATURE_ORDER: List[str] = []


# --- HELPER FUNCTION ---
def bound_prediction(value: float, min_val: float = 0.0, max_val: float = 1.0) -> float:
    """Clips a value to be within the specified range [min, max]."""
    return max(min_val, min(value, max_val))


# --- STARTUP EVENT ---
# NOTE(review): @app.on_event is deprecated in recent FastAPI releases in
# favour of lifespan handlers; kept as-is to preserve the existing interface.
@app.on_event("startup")
def load_models():
    """Load all CatBoost models from the /models directory into memory."""
    print("--- Loading models at startup ---")
    for target in TARGETS:
        model_path = os.path.join("models", f"catboost_regressor_{target}.cbm")
        if os.path.exists(model_path):
            model = CatBoostRegressor()
            model.load_model(model_path)
            MODELS[target] = model
            print(f"✅ Model loaded for {target}")
        else:
            print(f"🚨 WARNING: Model file not found at {model_path}")

    if MODELS:
        # All models were trained on the same features, so take the
        # canonical column order from any one of them.
        global MODEL_FEATURE_ORDER
        MODEL_FEATURE_ORDER = list(MODELS.values())[0].feature_names_
        print(f"Feature order set with {len(MODEL_FEATURE_ORDER)} features.")


# --- API ENDPOINTS ---
@app.get("/")
def read_root():
    """Health-check endpoint reporting how many models are loaded."""
    return {"status": "ok", "message": f"Prediction API is live. {len(MODELS)} models loaded."}


@app.post("/run-prediction-batch")
async def run_prediction_batch():
    """Triggers a batch prediction job.

    Fetches data, predicts, bounds the predictions to [0, 1], and returns
    the rows (persistence to Supabase is currently disabled by design).
    """
    if not MODELS:
        raise HTTPException(status_code=500, detail="Models are not loaded. Cannot run predictions.")

    print(f"\n--- Starting new prediction batch at {datetime.now().isoformat()} ---")

    tickers_to_predict = CONFIG["IDX_TICKERS"]
    all_predictions = []

    for ticker in tickers_to_predict:
        try:
            # 1. Fetch & prepare data.
            fetch_days = CONFIG['HISTORY_BUFFER_DAYS']
            start_date = (datetime.now() - timedelta(days=fetch_days)).strftime('%Y-%m-%d')
            end_date = (datetime.now() + timedelta(days=1)).strftime('%Y-%m-%d')

            candles = fetch_yahoo(ticker, CONFIG['PROCESS_TIMEFRAMES'][0], start_date, end_date)
            df_live = candles_to_dataframe(candles)

            # Require enough history for the longest-lookback features.
            if df_live.empty or len(df_live) < 250:
                print(f"Skipping {ticker}, not enough recent data.")
                continue

            latest_features_dict = create_features_for_df(df_live, CONFIG['PROCESS_TIMEFRAMES'][0])
            if not latest_features_dict:
                print(f"Skipping {ticker}, feature creation failed.")
                continue

            # 2. Predict with every model, clipping each score to [0, 1].
            features_for_pred = pd.DataFrame([latest_features_dict])[MODEL_FEATURE_ORDER]
            prediction_results = {}
            for target, model in MODELS.items():
                pred_score = model.predict(features_for_pred)[0]
                prediction_results[f'predicted_{target}'] = bound_prediction(pred_score)

            # 3. Assemble the row (shape matches the stock_predictions table).
            db_row = {
                "prediction_time": datetime.now().isoformat(),
                "ticker": ticker,
                "predicted_target_1d": prediction_results.get('predicted_target_1d'),
                "predicted_target_3d": prediction_results.get('predicted_target_3d'),
                "predicted_target_1w": prediction_results.get('predicted_target_1w'),
                "predicted_target_2w": prediction_results.get('predicted_target_2w'),
            }

            # 4. Collect the result.  The Supabase upsert was intentionally
            # removed ("REVISED LOGIC"); results are only returned to the caller.
            print(f"✅ Successfully predicted for {ticker}")
            all_predictions.append(db_row)

        except Exception as e:
            # Best-effort batch: log the failure and continue with the next ticker.
            print(f"🚨 An error occurred while processing {ticker}: {e}")
            continue

    return {
        "status": "success",
        "processed_count": len(all_predictions),
        "data": all_predictions,
    }