Dolixe committed on
Commit
71e2443
·
verified ·
1 Parent(s): 0b6f8af

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -9
app.py CHANGED
@@ -1,28 +1,58 @@
1
  import torch
2
  from chronos import ChronosPipeline
3
- import pandas as pd
4
  from fastapi import FastAPI, HTTPException
5
  from pydantic import BaseModel
6
- from typing import List
7
  import uvicorn
8
 
9
  app = FastAPI(title="Dolixe Kronos AI Service")
10
 
11
- # Load the model (Tiny version as requested)
12
  print("Loading Kronos (Chronos-T5-Tiny) model... this may take a minute on first run.")
13
- pipeline = ChronosPipeline.from_pretrained(
14
  "amazon/chronos-t5-tiny",
15
  device_map="cpu", # Use "cuda" if you have an NVIDIA GPU
16
  torch_dtype=torch.float32,
17
  )
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  class MarketData(BaseModel):
20
  prices: List[float]
21
  horizon: int = 12
 
22
 
23
  @app.get("/")
24
  def read_root():
25
- return {"status": "online", "model": "chronos-t5-tiny"}
26
 
27
  @app.post("/predict")
28
  async def predict(data: MarketData):
@@ -30,6 +60,16 @@ async def predict(data: MarketData):
30
  if not data.prices:
31
  raise HTTPException(status_code=400, detail="No price data provided")
32
 
 
 
 
 
 
 
 
 
 
 
33
  # Convert prices to a tensor
34
  context = torch.tensor(data.prices)
35
 
@@ -37,13 +77,17 @@ async def predict(data: MarketData):
37
  torch.manual_seed(42)
38
  prediction = pipeline.predict(context, data.horizon)
39
 
 
 
 
 
 
40
  # prediction[0] is the result for the first (and only) batch
41
  # We take the median (50th percentile) as our forecast
42
- # We also take 10th and 90th percentiles for confidence bands
43
- # Shape is (samples, horizon)
44
  forecast = prediction[0].median(dim=0).values.tolist()
45
- low_forecast = prediction[0].quantile(0.1, dim=0).tolist()
46
- high_forecast = prediction[0].quantile(0.9, dim=0).tolist()
47
 
48
  # Calculate Confidence Score (Monte Carlo Agreement)
49
  # 1. Determine if the median predicts market going UP or DOWN
@@ -75,6 +119,72 @@ async def predict(data: MarketData):
75
  print(f"Error during prediction: {str(e)}")
76
  raise HTTPException(status_code=500, detail=str(e))
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  if __name__ == "__main__":
79
  import os
80
  port = int(os.environ.get("PORT", 8000))
 
1
  import torch
2
  from chronos import ChronosPipeline
3
+
4
  from fastapi import FastAPI, HTTPException
5
  from pydantic import BaseModel
6
+ from typing import List, Optional, Dict, Any
7
  import uvicorn
8
 
9
  app = FastAPI(title="Dolixe Kronos AI Service")
10
 
11
# Load the models
# Every checkpoint is loaded once, at import time, on CPU in float32.
# NOTE(review): the Bolt checkpoints are loaded through the same
# ChronosPipeline class as the T5 ones -- confirm against the chronos
# documentation that this class supports Bolt models.
def _load_chronos(repo_id):
    """Return a CPU, float32 ChronosPipeline for the given model repo id."""
    return ChronosPipeline.from_pretrained(
        repo_id,
        device_map="cpu",  # Use "cuda" if you have an NVIDIA GPU
        torch_dtype=torch.float32,
    )


# The two T5 models are mandatory: a load failure here propagates.
print("Loading Kronos (Chronos-T5-Tiny) model... this may take a minute on first run.")
pipeline_tiny = _load_chronos("amazon/chronos-t5-tiny")

print("Loading Kronos (Chronos-T5-Base) model... this may take a bit longer.")
pipeline_base = _load_chronos("amazon/chronos-t5-base")

# The Bolt models are optional: on any load failure the error is printed and
# the corresponding name is bound to None.
print("Loading Chronos Bolt Tiny model... this should be extremely fast.")
try:
    pipeline_bolt_tiny = _load_chronos("amazon/chronos-bolt-tiny")
except Exception as load_error:
    print(f"FAILED to load Bolt Tiny: {load_error}. Skipping...")
    pipeline_bolt_tiny = None

print("Loading Chronos Bolt Base model... this may take a bit longer but will be faster than T5-Base.")
try:
    pipeline_bolt_base = _load_chronos("amazon/chronos-bolt-base")
except Exception as load_error:
    print(f"FAILED to load Bolt Base: {load_error}. Skipping...")
    pipeline_bolt_base = None
47
+
48
  class MarketData(BaseModel):
49
  prices: List[float]
50
  horizon: int = 12
51
+ model: Optional[str] = "tiny"
52
 
53
  @app.get("/")
54
  def read_root():
55
+ return {"status": "online", "models": ["chronos-t5-tiny", "chronos-t5-base", "chronos-bolt-tiny", "chronos-bolt-base"]}
56
 
57
  @app.post("/predict")
58
  async def predict(data: MarketData):
 
60
  if not data.prices:
61
  raise HTTPException(status_code=400, detail="No price data provided")
62
 
63
+ # Select pipeline
64
+ if data.model == "base":
65
+ pipeline = pipeline_base
66
+ elif data.model == "bolt-tiny":
67
+ pipeline = pipeline_bolt_tiny if pipeline_bolt_tiny else pipeline_tiny
68
+ elif data.model == "bolt-base":
69
+ pipeline = pipeline_bolt_base if pipeline_bolt_base else pipeline_base
70
+ else:
71
+ pipeline = pipeline_tiny
72
+
73
  # Convert prices to a tensor
