Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends, Request | |
| from fastapi.security import APIKeyHeader | |
| from pydantic import BaseModel | |
| from typing import List, Optional, Dict, Any | |
| import logging | |
| import threading | |
| import asyncio | |
| import numpy as np | |
| import redis | |
| import json | |
| import os | |
| import hashlib | |
| from core_engine import run_engine | |
| from opentelemetry import trace | |
| from opentelemetry.sdk.trace import TracerProvider | |
| from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter | |
| from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor | |
def _build_tracer_provider():
    """Create the SDK tracer provider wired to a batching console span exporter.

    Returns the provider together with its span processor so both stay
    reachable under their module-level names.
    """
    tracer_provider = TracerProvider()
    span_processor = BatchSpanProcessor(ConsoleSpanExporter())
    tracer_provider.add_span_processor(span_processor)
    return tracer_provider, span_processor


# Initialize OpenTelemetry tracing: install the provider globally, then
# obtain this module's tracer from it.
provider, processor = _build_tracer_provider()
trace.set_tracer_provider(provider)
tracer = trace.get_tracer(__name__)

app = FastAPI(version="1.0.0", title="Portfolio Engine API")
# Instrument FastAPI for automatic endpoint tracing.
FastAPIInstrumentor.instrument_app(app)
# Single shared secret that every HTTP and WebSocket client must present.
API_KEY = os.getenv("API_KEY")
if not API_KEY:
    # Refuse an unset *or empty* key: with API_KEY == "" a plain
    # `supplied != API_KEY` check would accept a client that sends an empty key.
    raise RuntimeError(
        "FATAL: API_KEY environment variable must be set. "
        "Refusing to start with default credentials."
    )
# Clients send the key in the "X-API-Key" request header.
api_key_header = APIKeyHeader(name="X-API-Key")
def verify_api_key(api_key: str = Depends(api_key_header)):
    """FastAPI dependency that checks the ``X-API-Key`` header against ``API_KEY``.

    Returns:
        The validated API key.

    Raises:
        HTTPException: 403 if the key is empty or does not match.
    """
    import secrets

    supplied = api_key.encode("utf-8") if isinstance(api_key, str) else b""
    # Constant-time comparison so response timing does not reveal how much of
    # a guessed key was correct. Compare bytes because compare_digest raises
    # TypeError for str arguments containing non-ASCII characters.
    if not supplied or not secrets.compare_digest(supplied, API_KEY.encode("utf-8")):
        raise HTTPException(status_code=403, detail="Could not validate credentials")
    return api_key
# Shared Redis client used for rate limiting and optimization-result caching.
# decode_responses=True makes replies come back as str rather than bytes.
redis_client = redis.Redis.from_url(
    os.environ.get("REDIS_URL", "redis://localhost:6379/0"),
    decode_responses=True,
)
def rate_limit(request: Request, limit: int = 10, window: int = 60):
    """Enforce a fixed-window request limit per (client IP, URL path).

    At most ``limit`` calls are allowed per ``window`` seconds for each client
    IP and path, counted in Redis. The limiter fails open: if Redis errors,
    a warning is logged and the request is allowed.

    Args:
        request: Incoming request; its client host and path form the key.
        limit: Maximum number of requests allowed in one window.
        window: Window length in seconds.

    Raises:
        HTTPException: 429 once the limit for the current window is exceeded.
    """
    ip = request.client.host if request.client else "127.0.0.1"
    key = f"rate_limit:{ip}:{request.url.path}"
    try:
        # INCR and TTL run in one transactional pipeline, so counting is atomic.
        # A separate GET before INCR would let concurrent requests all pass
        # the check before any of them incremented the counter.
        pipe = redis_client.pipeline()
        pipe.incr(key)
        pipe.ttl(key)
        count, ttl = pipe.execute()
        # Start the window only when the key has no expiry yet (TTL < 0).
        # Refreshing it on every hit would keep sliding the window forward,
        # so a client making slow but steady requests would eventually be
        # locked out.
        if ttl < 0:
            redis_client.expire(key, window)
    except redis.RedisError as e:
        logging.warning(f"Redis rate limiter failed, bypassing: {e}")
        return
    if count > limit:
        raise HTTPException(status_code=429, detail="Too Many Requests")
