# arfox-ai / main.py
# Author: exorcist123
# Commit 9bb5e85: "fix dont upsert to db"
# main.py
import asyncio
import os
from datetime import datetime, timedelta
from typing import List, Dict

import pandas as pd
from catboost import CatBoostRegressor
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from supabase import create_client, Client

# Local imports
from config import CONFIG
from utils import fetch_yahoo, candles_to_dataframe, create_features_for_df
# --- INITIALIZATION ---
load_dotenv()  # read variables from a local .env file into os.environ, if one exists

app = FastAPI(title="Stock Prediction API", version="1.0.0")

# Supabase credentials. os.environ.get returns None for an unset variable, so
# check explicitly and fail at import with a message naming both variables.
url = os.environ.get("SUPABASE_URL")
key = os.environ.get("SUPABASE_KEY")
if not url or not key:
    raise RuntimeError(
        "SUPABASE_URL and SUPABASE_KEY must be set (in the environment or a .env file)."
    )
supabase: Client = create_client(url, key)

# Loaded CatBoost regressors keyed by target name; filled in by the startup hook.
MODELS: Dict[str, CatBoostRegressor] = {}
# Target names; each is expected to have a file models/catboost_regressor_<target>.cbm.
TARGETS = ['target_1d', 'target_3d', 'target_1w', 'target_2w']
# Feature column order the models expect; filled in at startup from a loaded model.
MODEL_FEATURE_ORDER: List[str] = []
# --- HELPER FUNCTION ---
def bound_prediction(value: float, min_val: float = 0.0, max_val: float = 1.0) -> float:
    """Clamp ``value`` into the closed range [min_val, max_val].

    The upper bound is applied first, then the lower bound. A value that
    compares equal to a bound keeps the object already in hand (``value`` at
    the top, ``min_val`` at the bottom). A NaN ``value`` comes back as
    ``min_val``, because every comparison against NaN is false.
    """
    capped = max_val if max_val < value else value
    return capped if capped > min_val else min_val
# --- STARTUP EVENT ---
@app.on_event("startup")
def load_models():
    """Load the CatBoost regressor for every name in ``TARGETS`` into ``MODELS``.

    Each model is read from ``models/catboost_regressor_<target>.cbm``. This is
    a relative path, so it is resolved against the process's working directory.
    Missing files are reported and skipped, so the API can start with only
    some of the models.

    If at least one model loads, ``MODEL_FEATURE_ORDER`` is set from the first
    loaded model's ``feature_names_``. The batch endpoint feeds that single
    column order to every model, so a warning is printed for any model whose
    feature names differ.
    """
    global MODEL_FEATURE_ORDER
    print("--- Loading models at startup ---")
    for target in TARGETS:
        model_path = os.path.join("models", f"catboost_regressor_{target}.cbm")
        if not os.path.exists(model_path):
            print(f"🚨 WARNING: Model file not found at {model_path}")
            continue
        model = CatBoostRegressor()
        model.load_model(model_path)
        MODELS[target] = model
        print(f"✅ Model loaded for {target}")

    if not MODELS:
        return

    reference_target, reference_model = next(iter(MODELS.items()))
    MODEL_FEATURE_ORDER = list(reference_model.feature_names_)
    print(f"Feature order set with {len(MODEL_FEATURE_ORDER)} features.")

    # Every model receives columns in MODEL_FEATURE_ORDER, so a model trained
    # on a different feature list may be scored on mismatched inputs.
    for target, model in MODELS.items():
        if list(model.feature_names_) != MODEL_FEATURE_ORDER:
            print(
                f"🚨 WARNING: Feature names of model '{target}' differ from "
                f"those of '{reference_target}'; its predictions may use "
                f"mismatched features."
            )
# --- API ENDPOINTS ---
@app.get("/")
def read_root():
    """Liveness check that also reports how many models are loaded."""
    loaded_count = len(MODELS)
    message = f"Prediction API is live. {loaded_count} models loaded."
    return {"status": "ok", "message": message}
# Minimum number of candles (rows) a ticker needs before features are built.
# Presumably this covers the longest look-back window in
# create_features_for_df — TODO confirm.
_MIN_HISTORY_ROWS = 250


