# NOTE(review): the lines "Spaces: / Running / Running" were a Hugging Face
# Spaces status banner captured when this file was pasted; converted to a
# comment so the module stays valid Python.
from typing import Any, Dict, List, Optional

import torch
import uvicorn
from chronos import ChronosPipeline
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
app = FastAPI(title="Dolixe Kronos AI Service")


def _load_pipeline(model_id: str, description: str, required: bool = True):
    """Load one Chronos pipeline from the HuggingFace hub onto the CPU.

    Args:
        model_id: HF model identifier, e.g. "amazon/chronos-t5-tiny".
        description: human-readable name used in startup log lines.
        required: when True a load failure re-raises so the service fails
            fast at startup; when False the failure is logged and None is
            returned (callers fall back to a required T5 pipeline).

    Returns:
        A ChronosPipeline, or None for an optional model that failed to load.
    """
    print(f"Loading {description} ({model_id})... this may take a while on first run.")
    try:
        return ChronosPipeline.from_pretrained(
            model_id,
            device_map="cpu",  # Use "cuda" if you have an NVIDIA GPU
            torch_dtype=torch.float32,
        )
    except Exception as e:
        if required:
            raise
        print(f"FAILED to load {description}: {e}. Skipping...")
        return None


# Required baseline models: the optional Bolt variants below fall back to
# these, so a failure here must abort startup rather than leave them None.
pipeline_tiny = _load_pipeline("amazon/chronos-t5-tiny", "Kronos (Chronos-T5-Tiny)")
pipeline_base = _load_pipeline("amazon/chronos-t5-base", "Kronos (Chronos-T5-Base)")

# Optional Bolt models: loaded best-effort; None when unavailable.
pipeline_bolt_tiny = _load_pipeline("amazon/chronos-bolt-tiny", "Chronos Bolt Tiny", required=False)
pipeline_bolt_base = _load_pipeline("amazon/chronos-bolt-base", "Chronos Bolt Base", required=False)
class MarketData(BaseModel):
    # Request body for the forecast endpoint.
    # prices: historical price series; the last element is treated as the most
    # recent price by the prediction code (it compares forecasts against
    # prices[-1]) — assumes chronological order, oldest first.
    prices: List[float]
    # horizon: number of future steps to forecast.
    horizon: int = 12
    # model: "tiny" | "base" | "bolt-tiny" | "bolt-base"; any other value
    # (including None) falls back to the tiny T5 pipeline.
    model: Optional[str] = "tiny"
# NOTE(review): no route decorator survived the paste; "/" is the conventional
# FastAPI root for a status endpoint — confirm the path against the client.
@app.get("/")
def read_root():
    """Health-check endpoint: service status plus the model ids exposed."""
    return {"status": "online", "models": ["chronos-t5-tiny", "chronos-t5-base", "chronos-bolt-tiny", "chronos-bolt-base"]}