# Latest optimized portfolio, shared with the WebSocket dashboard.
# run_optimization replaces its contents; websocket_endpoint reads it and
# rewrites "prices" and "pnl" on every simulated tick.
GLOBAL_STATE = dict(
    capital=0.0,  # capital the portfolio was sized with
    weights={},   # ticker -> target weight (may include "CASH")
    prices={},    # ticker -> latest (simulated) price
    shares={},    # ticker -> share count
    pnl=0.0,      # current portfolio value minus capital
)
import asyncio
# Serializes GLOBAL_STATE updates between the optimization and WebSocket handlers.
GLOBAL_STATE_LOCK = asyncio.Lock()
| from pydantic import BaseModel, Field | |
class PortfolioRequest(BaseModel):
    """Parameters for a single portfolio optimization run."""

    # --- Universe, sizing and model selection ---
    tickers: List[str] = Field(
        default=["SPY", "TLT", "GLD"],
        min_length=1,
        description="List of asset tickers",
    )
    capital: float = Field(default=100000.0, gt=0, description="Total capital to allocate")
    risk: int = Field(default=5, ge=1, le=10, description="Risk tolerance level (1-10)")
    model: int = Field(
        default=6,
        ge=1,
        le=7,
        description="1=CAPM, 2=BL, 3=Bayes, 4=FF, 5=ML, 6=E2E, 7=World Model",
    )
    engine: int = Field(default=1, ge=1, le=2, description="Allocation engine (1=Convex, 2=HRP)")

    # --- Engine configuration switches (forwarded as config overrides) ---
    currency: str = Field(default="$", max_length=5)
    days: int = Field(default=252, ge=1, le=365)
    bsts: bool = Field(default=False)
    monthly: bool = Field(default=False)
    tax: bool = Field(default=False)
    excel: bool = Field(default=False)
    no_dynamic_risk: bool = Field(default=False)

    # --- Futures overlay ---
    with_futures: bool = Field(default=False)
    overlay_mode: str = Field(default="beta_hedge", description="Futures overlay mode")
    futures_target_beta: float = Field(default=0.0, ge=-2.0, le=2.0)
    futures_universe: List[str] = Field(default=["MES", "ES"])
    futures_safety_multiplier: float = Field(default=3.0, ge=1.0, le=10.0)
    futures_margin_headroom: float = Field(default=0.05, ge=0.0, le=0.5)

    # Existing holdings as ticker -> weight, passed to the engine unchanged.
    current_weights: Dict[str, float] = Field(default={})
class OptimizationResponse(BaseModel):
    """Outcome of an optimization request as a status/message pair."""

    status: str = Field(...)
    message: str = Field(...)
def get_risk_factor(risk_level: int) -> float:
    """Translate a 1-10 risk tolerance level into the engine's ``risk_factor``.

    Any level outside 1-10 (or any other unrecognised key) falls back to 3.0,
    which is the factor for level 5.
    """
    factors = (0.1, 0.5, 1.0, 2.0, 3.0, 5.0, 7.5, 10.0, 15.0, 25.0)
    by_level = dict(zip(range(1, len(factors) + 1), factors))
    try:
        return by_level[risk_level]
    except KeyError:
        return 3.0
def _engine_overrides(req: PortfolioRequest) -> Dict[str, Any]:
    """Translate a validated request into the ``overrides`` mapping for ``run_engine``."""
    cfg_overrides: Dict[str, Any] = {
        "currency_symbol": req.currency,
        "trading_days_per_year": req.days,
        "bsts_enabled": req.bsts,
        "tax_enabled": req.tax,
        "dynamic_risk": not req.no_dynamic_risk,
        "export_excel": req.excel,
        "with_futures": req.with_futures,
        "overlay_mode": req.overlay_mode,
        "futures_universe": req.futures_universe,
        "futures_target_beta": req.futures_target_beta,
        "futures_safety_multiplier": req.futures_safety_multiplier,
        "futures_margin_headroom": req.futures_margin_headroom,
    }
    if req.monthly:
        cfg_overrides["return_frequency"] = "monthly"
    return {
        "tickers": req.tickers,
        "capital": req.capital,
        "risk_input": req.risk,
        "risk_factor": get_risk_factor(req.risk),
        "model": req.model,
        "allocation_engine": req.engine,
        "current_weights_raw": req.current_weights,
        "headless": True,
        "cfg_overrides": cfg_overrides,
    }