def _predict_ticker(ticker: str, timeframe, start_date: str, end_date: str):
    """Fetch candles for one ticker and score it with every loaded model.

    This is synchronous (network fetch plus CatBoost inference), so the async
    endpoint runs it in a worker thread rather than on the event loop.

    Args:
        ticker: Symbol passed through to ``fetch_yahoo``.
        timeframe: Timeframe passed to ``fetch_yahoo`` and
            ``create_features_for_df``.
        start_date: Start of the fetch window, formatted ``YYYY-MM-DD``.
        end_date: End of the fetch window, formatted ``YYYY-MM-DD``.

    Returns:
        On success, a row dict containing:

        - ``prediction_time``
        - ``ticker``
        - one ``predicted_<target>`` score, clipped to [0, 1], for each name
          in ``TARGETS`` (``None`` for a target whose model is not loaded).

        Returns ``None`` if the ticker is skipped for insufficient data or
        empty features.

    Raises:
        Whatever the fetch, feature-building or predict calls raise. The
        endpoint logs the error and skips the ticker.
    """
    candles = fetch_yahoo(ticker, timeframe, start_date, end_date)
    df_live = candles_to_dataframe(candles)
    if df_live.empty or len(df_live) < _MIN_HISTORY_ROWS:
        print(f"Skipping {ticker}, not enough recent data.")
        return None

    latest_features_dict = create_features_for_df(df_live, timeframe)
    if not latest_features_dict:
        print(f"Skipping {ticker}, feature creation failed.")
        return None

    # Build a one-row frame with columns in the order the models expect.
    features_for_pred = pd.DataFrame([latest_features_dict])[MODEL_FEATURE_ORDER]

    scores = {}
    for target, model in MODELS.items():
        raw_score = model.predict(features_for_pred)[0]
        # Clip to [0, 1] and store a plain Python float so the row is JSON-friendly.
        scores[target] = float(bound_prediction(raw_score))

    db_row = {
        "prediction_time": datetime.now().isoformat(),
        "ticker": ticker,
    }
    for target in TARGETS:
        db_row[f"predicted_{target}"] = scores.get(target)
    return db_row


@app.post("/run-prediction-batch")
async def run_prediction_batch():
    """
    Trigger a batch prediction job over ``CONFIG["IDX_TICKERS"]``.

    For each ticker, this fetches recent candles, builds features, predicts
    with every loaded model and clips each score to [0, 1]. Results are
    returned in the response only; saving to Supabase is currently disabled.
    Tickers without enough data, or whose processing raises, are logged and
    skipped.

    Raises:
        HTTPException: 500 if no models are loaded.
    """
    if not MODELS:
        raise HTTPException(status_code=500, detail="Models are not loaded. Cannot run predictions.")

    print(f"\n--- Starting new prediction batch at {datetime.now().isoformat()} ---")

    # The timeframe and fetch window are the same for every ticker, so they
    # are computed once per batch.
    timeframe = CONFIG['PROCESS_TIMEFRAMES'][0]
    now = datetime.now()
    start_date = (now - timedelta(days=CONFIG['HISTORY_BUFFER_DAYS'])).strftime('%Y-%m-%d')
    end_date = (now + timedelta(days=1)).strftime('%Y-%m-%d')

    # fetch_yahoo and model.predict block. Calling them directly in this
    # coroutine would stall the event loop, and every other request, for the
    # whole batch, so each ticker is processed in the default thread pool.
    # NOTE(review): two concurrent calls to this endpoint can now overlap in
    # worker threads; previously they were serialized by the blocked loop.
    loop = asyncio.get_running_loop()
    all_predictions = []
    for ticker in CONFIG["IDX_TICKERS"]:
        try:
            db_row = await loop.run_in_executor(
                None, _predict_ticker, ticker, timeframe, start_date, end_date
            )
        except Exception as e:
            # Best-effort batch: one failing ticker must not abort the others.
            print(f"🚨 An error occurred while processing {ticker}: {e}")
            continue
        if db_row is None:
            continue

        # NOTE: saving rows via
        # supabase.table('stock_predictions').upsert(db_row).execute()
        # is intentionally disabled; results are only returned to the caller.
        print(f"✅ Successfully predicted for {ticker}")
        all_predictions.append(db_row)

    return {
        "status": "success",
        "processed_count": len(all_predictions),
        "data": all_predictions,
    }