Spaces:
Running
Running
File size: 7,409 Bytes
686c5ac 71e2443 686c5ac 71e2443 686c5ac 71e2443 686c5ac 71e2443 686c5ac 71e2443 686c5ac 71e2443 686c5ac 71e2443 686c5ac 71e2443 686c5ac 71e2443 686c5ac 71e2443 686c5ac 71e2443 686c5ac 0b6f8af 686c5ac 71e2443 686c5ac | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 | import torch
from chronos import ChronosPipeline
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
import uvicorn
# FastAPI application object; the route decorators below attach to it.
app = FastAPI(title="Dolixe Kronos AI Service")

# Every checkpoint is loaded with the same settings: CPU placement and
# float32 weights. Use "cuda" for device_map if you have an NVIDIA GPU.
_LOAD_OPTIONS = {"device_map": "cpu", "torch_dtype": torch.float32}

# --- Required models -------------------------------------------------------
# The two T5 checkpoints are loaded without a guard, so any exception raised
# here propagates out of the module import.
print("Loading Kronos (Chronos-T5-Tiny) model... this may take a minute on first run.")
pipeline_tiny = ChronosPipeline.from_pretrained("amazon/chronos-t5-tiny", **_LOAD_OPTIONS)

print("Loading Kronos (Chronos-T5-Base) model... this may take a bit longer.")
pipeline_base = ChronosPipeline.from_pretrained("amazon/chronos-t5-base", **_LOAD_OPTIONS)

# --- Optional models -------------------------------------------------------
# The Bolt checkpoints are best-effort: each name starts out as None and is
# only rebound when loading succeeds, so a failure is reported and skipped.
pipeline_bolt_tiny = None
print("Loading Chronos Bolt Tiny model... this should be extremely fast.")
try:
    pipeline_bolt_tiny = ChronosPipeline.from_pretrained("amazon/chronos-bolt-tiny", **_LOAD_OPTIONS)
except Exception as load_error:
    print(f"FAILED to load Bolt Tiny: {load_error}. Skipping...")

pipeline_bolt_base = None
print("Loading Chronos Bolt Base model... this may take a bit longer but will be faster than T5-Base.")
try:
    pipeline_bolt_base = ChronosPipeline.from_pretrained("amazon/chronos-bolt-base", **_LOAD_OPTIONS)
except Exception as load_error:
    print(f"FAILED to load Bolt Base: {load_error}. Skipping...")
class MarketData(BaseModel):
    """Request body for the ``/predict`` endpoint."""

    # Price history used as model context; the last element is treated as
    # the most recent known price when the forecast direction is computed.
    prices: List[float]
    # Number of future steps to forecast.
    horizon: int = 12
    # Model selector: "base", "bolt-tiny" or "bolt-base" pick the matching
    # pipeline; any other value selects the tiny T5 pipeline.
    model: Optional[str] = "tiny"
@app.get("/")
def read_root():
    """Return a static status payload naming the models this service advertises.

    The list is fixed: it does not reflect whether the optional Bolt
    checkpoints actually loaded at start-up.
    """
    advertised_models = [
        "chronos-t5-tiny",
        "chronos-t5-base",
        "chronos-bolt-tiny",
        "chronos-bolt-base",
    ]
    return {"status": "online", "models": advertised_models}
