Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends, Request | |
| from fastapi.security import APIKeyHeader | |
| from pydantic import BaseModel | |
| from typing import List, Optional, Dict, Any | |
| import logging | |
| import threading | |
| import asyncio | |
| import numpy as np | |
| import redis | |
| import json | |
| import os | |
| import hashlib | |
| from core_engine import run_engine | |
| try: | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| except ImportError: | |
| pass | |
| from opentelemetry import trace | |
| from opentelemetry.sdk.trace import TracerProvider | |
| from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter | |
| from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor | |
# Initialize OpenTelemetry Tracer
# Spans are handed to a batching processor that writes them through the
# console exporter; the provider is registered globally before the tracer
# used by this module is obtained.
provider = TracerProvider()
processor = BatchSpanProcessor(ConsoleSpanExporter())
provider.add_span_processor(processor)
trace.set_tracer_provider(provider)
tracer = trace.get_tracer(__name__)
app = FastAPI(title="Portfolio Engine API", version="1.0.0")
# Instrument FastAPI for automatic endpoint tracing
# (done after the provider is registered above).
FastAPIInstrumentor.instrument_app(app)
# Shared secret every client must present.  An empty or whitespace-only value
# is rejected as well as a missing one: an empty secret would compare equal to
# an empty client-supplied key and effectively disable authentication.
API_KEY = os.getenv("API_KEY")
if not API_KEY or not API_KEY.strip():
    raise RuntimeError(
        "FATAL: API_KEY environment variable must be set. "
        "Refusing to start with default credentials."
    )