# NOTE(review): route decorator lost in the paste; "/predict" inferred from
# the function name — confirm against the client.
@app.post("/predict")
async def predict(data: MarketData):
    """Forecast future prices with a Chronos pipeline.

    Returns median forecast, low/high quantile bands, a directional
    confidence score (percentage of Monte Carlo samples agreeing with the
    median's up/down call, floored at 50), the horizon, and the input size.

    Raises:
        HTTPException 400: empty price series.
        HTTPException 500: any unexpected prediction failure.
    """
    try:
        if not data.prices:
            raise HTTPException(status_code=400, detail="No price data provided")

        # Optional[str] default is "tiny"; guard against an explicit null so
        # the startswith() below cannot raise AttributeError.
        model_name = data.model or "tiny"

        # Select pipeline; Bolt variants fall back to their T5 counterpart
        # when they failed to load at startup.
        if model_name == "base":
            pipeline = pipeline_base
        elif model_name == "bolt-tiny":
            pipeline = pipeline_bolt_tiny if pipeline_bolt_tiny else pipeline_tiny
        elif model_name == "bolt-base":
            pipeline = pipeline_bolt_base if pipeline_bolt_base else pipeline_base
        else:
            pipeline = pipeline_tiny

        context = torch.tensor(data.prices)
        # Fixed seed so the same input always yields the same forecast.
        torch.manual_seed(42)
        prediction = pipeline.predict(context, data.horizon)

        # Quantile bands: Bolt models use tighter P25/P75 (50% zone), T5 uses
        # P10/P90 (80% zone). NOTE(review): when a Bolt model fell back to T5
        # above, the Bolt quantiles are still applied — kept as-is.
        is_bolt = model_name.startswith("bolt")
        q_low = 0.25 if is_bolt else 0.1
        q_high = 0.75 if is_bolt else 0.9

        # prediction[0] holds the samples for the single input series; the
        # median across samples is the forecast, the quantiles the bands.
        forecast = prediction[0].median(dim=0).values.tolist()
        low_forecast = prediction[0].quantile(q_low, dim=0).tolist()
        high_forecast = prediction[0].quantile(q_high, dim=0).tolist()

        # Confidence score (Monte Carlo agreement):
        # 1. Does the median forecast end above or below the last known price?
        last_known_price = data.prices[-1]
        median_direction_up = forecast[-1] > last_known_price
        # 2. How many sample paths agree with that direction at the horizon?
        samples = prediction[0]  # shape (num_samples, horizon)
        final_values = samples[:, -1]
        if median_direction_up:
            agreeing_samples = (final_values > last_known_price).sum().item()
        else:
            agreeing_samples = (final_values < last_known_price).sum().item()
        confidence_score = round((agreeing_samples / samples.shape[0]) * 100)
        # Floor at 50%: below that the split would contradict the median call.
        confidence_score = max(50, confidence_score)

        return {
            "forecast": forecast,
            "low": low_forecast,
            "high": high_forecast,
            "confidence": confidence_score,
            "horizon": data.horizon,
            "input_size": len(data.prices)
        }
    except HTTPException:
        # Bug fix: previously the 400 above was caught by the generic handler
        # below and re-raised as a 500. Let HTTP errors pass through intact.
        raise
    except Exception as e:
        print(f"Error during prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
class RollingBacktestData(BaseModel):
    # Request body for the rolling-backtest endpoint.
    # prices: full historical series to walk forward over; must be strictly
    # longer than window_size (enforced by the endpoint).
    prices: List[float]
    # window_size: number of preceding prices used as context for each
    # one-step-ahead prediction.
    window_size: int = 50
    # model: "tiny" | "base" | "bolt-tiny" | "bolt-base"; any other value
    # (including None) falls back to the tiny T5 pipeline.
    model: Optional[str] = "tiny"
# NOTE(review): route decorator lost in the paste; "/backtest/rolling"
# inferred from the function name — confirm against the client.
@app.post("/backtest/rolling")
async def backtest_rolling(data: RollingBacktestData):
    """Walk-forward one-step backtest of directional accuracy.

    For each index past window_size, predicts the next price from the
    preceding window and records whether the predicted up/down direction
    matched the actual move. Returns per-step results, the hit rate (%),
    the sample count, and the model name used.

    Raises:
        HTTPException 400: series not longer than window_size.
        HTTPException 500: any unexpected prediction failure.
    """
    try:
        if len(data.prices) <= data.window_size:
            raise HTTPException(status_code=400, detail="Not enough data for rolling backtest")

        # Select pipeline; Bolt variants fall back to their T5 counterpart
        # when they failed to load at startup.
        if data.model == "base":
            pipeline = pipeline_base
        elif data.model == "bolt-tiny":
            pipeline = pipeline_bolt_tiny if pipeline_bolt_tiny else pipeline_tiny
        elif data.model == "bolt-base":
            pipeline = pipeline_bolt_base if pipeline_bolt_base else pipeline_base
        else:
            pipeline = pipeline_tiny

        results = []
        # Fixed seed so repeated backtests over the same data agree.
        torch.manual_seed(42)
        # Walk forward: for each point past window_size, use the preceding
        # window_size prices as context and predict only 1 step ahead (fast).
        for i in range(data.window_size, len(data.prices)):
            context_prices = data.prices[i - data.window_size : i]
            actual_next_price = data.prices[i]
            context_tensor = torch.tensor(context_prices)
            prediction = pipeline.predict(context_tensor, 1)  # 1-step ahead
            predicted_median = prediction[0].median(dim=0).values.item()

            # Directional call relative to the last price in the window.
            last_price = context_prices[-1]
            predicted_dir = "UP" if predicted_median > last_price else "DOWN"
            actual_dir = "UP" if actual_next_price > last_price else "DOWN"

            results.append({
                "index": i,
                "predicted": predicted_median,
                "actual": actual_next_price,
                "predicted_dir": predicted_dir,
                "actual_dir": actual_dir,
                "correct": predicted_dir == actual_dir
            })

        # Overall directional accuracy.
        correct_count = sum(1 for r in results if r["correct"])
        hit_rate = (correct_count / len(results)) * 100 if results else 0

        return {
            "results": results,
            "hit_rate": round(hit_rate, 2),
            "total_samples": len(results),
            "model_used": data.model
        }
    except HTTPException:
        # Bug fix: previously the 400 above was caught by the generic handler
        # below and re-raised as a 500. Let HTTP errors pass through intact.
        raise
    except Exception as e:
        print(f"Error during rolling backtest: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    # Stand-alone entry point: bind every interface, honouring $PORT when set.
    import os

    listen_port = int(os.getenv("PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)