@app.post("/predict")
async def predict(data: MarketData):
    """Forecast future prices with quantile bands and a direction-agreement score.

    Args:
        data: Request body carrying the price history, the forecast horizon
            and the model selector.

    Returns:
        A dict with the per-step median ``forecast``, the ``low`` and ``high``
        quantile bands, an integer ``confidence`` percentage (never below 50),
        the requested ``horizon`` and the ``input_size``.

    Raises:
        HTTPException: 400 when ``prices`` is empty or ``horizon`` is below 1;
            500 when the model call or the post-processing fails.
    """
    # Validate outside the try block: HTTPException is itself an Exception, so
    # raising it inside would be caught below and reported as a 500.
    if not data.prices:
        raise HTTPException(status_code=400, detail="No price data provided")
    if data.horizon < 1:
        raise HTTPException(status_code=400, detail="horizon must be at least 1")

    # ``model`` is Optional, so an explicit null must fall back to the default
    # instead of crashing on ``.startswith`` further down.
    model_name = data.model or "tiny"

    try:
        # Select pipeline; a Bolt model that failed to load at start-up falls
        # back to the T5 model of the same size.
        if model_name == "base":
            pipeline = pipeline_base
        elif model_name == "bolt-tiny":
            pipeline = pipeline_bolt_tiny or pipeline_tiny
        elif model_name == "bolt-base":
            pipeline = pipeline_bolt_base or pipeline_base
        else:
            pipeline = pipeline_tiny

        # Convert prices to a tensor
        context = torch.tensor(data.prices)

        # Run prediction with a fixed seed to ensure consistency for the same input
        torch.manual_seed(42)
        prediction = pipeline.predict(context, data.horizon)

        # Bolt requests use tighter P25/P75 bands (50% zone); everything else
        # uses P10/P90 (80% zone). The choice follows the requested name.
        is_bolt = model_name.startswith("bolt")
        q_low = 0.25 if is_bolt else 0.1
        q_high = 0.75 if is_bolt else 0.9

        # prediction[0] is the result for the first (and only) series.
        # Its first axis is reduced to get the median and the band quantiles.
        # NOTE(review): assumes axis 0 holds sample paths; Bolt pipelines may
        # return quantile levels there instead - TODO confirm.
        samples = prediction[0]
        forecast = samples.median(dim=0).values.tolist()
        low_forecast = samples.quantile(q_low, dim=0).tolist()
        high_forecast = samples.quantile(q_high, dim=0).tolist()

        # Confidence score (Monte Carlo agreement):
        # 1. Does the median end above or below the last known price?
        last_known_price = data.prices[-1]
        median_direction_up = forecast[-1] > last_known_price

        # 2. Count how many paths end on the same side as the median.
        final_values = samples[:, -1]
        if median_direction_up:
            agreeing_samples = (final_values > last_known_price).sum().item()
        else:
            agreeing_samples = (final_values < last_known_price).sum().item()
        confidence_score = round((agreeing_samples / samples.shape[0]) * 100)

        # Present at least a 50% baseline (50/50 is the neutral range).
        confidence_score = max(50, confidence_score)

        return {
            "forecast": forecast,
            "low": low_forecast,
            "high": high_forecast,
            "confidence": confidence_score,
            "horizon": data.horizon,
            "input_size": len(data.prices)
        }
    except Exception as e:
        print(f"Error during prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
class RollingBacktestData(BaseModel):
    """Request body for the ``/backtest-rolling`` endpoint."""

    # Full price series to backtest over; it must be longer than window_size.
    prices: List[float]
    # Number of preceding prices used as context for each one-step prediction.
    window_size: int = 50
    # Model selector: "base", "bolt-tiny" or "bolt-base" pick the matching
    # pipeline; any other value selects the tiny T5 pipeline.
    model: Optional[str] = "tiny"
@app.post("/backtest-rolling")
async def backtest_rolling(data: RollingBacktestData):
    """Run a one-step-ahead rolling backtest and report directional accuracy.

    For every index from ``window_size`` to the end of the series, the
    preceding ``window_size`` prices are used as context to predict the next
    price, and the predicted direction is compared with the actual one.

    Args:
        data: Request body carrying the price series, the context window size
            and the model selector.

    Returns:
        A dict with the per-step ``results``, the ``hit_rate`` percentage
        rounded to two decimals, ``total_samples`` and the requested
        ``model_used``.

    Raises:
        HTTPException: 400 when ``window_size`` is below 1 or the series is
            not longer than the window; 500 when a model call fails.
    """
    # Validate outside the try block: HTTPException is itself an Exception, so
    # raising it inside would be caught below and reported as a 500.
    # A zero window yields an empty context, and a negative one makes the
    # range and slices below wrap around and silently produce garbage.
    if data.window_size < 1:
        raise HTTPException(status_code=400, detail="window_size must be at least 1")
    if len(data.prices) <= data.window_size:
        raise HTTPException(status_code=400, detail="Not enough data for rolling backtest")

    try:
        # Select pipeline; a Bolt model that failed to load at start-up falls
        # back to the T5 model of the same size.
        if data.model == "base":
            pipeline = pipeline_base
        elif data.model == "bolt-tiny":
            pipeline = pipeline_bolt_tiny or pipeline_tiny
        elif data.model == "bolt-base":
            pipeline = pipeline_bolt_base or pipeline_base
        else:
            pipeline = pipeline_tiny

        results = []
        # Seed once up front so the same request gives the same backtest.
        torch.manual_seed(42)

        # Walk the series from window_size onwards. Each step predicts only
        # the NEXT price (1-step horizon) to keep the backtest fast.
        for i in range(data.window_size, len(data.prices)):
            context_prices = data.prices[i - data.window_size : i]
            actual_next_price = data.prices[i]

            context_tensor = torch.tensor(context_prices)
            prediction = pipeline.predict(context_tensor, 1)  # Only 1-step ahead
            predicted_median = prediction[0].median(dim=0).values.item()

            # Simple directional logic: anything not strictly above the last
            # context price counts as "DOWN", including an unchanged price.
            last_price = context_prices[-1]
            predicted_dir = "UP" if predicted_median > last_price else "DOWN"
            actual_dir = "UP" if actual_next_price > last_price else "DOWN"

            results.append({
                "index": i,
                "predicted": predicted_median,
                "actual": actual_next_price,
                "predicted_dir": predicted_dir,
                "actual_dir": actual_dir,
                "correct": predicted_dir == actual_dir
            })

        # Calculate overall accuracy
        correct_count = sum(1 for r in results if r["correct"])
        hit_rate = (correct_count / len(results)) * 100 if results else 0

        return {
            "results": results,
            "hit_rate": round(hit_rate, 2),
            "total_samples": len(results),
            "model_used": data.model
        }
    except Exception as e:
        print(f"Error during rolling backtest: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces. The port is read
    # from the PORT environment variable and defaults to 8000.
    import os

    configured_port = os.environ.get("PORT", 8000)
    uvicorn.run(app, host="0.0.0.0", port=int(configured_port))
|