Spaces:
Sleeping
Sleeping
Update core/model_runner.py
Browse files- core/model_runner.py +61 -10
core/model_runner.py
CHANGED
|
@@ -26,39 +26,85 @@ def get_model(
|
|
| 26 |
target,
|
| 27 |
model_name="LSTM",
|
| 28 |
horizon=1,
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
epochs=50,
|
| 33 |
-
lr=0.001,
|
| 34 |
weight_decay=0.01,
|
| 35 |
dropout=0.2,
|
| 36 |
-
|
|
|
|
|
|
|
| 37 |
test_split=0.2,
|
| 38 |
selector_method="RandomForest",
|
| 39 |
importance_threshold=0.0,
|
| 40 |
scheduler_type="None",
|
|
|
|
| 41 |
verbose=True,
|
| 42 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
try:
|
| 44 |
-
#
|
|
|
|
| 45 |
if hidden_units is not None:
|
| 46 |
hidden = hidden_units
|
|
|
|
|
|
|
| 47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
model_classes = {
|
| 49 |
"LSTM": LSTMModel,
|
| 50 |
"GRU": GRUModel,
|
| 51 |
"CNN": CNNModel,
|
| 52 |
"MLP": MLPModel,
|
| 53 |
"Hybrid": HybridCNNGRUModel,
|
|
|
|
| 54 |
"Transformer": TransformerModel,
|
| 55 |
"BiLSTM": BiLSTMModel,
|
| 56 |
}
|
| 57 |
|
| 58 |
model_cls = model_classes.get(model_name, LSTMModel)
|
| 59 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 60 |
-
logging.info(f"Running {model_name} on device={device}")
|
| 61 |
|
|
|
|
| 62 |
result = train_and_evaluate(
|
| 63 |
df=df,
|
| 64 |
features=features,
|
|
@@ -69,6 +115,8 @@ def get_model(
|
|
| 69 |
layers=layers,
|
| 70 |
epochs=epochs,
|
| 71 |
lr=lr,
|
|
|
|
|
|
|
| 72 |
weight_decay=weight_decay,
|
| 73 |
dropout=dropout,
|
| 74 |
window=window,
|
|
@@ -80,14 +128,17 @@ def get_model(
|
|
| 80 |
verbose=verbose,
|
| 81 |
)
|
| 82 |
|
|
|
|
| 83 |
if not result:
|
| 84 |
logging.error(f"{model_name} returned empty result.")
|
| 85 |
-
return {"error": "Empty result"}
|
| 86 |
if isinstance(result, dict) and result.get("error"):
|
| 87 |
logging.error(f"{model_name} training error: {result['error']}")
|
| 88 |
return {"error": result["error"]}
|
|
|
|
|
|
|
| 89 |
return result
|
| 90 |
|
| 91 |
except Exception as e:
|
| 92 |
-
logging.error(f"Model runner error for {model_name}: {str(e)}")
|
| 93 |
return {"error": str(e)}
|
|
|
|
| 26 |
target,
|
| 27 |
model_name="LSTM",
|
| 28 |
horizon=1,
|
| 29 |
+
# Hidden size aliases
|
| 30 |
+
hidden=None,
|
| 31 |
+
hidden_units=None,
|
| 32 |
+
# Layers aliases
|
| 33 |
+
layers=None,
|
| 34 |
+
n_layers=None,
|
| 35 |
+
# Learning rate aliases
|
| 36 |
+
lr=None,
|
| 37 |
+
learning_rate=None,
|
| 38 |
+
# Betas for optimizer
|
| 39 |
+
beta1=0.9,
|
| 40 |
+
beta2=0.999,
|
| 41 |
+
# Other hyperparams
|
| 42 |
epochs=50,
|
|
|
|
| 43 |
weight_decay=0.01,
|
| 44 |
dropout=0.2,
|
| 45 |
+
# Window aliases
|
| 46 |
+
window=None,
|
| 47 |
+
window_size=None,
|
| 48 |
test_split=0.2,
|
| 49 |
selector_method="RandomForest",
|
| 50 |
importance_threshold=0.0,
|
| 51 |
scheduler_type="None",
|
| 52 |
+
device=None,
|
| 53 |
verbose=True,
|
| 54 |
):
|
| 55 |
+
"""
|
| 56 |
+
Wrapper that accepts many common argument names used by the UI/calls,
|
| 57 |
+
normalizes them, and calls train_and_evaluate(...) with the canonical names.
|
| 58 |
+
"""
|
| 59 |
try:
|
| 60 |
+
# --- Normalize aliases & defaults ---
|
| 61 |
+
# hidden size: prefer explicit hidden_units, then hidden, else default 64
|
| 62 |
if hidden_units is not None:
|
| 63 |
hidden = hidden_units
|
| 64 |
+
if hidden is None:
|
| 65 |
+
hidden = 64
|
| 66 |
|
| 67 |
+
# layers: prefer explicit n_layers, then layers, else default 1
|
| 68 |
+
if n_layers is not None:
|
| 69 |
+
layers = n_layers
|
| 70 |
+
if layers is None:
|
| 71 |
+
layers = 1
|
| 72 |
+
|
| 73 |
+
# learning rate: prefer learning_rate then lr, else default 0.001
|
| 74 |
+
if learning_rate is not None:
|
| 75 |
+
lr = learning_rate
|
| 76 |
+
if lr is None:
|
| 77 |
+
lr = 0.001
|
| 78 |
+
|
| 79 |
+
# window size: prefer window_size then window, else default 30
|
| 80 |
+
if window_size is not None:
|
| 81 |
+
window = window_size
|
| 82 |
+
if window is None:
|
| 83 |
+
window = 30
|
| 84 |
+
|
| 85 |
+
# device: caller may pass it; otherwise detect automatically
|
| 86 |
+
if device is None:
|
| 87 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 88 |
+
|
| 89 |
+
logging.info(
|
| 90 |
+
f"get_model called: model={model_name}, device={device}, hidden={hidden}, layers={layers}, lr={lr}, window={window}, epochs={epochs}"
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
# --- Select model class mapping (keys as used in UI) ---
|
| 94 |
model_classes = {
|
| 95 |
"LSTM": LSTMModel,
|
| 96 |
"GRU": GRUModel,
|
| 97 |
"CNN": CNNModel,
|
| 98 |
"MLP": MLPModel,
|
| 99 |
"Hybrid": HybridCNNGRUModel,
|
| 100 |
+
"HybridCNNGRU": HybridCNNGRUModel,
|
| 101 |
"Transformer": TransformerModel,
|
| 102 |
"BiLSTM": BiLSTMModel,
|
| 103 |
}
|
| 104 |
|
| 105 |
model_cls = model_classes.get(model_name, LSTMModel)
|
|
|
|
|
|
|
| 106 |
|
| 107 |
+
# --- Call the core training function with canonical param names ---
|
| 108 |
result = train_and_evaluate(
|
| 109 |
df=df,
|
| 110 |
features=features,
|
|
|
|
| 115 |
layers=layers,
|
| 116 |
epochs=epochs,
|
| 117 |
lr=lr,
|
| 118 |
+
beta1=beta1,
|
| 119 |
+
beta2=beta2,
|
| 120 |
weight_decay=weight_decay,
|
| 121 |
dropout=dropout,
|
| 122 |
window=window,
|
|
|
|
| 128 |
verbose=verbose,
|
| 129 |
)
|
| 130 |
|
| 131 |
+
# --- Normalize return ---
|
| 132 |
if not result:
|
| 133 |
logging.error(f"{model_name} returned empty result.")
|
| 134 |
+
return {"error": "Empty result from training"}
|
| 135 |
if isinstance(result, dict) and result.get("error"):
|
| 136 |
logging.error(f"{model_name} training error: {result['error']}")
|
| 137 |
return {"error": result["error"]}
|
| 138 |
+
|
| 139 |
+
logging.info(f"{model_name} training completed successfully")
|
| 140 |
return result
|
| 141 |
|
| 142 |
except Exception as e:
|
| 143 |
+
logging.error(f"Model runner error for {model_name}: {str(e)}", exc_info=True)
|
| 144 |
return {"error": str(e)}
|