def _shares_from_weights(
    weights: Dict[str, float], prices: Dict[str, float], capital: float
) -> Dict[str, float]:
    """Convert target weights into share counts at the given prices.

    The ``CASH`` entry and tickers that have no price are skipped.
    """
    return {
        ticker: (capital * weight) / prices[ticker]
        for ticker, weight in weights.items()
        if ticker != "CASH" and ticker in prices
    }


def _write_audit_log(
    user_id: str,
    endpoint: str,
    req_hash: str,
    request_body: Dict[str, Any],
    weights: Dict[str, float],
    ip_address: str,
) -> None:
    """Persist one ``AuditLog`` row. Best-effort: any failure is only logged.

    Performs synchronous SQLAlchemy I/O, so async callers should run it in
    an executor rather than directly on the event loop.
    """
    try:
        from database import get_pg_engine, AuditLog
        from sqlalchemy.orm import sessionmaker

        Session = sessionmaker(bind=get_pg_engine())
        with Session() as session:
            session.add(
                AuditLog(
                    user_id=user_id,
                    endpoint=endpoint,
                    request_hash=req_hash,
                    request_body=request_body,
                    response_weights=weights,
                    ip_address=ip_address,
                )
            )
            session.commit()
    except Exception as e:
        logging.error(f"Failed to write audit log: {e}")


async def run_optimization(req: PortfolioRequest, request: Request, api_key: str = Depends(verify_api_key)):
    """Triggers the heavy optimization pipeline natively in Python via cvxpy/ML stack.

    Flow:
      1. Rate limit the request (5 calls per 60 s per client IP and path).
      2. Return a cached result if one exists in Redis. The cache key is the
         SHA-256 of the request body.
      3. Otherwise run ``run_engine`` in the default executor and derive share
         counts from its target weights and prices.
      4. Publish the result to ``GLOBAL_STATE`` for the WebSocket dashboard,
         cache it for 3600 s, and write an audit-log row.

    Redis, serialization and audit-log failures are logged and do not fail
    the request.

    Raises:
        HTTPException: 429 from the rate limiter, or 500 if the pipeline fails.
    """
    import functools

    rate_limit(request, limit=5, window=60)
    try:
        request_body = req.model_dump()
        req_hash = hashlib.sha256(json.dumps(request_body, sort_keys=True).encode()).hexdigest()
        cache_key = f"opt_{req_hash}"

        cached_state = None
        try:
            cached_state_json = redis_client.get(cache_key)
            if cached_state_json:
                cached_state = json.loads(cached_state_json)
        except (redis.RedisError, ValueError) as e:
            # ValueError covers json.JSONDecodeError: a corrupt cache entry
            # falls through to a fresh run instead of failing the request.
            logging.warning(f"Redis cache check failed: {e}")
            cached_state = None
        if cached_state is not None:
            logging.info("Returning cached optimization result")
            async with GLOBAL_STATE_LOCK:
                GLOBAL_STATE.update(cached_state)
            return {"status": "success", "message": "Optimization completed successfully (cached)."}

        overrides = _engine_overrides(req)
        loop = asyncio.get_running_loop()
        # The span encloses the await so it measures the engine run itself,
        # not just the scheduling of the executor job.
        with tracer.start_as_current_span("run_engine_pipeline_async_task"):
            try:
                opt_res = await loop.run_in_executor(
                    None, functools.partial(run_engine, overrides=overrides)
                )
            except asyncio.CancelledError:
                # Cancelling the await does not stop a worker thread that is
                # already executing run_engine.
                logging.info("Optimization task cancelled by client.")
                raise

        # Populate global state for live streaming.
        weights = opt_res.get("target_weights", {})
        prices = opt_res.get("prices", {})
        capital = req.capital
        state_update = {
            "capital": capital,
            "weights": weights,
            "prices": prices.copy(),
            "shares": _shares_from_weights(weights, prices, capital),
            "pnl": 0.0,
        }
        async with GLOBAL_STATE_LOCK:
            GLOBAL_STATE.update(state_update)

        try:
            redis_client.setex(cache_key, 3600, json.dumps(state_update))
        except (redis.RedisError, TypeError, ValueError) as e:
            # TypeError/ValueError: engine output that json cannot serialize.
            # Caching is best-effort and must not turn an already-published
            # result into a 500.
            logging.warning(f"Failed to cache result in Redis: {e}")

        # Blocking DB write runs off the event loop.
        # NOTE(review): this persists the raw API key as user_id; consider a
        # fingerprint instead — confirm with AuditLog consumers.
        await loop.run_in_executor(
            None,
            functools.partial(
                _write_audit_log,
                user_id=api_key,
                endpoint=request.url.path,
                req_hash=req_hash,
                request_body=request_body,
                weights=weights,
                ip_address=request.client.host if request.client else "unknown",
            ),
        )
        return {"status": "success", "message": "Optimization completed successfully."}
    except Exception as e:
        logging.exception("Optimization request failed")
        # NOTE(review): str(e) exposes internal error text to clients; kept
        # for compatibility with existing callers.
        raise HTTPException(status_code=500, detail=str(e)) from e
