RayoDeCodigos commited on
Commit
3c199fc
·
verified ·
1 Parent(s): ea6ac62

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -33
app.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import torch
3
  import numpy as np
4
  from fastapi import FastAPI, Request
5
- from chronos import ChronosPipeline
6
  import uvicorn
7
 
8
  app = FastAPI()
@@ -14,12 +14,12 @@ pipeline = None
14
  def load_model():
15
  global pipeline
16
  print("Loading Chronos model...")
17
- # Using 'tiny' first to ensure it fits in RAM.
18
- # Swap to 'amazon/chronos-t5-small' if you have confirmed 16GB+ RAM
19
  pipeline = ChronosPipeline.from_pretrained(
20
  "amazon/chronos-t5-tiny",
21
- device_map="auto",
22
- dtype=torch.bfloat16
23
  )
24
  print("Model loaded successfully.")
25
 
@@ -29,37 +29,55 @@ def home():
29
 
30
  @app.post("/predict")
31
  async def get_forecast(request: Request):
32
- body = await request.json()
33
- history = body.get("history", [])
34
- user_rr = float(body.get("rr_ratio", 2.0))
35
-
36
- if not history:
37
- return {"error": "History array is empty"}
 
 
38
 
39
- # Chronos prediction logic
40
- context = torch.tensor(history, dtype=torch.float32)
41
- forecast = pipeline.predict(context, prediction_length=12)
42
-
43
- forecast_np = forecast.numpy()[0]
44
- low_bound = np.percentile(forecast_np, 10, axis=0)
45
- median_pred = np.percentile(forecast_np, 50, axis=0)
46
- high_bound = np.percentile(forecast_np, 90, axis=0)
 
 
 
 
 
 
 
47
 
48
- entry_price = float(history[-1])
49
- sl = float(low_bound[0])
50
- if sl >= entry_price: sl = entry_price * 0.98
51
-
52
- tp = entry_price + ((entry_price - sl) * user_rr)
 
 
 
 
 
 
 
53
 
54
- return {
55
- "prediction": median_pred.tolist(),
56
- "upper_bound": high_bound.tolist(),
57
- "lower_bound": low_bound.tolist(),
58
- "suggested_sl": round(sl, 4),
59
- "suggested_tp": round(tp, 4),
60
- "verdict": "Entry Confirmed" if median_pred[0] > entry_price else "Wait"
61
- }
 
 
62
 
63
  if __name__ == "__main__":
64
- # CRITICAL: Hugging Face requires port 7860
65
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
2
  import torch
3
  import numpy as np
4
  from fastapi import FastAPI, Request
5
+ from chronos import ChronosPipeline # This works ONLY if chronos-forecasting is installed
6
  import uvicorn
7
 
8
  app = FastAPI()
 
14
  def load_model():
15
  global pipeline
16
  print("Loading Chronos model...")
17
+ # Using 'tiny' for stability on HF basic instances
18
+ # Ensure dtype is torch.bfloat16 for efficiency
19
  pipeline = ChronosPipeline.from_pretrained(
20
  "amazon/chronos-t5-tiny",
21
+ device_map="cpu", # Use "cpu" if no GPU, or "auto" if you have a T4
22
+ torch_dtype=torch.bfloat16
23
  )
24
  print("Model loaded successfully.")
25
 
 
29
 
30
  @app.post("/predict")
31
  async def get_forecast(request: Request):
32
+ global pipeline
33
+ try:
34
+ body = await request.json()
35
+ history = body.get("history", [])
36
+ user_rr = float(body.get("rr_ratio", 2.0))
37
+
38
+ if not history:
39
+ return {"error": "History array is empty"}
40
 
41
+ # 1. Convert to tensor and add batch dimension [1, seq_len]
42
+ # Chronos expects a batch dimension
43
+ context = torch.tensor(history, dtype=torch.float32).unsqueeze(0)
44
+
45
+ # 2. Prediction logic
46
+ # num_samples=20 gives enough for quantiles without killing RAM
47
+ forecast = pipeline.predict(context, prediction_length=12, num_samples=20)
48
+
49
+ # 3. Extract results (remove batch dimension for processing)
50
+ # forecast shape is [batch, samples, horizon] -> [samples, horizon]
51
+ forecast_np = forecast[0].numpy()
52
+
53
+ low_bound = np.percentile(forecast_np, 10, axis=0)
54
+ median_pred = np.percentile(forecast_np, 50, axis=0)
55
+ high_bound = np.percentile(forecast_np, 90, axis=0)
56
 
57
+ # 4. Trading Logic
58
+ entry_price = float(history[-1])
59
+ first_pred = float(median_pred[0])
60
+
61
+ # Stop Loss (SL) at the 10th percentile
62
+ sl = float(low_bound[0])
63
+ if sl >= entry_price:
64
+ sl = entry_price * 0.98 # Safety fallback
65
+
66
+ # Take Profit (TP)
67
+ risk = entry_price - sl
68
+ tp = entry_price + (risk * user_rr)
69
 
70
+ return {
71
+ "prediction": median_pred.tolist(),
72
+ "upper_bound": high_bound.tolist(),
73
+ "lower_bound": low_bound.tolist(),
74
+ "suggested_sl": round(sl, 4),
75
+ "suggested_tp": round(tp, 4),
76
+ "verdict": "Entry Confirmed" if first_pred > entry_price else "Wait for better entry"
77
+ }
78
+ except Exception as e:
79
+ return {"error": str(e)}
80
 
81
  if __name__ == "__main__":
82
+ # Port 7860 is required for Hugging Face Spaces
83
  uvicorn.run(app, host="0.0.0.0", port=7860)