StockPredict / core /model_runner.py
aromidvar1355's picture
Update core/model_runner.py
7277bfe verified
import pandas as pd
from core.train_eval import train_and_evaluate
from core.models import LSTMModel, GRUModel, CNNModel, TransformerModel, MLPModel, BiLSTMModel, HybridModel
def get_model(df, model_name, horizon, hidden_units, n_layers, epochs, learning_rate, beta1, beta2, weight_decay, dropout, window_size, test_split):
    """Resolve ``model_name`` to a model class and run ``train_and_evaluate`` on it.

    Args:
        df: Data passed through unchanged to ``train_and_evaluate``.
        model_name: Registry key; one of "LSTM", "GRU", "CNN",
            "Transformer", "MLP", "BiLSTM" or "Hybrid" (exact,
            case-sensitive match).
        horizon, epochs, beta1, beta2, weight_decay, dropout, test_split:
            Forwarded to ``train_and_evaluate`` under the same keyword names.
        hidden_units: Forwarded as ``hidden``.
        n_layers: Forwarded as ``layers``.
        learning_rate: Forwarded as ``lr``.
        window_size: Forwarded as ``window``.

    Returns:
        Whatever ``train_and_evaluate`` returns, unmodified.

    Raises:
        ValueError: If ``model_name`` is not one of the registry keys.
    """
    registry = dict(
        LSTM=LSTMModel,
        GRU=GRUModel,
        CNN=CNNModel,
        Transformer=TransformerModel,
        MLP=MLPModel,
        BiLSTM=BiLSTMModel,
        Hybrid=HybridModel,
    )

    chosen_cls = registry.get(model_name)
    if chosen_cls is None:
        raise ValueError(f"Model {model_name} not supported.")

    # The trainer uses shorter keyword names than this public API
    # (hidden / layers / lr / window), so they are translated here.
    return train_and_evaluate(
        df=df,
        model_cls=chosen_cls,
        horizon=horizon,
        hidden=hidden_units,
        layers=n_layers,
        epochs=epochs,
        lr=learning_rate,
        beta1=beta1,
        beta2=beta2,
        weight_decay=weight_decay,
        dropout=dropout,
        window=window_size,
        test_split=test_split,
    )