File size: 1,093 Bytes
7277bfe
 
 
334cdc0
7277bfe
 
334cdc0
 
 
 
7277bfe
 
334cdc0
7277bfe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import pandas as pd
from core.train_eval import train_and_evaluate
from core.models import LSTMModel, GRUModel, CNNModel, TransformerModel, MLPModel, BiLSTMModel, HybridModel

def get_model(df, model_name, horizon, hidden_units, n_layers, epochs, learning_rate, beta1, beta2, weight_decay, dropout, window_size, test_split):
    """Resolve *model_name* to a model class and run training/evaluation with it.

    Despite its name, this function does not return a model instance. It
    maps ``model_name`` to one of the imported model classes, passes that
    class and every hyperparameter to ``train_and_evaluate``, and returns
    whatever that call returns.

    Args:
        df: Data forwarded unchanged as ``df``. Presumably a pandas
            DataFrame; confirm against ``train_and_evaluate``.
        model_name: Registry key. Must be one of "LSTM", "GRU", "CNN",
            "Transformer", "MLP", "BiLSTM", "Hybrid".
        horizon: Forwarded as ``horizon``.
        hidden_units: Forwarded as ``hidden``.
        n_layers: Forwarded as ``layers``.
        epochs: Forwarded as ``epochs``.
        learning_rate: Forwarded as ``lr``.
        beta1: Forwarded as ``beta1``.
        beta2: Forwarded as ``beta2``.
        weight_decay: Forwarded as ``weight_decay``.
        dropout: Forwarded as ``dropout``.
        window_size: Forwarded as ``window``.
        test_split: Forwarded as ``test_split``.

    Returns:
        The value returned by ``train_and_evaluate``, unchanged.

    Raises:
        ValueError: If ``model_name`` is not a key of the registry below.
    """
    registry = {
        "LSTM": LSTMModel,
        "GRU": GRUModel,
        "CNN": CNNModel,
        "Transformer": TransformerModel,
        "MLP": MLPModel,
        "BiLSTM": BiLSTMModel,
        "Hybrid": HybridModel,
    }

    # Every registry value is a class, so a None result means "unknown name".
    model_cls = registry.get(model_name)
    if model_cls is None:
        raise ValueError(f"Model {model_name} not supported.")

    # Translate this function's public parameter names into the keyword
    # names that train_and_evaluate expects.
    training_kwargs = {
        "horizon": horizon,
        "hidden": hidden_units,
        "layers": n_layers,
        "epochs": epochs,
        "lr": learning_rate,
        "beta1": beta1,
        "beta2": beta2,
        "weight_decay": weight_decay,
        "dropout": dropout,
        "window": window_size,
        "test_split": test_split,
    }
    return train_and_evaluate(df=df, model_cls=model_cls, **training_kwargs)