File size: 11,071 Bytes
1f5470c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 | import numpy as np
from keras.src import backend
from keras.src import callbacks as callbacks_module
from keras.src import tree
from keras.src.backend.common import standardize_dtype
from keras.src.backend.common.keras_tensor import KerasTensor
from keras.src.backend.numpy.core import is_tensor
from keras.src.trainers import trainer as base_trainer
from keras.src.trainers.data_adapters import data_adapter_utils
from keras.src.trainers.epoch_iterator import EpochIterator
from keras.src.utils import traceback_utils
class NumpyTrainer(base_trainer.Trainer):
    """Trainer implementation for the NumPy backend.

    Only inference is supported: `evaluate`, `predict`, `test_on_batch` and
    `predict_on_batch` are implemented, while `fit` and `train_on_batch`
    raise `NotImplementedError`.
    """

    def __init__(self):
        super().__init__()
        # Created lazily by `make_test_function` / `make_predict_function`.
        self.test_function = None
        self.predict_function = None

    def test_step(self, data):
        """Run one evaluation step on a single batch.

        Args:
            data: A batch that `data_adapter_utils.unpack_x_y_sample_weight`
                splits into `(x, y, sample_weight)`.

        Returns:
            The result of `self.compute_metrics` for this batch.
        """
        (
            x,
            y,
            sample_weight,
        ) = data_adapter_utils.unpack_x_y_sample_weight(data)
        if self._call_has_training_arg:
            y_pred = self(x, training=False)
        else:
            y_pred = self(x)
        loss = self._compute_loss(
            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=False
        )
        # The loss tracker is weighted by the batch size, read from the
        # leading dimension of the first (flattened) input.
        self._loss_tracker.update_state(
            loss, sample_weight=tree.flatten(x)[0].shape[0]
        )
        return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)

    def predict_step(self, data):
        """Run the model forward on a single batch and return its output.

        Targets and sample weights in `data`, if any, are ignored.
        """
        x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data)
        if self._call_has_training_arg:
            y_pred = self(x, training=False)
        else:
            y_pred = self(x)
        return y_pred

    def make_test_function(self, force=False):
        """Create (or reuse) `self.test_function`.

        The resulting function takes a list of batches. When
        `self.steps_per_execution > 1` it runs `test_step` on every batch in
        order and returns the logs of the last one. Otherwise it runs
        `test_step` on the first batch only.

        Args:
            force: If True, rebuild the function even if one already exists.

        Returns:
            The test function, which is also stored on `self.test_function`.
        """
        if self.test_function is not None and not force:
            return self.test_function

        def one_test_step(data):
            return self.test_step(data[0])

        def multi_test_steps(data):
            for single_step_data in data:
                logs = one_test_step([single_step_data])
            return logs

        if self.steps_per_execution > 1:
            self.test_function = multi_test_steps
        else:
            self.test_function = one_test_step
        # Return on both paths, consistent with the early-exit branch above.
        return self.test_function

    def make_predict_function(self, force=False):
        """Create (or reuse) `self.predict_function`.

        The resulting function takes a list of batches. When
        `self.steps_per_execution > 1` it runs `predict_step` on every batch
        and concatenates the outputs with `np.concatenate` (default axis 0).
        Otherwise it runs `predict_step` on the first batch only.

        Args:
            force: If True, rebuild the function even if one already exists.

        Returns:
            The predict function, which is also stored on
            `self.predict_function`.
        """
        if self.predict_function is not None and not force:
            return self.predict_function

        def one_predict_step(data):
            return self.predict_step(data[0])

        def multi_predict_steps(data):
            outputs = one_predict_step(data[:1])
            for single_step_data in data[1:]:
                step_outputs = one_predict_step([single_step_data])
                outputs = tree.map_structure(
                    lambda t1, t2: np.concatenate([t1, t2]),
                    outputs,
                    step_outputs,
                )
            return outputs

        if self.steps_per_execution > 1:
            self.predict_function = multi_predict_steps
        else:
            self.predict_function = one_predict_step
        # Return on both paths, consistent with the early-exit branch above.
        return self.predict_function

    def _symbolic_build(self, data_batch):
        """Build any unbuilt model, compile-metric or compile-loss state.

        If anything is unbuilt, tensors in `data_batch` (per `is_tensor`) are
        replaced by `KerasTensor`s with the same shape and standardized dtype.
        Other values pass through unchanged. The model, `compute_metrics` and
        `_compute_loss` are then run through `backend.compute_output_spec` on
        those symbolic inputs so that their state gets created.
        `self._post_build()` is always called at the end.

        Args:
            data_batch: One batch of data, later split with
                `data_adapter_utils.unpack_x_y_sample_weight`.

        Raises:
            RuntimeError: If the model cannot be built from `data_batch`.
        """
        model_unbuilt = not all(layer.built for layer in self._flatten_layers())
        compile_metrics_unbuilt = (
            self._compile_metrics is not None
            and not self._compile_metrics.built
        )
        compile_loss_unbuilt = (
            self._compile_loss is not None and not self._compile_loss.built
        )
        if model_unbuilt or compile_metrics_unbuilt or compile_loss_unbuilt:
            # Create symbolic tensors matching an input batch.
            def to_symbolic_input(v):
                if is_tensor(v):
                    return KerasTensor(v.shape, standardize_dtype(v.dtype))
                return v

            data_batch = tree.map_structure(to_symbolic_input, data_batch)
            (
                x,
                y,
                sample_weight,
            ) = data_adapter_utils.unpack_x_y_sample_weight(data_batch)
            # Build all model state with `backend.compute_output_spec`.
            try:
                y_pred = backend.compute_output_spec(self, x)
            except Exception as e:
                # Narrowed from a bare `except:` so that KeyboardInterrupt /
                # SystemExit still propagate. The error is chained so the
                # real cause remains visible in the traceback.
                raise RuntimeError(
                    "Unable to automatically build the model. "
                    "Please build it yourself before calling "
                    "fit/evaluate/predict. "
                    "A model is 'built' when its variables have "
                    "been created and its `self.built` attribute "
                    "is True. Usually, calling the model on a batch "
                    "of data is the right way to build it."
                ) from e
            if compile_metrics_unbuilt:
                # Build all metric state with `backend.compute_output_spec`.
                backend.compute_output_spec(
                    self.compute_metrics,
                    x,
                    y,
                    y_pred,
                    sample_weight=sample_weight,
                )
            if compile_loss_unbuilt:
                # Build `CompileLoss` state with `backend.compute_output_spec`.
                backend.compute_output_spec(
                    self._compute_loss,
                    x,
                    y,
                    y_pred,
                    sample_weight=sample_weight,
                )
        self._post_build()

    def fit(
        self,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose="auto",
        callbacks=None,
        validation_split=0.0,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_batch_size=None,
        validation_freq=1,
    ):
        """Training is not supported by the NumPy backend.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("fit not implemented for NumPy backend.")

    @traceback_utils.filter_traceback
    def predict(
        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
    ):
        """Generate predictions for `x`, batch by batch.

        Args:
            x: Input data, forwarded to `EpochIterator` (never shuffled).
            batch_size: Forwarded to `EpochIterator`.
            verbose: Forwarded to `CallbackList`. A progress bar is added
                whenever `verbose != 0`.
            steps: Forwarded to `EpochIterator` as `steps_per_epoch`.
            callbacks: A list of callbacks or a ready-made `CallbackList`.

        Returns:
            The per-batch outputs, concatenated leaf-wise with
            `np.concatenate`, in the structure returned by `predict_step`.

        Raises:
            ValueError: If the data iterator yields no batches.
        """
        # Create an iterator that yields batches of input data.
        epoch_iterator = EpochIterator(
            x=x,
            batch_size=batch_size,
            steps_per_epoch=steps,
            shuffle=False,
            steps_per_execution=self.steps_per_execution,
        )

        # Container that configures and calls callbacks.
        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(
                callbacks,
                add_progbar=verbose != 0,
                verbose=verbose,
                epochs=1,
                steps=epoch_iterator.num_batches,
                model=self,
            )

        def append_to_outputs(batch_outputs, outputs):
            # First batch: wrap every output leaf in a list; afterwards
            # append each new leaf to its list.
            if outputs is None:
                outputs = tree.map_structure(
                    lambda batch_output: [batch_output],
                    batch_outputs,
                )
            else:
                tree.map_structure_up_to(
                    batch_outputs,
                    lambda output, batch_output: output.append(batch_output),
                    outputs,
                    batch_outputs,
                )
            return outputs

        self.make_predict_function()
        self.stop_predicting = False
        callbacks.on_predict_begin()
        outputs = None
        batch_outputs = None
        # Tracks whether the loop ran at least once. Without this check, an
        # empty iterator would leave nothing to concatenate and surface as
        # an opaque error below.
        saw_batch = False
        for step, data in epoch_iterator:
            saw_batch = True
            callbacks.on_predict_batch_begin(step)
            batch_outputs = self.predict_function(data)
            outputs = append_to_outputs(batch_outputs, outputs)
            callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
            if self.stop_predicting:
                break
        callbacks.on_predict_end()
        if not saw_batch:
            raise ValueError(
                "`predict` received no batches of input data. "
                "Make sure `x` is not empty."
            )
        return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)

    @traceback_utils.filter_traceback
    def evaluate(
        self,
        x=None,
        y=None,
        batch_size=None,
        verbose="auto",
        sample_weight=None,
        steps=None,
        callbacks=None,
        return_dict=False,
        **kwargs,
    ):
        """Compute the loss and metrics over the given data.

        Args:
            x, y, sample_weight, batch_size: Forwarded to `EpochIterator`
                (never shuffled).
            verbose: Forwarded to `CallbackList`. A progress bar is added
                whenever `verbose != 0`.
            steps: Forwarded to `EpochIterator` as `steps_per_epoch`.
            callbacks: A list of callbacks or a ready-made `CallbackList`.
            return_dict: If True, return the logs dict. Otherwise return
                `self._flatten_metrics_in_order(logs)`.
            **kwargs: Only the private `_use_cached_eval_dataset` is
                accepted. When True, `self._eval_epoch_iterator` is reused
                instead of building a new iterator.

        Raises:
            ValueError: If any other keyword argument is passed.
        """
        # TODO: respect compiled trainable state
        use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False)
        if kwargs:
            raise ValueError(f"Arguments not recognized: {kwargs}")

        if use_cached_eval_dataset:
            epoch_iterator = self._eval_epoch_iterator
        else:
            # Create an iterator that yields batches of input/target data.
            epoch_iterator = EpochIterator(
                x=x,
                y=y,
                sample_weight=sample_weight,
                batch_size=batch_size,
                steps_per_epoch=steps,
                shuffle=False,
                steps_per_execution=self.steps_per_execution,
            )

        if not all(layer.built for layer in self._flatten_layers()):
            # Build the model on one batch of data.
            for _, data in epoch_iterator:
                data_batch = data[0]
                self._symbolic_build(data_batch)
                break

        # Container that configures and calls callbacks.
        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(
                callbacks,
                add_progbar=verbose != 0,
                verbose=verbose,
                epochs=1,
                steps=epoch_iterator.num_batches,
                model=self,
            )

        self.make_test_function()
        self.stop_evaluating = False
        callbacks.on_test_begin()
        logs = {}
        self.reset_metrics()
        for step, data in epoch_iterator:
            callbacks.on_test_batch_begin(step)
            logs = self.test_function(data)
            callbacks.on_test_batch_end(step, logs)
            if self.stop_evaluating:
                break
        logs = self._get_metrics_result_or_logs(logs)
        callbacks.on_test_end(logs)

        if return_dict:
            return logs
        return self._flatten_metrics_in_order(logs)

    def train_on_batch(
        self,
        x,
        y=None,
        sample_weight=None,
        class_weight=None,
        return_dict=False,
    ):
        """Training is not supported by the NumPy backend.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "train_on_batch not implemented for NumPy backend."
        )

    def test_on_batch(
        self,
        x,
        y=None,
        sample_weight=None,
        return_dict=False,
    ):
        """Evaluate the model on a single batch.

        Requires `compile()` to have been called. Builds the model from this
        batch if needed.

        Returns:
            The logs, with each leaf converted via `np.array`, as a dict if
            `return_dict` is True. Otherwise the result of
            `self._flatten_metrics_in_order(logs)`.
        """
        self._assert_compile_called("test_on_batch")

        data = (x, y, sample_weight)

        # Maybe build model
        self._symbolic_build(data)
        self.make_test_function()

        # The test function expects a list of batches.
        logs = self.test_function([data])
        logs = tree.map_structure(np.array, logs)
        if return_dict:
            return logs
        return self._flatten_metrics_in_order(logs)

    def predict_on_batch(self, x):
        """Return predictions for a single batch `x`.

        The outputs are converted with `backend.convert_to_numpy`.
        """
        self.make_predict_function()
        # The predict function expects a list of batches, each unpacked by
        # `unpack_x_y_sample_weight`, so wrap `x` as a 1-tuple.
        batch_outputs = self.predict_function([(x,)])
        batch_outputs = tree.map_structure(
            backend.convert_to_numpy, batch_outputs
        )
        return batch_outputs
|