RayoDeCodigos commited on
Commit
d148f2e
·
verified ·
1 Parent(s): 72f6e5f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -45
app.py CHANGED
@@ -1,69 +1,65 @@
 
1
  import torch
2
  import numpy as np
3
  from fastapi import FastAPI, Request
4
  from chronos import ChronosPipeline
 
5
 
6
  app = FastAPI()
7
 
8
- # Pipeline initialization with the corrected 'dtype'
9
- pipeline = ChronosPipeline.from_pretrained(
10
- "amazon/chronos-t5-small",
11
- device_map="auto",
12
- dtype=torch.bfloat16,
13
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  @app.post("/predict")
16
  async def get_forecast(request: Request):
17
- # 1. Receive data from your backend
18
  body = await request.json()
19
  history = body.get("history", [])
20
- user_rr = body.get("rr_ratio", 2.0) # Default to 2:1 if not provided
21
 
22
  if not history:
23
- return {"error": "No history data provided"}
24
 
25
- # 2. Convert history to tensor (No length limit - uses your 2-3 year data)
26
- context = torch.tensor(history)
27
-
28
- # 3. Predict (Defaulting to 12 steps ahead)
29
- # Chronos returns (num_samples, prediction_length)
30
  forecast = pipeline.predict(context, prediction_length=12)
31
 
32
- # 4. Extract Quantiles for Bounds
33
- # P10 = Lower Bound, P50 = Median (Prediction), P90 = Upper Bound
34
- low_bound = np.percentile(forecast.numpy(), 10, axis=0)
35
- median_pred = np.percentile(forecast.numpy(), 50, axis=0)
36
- high_bound = np.percentile(forecast.numpy(), 90, axis=0)
37
 
38
- # 5. Trading System Logic
39
- entry_price = history[-1]
40
- next_move = median_pred[0]
41
 
42
- # Stop Loss (SL) based on the model's P10 (statistical support)
43
- # We ensure SL is actually below entry; if not, we use a 2% buffer
44
- sl = min(low_bound[0], entry_price * 0.98)
45
-
46
- # Calculate Take Profit (TP) based on User RR Ratio
47
- risk_amount = entry_price - sl
48
- tp = entry_price + (risk_amount * user_rr)
49
-
50
- # 6. Verdict Engine
51
- if next_move > entry_price * 1.01:
52
- verdict = "Entry Confirmed: Upward momentum detected."
53
- elif next_move < entry_price * 0.99:
54
- verdict = "Wait: Potential pullback or downward trend."
55
- else:
56
- verdict = "Neutral: Sideways movement expected. Wait for breakout."
57
 
58
- # 7. Response to Backend
59
  return {
60
- "status": "success",
61
- "entry": float(entry_price),
62
  "prediction": median_pred.tolist(),
63
  "upper_bound": high_bound.tolist(),
64
  "lower_bound": low_bound.tolist(),
65
- "suggested_sl": round(float(sl), 2),
66
- "suggested_tp": round(float(tp), 2),
67
- "verdict": verdict,
68
- "rr_applied": f"{user_rr}:1"
69
- }
 
 
 
 
1
+ import os
2
  import torch
3
  import numpy as np
4
  from fastapi import FastAPI, Request
5
  from chronos import ChronosPipeline
6
+ import uvicorn
7
 
8
  app = FastAPI()
9
 
10
+ # Global variable for the model
11
+ pipeline = None
12
+
13
+ @app.on_event("startup")
14
+ def load_model():
15
+ global pipeline
16
+ print("Loading Chronos model...")
17
+ # Using 'tiny' first to ensure it fits in RAM.
18
+ # Swap to 'amazon/chronos-t5-small' if you have confirmed 16GB+ RAM
19
+ pipeline = ChronosPipeline.from_pretrained(
20
+ "amazon/chronos-t5-tiny",
21
+ device_map="auto",
22
+ dtype=torch.bfloat16
23
+ )
24
+ print("Model loaded successfully.")
25
+
26
+ @app.get("/")
27
+ def home():
28
+ return {"status": "Model is running", "info": "Send POST to /predict"}
29
 
30
  @app.post("/predict")
31
  async def get_forecast(request: Request):
 
32
  body = await request.json()
33
  history = body.get("history", [])
34
+ user_rr = float(body.get("rr_ratio", 2.0))
35
 
36
  if not history:
37
+ return {"error": "History array is empty"}
38
 
39
+ # Chronos prediction logic
40
+ context = torch.tensor(history, dtype=torch.float32)
 
 
 
41
  forecast = pipeline.predict(context, prediction_length=12)
42
 
43
+ forecast_np = forecast.numpy()[0]
44
+ low_bound = np.percentile(forecast_np, 10, axis=0)
45
+ median_pred = np.percentile(forecast_np, 50, axis=0)
46
+ high_bound = np.percentile(forecast_np, 90, axis=0)
 
47
 
48
+ entry_price = float(history[-1])
49
+ sl = float(low_bound[0])
50
+ if sl >= entry_price: sl = entry_price * 0.98
51
 
52
+ tp = entry_price + ((entry_price - sl) * user_rr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
 
54
  return {
 
 
55
  "prediction": median_pred.tolist(),
56
  "upper_bound": high_bound.tolist(),
57
  "lower_bound": low_bound.tolist(),
58
+ "suggested_sl": round(sl, 4),
59
+ "suggested_tp": round(tp, 4),
60
+ "verdict": "Entry Confirmed" if median_pred[0] > entry_price else "Wait"
61
+ }
62
+
63
+ if __name__ == "__main__":
64
+ # CRITICAL: Hugging Face requires port 7860
65
+ uvicorn.run(app, host="0.0.0.0", port=7860)