Spaces:
Sleeping
Sleeping
File size: 4,172 Bytes
7451255 64862f3 317843a 7451255 317843a 64862f3 7451255 64862f3 df473db 64862f3 df473db b0a80d3 7451255 df473db 7451255 64862f3 df473db 317843a df473db 847d6b3 df473db 847d6b3 df473db 64862f3 7451255 df473db 7451255 64862f3 7451255 64862f3 7451255 df473db 64862f3 7451255 64862f3 7451255 df473db 64862f3 7451255 64862f3 b0a80d3 7451255 64862f3 7451255 df473db 64862f3 7451255 df473db b0a80d3 7451255 b0a80d3 df473db 64862f3 7451255 317843a df473db 7451255 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 | # core/model_runner.py
import torch
import logging
from core.train_eval import train_and_evaluate
from core.models import (
LSTMModel,
GRUModel,
CNNModel,
MLPModel,
HybridCNNGRUModel,
TransformerModel,
BiLSTMModel,
)
# Process-wide logging setup: records at INFO and above are appended (not
# overwritten) to a file under /tmp, so successive runs accumulate history.
# NOTE(review): per the stdlib docs, basicConfig does nothing if the root
# logger already has handlers, and calling it at import time in a library
# module configures logging for the whole process — consider moving this to
# the application entry point.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    filemode="a",
    filename="/tmp/app_log.txt",
    level=logging.INFO,
)
def _first_not_none(*candidates, default):
    """Return the first candidate that is not ``None``, else ``default``.

    Candidates are given in priority order; used to resolve parameter aliases.
    """
    return next((value for value in candidates if value is not None), default)


def _error_message(exc):
    """Return a non-empty, human-readable message for ``exc``.

    ``str(exc)`` is empty for exceptions raised without arguments (e.g. a bare
    ``assert`` or ``raise ValueError()``). An empty ``"error"`` value is falsy,
    so a caller testing ``result.get("error")`` would mistake the failure for
    success; fall back to the exception's class name in that case.
    """
    return str(exc) or type(exc).__name__


def get_model(
    df,
    features,
    target,
    model_name="LSTM",
    horizon=1,
    # Hidden size aliases
    hidden=None,
    hidden_units=None,
    # Layers aliases
    layers=None,
    n_layers=None,
    # Learning rate aliases
    lr=None,
    learning_rate=None,
    # Betas for optimizer
    beta1=0.9,
    beta2=0.999,
    # Other hyperparams
    epochs=50,
    weight_decay=0.01,
    dropout=0.2,
    # Window aliases
    window=None,
    window_size=None,
    test_split=0.2,
    selector_method="RandomForest",
    importance_threshold=0.0,
    scheduler_type="None",
    device=None,
    verbose=True,
):
    """Normalize UI/caller argument aliases and run ``train_and_evaluate``.

    Alias resolution (the first non-None value wins, otherwise the default):
        hidden size   : ``hidden_units``, then ``hidden``, default 64
        layers        : ``n_layers``, then ``layers``, default 1
        learning rate : ``learning_rate``, then ``lr``, default 0.001
        window        : ``window_size``, then ``window``, default 30

    Args:
        df, features, target, horizon, beta1, beta2, epochs, weight_decay,
        dropout, test_split, selector_method, importance_threshold,
        scheduler_type, verbose: forwarded unchanged to ``train_and_evaluate``.
        model_name: one of "LSTM", "GRU", "CNN", "MLP", "Hybrid",
            "HybridCNNGRU", "Transformer", "BiLSTM". Any other value falls
            back to ``LSTMModel`` and logs a warning.
        device: passed through if given; otherwise "cuda" when
            ``torch.cuda.is_available()`` is true, else "cpu".

    Returns:
        Whatever ``train_and_evaluate`` returns, on success. Otherwise a dict
        ``{"error": <message>}`` when:
        - the result is falsy;
        - the result is a dict with a truthy ``"error"`` entry (its other keys
          are dropped);
        - any exception is raised. Exceptions are logged with a traceback and
          never propagate.
    """
    try:
        # --- Normalize aliases & defaults (explicit long name wins) ---
        hidden = _first_not_none(hidden_units, hidden, default=64)
        layers = _first_not_none(n_layers, layers, default=1)
        lr = _first_not_none(learning_rate, lr, default=0.001)
        window = _first_not_none(window_size, window, default=30)

        # Device: caller may pass it; otherwise detect automatically.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        logging.info(
            "get_model called: model=%s, device=%s, hidden=%s, layers=%s, "
            "lr=%s, window=%s, epochs=%s",
            model_name, device, hidden, layers, lr, window, epochs,
        )

        # --- Select model class (keys as used in the UI) ---
        model_classes = {
            "LSTM": LSTMModel,
            "GRU": GRUModel,
            "CNN": CNNModel,
            "MLP": MLPModel,
            "Hybrid": HybridCNNGRUModel,
            "HybridCNNGRU": HybridCNNGRUModel,
            "Transformer": TransformerModel,
            "BiLSTM": BiLSTMModel,
        }
        model_cls = model_classes.get(model_name)
        if model_cls is None:
            # Keep the best-effort LSTM fallback, but make it visible: without
            # this, a misspelled name silently trains an LSTM while the logs
            # still report the requested model name.
            logging.warning(
                "Unknown model_name %r; falling back to LSTMModel", model_name
            )
            model_cls = LSTMModel

        # --- Call the core training function with canonical param names ---
        result = train_and_evaluate(
            df=df,
            features=features,
            target=target,
            model_cls=model_cls,
            horizon=horizon,
            hidden=hidden,
            layers=layers,
            epochs=epochs,
            lr=lr,
            beta1=beta1,
            beta2=beta2,
            weight_decay=weight_decay,
            dropout=dropout,
            window=window,
            test_split=test_split,
            selector_method=selector_method,
            importance_threshold=importance_threshold,
            scheduler_type=scheduler_type,
            device=device,
            verbose=verbose,
        )

        # --- Normalize return ---
        # NOTE(review): `not result` assumes a container/None-like return; if
        # train_and_evaluate can return an array-like object whose truth value
        # is ambiguous, this raises and is reported via the except below.
        if not result:
            logging.error("%s returned empty result.", model_name)
            return {"error": "Empty result from training"}
        if isinstance(result, dict) and result.get("error"):
            logging.error("%s training error: %s", model_name, result["error"])
            return {"error": result["error"]}
        logging.info("%s training completed successfully", model_name)
        return result
    except Exception as e:
        # Boundary for UI callers: report failures as data instead of raising.
        message = _error_message(e)
        logging.exception("Model runner error for %s: %s", model_name, message)
        return {"error": message}
|