"""AlgoQuant HTTP API.

FastAPI application exposing authentication (signup/login with JWT bearer
tokens), strategy backtesting, HMM-SVR model management, simulated trading
bot sessions and manual market trading routes.
"""

import os
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from typing import Annotated, Optional

from fastapi import Depends, FastAPI, Header, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel

# SQLModel & database imports
from sqlmodel import Session, select

from database import create_db_and_tables, engine
from models import User

# Strategy logic
from strategy import train_models_and_backtest

# Model manager for HMM-SVR models
from model_manager import (
    load_all_models,
    train_and_save_model,
    load_model,
    is_model_trained,
    get_model_info,
    get_cached_models,
)


# --- 1. LIFESPAN (Create Tables on Startup) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Prepare the database and load saved models before serving requests."""
    create_db_and_tables()

    # Load all pre-trained HMM-SVR models from disk into memory
    print("\n🚀 Starting AlgoQuant API...")
    loaded_models = load_all_models()
    if loaded_models:
        print(f"✅ Loaded {len(loaded_models)} HMM-SVR models: {list(loaded_models.keys())}")
    else:
        # Not an f-string on purpose: "{symbol}" is a literal route placeholder.
        print("ℹ️ No pre-trained models found. Train models using /api/models/train/{symbol}")

    yield


app = FastAPI(lifespan=lifespan)

# --- CONFIGURATION ---
# NOTE(review): the fallback secret is visible in source control; make sure
# SECRET_KEY is set in every real deployment so tokens cannot be forged.
SECRET_KEY = os.getenv("SECRET_KEY", "algoquant_super_secret_key")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 43200  # 30 days (30 * 24 * 60)

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# --- CORS ---
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:3000",
        "http://127.0.0.1:3000",
        "https://algo-quant-pi.vercel.app",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# --- Health Check Endpoints ---
@app.get("/")
async def root():
    """Return a static liveness payload."""
    return {"status": "healthy", "message": "AlgoQuant API is running"}


@app.get("/health")
async def health_check():
    """Return a liveness payload with the current UTC time (naive ISO format)."""
    # datetime.utcnow() is deprecated; dropping tzinfo keeps the emitted
    # string in the same offset-less format clients already receive.
    now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
    return {"status": "healthy", "timestamp": now_utc.isoformat()}


# --- Pydantic Models (For Request Body) ---
class UserCreate(BaseModel):
    """Signup request body."""

    email: str
    password: str
    name: Optional[str] = None


class UserLogin(BaseModel):
    """Login request body."""

    email: str
    password: str


class BacktestRequest(BaseModel):
    """Backtest request body."""

    ticker: str
    start_date: str
    end_date: str
    strategy: str = "hmm_svr"
    # Strategy-specific parameters
    short_window: int = 12
    long_window: int = 26
    n_states: int = 3


class Token(BaseModel):
    """Response body returned by signup and login."""

    access_token: str
    token_type: str


class SimulatedTradingRequest(BaseModel):
    """Request body for starting a simulated trading session."""

    symbol: str
    trade_amount: float
    duration: int
    duration_unit: str = "minutes"  # "minutes" or "days"


# --- DATABASE DEPENDENCY ---
def get_session():
    """Yield a database session that is closed when the request finishes."""
    with Session(engine) as session:
        yield session


# --- AUTH HELPERS ---
def verify_password(plain_password, hashed_password) -> bool:
    """Return True when ``plain_password`` matches ``hashed_password``."""
    return pwd_context.verify(plain_password, hashed_password)


