# dolixecharting / app.py
# Author: Dolixe
# Uploaded via Hugging Face (commit 71e2443, verified)
import torch
from chronos import ChronosPipeline
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
import uvicorn
app = FastAPI(title="Dolixe Kronos AI Service")


def _load_pipeline(model_id, load_message, failure_label=None):
    """Load a Chronos pipeline onto the CPU.

    Args:
        model_id: Hugging Face model identifier to download/load.
        load_message: Exact message printed before the load starts.
        failure_label: When given, a load failure is logged with this label
            and None is returned so the service can still start with the
            remaining models. When None, the exception propagates (the
            model is considered required).

    Returns:
        A ChronosPipeline instance, or None when an optional model failed
        to load.
    """
    print(load_message)
    try:
        return ChronosPipeline.from_pretrained(
            model_id,
            device_map="cpu",  # Use "cuda" if you have an NVIDIA GPU
            torch_dtype=torch.float32,
        )
    except Exception as e:
        if failure_label is None:
            # Required model: let startup fail loudly.
            raise
        print(f"FAILED to load {failure_label}: {e}. Skipping...")
        return None


# Load the models. The T5 variants are required; the Bolt variants are
# optional and fall back to None if unavailable.
pipeline_tiny = _load_pipeline(
    "amazon/chronos-t5-tiny",
    "Loading Kronos (Chronos-T5-Tiny) model... this may take a minute on first run.",
)
pipeline_base = _load_pipeline(
    "amazon/chronos-t5-base",
    "Loading Kronos (Chronos-T5-Base) model... this may take a bit longer.",
)
pipeline_bolt_tiny = _load_pipeline(
    "amazon/chronos-bolt-tiny",
    "Loading Chronos Bolt Tiny model... this should be extremely fast.",
    failure_label="Bolt Tiny",
)
pipeline_bolt_base = _load_pipeline(
    "amazon/chronos-bolt-base",
    "Loading Chronos Bolt Base model... this may take a bit longer but will be faster than T5-Base.",
    failure_label="Bolt Base",
)
class MarketData(BaseModel):
    """Request body for POST /predict."""
    # Historical price series used as forecasting context, oldest first.
    prices: List[float]
    # Number of future steps to forecast.
    horizon: int = 12
    # Model selector: "tiny" (default), "base", "bolt-tiny", or "bolt-base".
    model: Optional[str] = "tiny"
@app.get("/")
def read_root():
    """Health-check endpoint: report service status and the model lineup."""
    model_names = [
        "chronos-t5-tiny",
        "chronos-t5-base",
        "chronos-bolt-tiny",
        "chronos-bolt-base",
    ]
    return {"status": "online", "models": model_names}
