import os
from typing import Any, Dict, List, Optional

import torch
import uvicorn
from chronos import ChronosPipeline
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI(title="Dolixe Kronos AI Service")

# ---------------------------------------------------------------------------
# Model loading (runs once at import time). The two T5 checkpoints are
# treated as required; the Bolt checkpoints are optional and, if they fail
# to load, requests for them fall back to the matching T5 model.
# ---------------------------------------------------------------------------
print("Loading Kronos (Chronos-T5-Tiny) model... this may take a minute on first run.")
pipeline_tiny = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-tiny",
    device_map="cpu",  # Use "cuda" if you have an NVIDIA GPU
    torch_dtype=torch.float32,
)

print("Loading Kronos (Chronos-T5-Base) model... this may take a bit longer.")
pipeline_base = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-base",
    device_map="cpu",
    torch_dtype=torch.float32,
)

print("Loading Chronos Bolt Tiny model... this should be extremely fast.")
try:
    pipeline_bolt_tiny = ChronosPipeline.from_pretrained(
        "amazon/chronos-bolt-tiny",
        device_map="cpu",
        torch_dtype=torch.float32,
    )
except Exception as e:
    print(f"FAILED to load Bolt Tiny: {e}. Skipping...")
    pipeline_bolt_tiny = None

print("Loading Chronos Bolt Base model... this may take a bit longer but will be faster than T5-Base.")
try:
    pipeline_bolt_base = ChronosPipeline.from_pretrained(
        "amazon/chronos-bolt-base",
        device_map="cpu",
        torch_dtype=torch.float32,
    )
except Exception as e:
    # BUG FIX: this message previously contained a literal newline inside the
    # f-string (split across source lines), which is a SyntaxError; rejoined.
    print(f"FAILED to load Bolt Base: {e}. Skipping...")
    pipeline_bolt_base = None


def _select_pipeline(model: Optional[str]) -> ChronosPipeline:
    """Return the pipeline for *model* ("tiny", "base", "bolt-tiny", "bolt-base").

    Falls back to the matching T5 pipeline when an optional Bolt checkpoint
    failed to load at startup; unknown/None names default to T5-Tiny.
    """
    if model == "base":
        return pipeline_base
    if model == "bolt-tiny":
        return pipeline_bolt_tiny if pipeline_bolt_tiny else pipeline_tiny
    if model == "bolt-base":
        return pipeline_bolt_base if pipeline_bolt_base else pipeline_base
    return pipeline_tiny


class MarketData(BaseModel):
    # Historical prices, oldest first; must be non-empty.
    prices: List[float]
    # Number of future steps to forecast.
    horizon: int = 12
    # One of "tiny" (default), "base", "bolt-tiny", "bolt-base".
    model: Optional[str] = "tiny"


@app.get("/")
def read_root():
    """Health check listing the model names this service can serve."""
    return {
        "status": "online",
        "models": [
            "chronos-t5-tiny",
            "chronos-t5-base",
            "chronos-bolt-tiny",
            "chronos-bolt-base",
        ],
    }