# Header scheme used by HTTP endpoints to read the key from ``X-API-Key``.
api_key_header = APIKeyHeader(name="X-API-Key")
def verify_api_key(api_key: str = Depends(api_key_header)):
    """Dependency that authenticates a request by its ``X-API-Key`` header.

    Args:
        api_key: Key extracted from the request by ``api_key_header``.

    Returns:
        The presented key, unchanged, when it matches ``API_KEY``.

    Raises:
        HTTPException: 403 when the key is empty or does not match.
    """
    import hmac

    # Compare as bytes in constant time: ``!=`` stops at the first differing
    # character, which leaks information about the secret through response
    # timing.  Encoding first also avoids the TypeError that compare_digest
    # raises for non-ASCII ``str`` arguments.
    if not api_key or not hmac.compare_digest(
        api_key.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Could not validate credentials")
    return api_key
# Shared Redis connection, used in this file for rate limiting and for caching
# optimization results.  ``decode_responses=True`` makes replies come back as
# ``str`` instead of ``bytes``.  Falls back to a local instance when REDIS_URL
# is not set.
redis_client = redis.Redis.from_url(os.getenv("REDIS_URL", "redis://localhost:6379/0"), decode_responses=True)
def rate_limit(request: Request, limit: int = 10, window: int = 60):
    """Enforce a fixed-window request limit per client IP and URL path.

    Args:
        request: Incoming request; its client host and path form the counter key.
        limit: Maximum number of requests allowed inside one window.
        window: Window length in seconds.

    Raises:
        HTTPException: 429 once the caller exceeds ``limit`` in the current window.

    Redis failures are logged and the request is let through (fail-open).
    """
    ip = request.client.host if request.client else "127.0.0.1"
    key = f"rate_limit:{ip}:{request.url.path}"
    try:
        # Increment first and decide on the returned value: a separate
        # GET-then-INCR lets concurrent requests all read the same
        # pre-increment count and slip past the limit together.
        pipe = redis_client.pipeline()
        pipe.incr(key)
        pipe.ttl(key)
        count, ttl = pipe.execute()
        # Arm the expiry only when the key has none (TTL == -1).  Re-arming it
        # on every request would keep the counter alive for as long as traffic
        # keeps arriving, so a steady client would eventually be locked out.
        if ttl == -1:
            redis_client.expire(key, window)
    except redis.RedisError as e:
        logging.warning(f"Redis rate limiter failed, bypassing: {e}")
        return
    if count > limit:
        raise HTTPException(status_code=429, detail="Too Many Requests")
# Latest portfolio snapshot, shared by the optimization endpoint (which
# replaces it) and the WebSocket dashboard feed (which reads and reprices it).
GLOBAL_STATE = dict(
    capital=0.0,
    weights={},
    prices={},
    shares={},
    pnl=0.0,
)
import asyncio
# Serializes compound reads/updates of GLOBAL_STATE between coroutines.
GLOBAL_STATE_LOCK = asyncio.Lock()
| from pydantic import BaseModel, Field | |
class PortfolioRequest(BaseModel):
    # Request body for an optimization run.  Range and length limits are
    # declared on the fields and enforced by pydantic validation.
    # (Documented with comments on purpose: a class docstring would become
    # the model's schema description.)

    # --- Universe, sizing and model selection ---
    tickers: List[str] = Field(default=["SPY", "TLT", "GLD"], min_length=1, description="List of asset tickers")
    capital: float = Field(default=100000.0, gt=0, description="Total capital to allocate")
    risk: int = Field(default=5, ge=1, le=10, description="Risk tolerance level (1-10)")
    model: int = Field(default=6, ge=1, le=7, description="1=CAPM, 2=BL, 3=Bayes, 4=FF, 5=ML, 6=E2E, 7=World Model")
    engine: int = Field(default=1, ge=1, le=2, description="Allocation engine (1=Convex, 2=HRP)")
    currency: str = Field(default="$", max_length=5)
    days: int = Field(default=252, ge=1, le=365)

    # --- Feature switches (all off by default) ---
    bsts: bool = False
    monthly: bool = False
    tax: bool = False
    excel: bool = False
    no_dynamic_risk: bool = False
    with_futures: bool = False

    # --- Futures overlay settings ---
    overlay_mode: str = Field(default="beta_hedge", description="Futures overlay mode")
    futures_target_beta: float = Field(default=0.0, ge=-2.0, le=2.0)
    futures_universe: List[str] = Field(default=["MES", "ES"])
    futures_safety_multiplier: float = Field(default=3.0, ge=1.0, le=10.0)
    futures_margin_headroom: float = Field(default=0.05, ge=0.0, le=0.5)

    # --- Existing holdings, keyed by ticker ---
    current_weights: Dict[str, float] = Field(default={})
class OptimizationResponse(BaseModel):
    # Two-field status envelope; the endpoints in this file return dicts with
    # the same "status" / "message" keys.
    status: str
    message: str
def get_risk_factor(risk_level: int) -> float:
    """Map a risk tolerance level (1-10) to its risk-factor multiplier.

    Levels that are not in the table fall back to 3.0, the value for level 5.
    """
    factors = (0.1, 0.5, 1.0, 2.0, 3.0, 5.0, 7.5, 10.0, 15.0, 25.0)
    lookup = dict(zip(range(1, 11), factors))
    try:
        return lookup[risk_level]
    except KeyError:
        return 3.0
def _build_engine_overrides(req: PortfolioRequest) -> Dict[str, Any]:
    """Translate a validated request into the ``overrides`` mapping given to ``run_engine``."""
    cfg_overrides = {
        "currency_symbol": req.currency,
        "trading_days_per_year": req.days,
        "bsts_enabled": req.bsts,
        "tax_enabled": req.tax,
        "dynamic_risk": not req.no_dynamic_risk,
        "export_excel": req.excel,
        "with_futures": req.with_futures,
        "overlay_mode": req.overlay_mode,
        "futures_universe": req.futures_universe,
        "futures_target_beta": req.futures_target_beta,
        "futures_safety_multiplier": req.futures_safety_multiplier,
        "futures_margin_headroom": req.futures_margin_headroom,
    }
    if req.monthly:
        cfg_overrides["return_frequency"] = "monthly"
    return {
        "tickers": req.tickers,
        "capital": req.capital,
        "risk_input": req.risk,
        "risk_factor": get_risk_factor(req.risk),
        "model": req.model,
        "allocation_engine": req.engine,
        "current_weights_raw": req.current_weights,
        "headless": True,
        "cfg_overrides": cfg_overrides,
    }


def _compute_shares(capital, weights, prices):
    """Convert target weights into share counts, skipping CASH and unpriced tickers."""
    shares = {}
    for ticker, weight in weights.items():
        if ticker == 'CASH' or ticker not in prices:
            continue
        price = prices[ticker]
        if not price:
            # A zero/None price would raise on division and fail the whole
            # request; leave the position out instead.
            logging.warning(f"Skipping {ticker}: no usable price ({price!r})")
            continue
        shares[ticker] = (capital * weight) / price
    return shares


def _write_audit_log(api_key, path, req_hash, request_body, weights, ip_address):
    """Persist one audit row; any failure is logged and never propagates."""
    try:
        from database import get_pg_engine, AuditLog
        from sqlalchemy.orm import sessionmaker
        engine = get_pg_engine()
        Session = sessionmaker(bind=engine)
        with Session() as session:
            # NOTE(review): the raw API key is stored as user_id; consider
            # storing a digest instead so the secret is not kept in the table.
            session.add(
                AuditLog(
                    user_id=api_key,
                    endpoint=path,
                    request_hash=req_hash,
                    request_body=request_body,
                    response_weights=weights,
                    ip_address=ip_address,
                )
            )
            session.commit()
    except Exception as e:
        logging.error(f"Failed to write audit log: {e}")


async def run_optimization(req: PortfolioRequest, request: Request, api_key: str = Depends(verify_api_key)):
    """Triggers the heavy optimization pipeline natively in Python via cvxpy/ML stack.

    Flow: rate-limit the caller, return a cached result for an identical
    request if one exists, otherwise run ``run_engine`` in the default
    executor, publish the result to ``GLOBAL_STATE``, cache it for an hour and
    write an audit row.

    Raises:
        HTTPException: 429 from the rate limiter, 500 for any failure in the pipeline.
    """
    import functools

    rate_limit(request, limit=5, window=60)
    try:
        request_body = req.model_dump()
        req_hash = hashlib.sha256(json.dumps(request_body, sort_keys=True).encode()).hexdigest()
        cache_key = f"opt_{req_hash}"
        try:
            cached_state_json = redis_client.get(cache_key)
            if cached_state_json:
                logging.info("Returning cached optimization result")
                cached_state = json.loads(cached_state_json)
                async with GLOBAL_STATE_LOCK:
                    GLOBAL_STATE.update(cached_state)
                return {"status": "success", "message": "Optimization completed successfully (cached)."}
        except (redis.RedisError, ValueError) as e:
            # ValueError covers an undecodable cache entry: recompute rather
            # than fail the request.
            logging.warning(f"Redis cache check failed: {e}")

        overrides = _build_engine_overrides(req)
        # get_running_loop() is the supported way to obtain the loop from
        # inside a coroutine; get_event_loop() is deprecated for this use.
        loop = asyncio.get_running_loop()
        # The await stays inside the span so the span covers the engine's
        # wall-clock time rather than only the scheduling of the task.
        with tracer.start_as_current_span("run_engine_pipeline_async_task"):
            task = loop.run_in_executor(None, functools.partial(run_engine, overrides=overrides))
            try:
                opt_res = await task
            except asyncio.CancelledError:
                logging.info("Optimization task cancelled by client.")
                raise

        # Populate global state for live streaming
        weights = opt_res.get("target_weights", {})
        prices = opt_res.get("prices", {})
        capital = req.capital
        state_update = {
            "capital": capital,
            "weights": weights,
            "prices": prices.copy(),
            "shares": _compute_shares(capital, weights, prices),
            "pnl": 0.0,
        }
        async with GLOBAL_STATE_LOCK:
            GLOBAL_STATE.update(state_update)

        try:
            redis_client.setex(cache_key, 3600, json.dumps(state_update))
        except (redis.RedisError, TypeError, ValueError) as e:
            # TypeError/ValueError: the result is not JSON-serializable.  The
            # optimization itself succeeded, so only the caching is skipped.
            logging.warning(f"Failed to cache result in Redis: {e}")

        # Write to Audit Log (blocking DB I/O, so keep it off the event loop).
        await loop.run_in_executor(
            None,
            functools.partial(
                _write_audit_log,
                api_key,
                request.url.path,
                req_hash,
                request_body,
                weights,
                request.client.host if request.client else "unknown",
            ),
        )
        return {"status": "success", "message": "Optimization completed successfully."}
    except HTTPException:
        raise
    except Exception as e:
        logging.exception("Optimization request failed")
        # NOTE(review): ``detail`` exposes the raw exception text to the client.
        raise HTTPException(status_code=500, detail=str(e)) from e
def _latest_close(close_data, ticker):
    """Return the most recent non-NaN close for *ticker*, or None if there is none.

    Accepts either a frame with one column per ticker or a single series.
    """
    columns = getattr(close_data, "columns", None)
    if columns is not None:
        if ticker not in columns:
            return None
        series = close_data[ticker]
    else:
        series = close_data
    series = series.dropna()
    if series.empty:
        return None
    return float(series.iloc[-1])


async def websocket_endpoint(websocket: WebSocket):
    """Stream live portfolio valuations to an authenticated dashboard client.

    The client authenticates with the ``X-API-Key`` header or an ``api_key``
    query parameter.  While ``GLOBAL_STATE`` holds share positions, one-minute
    closes are downloaded every 10 seconds, the stored prices and PnL are
    updated, and a ``live_update`` message is sent.  The loop ends when the
    client disconnects or a send fails.
    """
    import functools
    import hmac
    import math

    api_key = websocket.headers.get("X-API-Key") or websocket.query_params.get("api_key")
    # Constant-time comparison on bytes; an empty key is always rejected.
    if not api_key or not hmac.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        await websocket.close(code=1008)
        return
    await websocket.accept()
    loop = asyncio.get_running_loop()
    try:
        while True:
            async with GLOBAL_STATE_LOCK:
                tickers_list = list(GLOBAL_STATE["shares"].keys())
            if not tickers_list:
                await asyncio.sleep(1)
                continue

            payload = None
            try:
                # Fetch real live data
                import yfinance as yf
                # yf.download blocks on network I/O; run it in the executor so
                # other requests keep being served meanwhile.
                data = await loop.run_in_executor(
                    None,
                    functools.partial(
                        yf.download,
                        " ".join(tickers_list),
                        period="1d",
                        interval="1m",
                        progress=False,
                    ),
                )
                if not data.empty and 'Close' in data:
                    close_data = data['Close']
                    current_value = 0.0
                    new_prices = {}
                    async with GLOBAL_STATE_LOCK:
                        for t, share_qty in GLOBAL_STATE["shares"].items():
                            try:
                                price = _latest_close(close_data, t)
                                if price is None:
                                    # No fresh quote: value the position at the
                                    # last stored price instead of dropping it.
                                    price = GLOBAL_STATE["prices"].get(t)
                                if price is None or math.isnan(price):
                                    logging.warning(f"No usable price for {t}; skipping")
                                    continue
                                GLOBAL_STATE["prices"][t] = price
                                new_prices[t] = round(price, 2)
                                current_value += share_qty * price
                            except Exception as e:
                                logging.error(f"Error extracting price for {t}: {e}")
                        cash = GLOBAL_STATE["capital"] * GLOBAL_STATE["weights"].get("CASH", 0.0)
                        current_value += cash
                        GLOBAL_STATE["pnl"] = current_value - GLOBAL_STATE["capital"]
                        payload = {
                            "type": "live_update",
                            "capital": round(current_value, 2),
                            "pnl": round(GLOBAL_STATE["pnl"], 2),
                            "prices": new_prices
                        }
            except Exception as e:
                logging.error(f"Error fetching live data: {e}")

            # Send outside the broad handler above: a failed send means the
            # client is gone and must end the loop, not be logged and retried.
            if payload is not None:
                await websocket.send_json(payload)
            await asyncio.sleep(10)
    except WebSocketDisconnect:
        logging.info("WebSocket disconnected")
def health_check():
    # Liveness probe: returns a constant payload and touches no dependency.
    status_payload = dict(status="healthy")
    return status_payload
async def ping():
    """Endpoint for UptimeRobot to ping Render, which in turn pings HF to keep both awake.

    The upstream request is best-effort: a failure is logged and the endpoint
    still answers ``{"status": "awake"}``.
    """
    import functools

    hf_url = os.getenv("HF_BACKEND_URL", "https://engineportf-portfolio-opt.hf.space").rstrip('/')
    import requests
    try:
        # requests.get blocks for up to the timeout; run it in the executor so
        # the event loop keeps serving other requests meanwhile.
        await asyncio.get_running_loop().run_in_executor(
            None, functools.partial(requests.get, f"{hf_url}/", timeout=10)
        )
    except Exception as e:
        # Still best-effort, but no longer a bare ``except`` that hides the
        # failure and also traps SystemExit/KeyboardInterrupt.
        logging.warning(f"Keep-alive ping to {hf_url} failed: {e}")
    return {"status": "awake"}
class ChatRequest(BaseModel):
    # Request body for the portfolio chat endpoint.
    # The user's question, inserted verbatim into the model prompt.
    message: str
    # Free-form portfolio data, interpolated into the prompt as context.
    portfolio_context: dict
def _hf_inference_via_requests(prompt: str, hf_token: str) -> Dict[str, Any]:
    """Call the hosted inference HTTP API directly and return a status dict.

    Request and parsing errors are reported in the returned dict, not raised.
    """
    import requests
    api_url = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"
    headers = {"Authorization": f"Bearer {hf_token}"}
    payload = {
        "inputs": prompt,
        "parameters": {"max_new_tokens": 500, "temperature": 0.3, "return_full_text": False}
    }
    try:
        res = requests.post(api_url, headers=headers, json=payload, timeout=60)
        if not res.ok:
            return {"status": "error", "detail": f"Hugging Face API Error: {res.status_code} - {res.text}"}
        data = res.json()
        if isinstance(data, list) and len(data) > 0:
            response_text = data[0].get("generated_text", "AI response empty.")
            return {"status": "success", "response": response_text.strip()}
        if isinstance(data, dict) and "error" in data:
            return {"status": "error", "detail": f"Hugging Face AI Error: {data['error']}"}
        return {"status": "success", "response": str(data)}
    except Exception as req_err:
        return {"status": "error", "detail": f"AI temporarily unavailable due to server networking issues (DNS): {req_err}"}


async def chat_with_portfolio(req: ChatRequest):
    """Answer a question about the user's portfolio with a hosted LLM.

    Tries ``huggingface_hub.InferenceClient`` first and falls back to a direct
    HTTP call if that fails.  Both calls are blocking, so they run in the
    default executor.

    Returns:
        ``{"status": "success", "response": ...}`` or
        ``{"status": "error", "detail": ...}``.

    Raises:
        HTTPException: 500 when ``huggingface_hub`` is missing or an
            unexpected error occurs.
    """
    # NOTE(review): unlike run_optimization, this handler applies no API-key
    # check or rate limit in its signature/body — confirm that is intended.
    import functools

    try:
        from huggingface_hub import InferenceClient
    except ImportError:
        InferenceClient = None
    if InferenceClient is None:
        raise HTTPException(status_code=500, detail="huggingface_hub is not installed on the server.")
    try:
        hf_token = os.environ.get("HF_TOKEN", "")
        if not hf_token:
            return {"status": "error", "detail": "AI is disabled. Please add 'HF_TOKEN' to your Hugging Face Space Secrets to enable the AI."}
        system_prompt = (
            "You are an elite quantitative analyst AI. "
            "You are explaining the user's mathematical portfolio allocation. "
            "Never give explicit financial advice (e.g. 'You must buy this stock'). "
            "Only explain WHY the math chose these weights based on the user's inputs and market metrics. "
            f"Here is the user's current mathematically optimized portfolio context: {req.portfolio_context}"
        )
        # NOTE(review): the portfolio context appears twice in the prompt (in
        # the system text above and under "Context:"); kept as-is.
        prompt = f"<s>[INST] {system_prompt}\n\nContext:\n{req.portfolio_context}\n\nUser: {req.message} [/INST]"
        loop = asyncio.get_running_loop()
        try:
            client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.3", token=hf_token)
            response = await loop.run_in_executor(
                None,
                functools.partial(
                    client.text_generation,
                    prompt,
                    max_new_tokens=500,
                    temperature=0.3,
                    return_full_text=False,
                ),
            )
            return {"status": "success", "response": response.strip()}
        except Exception as client_err:
            logging.warning(f"InferenceClient failed: {client_err}. Falling back to requests.")
        return await loop.run_in_executor(
            None, functools.partial(_hf_inference_via_requests, prompt, hf_token)
        )
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"AI Chat error: {e}")
        raise HTTPException(status_code=500, detail=str(e))