File size: 2,819 Bytes
d148f2e
b90f986
 
72f6e5f
3c199fc
d148f2e
b90f986
3260c40
b90f986
d148f2e
 
 
 
 
 
 
3c199fc
 
d148f2e
 
3c199fc
 
d148f2e
 
 
 
 
 
b90f986
3260c40
72f6e5f
3c199fc
 
 
 
 
 
 
 
72f6e5f
3c199fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3260c40
3c199fc
 
 
 
 
 
 
 
 
 
 
 
3260c40
3c199fc
 
 
 
 
 
 
 
 
 
d148f2e
 
3c199fc
d148f2e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import torch
import numpy as np
from fastapi import FastAPI, Request
from chronos import ChronosPipeline  # This works ONLY if chronos-forecasting is installed
import uvicorn

# FastAPI application instance; served by uvicorn in the __main__ guard below.
app = FastAPI()

# Global variable for the model: holds the Chronos pipeline once load_model()
# runs at startup. Remains None until loading completes, so handlers that run
# before startup finishes will see None here.
pipeline = None

@app.on_event("startup")
def load_model():
    """Load the Chronos pipeline once at application startup.

    Stores the loaded pipeline in the module-level ``pipeline`` global so
    the request handlers can reuse it across calls.
    """
    global pipeline
    print("Loading Chronos model...")
    # The 'tiny' checkpoint keeps memory low enough for HF basic instances,
    # and bfloat16 halves the weight footprint relative to float32.
    checkpoint = "amazon/chronos-t5-tiny"
    pipeline = ChronosPipeline.from_pretrained(
        checkpoint,
        device_map="cpu",  # switch to "auto" when a GPU (e.g. a T4) is available
        torch_dtype=torch.bfloat16,
    )
    print("Model loaded successfully.")

@app.get("/")
def home():
    """Health-check endpoint: confirms the service is up and points at /predict."""
    payload = {
        "status": "Model is running",
        "info": "Send POST to /predict",
    }
    return payload

@app.post("/predict")
async def get_forecast(request: Request):
    """Forecast the price series and derive simple long-side trade levels.

    Expected JSON body:
        history  -- list of floats, the recent price series (required)
        rr_ratio -- float risk/reward multiple for the take-profit (default 2.0)
        horizon  -- int number of steps to forecast (default 12)

    Returns a dict with the median/upper/lower forecast paths, suggested
    stop-loss and take-profit levels, and an entry verdict, or an
    {"error": ...} dict on any failure.
    """
    # Reading a module global needs no 'global' declaration. Fail fast with a
    # clear message if a request arrives before the startup hook has finished
    # loading the model (previously this crashed into the generic except).
    if pipeline is None:
        return {"error": "Model is still loading, try again shortly"}
    try:
        body = await request.json()
        history = body.get("history", [])
        user_rr = float(body.get("rr_ratio", 2.0))
        # Optional horizon; defaults to the previous hard-coded 12 steps.
        horizon = int(body.get("horizon", 12))

        if not history:
            return {"error": "History array is empty"}

        # 1. Chronos expects a batch dimension: shape [1, seq_len].
        context = torch.tensor(history, dtype=torch.float32).unsqueeze(0)

        # 2. num_samples=20 gives enough draws for stable 10/50/90 quantiles
        #    without excessive RAM use.
        forecast = pipeline.predict(context, prediction_length=horizon, num_samples=20)

        # 3. Drop the batch dimension: [batch, samples, horizon] -> [samples, horizon].
        forecast_np = forecast[0].numpy()

        low_bound = np.percentile(forecast_np, 10, axis=0)
        median_pred = np.percentile(forecast_np, 50, axis=0)
        high_bound = np.percentile(forecast_np, 90, axis=0)

        # 4. Trading logic (long-side only).
        entry_price = float(history[-1])
        first_pred = float(median_pred[0])

        # Stop-loss at the 10th percentile of the first forecast step. If that
        # sits at/above entry, the distribution gives no usable downside bound,
        # so fall back to a fixed 2% stop.
        sl = float(low_bound[0])
        if sl >= entry_price:
            sl = entry_price * 0.98  # Safety fallback

        # Take-profit scales the risked distance by the requested R:R multiple.
        risk = entry_price - sl
        tp = entry_price + (risk * user_rr)

        return {
            "prediction": median_pred.tolist(),
            "upper_bound": high_bound.tolist(),
            "lower_bound": low_bound.tolist(),
            "suggested_sl": round(sl, 4),
            "suggested_tp": round(tp, 4),
            "verdict": "Entry Confirmed" if first_pred > entry_price else "Wait for better entry"
        }
    except Exception as e:
        # API boundary: surface the failure to the client as JSON rather than
        # an empty 500. NOTE(review): consider an HTTP error status code here.
        return {"error": str(e)}

if __name__ == "__main__":
    # Hugging Face Spaces routes traffic to port 7860; bind all interfaces.
    serve_host, serve_port = "0.0.0.0", 7860
    uvicorn.run(app, host=serve_host, port=serve_port)