Aziz30 commited on
Commit
9523912
·
verified ·
1 Parent(s): 5b57aa4

Update Interface_Graphique/interface_graphique/services/model_service.py

Browse files
Interface_Graphique/interface_graphique/services/model_service.py CHANGED
@@ -1,12 +1,54 @@
1
  import numpy as np
 
 
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  def predict_signal(symbol, returns):
4
 
5
- avg = np.mean(returns)
 
 
 
 
 
 
 
6
 
7
- if avg > 0:
8
  return "BUY"
9
- elif avg < 0:
10
  return "SELL"
11
  else:
12
  return "HOLD"
 
1
  import numpy as np
2
+ import joblib
3
+ from tensorflow.keras.models import load_model
4
 
5
+ # =========================
6
+ # LOAD MODEL + PREPROCESS
7
+ # =========================
8
+ model = load_model("models/global_return_lstm.keras")
9
+ scaler = joblib.load("models/return_scaler.save")
10
+ encoder = joblib.load("models/symbol_encoder.save")
11
+
12
+ SEQ_LEN = 60
13
+ SIGNAL_THRESHOLD = 0.001
14
+
15
+
16
+ # =========================
17
+ # PREPARE INPUT
18
+ # =========================
19
+ def prepare_input(symbol, returns):
20
+
21
+ # encode symbol
22
+ symbol_id = encoder.transform([symbol])[0]
23
+
24
+ # scale returns
25
+ returns = returns.reshape(-1, 1)
26
+ returns_scaled = scaler.transform(returns)
27
+
28
+ # reshape for LSTM
29
+ X_price = returns_scaled.reshape(1, SEQ_LEN, 1)
30
+ X_symbol = np.array([[symbol_id]])
31
+
32
+ return X_price, X_symbol
33
+
34
+
35
+ # =========================
36
+ # PREDICT SIGNAL
37
+ # =========================
38
  def predict_signal(symbol, returns):
39
 
40
+ if len(returns) < SEQ_LEN:
41
+ return "HOLD"
42
+
43
+ returns = returns[-SEQ_LEN:]
44
+
45
+ X_price, X_symbol = prepare_input(symbol, returns)
46
+
47
+ pred_return = model.predict([X_price, X_symbol], verbose=0)[0][0]
48
 
49
+ if pred_return > SIGNAL_THRESHOLD:
50
  return "BUY"
51
+ elif pred_return < -SIGNAL_THRESHOLD:
52
  return "SELL"
53
  else:
54
  return "HOLD"