Spaces:
Sleeping
Sleeping
File size: 2,819 Bytes
d148f2e b90f986 72f6e5f 3c199fc d148f2e b90f986 3260c40 b90f986 d148f2e 3c199fc d148f2e 3c199fc d148f2e b90f986 3260c40 72f6e5f 3c199fc 72f6e5f 3c199fc 3260c40 3c199fc 3260c40 3c199fc d148f2e 3c199fc d148f2e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 | import os
import torch
import numpy as np
from fastapi import FastAPI, Request
from chronos import ChronosPipeline # This works ONLY if chronos-forecasting is installed
import uvicorn
# Handle to the loaded Chronos pipeline; stays None until the startup hook
# (load_model) assigns it.
pipeline = None

# ASGI application instance that the route decorators below register on.
app = FastAPI()
@app.on_event("startup")
def load_model():
    """Load the Chronos forecasting model once, when the application starts.

    The result is stored in the module-level ``pipeline`` so request handlers
    can reuse it instead of reloading the weights on every call.
    """
    global pipeline
    # The 'tiny' checkpoint is chosen for stability on basic HF instances.
    checkpoint = "amazon/chronos-t5-tiny"
    print("Loading Chronos model...")
    pipeline = ChronosPipeline.from_pretrained(
        checkpoint,
        device_map="cpu",  # switch to "auto" when a GPU (e.g. a T4) is present
        torch_dtype=torch.bfloat16,  # bfloat16 weights, as in the original setup
    )
    print("Model loaded successfully.")
@app.get("/")
def home():
    """Root endpoint: a static status message pointing clients at /predict."""
    payload = {"status": "Model is running", "info": "Send POST to /predict"}
    return payload
def _summarize_forecast(samples, entry_price, rr_ratio):
    """Convert forecast sample paths into quantile bands and SL/TP suggestions.

    Args:
        samples: Array of sample paths for one series. Quantiles are taken over
            axis 0, so it is expected to be shaped ``[samples, horizon]``.
        entry_price: Last observed history value, used as the entry price.
        rr_ratio: Reward-to-risk multiple applied to the stop distance.

    Returns:
        The JSON-serialisable response dict for ``/predict``.
    """
    low_bound = np.percentile(samples, 10, axis=0)
    median_pred = np.percentile(samples, 50, axis=0)
    high_bound = np.percentile(samples, 90, axis=0)

    # Stop loss (SL) at the first step's 10th percentile. It must sit below the
    # entry; otherwise fall back to a fixed 2% stop.
    sl = float(low_bound[0])
    if sl >= entry_price:
        sl = entry_price * 0.98  # Safety fallback

    # Take profit (TP): the stop distance scaled by the reward-to-risk ratio.
    risk = entry_price - sl
    tp = entry_price + (risk * rr_ratio)

    first_pred = float(median_pred[0])
    return {
        "prediction": median_pred.tolist(),
        "upper_bound": high_bound.tolist(),
        "lower_bound": low_bound.tolist(),
        "suggested_sl": round(sl, 4),
        "suggested_tp": round(tp, 4),
        "verdict": "Entry Confirmed" if first_pred > entry_price else "Wait for better entry",
    }


@app.post("/predict")
async def get_forecast(request: Request):
    """Forecast the next 12 steps of ``history`` and suggest trade levels.

    Expects a JSON object body with:
        history: non-empty array of numbers; the last value is the entry price.
        rr_ratio: optional reward-to-risk multiple (default 2.0).

    Returns the 10th/50th/90th percentile forecast bands plus suggested stop
    loss / take profit and a verdict. On invalid input or inference failure it
    returns a dict with a single ``"error"`` key (errors are reported in the
    response body, not via HTTP status codes).
    """
    import asyncio
    import functools
    import logging
    import math

    if pipeline is None:
        # Startup hook has not completed (or failed); avoid an opaque
        # "'NoneType' object has no attribute 'predict'" message.
        return {"error": "Model is not loaded"}

    try:
        body = await request.json()
        if not isinstance(body, dict):
            return {"error": "Request body must be a JSON object"}
        history = body.get("history", [])
        user_rr = float(body.get("rr_ratio", 2.0))

        if not history:
            return {"error": "History array is empty"}
        if not isinstance(history, list):
            return {"error": "History must be an array of numbers"}

        # NaN/inf would propagate into SL/TP and are not valid JSON numbers,
        # so reject them before spending time on inference.
        entry_price = float(history[-1])
        if not math.isfinite(entry_price):
            return {"error": "Last history value must be a finite number"}
        if not math.isfinite(user_rr) or user_rr <= 0:
            return {"error": "rr_ratio must be a positive finite number"}

        # Chronos expects a batch dimension: [1, seq_len].
        context = torch.tensor(history, dtype=torch.float32).unsqueeze(0)

        # predict() is synchronous and CPU-heavy; run it in a worker thread so
        # the event loop (and other endpoints such as "/") stays responsive.
        # num_samples=20 gives enough samples for quantiles without killing RAM.
        loop = asyncio.get_running_loop()
        forecast = await loop.run_in_executor(
            None,
            functools.partial(
                pipeline.predict, context, prediction_length=12, num_samples=20
            ),
        )

        # forecast is [batch, samples, horizon]; drop the batch dimension.
        # .float().cpu() is a defensive no-op if predict already returns
        # float32 CPU tensors.
        samples = forecast[0].float().cpu().numpy()
        return _summarize_forecast(samples, entry_price, user_rr)
    except Exception as e:
        # Keep the error-in-body contract, but don't lose the traceback.
        logging.getLogger(__name__).exception("Forecast request failed")
        return {"error": str(e)}
if __name__ == "__main__":
    # Port 7860 is required for Hugging Face Spaces, so it stays the default;
    # a PORT environment variable, if set, overrides it when running elsewhere.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))