74
  context = torch.tensor(data.prices)
75
 
 
77
  torch.manual_seed(42)
78
  prediction = pipeline.predict(context, data.horizon)
79
 
80
+ # Set quantiles: Bolt models use tighter P25/P75 (50% zone), T5 uses P10/P90 (80% zone)
81
+ is_bolt = data.model.startswith("bolt")
82
+ q_low = 0.25 if is_bolt else 0.1
83
+ q_high = 0.75 if is_bolt else 0.9
84
+
85
  # prediction[0] is the result for the first (and only) batch
86
  # We take the median (50th percentile) as our forecast
87
+ # We also take the selected quantiles for confidence bands
 
88
  forecast = prediction[0].median(dim=0).values.tolist()
89
+ low_forecast = prediction[0].quantile(q_low, dim=0).tolist()
90
+ high_forecast = prediction[0].quantile(q_high, dim=0).tolist()
91
 
92
  # Calculate Confidence Score (Monte Carlo Agreement)
93
  # 1. Determine if the median predicts market going UP or DOWN
 
119
  print(f"Error during prediction: {str(e)}")
120
  raise HTTPException(status_code=500, detail=str(e))
121
 
122
class RollingBacktestData(BaseModel):
    """Request body accepted by the /backtest-rolling endpoint."""

    # Full price history; the endpoint requires more prices than window_size.
    prices: List[float]
    # Number of preceding prices used as context for each one-step prediction.
    window_size: int = 50
    # Model key: "base", "bolt-tiny" or "bolt-base"; any other value
    # selects the tiny T5 pipeline.
    model: Optional[str] = "tiny"
126
+
127
@app.post("/backtest-rolling")
async def backtest_rolling(data: RollingBacktestData):
    """Run a rolling one-step-ahead directional backtest over ``data.prices``.

    For every index ``i`` from ``window_size`` to the end of the series, the
    preceding ``window_size`` prices are used as context, the selected
    pipeline predicts one step, and the median of that prediction is compared
    with the actual price at ``i``.

    Returns:
        A dict with the per-step ``results``, ``hit_rate`` (percentage of
        steps whose predicted direction matched the actual one, rounded to
        two decimals), ``total_samples`` and ``model_used`` (the requested
        model key, echoed back).

    Raises:
        HTTPException: 400 when ``window_size`` is not positive or when there
            are not more prices than ``window_size``; 500 when anything fails
            during prediction.
    """
    # Validate before entering the try block so that client errors are
    # reported as 400 instead of being caught below and re-wrapped as 500.
    if data.window_size < 1:
        # A zero window would give an empty context; a negative one would
        # produce meaningless slices.
        raise HTTPException(status_code=400, detail="window_size must be a positive integer")
    if len(data.prices) <= data.window_size:
        raise HTTPException(status_code=400, detail="Not enough data for rolling backtest")

    try:
        # Select pipeline. Bolt keys fall back to the matching T5 pipeline
        # when the Bolt model failed to load at startup.
        if data.model == "base":
            pipeline = pipeline_base
        elif data.model == "bolt-tiny":
            pipeline = pipeline_bolt_tiny if pipeline_bolt_tiny else pipeline_tiny
        elif data.model == "bolt-base":
            pipeline = pipeline_bolt_base if pipeline_bolt_base else pipeline_base
        else:
            pipeline = pipeline_tiny

        results = []
        # Seed once so repeated requests with the same input are reproducible.
        torch.manual_seed(42)

        # NOTE(review): pipeline.predict runs synchronously inside this async
        # handler, once per step -- confirm this is acceptable for long series.
        for i in range(data.window_size, len(data.prices)):
            context_prices = data.prices[i - data.window_size : i]
            actual_next_price = data.prices[i]

            context_tensor = torch.tensor(context_prices)
            prediction = pipeline.predict(context_tensor, 1)  # Only 1-step ahead

            # Assumes prediction[0] is (samples, 1), so the median over dim 0
            # holds a single value -- TODO confirm for every pipeline type.
            predicted_median = prediction[0].median(dim=0).values.item()

            # Simple directional logic relative to the last context price.
            # A flat move (equal prices) is classified as "DOWN".
            last_price = context_prices[-1]
            predicted_dir = "UP" if predicted_median > last_price else "DOWN"
            actual_dir = "UP" if actual_next_price > last_price else "DOWN"

            results.append({
                "index": i,
                "predicted": predicted_median,
                "actual": actual_next_price,
                "predicted_dir": predicted_dir,
                "actual_dir": actual_dir,
                "correct": predicted_dir == actual_dir,
            })

        # Calculate overall accuracy
        correct_count = sum(1 for r in results if r["correct"])
        hit_rate = (correct_count / len(results)) * 100 if results else 0

        return {
            "results": results,
            "hit_rate": round(hit_rate, 2),
            "total_samples": len(results),
            # NOTE(review): this echoes the requested key even when a Bolt
            # request fell back to a T5 pipeline.
            "model_used": data.model,
        }
    except HTTPException:
        # Deliberate HTTP errors must keep their own status code.
        raise
    except Exception as e:
        print(f"Error during rolling backtest: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
187
+
188
  if __name__ == "__main__":
189
  import os
190
  port = int(os.environ.get("PORT", 8000))