async def websocket_endpoint(websocket: WebSocket):
    """Stream simulated live portfolio updates to an authenticated dashboard client.

    Authentication:
        The client sends the key in the ``X-API-Key`` header or the ``api_key``
        query parameter. A missing, empty or wrong key closes the socket with
        code 1008 before it is accepted.

    Streaming:
        Until ``GLOBAL_STATE["shares"]`` is populated, the handler polls once
        per second. After that, every 5 seconds it applies one random price
        shock per held ticker, writes the new prices and PnL back into
        ``GLOBAL_STATE``, and sends a ``live_update`` JSON message.

    Because prices are written back into the shared ``GLOBAL_STATE``, every
    connected client applies its own shocks to the same prices.
    """
    import secrets

    supplied_key = websocket.headers.get("X-API-Key") or websocket.query_params.get("api_key") or ""
    # Constant-time comparison on bytes (compare_digest rejects non-ASCII str).
    # An empty key is always refused.
    if not supplied_key or not secrets.compare_digest(
        supplied_key.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        await websocket.close(code=1008)
        return

    await websocket.accept()
    rng = np.random.default_rng()
    # Simulation constants, hoisted out of the loop.
    # 5-second tick expressed in years: 252 trading days of 23,400 s (6.5 h).
    dt = 5 / (252 * 23400)
    vol = 0.15  # 15% annualized vol approx
    tick_sigma = vol * np.sqrt(dt)
    try:
        while True:
            # Skip if not initialized (no optimization has run yet).
            if not GLOBAL_STATE["shares"]:
                await asyncio.sleep(1)
                continue
            # Simulate continuously 24/7 for dashboard testing purposes.
            current_value = 0.0
            new_prices = {}
            async with GLOBAL_STATE_LOCK:
                for t, share_qty in GLOBAL_STATE["shares"].items():
                    price = GLOBAL_STATE["prices"].get(t, 100.0)
                    # Euler step of a driftless GBM: S <- S * (1 + sigma*sqrt(dt)*Z).
                    new_price = price * (1 + rng.normal(0, tick_sigma))
                    GLOBAL_STATE["prices"][t] = new_price
                    new_prices[t] = round(new_price, 2)
                    current_value += share_qty * new_price
                # The CASH allocation is held at its original notional value.
                current_value += GLOBAL_STATE["capital"] * GLOBAL_STATE["weights"].get("CASH", 0.0)
                pnl = current_value - GLOBAL_STATE["capital"]
                GLOBAL_STATE["pnl"] = pnl
            payload = {
                "type": "live_update",
                "capital": round(current_value, 2),
                "pnl": round(pnl, 2),
                "prices": new_prices,
            }
            await websocket.send_json(payload)
            await asyncio.sleep(5)
    except WebSocketDisconnect:
        logging.info("WebSocket disconnected")
def health_check():
    """Liveness probe that unconditionally reports the service as healthy."""
    payload = dict(status="healthy")
    return payload