Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,7 +2,7 @@ import os
|
|
| 2 |
import torch
|
| 3 |
import numpy as np
|
| 4 |
from fastapi import FastAPI, Request
|
| 5 |
-
from chronos import ChronosPipeline
|
| 6 |
import uvicorn
|
| 7 |
|
| 8 |
app = FastAPI()
|
|
@@ -14,12 +14,12 @@ pipeline = None
|
|
| 14 |
def load_model():
|
| 15 |
global pipeline
|
| 16 |
print("Loading Chronos model...")
|
| 17 |
-
# Using 'tiny'
|
| 18 |
-
#
|
| 19 |
pipeline = ChronosPipeline.from_pretrained(
|
| 20 |
"amazon/chronos-t5-tiny",
|
| 21 |
-
device_map="
|
| 22 |
-
|
| 23 |
)
|
| 24 |
print("Model loaded successfully.")
|
| 25 |
|
|
@@ -29,37 +29,55 @@ def home():
|
|
| 29 |
|
| 30 |
@app.post("/predict")
|
| 31 |
async def get_forecast(request: Request):
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
|
|
|
|
|
|
| 62 |
|
| 63 |
if __name__ == "__main__":
|
| 64 |
-
#
|
| 65 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
|
| 2 |
import torch
|
| 3 |
import numpy as np
|
| 4 |
from fastapi import FastAPI, Request
|
| 5 |
+
from chronos import ChronosPipeline # This works ONLY if chronos-forecasting is installed
|
| 6 |
import uvicorn
|
| 7 |
|
| 8 |
app = FastAPI()
|
|
|
|
| 14 |
def load_model():
    """Initialise the module-level Chronos pipeline (intended to run once at startup)."""
    global pipeline
    print("Loading Chronos model...")
    # 'tiny' checkpoint for stability on HF basic instances; bfloat16 keeps
    # the weight footprint small.
    model_kwargs = {
        "device_map": "cpu",  # Use "cpu" if no GPU, or "auto" if you have a T4
        "torch_dtype": torch.bfloat16,
    }
    pipeline = ChronosPipeline.from_pretrained("amazon/chronos-t5-tiny", **model_kwargs)
    print("Model loaded successfully.")
|
| 25 |
|
|
|
|
| 29 |
|
| 30 |
@app.post("/predict")
async def get_forecast(request: Request):
    """Forecast the next 12 points of a price series and derive trade levels.

    Expects a JSON body:
        history  -- list of recent prices (floats), most recent last
        rr_ratio -- desired reward:risk ratio (optional, defaults to 2.0)

    Returns the median forecast with 10th/90th percentile bands plus a
    suggested stop-loss / take-profit, or {"error": ...} on any failure.
    """
    global pipeline
    try:
        # Fail with a clear message if load_model() has not run (or failed),
        # instead of an opaque AttributeError from the generic handler below.
        if pipeline is None:
            return {"error": "Model is not loaded yet"}

        body = await request.json()
        history = body.get("history", [])
        user_rr = float(body.get("rr_ratio", 2.0))

        if not history:
            return {"error": "History array is empty"}

        # 1. Convert to tensor and add batch dimension [1, seq_len]
        # Chronos expects a batch dimension
        context = torch.tensor(history, dtype=torch.float32).unsqueeze(0)

        # 2. Prediction logic
        # num_samples=20 gives enough for quantiles without killing RAM
        forecast = pipeline.predict(context, prediction_length=12, num_samples=20)

        # 3. Extract results (remove batch dimension for processing)
        # forecast shape is [batch, samples, horizon] -> [samples, horizon].
        # .cpu().float() first so this also works when the model runs on a GPU
        # or emits bfloat16 (numpy has no bfloat16 and .numpy() would raise).
        forecast_np = forecast[0].cpu().float().numpy()

        low_bound = np.percentile(forecast_np, 10, axis=0)
        median_pred = np.percentile(forecast_np, 50, axis=0)
        high_bound = np.percentile(forecast_np, 90, axis=0)

        # 4. Trading Logic (long-only)
        entry_price = float(history[-1])
        first_pred = float(median_pred[0])

        # Stop Loss (SL) at the 10th percentile
        sl = float(low_bound[0])
        if sl >= entry_price:
            sl = entry_price * 0.98  # Safety fallback

        # Take Profit (TP): scale the risk by the requested reward:risk ratio
        risk = entry_price - sl
        tp = entry_price + (risk * user_rr)

        return {
            "prediction": median_pred.tolist(),
            "upper_bound": high_bound.tolist(),
            "lower_bound": low_bound.tolist(),
            "suggested_sl": round(sl, 4),
            "suggested_tp": round(tp, 4),
            "verdict": "Entry Confirmed" if first_pred > entry_price else "Wait for better entry"
        }
    except Exception as e:
        # Boundary handler: surface any failure as JSON rather than a 500.
        return {"error": str(e)}
|
| 80 |
|
| 81 |
# Script entry point: serve the FastAPI app directly (not needed when a
# process manager launches uvicorn itself).
if __name__ == "__main__":
    # Port 7860 is required for Hugging Face Spaces
    uvicorn.run(app, host="0.0.0.0", port=7860)
|