def get_password_hash(password) -> str:
    """Return the hash of ``password`` using the configured scheme."""
    return pwd_context.hash(password)


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Return a signed JWT carrying ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed; copied, never mutated.
        expires_delta: Token lifetime. When omitted or falsy,
            ``ACCESS_TOKEN_EXPIRE_MINUTES`` is used.
    """
    lifetime = expires_delta if expires_delta else timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    # Use an aware UTC timestamp: a naive local datetime would be interpreted
    # as UTC when converted to the numeric "exp" claim, shifting the expiry
    # by the server's UTC offset.
    expire = datetime.now(timezone.utc) + lifetime

    to_encode = data.copy()
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)


# --- THE GUARD (Protect Routes) ---
async def get_current_user(authorization: str = Header(None)) -> str:
    """Resolve the caller's email from the ``Authorization`` header.

    The header is expected to look like ``"<scheme> <token>"``; the second
    space-separated field is decoded as a JWT and its ``sub`` claim returned.

    Raises:
        HTTPException: 401 when the header is missing, malformed, the token
            fails validation, or it carries no ``sub`` claim.
    """
    if not authorization:
        raise HTTPException(status_code=401, detail="Missing Token")

    try:
        token = authorization.split(" ")[1]
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except (JWTError, IndexError):
        raise HTTPException(status_code=401, detail="Could not validate credentials")

    email: str = payload.get("sub")
    if email is None:
        raise HTTPException(status_code=401, detail="Invalid Token")
    return email


# --- ROUTES ---
@app.post("/api/signup", response_model=Token)
def signup(user_data: UserCreate, session: Session = Depends(get_session)):
    """Register a new user and return a bearer token for immediate login.

    Raises:
        HTTPException: 400 if the email is already registered, 500 on any
            unexpected failure (the transaction is rolled back first).
    """
    try:
        # 1. Check if user exists in DB
        statement = select(User).where(User.email == user_data.email)
        if session.exec(statement).first():
            raise HTTPException(status_code=400, detail="Email already registered")

        # 2. Hash password & create user object
        new_user = User(
            email=user_data.email,
            name=user_data.name,
            hashed_password=get_password_hash(user_data.password),
        )

        # 3. Save to DB
        session.add(new_user)
        session.commit()
        session.refresh(new_user)

        # 4. Auto-login (return token immediately)
        access_token = create_access_token(data={"sub": new_user.email})
        return {"access_token": access_token, "token_type": "bearer"}
    except HTTPException:
        # Deliberate client errors pass through untouched.
        raise
    except Exception as e:
        session.rollback()
        print(f"Signup error: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error during signup")


@app.post("/api/login", response_model=Token)
def login(user_data: UserLogin, session: Session = Depends(get_session)):
    """Authenticate a user and return a bearer token.

    Raises:
        HTTPException: 400 when email or password is empty, 401 on bad
            credentials, 500 on any unexpected failure.
    """
    try:
        # 1. Validate input
        if not user_data.email or not user_data.password:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Email and password are required",
            )

        # 2. Select user from DB
        statement = select(User).where(User.email == user_data.email)
        user = session.exec(statement).first()

        # 3. Verify (same error for unknown user and wrong password)
        if not user or not verify_password(user_data.password, user.hashed_password):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Incorrect email or password",
            )

        # 4. Issue token (default lifetime)
        access_token = create_access_token(data={"sub": user.email})
        return {"access_token": access_token, "token_type": "bearer"}
    except HTTPException:
        raise
    except Exception as e:
        print(f"Login error: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error during login")


@app.post("/api/backtest")
def run_backtest(
    req: BacktestRequest,
    current_user: str = Depends(get_current_user),
):
    """Run a backtest for the requested ticker and date range.

    ``req.strategy`` is only logged here; the computation is delegated to
    ``train_models_and_backtest`` and its result is returned unchanged.
    """
    print(f"User {current_user} is running {req.strategy} backtest...")

    return train_models_and_backtest(
        req.ticker,
        req.start_date,
        req.end_date,
        short_window=req.short_window,
        long_window=req.long_window,
        n_states=req.n_states,
    )


@app.get("/api/backtest/strategies")
def get_backtest_strategies(current_user: str = Depends(get_current_user)):
    """Get available backtest strategies (currently an empty list)."""
    return {"strategies": []}


# --- MODEL MANAGEMENT ROUTES (HMM-SVR) ---
@app.post("/api/models/train/{symbol}")
def train_model(symbol: str, current_user: str = Depends(get_current_user)):
    """Train and save the HMM-SVR model for a specific symbol.

    This trains on 4 years of historical data and saves the model to disk.
    The model will be automatically loaded on next startup.

    Example: POST /api/models/train/BTCUSDT

    Raises:
        HTTPException: 400 when the symbol does not end with ``USDT`` or
            when training reports an error.
    """
    print(f"[API] User {current_user} requested model training for {symbol}")

    # Validate symbol format
    symbol = symbol.upper()
    if not symbol.endswith('USDT'):
        raise HTTPException(
            status_code=400,
            detail="Symbol must end with USDT (e.g., BTCUSDT, ETHUSDT)",
        )

    result = train_and_save_model(symbol, n_states=3)
    if "error" in result:
        raise HTTPException(status_code=400, detail=result["error"])

    return {
        "success": True,
        "message": f"Model trained and saved for {symbol}",
        "details": result,
    }


@app.get("/api/models/status/{symbol}")
def get_model_status(symbol: str, current_user: str = Depends(get_current_user)):
    """Check if a model exists for a symbol and return its metadata."""
    symbol = symbol.upper()

    if not is_model_trained(symbol):
        return {
            "trained": False,
            "symbol": symbol,
            "message": f"No model found for {symbol}. Train it using POST /api/models/train/{symbol}",
        }

    return {
        "trained": True,
        "symbol": symbol,
        "info": get_model_info(symbol),
    }


@app.get("/api/models")
def list_models(current_user: str = Depends(get_current_user)):
    """List models loaded in memory and model files present on disk."""
    from model_manager import MODEL_DIR

    cached = get_cached_models()

    # Also check for models on disk that aren't loaded yet
    suffix = '_hmm_svr.pkl'
    disk_models = []
    if os.path.exists(MODEL_DIR):
        disk_models = [
            filename.replace(suffix, '').upper()
            for filename in os.listdir(MODEL_DIR)
            if filename.endswith(suffix)
        ]

    return {
        "loaded_models": cached,
        "available_on_disk": disk_models,
        # Count each symbol once even if it is both cached and on disk.
        "total_count": len(set(cached.keys()) | set(disk_models)),
    }


@app.post("/api/models/reload")
def reload_models(current_user: str = Depends(get_current_user)):
    """Reload all models from disk into memory.

    Useful if models were trained externally or after a restart.
    """
    result = load_all_models()
    return {
        "success": True,
        "loaded_models": list(result.keys()),
        "count": sum(result.values()),
    }


def _describe_action(target_position, regime_label):
    """Map a target position to ``(action, action_color, action_description)``.

    Five-level system: 0x, 0.5x, 1x, 2x, 3x. Any position that is not
    exactly 0, 0.5, 2 or 3 is reported as the standard 1x buy.
    """
    if target_position == 0:
        if regime_label == 'Crash':
            return "STAY OUT", "red", "🚨 Crash Protocol: Safety override activated"
        return "WAIT", "yellow", "Bearish trend - waiting for reversal"
    if target_position == 3:
        return "STRONG BUY (3x)", "green", "🚀 Max Leverage: Safe regime + very low risk!"
    if target_position == 2:
        return "BUY (2x)", "cyan", "📈 Medium Leverage: Favorable conditions"
    if target_position == 0.5:
        return "CAUTIOUS BUY (0.5x)", "orange", "⚠️ Defensive: High risk detected"
    return "BUY (1x)", "blue", "✅ Standard bullish position"


def _risk_level(risk_ratio):
    """Classify a risk ratio as ``"Low"`` (<0.5), ``"High"`` (>1.5) or ``"Moderate"``."""
    if risk_ratio < 0.5:
        return "Low"
    if risk_ratio > 1.5:
        return "High"
    return "Moderate"


@app.get("/api/models/signal/{symbol}")
def get_instant_signal(symbol: str, current_user: str = Depends(get_current_user)):
    """Get an instant trading signal for a symbol from its HMM-SVR model.

    Auto-trains the model if it doesn't exist. Returns the current regime,
    recommended position size and trading signal. Failures are reported as
    ``{"success": False, "error": ...}`` payloads rather than HTTP errors.

    Example: GET /api/models/signal/BTCUSDT
    """
    # Heavy / optional dependencies are imported lazily, per request.
    from model_manager import calculate_signal_and_position
    import yfinance as yf
    import pandas as pd

    symbol = symbol.upper()
    base_symbol = symbol.replace('USDT', '')
    yahoo_symbol = f"{base_symbol}-USD"  # Convert to Yahoo Finance format

    # Check if model exists, train if not (same as bot auto-training)
    if not is_model_trained(base_symbol) and not is_model_trained(symbol):
        print(f"[SignalAPI] No model found for {base_symbol}, training now...")
        try:
            # Train with both the Yahoo symbol and the Binance symbol for
            # fallback; save under the base symbol name (BNB), not the
            # Yahoo format (BNB-USD).
            train_result = train_and_save_model(
                symbol=yahoo_symbol,
                n_states=3,
                binance_symbol=symbol,
                save_as=base_symbol,
            )
        except Exception as e:
            return {
                "success": False,
                "error": f"Model training failed: {str(e)}",
            }

        if train_result and 'error' not in train_result:
            print(f"[SignalAPI] ✅ Model trained for {base_symbol} with {train_result.get('train_days', 0)} days")
        else:
            # train_result may be None/empty; never dereference it blindly.
            failure = train_result or {}
            return {
                "success": False,
                "error": f"Failed to train model: {failure.get('error', 'Unknown error')}",
                "action_required": "Insufficient data to train model",
            }

    # Fetch recent price data (450 days for proper feature calculation)
    try:
        end_date = datetime.now()
        start_date = end_date - timedelta(days=450)

        df = yf.download(yahoo_symbol, start=start_date, end=end_date, progress=False, auto_adjust=True)
        if df.empty:
            return {
                "success": False,
                "error": f"Could not fetch price data for {yahoo_symbol}",
            }

        # Handle MultiIndex columns: keep whichever level holds the field names
        if isinstance(df.columns, pd.MultiIndex):
            if 'Close' in df.columns.get_level_values(0):
                df.columns = df.columns.get_level_values(0)
            else:
                df.columns = df.columns.get_level_values(1)

        # Get signal from model (base_symbol for model lookup, yahoo_symbol for data)
        result = calculate_signal_and_position(
            symbol=base_symbol,
            recent_data=df,
            short_window=12,
            long_window=26,
        )

        if result is None or 'error' in result:
            return {
                "success": False,
                "error": result.get('error', 'Unknown error') if result else "Failed to calculate signal",
            }

        ema_signal = result.get('ema_signal', 0)
        target_position = result.get('target_position', 0)
        position_multiplier = result.get('position_size_multiplier', 1.0)
        regime = result.get('regime', 1)
        regime_label = result.get('regime_label', 'Normal')
        risk_ratio = result.get('risk_ratio', 1.0)

        # Human-readable recommendation
        action, action_color, action_description = _describe_action(target_position, regime_label)

        if regime == 0:
            regime_description = "Low volatility"
        elif regime_label == 'Crash':
            regime_description = "High volatility - danger"
        else:
            regime_description = "Normal volatility"

        return {
            "success": True,
            "symbol": symbol,
            "current_price": result.get('close_price', 0),
            "signal": {
                "action": action,
                "action_color": action_color,
                "action_description": action_description,
                "ema_trend": "Bullish" if ema_signal == 1 else "Bearish",
                "position_multiplier": position_multiplier,
                "target_position": target_position,
                "signal_stability": result.get('signal_stability', 0.5),
                "ema_gap_percent": result.get('ema_gap_percent', 0),  # Trend strength
            },
            "regime": {
                "state": regime,
                "label": regime_label,
                "description": regime_description,
            },
            "risk": {
                "ratio": risk_ratio,
                "level": _risk_level(risk_ratio),
                "predicted_volatility": result.get('predicted_vol', 0),
            },
            "technicals": {
                "ema_short": result.get('ema_short', 0),
                "ema_long": result.get('ema_long', 0),
            },
            "reasoning": result.get('reasoning', ''),
            "timestamp": datetime.now().isoformat(),
        }

    except Exception as e:
        # Best-effort endpoint: report the failure in the payload.
        return {
            "success": False,
            "error": f"Error calculating signal: {str(e)}",
        }


# --- SIMULATED TRADING ROUTES ---
@app.get("/api/simulated/trades")
def get_simulated_trades(
    limit: int = 50,
    current_user: str = Depends(get_current_user),
):
    """Get recent simulated trades for the current user."""
    from simulated_endpoints import get_simulated_trades_endpoint

    return get_simulated_trades_endpoint(limit, current_user)


@app.get("/api/simulated/sessions")
def get_simulated_sessions(current_user: str = Depends(get_current_user)):
    """Get all simulated trading sessions for the current user."""
    from simulated_endpoints import get_simulated_sessions_endpoint

    return get_simulated_sessions_endpoint(current_user)


@app.get("/api/simulated/portfolio")
def get_simulated_portfolio(current_user: str = Depends(get_current_user)):
    """Get the internal simulated portfolio (database-driven wallet)."""
    from simulated_exchange import get_portfolio_summary
    from database import initialize_portfolio_if_empty

    # Initialize portfolio with 10k USDT if this is a new user
    initialize_portfolio_if_empty(user_email=current_user)

    return get_portfolio_summary(user_email=current_user)


@app.post("/api/simulated/start")
def start_simulated_session(req: SimulatedTradingRequest, current_user: str = Depends(get_current_user)):
    """Start an HMM-SVR trading bot session.

    ``req.duration`` is interpreted as minutes unless ``req.duration_unit``
    is exactly ``"days"``.

    Raises:
        HTTPException: 400 when the trading backend reports an error.
    """
    from simulated_trading import start_simulated_trading
    from database import initialize_portfolio_if_empty

    # Initialize portfolio with 10k USDT if this is a new user
    initialize_portfolio_if_empty(user_email=current_user)

    if req.duration_unit == "days":
        duration_minutes = req.duration * 24 * 60
    else:
        duration_minutes = req.duration

    result = start_simulated_trading(
        user_email=current_user,
        symbol=req.symbol,
        trade_amount=req.trade_amount,
        duration_minutes=duration_minutes,
    )

    if "error" in result:
        raise HTTPException(status_code=400, detail=result["error"])

    return result


@app.post("/api/simulated/stop/{session_id}")
def stop_simulated_session(session_id: str, current_user: str = Depends(get_current_user)):
    """Stop a trading bot session.

    Raises:
        HTTPException: 404 when the trading backend reports an error.
    """
    from simulated_trading import stop_simulated_trading

    # NOTE(review): the caller's identity is not passed on, so session
    # ownership is not checked here - confirm the backend enforces it.
    result = stop_simulated_trading(session_id)

    if "error" in result:
        raise HTTPException(status_code=404, detail=result["error"])

    return result


@app.get("/api/simulated/session/{session_id}")
def get_simulated_session(session_id: str, current_user: str = Depends(get_current_user)):
    """Get bot session status.

    Raises:
        HTTPException: 404 when the trading backend reports an error.
    """
    from simulated_trading import get_simulated_session_status

    # NOTE(review): session ownership is not checked here either - confirm.
    session_status = get_simulated_session_status(session_id)

    if "error" in session_status:
        raise HTTPException(status_code=404, detail=session_status["error"])

    return session_status


# --- MANUAL TRADING ROUTES (Market Page) ---
class ManualBuyRequest(BaseModel):
    """Request body for a manual buy order."""

    symbol: str  # e.g., 'BTC', 'ETH'
    usdt_amount: float  # Amount in USDT to spend


class ManualSellRequest(BaseModel):
    """Request body for a manual sell order."""

    symbol: str  # e.g., 'BTC', 'ETH'
    quantity: float  # Amount of asset to sell


class ManualSellPercentRequest(BaseModel):
    """Request body for selling a share of current holdings."""

    symbol: str  # e.g., 'BTC', 'ETH'
    percentage: float  # Percentage of holdings to sell (0-100)


@app.post("/api/market/buy")
def manual_buy(req: ManualBuyRequest, current_user: str = Depends(get_current_user)):
    """Execute a manual buy order from the Market page.

    This is independent from automated trading bot strategies.
    Updates portfolio and creates trade log entry.

    Raises:
        HTTPException: 400 when the amount is not positive, is below
            1 USDT, or the order is rejected.
    """
    from manual_trading import execute_manual_buy
    from database import initialize_portfolio_if_empty

    # Ensure user has portfolio initialized
    initialize_portfolio_if_empty(user_email=current_user)

    # Validate input
    if req.usdt_amount <= 0:
        raise HTTPException(status_code=400, detail="Amount must be positive")
    if req.usdt_amount < 1:
        raise HTTPException(status_code=400, detail="Minimum buy amount is 1 USDT")

    success, trade_info, error = execute_manual_buy(
        symbol=req.symbol,
        usdt_amount=req.usdt_amount,
        user_email=current_user,
    )

    if not success:
        raise HTTPException(status_code=400, detail=error)

    return {
        "success": True,
        "message": f"Successfully bought {trade_info['quantity']:.8f} {req.symbol}",
        "trade": trade_info,
    }


@app.post("/api/market/sell")
def manual_sell(req: ManualSellRequest, current_user: str = Depends(get_current_user)):
    """Execute a manual sell order from the Market page.

    This is independent from automated trading bot strategies.
    Updates portfolio and creates trade log entry.

    Raises:
        HTTPException: 400 when the quantity is not positive or the order
            is rejected.
    """
    from manual_trading import execute_manual_sell
    from database import initialize_portfolio_if_empty

    # Ensure user has portfolio initialized
    initialize_portfolio_if_empty(user_email=current_user)

    # Validate input
    if req.quantity <= 0:
        raise HTTPException(status_code=400, detail="Quantity must be positive")

    success, trade_info, error = execute_manual_sell(
        symbol=req.symbol,
        quantity=req.quantity,
        user_email=current_user,
    )

    if not success:
        raise HTTPException(status_code=400, detail=error)

    return {
        "success": True,
        "message": f"Successfully sold {trade_info['quantity']:.8f} {req.symbol}",
        "trade": trade_info,
    }


@app.post("/api/market/sell-percent")
def manual_sell_percent(req: ManualSellPercentRequest, current_user: str = Depends(get_current_user)):
    """Sell a percentage of holdings for a specific asset.

    Useful for quick "Sell 25%", "Sell 50%", "Sell All" actions.

    Raises:
        HTTPException: 400 when the percentage is outside (0, 100], the
            user holds none of the asset, or the order is rejected.
    """
    from manual_trading import execute_manual_sell, get_user_balance
    from database import initialize_portfolio_if_empty

    # Ensure user has portfolio initialized
    initialize_portfolio_if_empty(user_email=current_user)

    # Validate percentage
    if req.percentage <= 0 or req.percentage > 100:
        raise HTTPException(status_code=400, detail="Percentage must be between 0 and 100")

    # Get current balance
    balance = get_user_balance(req.symbol.upper(), current_user)
    if balance <= 0:
        raise HTTPException(status_code=400, detail=f"No {req.symbol} holdings to sell")

    # Calculate quantity to sell
    quantity_to_sell = balance * (req.percentage / 100)

    success, trade_info, error = execute_manual_sell(
        symbol=req.symbol,
        quantity=quantity_to_sell,
        user_email=current_user,
    )

    if not success:
        raise HTTPException(status_code=400, detail=error)

    return {
        "success": True,
        "message": f"Successfully sold {req.percentage}% ({trade_info['quantity']:.8f}) {req.symbol}",
        "trade": trade_info,
    }


@app.get("/api/market/trades")
def get_manual_trades(limit: int = 50, current_user: str = Depends(get_current_user)):
    """Get manual trade history for the current user."""
    from manual_trading import get_manual_trade_history

    return {"trades": get_manual_trade_history(current_user, limit)}


@app.get("/api/market/prices")
def get_market_prices(current_user: str = Depends(get_current_user)):
    """Get current prices for all supported assets.

    Useful for initial page load before WebSocket connects.
    """
    from manual_trading import get_prices_for_assets

    return {"prices": get_prices_for_assets()}


@app.get("/api/market/assets")
def get_supported_assets(current_user: str = Depends(get_current_user)):
    """Get display metadata for the assets supported by manual trading."""
    from manual_trading import SUPPORTED_ASSETS

    assets = [
        {"symbol": "BTC", "name": "Bitcoin", "logo": "₿", "color": "#F7931A"},
        {"symbol": "ETH", "name": "Ethereum", "logo": "Ξ", "color": "#627EEA"},
        {"symbol": "SOL", "name": "Solana", "logo": "◎", "color": "#14F195"},
        {"symbol": "LINK", "name": "Chainlink", "logo": "⬔", "color": "#2A5ADA"},
        {"symbol": "DOGE", "name": "Dogecoin", "logo": "Ɛ", "color": "#C2A633"},
        {"symbol": "BNB", "name": "BNB", "logo": "⬔", "color": "#F3BA2F"},
    ]

    # Only advertise assets the trading module actually supports.
    return {"assets": [a for a in assets if a["symbol"] in SUPPORTED_ASSETS]}


@app.get("/api/market/cost-basis/{symbol}")
def get_cost_basis(symbol: str, current_user: str = Depends(get_current_user)):
    """Get the average cost basis and investment info for a specific asset.

    Used to show estimated PnL before selling. Unrealized PnL is computed
    net of the estimated trading fee; when no price is available or the
    balance is not positive, value and PnL fields are reported as 0.0.
    """
    from manual_trading import get_asset_cost_basis, get_current_price_from_binance, TRADING_FEE

    asset = symbol.upper()
    cost_info = get_asset_cost_basis(asset, current_user)
    balance = cost_info['balance']
    total_invested = cost_info['total_invested']

    # Get current price to calculate unrealized PnL
    current_price = get_current_price_from_binance(asset, "USDT")

    current_value = 0.0
    unrealized_pnl = 0.0
    unrealized_pnl_percent = 0.0
    if current_price and balance > 0:
        current_value = current_price * balance
        net_value = current_value - current_value * TRADING_FEE
        unrealized_pnl = net_value - total_invested
        # Guard the division: nothing invested means no meaningful percentage.
        if total_invested > 0:
            unrealized_pnl_percent = ((net_value / total_invested) - 1) * 100

    return {
        "symbol": asset,
        "balance": balance,
        "avg_cost_basis": cost_info['avg_cost_basis'],
        "total_invested": total_invested,
        "current_price": current_price,
        "current_value": current_value,
        "unrealized_pnl": unrealized_pnl,
        "unrealized_pnl_percent": unrealized_pnl_percent,
    }