@app.post("/predict")
async def predict(data: MarketData):
    """Forecast future prices from a historical series.

    Returns the median forecast, low/high quantile bands, and a
    Monte Carlo direction-agreement confidence score.

    Raises:
        HTTPException 400: empty price list.
        HTTPException 500: any unexpected failure during inference.
    """
    # Validate BEFORE the try block: previously the 400 raised here was
    # caught by `except Exception` below and rewrapped as a 500.
    if not data.prices:
        raise HTTPException(status_code=400, detail="No price data provided")
    try:
        # `model` is Optional; treat an explicit None like the default
        # so `.startswith` below cannot crash.
        model_name = data.model or "tiny"
        # Select pipeline (Bolt variants fall back to their T5 counterpart
        # if they failed to load at startup).
        if model_name == "base":
            pipeline = pipeline_base
        elif model_name == "bolt-tiny":
            pipeline = pipeline_bolt_tiny if pipeline_bolt_tiny else pipeline_tiny
        elif model_name == "bolt-base":
            pipeline = pipeline_bolt_base if pipeline_bolt_base else pipeline_base
        else:
            pipeline = pipeline_tiny
        # Convert prices to a tensor
        context = torch.tensor(data.prices)
        # Fixed seed so identical inputs produce identical forecasts.
        torch.manual_seed(42)
        prediction = pipeline.predict(context, data.horizon)
        # Quantile bands: Bolt models use tighter P25/P75 (50% zone),
        # T5 models use P10/P90 (80% zone).
        is_bolt = model_name.startswith("bolt")
        q_low = 0.25 if is_bolt else 0.1
        q_high = 0.75 if is_bolt else 0.9
        # prediction[0] is the result for the first (and only) batch,
        # shape (num_samples, horizon). The median is the point forecast;
        # the selected quantiles form the confidence band.
        forecast = prediction[0].median(dim=0).values.tolist()
        low_forecast = prediction[0].quantile(q_low, dim=0).tolist()
        high_forecast = prediction[0].quantile(q_high, dim=0).tolist()
        # Confidence score (Monte Carlo agreement):
        # 1. Does the median forecast end above the last known price?
        last_known_price = data.prices[-1]
        median_direction_up = forecast[-1] > last_known_price
        # 2. How many sample paths agree with that direction at the horizon?
        samples = prediction[0]  # shape (num_samples, horizon)
        final_values = samples[:, -1]
        if median_direction_up:
            agreeing_samples = (final_values > last_known_price).sum().item()
        else:
            agreeing_samples = (final_values < last_known_price).sum().item()
        confidence_score = round((agreeing_samples / samples.shape[0]) * 100)
        # Floor at 50%: a 50/50 split is directionally neutral.
        confidence_score = max(50, confidence_score)
        return {
            "forecast": forecast,
            "low": low_forecast,
            "high": high_forecast,
            "confidence": confidence_score,
            "horizon": data.horizon,
            "input_size": len(data.prices)
        }
    except HTTPException:
        # Preserve deliberate HTTP errors instead of rewrapping them as 500s.
        raise
    except Exception as e:
        print(f"Error during prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
class RollingBacktestData(BaseModel):
    """Request body for POST /backtest-rolling."""
    # Full historical price series to backtest over, oldest first.
    prices: List[float]
    # Number of preceding prices used as context for each 1-step prediction.
    window_size: int = 50
    # Model selector: "tiny" (default), "base", "bolt-tiny", or "bolt-base".
    model: Optional[str] = "tiny"
@app.post("/backtest-rolling")
async def backtest_rolling(data: RollingBacktestData):
    """Walk-forward 1-step directional backtest over a price series.

    For each index past `window_size`, the preceding window is used as
    context to predict the next price; the predicted direction is then
    compared against the realized direction.

    Returns per-step results, an overall directional hit rate, and the
    model name that was requested.

    Raises:
        HTTPException 400: series not longer than the window.
        HTTPException 500: any unexpected failure during inference.
    """
    # Validate BEFORE the try block: previously the 400 raised here was
    # caught by `except Exception` below and rewrapped as a 500.
    if len(data.prices) <= data.window_size:
        raise HTTPException(status_code=400, detail="Not enough data for rolling backtest")
    try:
        # Select pipeline (Bolt variants fall back to their T5 counterpart
        # if they failed to load at startup).
        if data.model == "base":
            pipeline = pipeline_base
        elif data.model == "bolt-tiny":
            pipeline = pipeline_bolt_tiny if pipeline_bolt_tiny else pipeline_tiny
        elif data.model == "bolt-base":
            pipeline = pipeline_bolt_base if pipeline_bolt_base else pipeline_base
        else:
            pipeline = pipeline_tiny
        results = []
        torch.manual_seed(42)
        # Slide a window across the series: at each step the preceding
        # window_size prices are the context and we predict only the NEXT
        # price (1-step ahead keeps the loop fast).
        for i in range(data.window_size, len(data.prices)):
            context_prices = data.prices[i - data.window_size : i]
            actual_next_price = data.prices[i]
            context_tensor = torch.tensor(context_prices)
            prediction = pipeline.predict(context_tensor, 1)  # Only 1-step ahead
            predicted_median = prediction[0].median(dim=0).values.item()
            # Directional call: compare forecast and realized price against
            # the last in-window price.
            last_price = context_prices[-1]
            predicted_dir = "UP" if predicted_median > last_price else "DOWN"
            actual_dir = "UP" if actual_next_price > last_price else "DOWN"
            results.append({
                "index": i,
                "predicted": predicted_median,
                "actual": actual_next_price,
                "predicted_dir": predicted_dir,
                "actual_dir": actual_dir,
                "correct": predicted_dir == actual_dir
            })
        # Calculate overall accuracy
        correct_count = sum(1 for r in results if r["correct"])
        hit_rate = (correct_count / len(results)) * 100 if results else 0
        return {
            "results": results,
            "hit_rate": round(hit_rate, 2),
            "total_samples": len(results),
            "model_used": data.model
        }
    except HTTPException:
        # Preserve deliberate HTTP errors instead of rewrapping them as 500s.
        raise
    except Exception as e:
        print(f"Error during rolling backtest: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    import os

    # Honor a platform-provided PORT (hosted deployments), defaulting to
    # 8000 for local runs.
    listen_port = int(os.environ.get("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)