@app.post("/predict")
async def predict(data: MarketData):
    """Forecast `horizon` steps ahead and report a directional confidence score.

    Returns the median forecast, low/high quantile bands, and the percentage
    of Monte Carlo samples that agree with the median's final direction.
    Raises 400 for empty input and 500 for unexpected model errors.
    """
    try:
        if not data.prices:
            raise HTTPException(status_code=400, detail="No price data provided")

        pipeline = _select_pipeline(data.model)

        # Convert prices to a tensor
        context = torch.tensor(data.prices)

        # Fixed seed so identical inputs yield identical forecasts.
        torch.manual_seed(42)
        prediction = pipeline.predict(context, data.horizon)

        # Quantile bands: Bolt models use tighter P25/P75 (50% zone),
        # T5 models use P10/P90 (80% zone).
        # BUG FIX: `data.model` may be None (Optional field), so guard before
        # calling startswith to avoid AttributeError -> spurious 500.
        is_bolt = bool(data.model) and data.model.startswith("bolt")
        q_low = 0.25 if is_bolt else 0.1
        q_high = 0.75 if is_bolt else 0.9

        # prediction[0] holds the samples for the single batch item,
        # shaped (num_samples, horizon). Median is the forecast; the
        # selected quantiles form the confidence bands.
        samples = prediction[0]
        forecast = samples.median(dim=0).values.tolist()
        low_forecast = samples.quantile(q_low, dim=0).tolist()
        high_forecast = samples.quantile(q_high, dim=0).tolist()

        # Confidence score (Monte Carlo agreement):
        # 1. Does the median predict the market going UP or DOWN?
        last_known_price = data.prices[-1]
        median_direction_up = forecast[-1] > last_known_price

        # 2. Fraction of samples whose final value agrees with that direction.
        final_values = samples[:, -1]
        if median_direction_up:
            agreeing_samples = (final_values > last_known_price).sum().item()
        else:
            agreeing_samples = (final_values < last_known_price).sum().item()
        confidence_score = round((agreeing_samples / samples.shape[0]) * 100)

        # 50% is the neutral floor (coin flip), so never report below it.
        confidence_score = max(50, confidence_score)

        return {
            "forecast": forecast,
            "low": low_forecast,
            "high": high_forecast,
            "confidence": confidence_score,
            "horizon": data.horizon,
            "input_size": len(data.prices),
        }
    except HTTPException:
        # BUG FIX: re-raise intentional HTTP errors (e.g. the 400 above)
        # instead of letting the generic handler convert them into 500s.
        raise
    except Exception as e:
        print(f"Error during prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


class RollingBacktestData(BaseModel):
    # Historical prices, oldest first; needs more than `window_size` points.
    prices: List[float]
    # Context length used for each one-step-ahead prediction.
    window_size: int = 50
    # Same model names as MarketData.model.
    model: Optional[str] = "tiny"


@app.post("/backtest-rolling")
async def backtest_rolling(data: RollingBacktestData):
    """Walk-forward backtest of one-step-ahead directional accuracy.

    For every index past `window_size`, the preceding `window_size` prices
    form the context for a 1-step prediction, which is scored UP/DOWN
    against the actual next price. Returns per-step results plus the
    overall hit rate. Raises 400 when the series is too short.
    """
    try:
        if len(data.prices) <= data.window_size:
            raise HTTPException(status_code=400, detail="Not enough data for rolling backtest")

        pipeline = _select_pipeline(data.model)

        results = []
        torch.manual_seed(42)

        # One-step-ahead prediction at each point keeps the loop fast while
        # still exercising the model across the whole series.
        for i in range(data.window_size, len(data.prices)):
            context_prices = data.prices[i - data.window_size : i]
            actual_next_price = data.prices[i]

            context_tensor = torch.tensor(context_prices)
            prediction = pipeline.predict(context_tensor, 1)  # Only 1-step ahead
            predicted_median = prediction[0].median(dim=0).values.item()

            # Simple directional scoring against the last known price.
            last_price = context_prices[-1]
            predicted_dir = "UP" if predicted_median > last_price else "DOWN"
            actual_dir = "UP" if actual_next_price > last_price else "DOWN"

            results.append({
                "index": i,
                "predicted": predicted_median,
                "actual": actual_next_price,
                "predicted_dir": predicted_dir,
                "actual_dir": actual_dir,
                "correct": predicted_dir == actual_dir,
            })

        # Calculate overall accuracy
        correct_count = sum(1 for r in results if r["correct"])
        hit_rate = (correct_count / len(results)) * 100 if results else 0

        return {
            "results": results,
            "hit_rate": round(hit_rate, 2),
            "total_samples": len(results),
            "model_used": data.model,
        }
    except HTTPException:
        # Preserve intentional HTTP errors (e.g. the 400 above).
        raise
    except Exception as e:
        print(f"Error during rolling backtest: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=port)