|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Helper for computing and logging training summaries (log_train_summary)."""
|
|
|
|
|
from typing import Any, Callable, Dict, Tuple, Sequence, Optional, Mapping, Union |
|
|
|
|
|
from clu import metric_writers |
|
|
import jax |
|
|
import jax.numpy as jnp |
|
|
|
|
|
from scenic.train_lib import train_utils |
|
|
|
|
|
|
|
|
|
|
|
# Loose type alias for a JAX pytree: either a (possibly nested) string-keyed
# mapping or any other value.
PyTree = Union[Mapping[str, Mapping], Any]


# Type alias for PRNG keys, typed here simply as a `jnp.ndarray`.
PRNGKey = jnp.ndarray
|
|
|
|
|
|
|
|
def log_train_summary(step: int,
                      *,
                      writer: metric_writers.MetricWriter,
                      train_metrics: Sequence[Dict[str, Tuple[float, int]]],
                      train_images: Any = None,
                      extra_training_logs: Optional[Sequence[Dict[str,
                                                                  Any]]] = None,
                      metrics_normalizer_fn: Optional[
                          Callable[[Dict[str, Tuple[float, int]], str],
                                   Dict[str, float]]] = None,
                      prefix: str = 'train',
                      step_idx: Optional[int] = None,
                      key_separator: str = '_') -> Dict[str, float]:
  """Computes and logs train metrics, extra scalars and (optionally) images.

  Args:
    step: Step at which all summaries are written to `writer`.
    writer: Metric writer that receives the images and scalars. It is flushed
      before this function returns.
    train_metrics: Sequence of per-step metric dicts mapping a metric name to
      a `(value, normalizer)` pair. They are stacked with
      `train_utils.stack_forest`, summed, then normalized with
      `metrics_normalizer_fn`.
    train_images: Optional collection of image dicts. When given, it is
      stacked with `train_utils.stack_forest`, each leaf is concatenated along
      its leading axis and truncated to the first 4 entries, and every entry
      is written as a separate image keyed by name, batch index and
      `step_idx`.
    extra_training_logs: Optional sequence of dicts of additional values. Each
      key is stacked across the sequence, averaged with `.mean()`, and written
      under its own name (no `prefix`). `None` or an empty sequence means no
      extra logs.
    metrics_normalizer_fn: Callable `(summary, split) -> normalized_summary`.
      Defaults to `train_utils.normalize_metrics_summary`.
    prefix: Prefix joined onto each train-metric name with `key_separator`.
    step_idx: Step number embedded in the image summary names. Defaults to
      `step`.
    key_separator: Separator placed between `prefix` and each metric name.

  Returns:
    The normalized train-metric summary. Its keys do not include `prefix`.
  """
  if step_idx is None:
    step_idx = step

  if train_images is not None:

    def fmt(i, p):
      # Format `i` with at least `p` digits, zero-padded (e.g. fmt(3, 2) ->
      # '03'), so image names sort lexicographically.
      return f'%.{p}d' % i

    train_images = train_utils.stack_forest(train_images)
    # NOTE(review): presumably each leaf is (steps, batch, ...) after
    # stacking; concatenation merges the leading axes and only the first 4
    # entries are kept — confirm the expected layout against callers.
    train_images = jax.tree_util.tree_map(lambda x: jnp.concatenate(x)[:4],
                                          train_images)
    new_train_images = {}
    for key, value in train_images.items():
      for batch_idx, image in enumerate(value):
        # Only the first slice along the image's leading axis is written.
        name = f'{key}/bi{fmt(batch_idx, p=2)}/s{fmt(step_idx, p=8)}'
        new_train_images[name] = image[0, ...]
    writer.write_images(step, new_train_images)

  # Stack per-step metrics and sum every leaf into a scalar total.
  train_metrics = train_utils.stack_forest(train_metrics)
  train_metrics_summary = jax.tree_util.tree_map(lambda x: x.sum(),
                                                 train_metrics)

  # Normalize the summed metrics. The split name handed to the normalizer is
  # always 'train', independently of `prefix`.
  metrics_normalizer_fn = (
      metrics_normalizer_fn or train_utils.normalize_metrics_summary)
  train_metrics_summary = metrics_normalizer_fn(train_metrics_summary, 'train')

  # Fall back to a one-element *sequence* holding an empty dict, matching the
  # declared Sequence[Dict] type. The previous bare `{}` fallback gave
  # `stack_forest` no dict to stack at all (it transposes the elements of the
  # sequence), which broke the default `extra_training_logs=None`.
  extra_training_logs = extra_training_logs or [{}]
  train_logs = train_utils.stack_forest(extra_training_logs)

  writer.write_scalars(
      step, {
          key_separator.join((prefix, key)): val
          for key, val in train_metrics_summary.items()
      })
  # Extra logs are averaged over the stacked sequence and written unprefixed.
  writer.write_scalars(step,
                       {key: val.mean() for key, val in train_logs.items()})

  writer.flush()
  return train_metrics_summary
|
|
|