Aziz30's picture
Update Interface_Graphique/interface_graphique/services/model_service.py
d687a96 verified
import logging
import os

import joblib
import numpy as np
from tensorflow.keras.models import load_model
# =========================
# LOAD MODEL + PREPROCESS
# =========================
logger = logging.getLogger(__name__)

# Resolve paths relative to this file (made absolute) so artifact loading does
# not depend on the process's current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODELS_DIR = os.path.normpath(os.path.join(BASE_DIR, "..", "..", "models"))

try:
    model = load_model(os.path.join(MODELS_DIR, "global_return_lstm.keras"))
    scaler = joblib.load(os.path.join(MODELS_DIR, "return_scaler.save"))
    encoder = joblib.load(os.path.join(MODELS_DIR, "symbol_encoder.save"))
except Exception:
    # Best-effort load: keep the module importable. Reset *all* artifacts so
    # no half-loaded state survives (e.g. model loaded but scaler missing)
    # and every name is always defined; predict_signal() returns "HOLD"
    # whenever model is None.
    logger.exception("MODEL LOAD ERROR")
    model = None
    scaler = None
    encoder = None

# Number of most-recent returns fed to the model as one input window.
SEQ_LEN = 60
# A predicted return must exceed +/- this value to produce BUY / SELL.
SIGNAL_THRESHOLD = 0.001
# =========================
# PREPARE INPUT
# =========================
def prepare_input(symbol, returns):
    """Build the two model inputs for one symbol's window of returns.

    Args:
        symbol: label passed to ``encoder.transform``; presumably it must be
            one of the labels the encoder was fitted on (TODO confirm what
            the saved encoder does with unseen symbols).
        returns: exactly ``SEQ_LEN`` return values, as a NumPy array or any
            array-like (list, tuple, pandas Series).

    Returns:
        ``(X_price, X_symbol)``: the scaled returns reshaped to
        ``(1, SEQ_LEN, 1)`` and the encoded symbol as a ``(1, 1)`` array.

    Raises:
        ValueError: if ``returns`` does not contain exactly ``SEQ_LEN`` values.
    """
    # np.asarray also accepts lists/tuples/Series, not only ndarrays; the
    # scaler expects a 2-D (n_samples, 1) column.
    values = np.asarray(returns, dtype=float).reshape(-1, 1)
    if values.shape[0] != SEQ_LEN:
        raise ValueError(f"expected {SEQ_LEN} returns, got {values.shape[0]}")

    symbol_id = encoder.transform([symbol])[0]
    returns_scaled = scaler.transform(values)

    # LSTM price input: (batch=1, timesteps=SEQ_LEN, features=1).
    X_price = returns_scaled.reshape(1, SEQ_LEN, 1)
    X_symbol = np.array([[symbol_id]])
    return X_price, X_symbol
# =========================
# PREDICT SIGNAL
# =========================
def predict_signal(symbol, returns):
    """Turn the model's predicted return for ``symbol`` into a trade signal.

    Returns ``"HOLD"`` without calling the model when ``model`` is None or
    fewer than ``SEQ_LEN`` returns are available. Otherwise only the last
    ``SEQ_LEN`` returns are used, and the prediction is classified as:

    * greater than ``SIGNAL_THRESHOLD``  -> ``"BUY"``
    * less than ``-SIGNAL_THRESHOLD``    -> ``"SELL"``
    * anything else (band edges, NaN)    -> ``"HOLD"``
    """
    # `and` short-circuits, so len() is only evaluated once a model exists.
    can_predict = model is not None and len(returns) >= SEQ_LEN
    if not can_predict:
        return "HOLD"

    window = returns[-SEQ_LEN:]
    price_input, symbol_input = prepare_input(symbol, window)
    predicted = model.predict([price_input, symbol_input], verbose=0)[0][0]

    if predicted > SIGNAL_THRESHOLD:
        return "BUY"
    if predicted < -SIGNAL_THRESHOLD:
        return "SELL"
    return "HOLD"