Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,63 +1,69 @@
|
|
| 1 |
-
from fastapi import FastAPI, Request
|
| 2 |
-
from chronos import ChronosPipeline
|
| 3 |
import torch
|
| 4 |
import numpy as np
|
|
|
|
|
|
|
| 5 |
|
| 6 |
app = FastAPI()
|
| 7 |
|
| 8 |
-
#
|
| 9 |
pipeline = ChronosPipeline.from_pretrained(
|
| 10 |
-
"amazon/chronos-t5-small",
|
| 11 |
device_map="auto",
|
| 12 |
-
|
| 13 |
)
|
| 14 |
|
| 15 |
@app.post("/predict")
|
| 16 |
-
async def
|
| 17 |
-
data
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
-
#
|
| 26 |
-
# returns
|
| 27 |
-
forecast = pipeline.predict(context, prediction_length)
|
| 28 |
|
| 29 |
-
# Extract Quantiles
|
|
|
|
| 30 |
low_bound = np.percentile(forecast.numpy(), 10, axis=0)
|
| 31 |
median_pred = np.percentile(forecast.numpy(), 50, axis=0)
|
| 32 |
high_bound = np.percentile(forecast.numpy(), 90, axis=0)
|
| 33 |
|
| 34 |
-
# Trading
|
| 35 |
-
entry_price =
|
| 36 |
-
|
| 37 |
-
p10_support = low_bound[0]
|
| 38 |
|
| 39 |
-
# Stop Loss (SL)
|
| 40 |
-
|
| 41 |
-
|
| 42 |
|
| 43 |
-
# Take Profit (TP) based on
|
| 44 |
-
|
| 45 |
-
tp = entry_price + (
|
| 46 |
|
| 47 |
-
# Verdict
|
| 48 |
-
if
|
| 49 |
-
verdict = "
|
| 50 |
-
elif
|
| 51 |
-
verdict = "Wait
|
| 52 |
else:
|
| 53 |
-
verdict = "Neutral
|
| 54 |
|
|
|
|
| 55 |
return {
|
|
|
|
|
|
|
| 56 |
"prediction": median_pred.tolist(),
|
| 57 |
"upper_bound": high_bound.tolist(),
|
| 58 |
"lower_bound": low_bound.tolist(),
|
| 59 |
-
"
|
| 60 |
-
"
|
| 61 |
"verdict": verdict,
|
| 62 |
-
"
|
| 63 |
}
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import numpy as np
|
| 3 |
+
from fastapi import FastAPI, Request
|
| 4 |
+
from chronos import ChronosPipeline
|
| 5 |
|
| 6 |
app = FastAPI()
|
| 7 |
|
| 8 |
+
# Pipeline initialization with the corrected 'dtype'
|
| 9 |
pipeline = ChronosPipeline.from_pretrained(
|
| 10 |
+
"amazon/chronos-t5-small",
|
| 11 |
device_map="auto",
|
| 12 |
+
dtype=torch.bfloat16,
|
| 13 |
)
|
| 14 |
|
| 15 |
@app.post("/predict")
|
| 16 |
+
async def get_forecast(request: Request):
|
| 17 |
+
# 1. Receive data from your backend
|
| 18 |
+
body = await request.json()
|
| 19 |
+
history = body.get("history", [])
|
| 20 |
+
user_rr = body.get("rr_ratio", 2.0) # Default to 2:1 if not provided
|
| 21 |
|
| 22 |
+
if not history:
|
| 23 |
+
return {"error": "No history data provided"}
|
| 24 |
+
|
| 25 |
+
# 2. Convert history to tensor (No length limit - uses your 2-3 year data)
|
| 26 |
+
context = torch.tensor(history)
|
| 27 |
|
| 28 |
+
# 3. Predict (Defaulting to 12 steps ahead)
|
| 29 |
+
# Chronos returns (num_samples, prediction_length)
|
| 30 |
+
forecast = pipeline.predict(context, prediction_length=12)
|
| 31 |
|
| 32 |
+
# 4. Extract Quantiles for Bounds
|
| 33 |
+
# P10 = Lower Bound, P50 = Median (Prediction), P90 = Upper Bound
|
| 34 |
low_bound = np.percentile(forecast.numpy(), 10, axis=0)
|
| 35 |
median_pred = np.percentile(forecast.numpy(), 50, axis=0)
|
| 36 |
high_bound = np.percentile(forecast.numpy(), 90, axis=0)
|
| 37 |
|
| 38 |
+
# 5. Trading System Logic
|
| 39 |
+
entry_price = history[-1]
|
| 40 |
+
next_move = median_pred[0]
|
|
|
|
| 41 |
|
| 42 |
+
# Stop Loss (SL) based on the model's P10 (statistical support)
|
| 43 |
+
# We ensure SL is actually below entry; if not, we use a 2% buffer
|
| 44 |
+
sl = min(low_bound[0], entry_price * 0.98)
|
| 45 |
|
| 46 |
+
# Calculate Take Profit (TP) based on User RR Ratio
|
| 47 |
+
risk_amount = entry_price - sl
|
| 48 |
+
tp = entry_price + (risk_amount * user_rr)
|
| 49 |
|
| 50 |
+
# 6. Verdict Engine
|
| 51 |
+
if next_move > entry_price * 1.01:
|
| 52 |
+
verdict = "Entry Confirmed: Upward momentum detected."
|
| 53 |
+
elif next_move < entry_price * 0.99:
|
| 54 |
+
verdict = "Wait: Potential pullback or downward trend."
|
| 55 |
else:
|
| 56 |
+
verdict = "Neutral: Sideways movement expected. Wait for breakout."
|
| 57 |
|
| 58 |
+
# 7. Response to Backend
|
| 59 |
return {
|
| 60 |
+
"status": "success",
|
| 61 |
+
"entry": float(entry_price),
|
| 62 |
"prediction": median_pred.tolist(),
|
| 63 |
"upper_bound": high_bound.tolist(),
|
| 64 |
"lower_bound": low_bound.tolist(),
|
| 65 |
+
"suggested_sl": round(float(sl), 2),
|
| 66 |
+
"suggested_tp": round(float(tp), 2),
|
| 67 |
"verdict": verdict,
|
| 68 |
+
"rr_applied": f"{user_rr}:1"
|
| 69 |
}
|