|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Mapping, Sequence |
|
|
|
|
|
import torch |
|
|
from lightning.pytorch.plugins import HalfPrecision |
|
|
from typing_extensions import override |
|
|
|
|
|
# Whether hydra + omegaconf are importable; consulted at runtime by
# resolve_trainer_cfg to decide if the config can be resolved/instantiated.
_HAS_HYDRA = True


try:
    import hydra
    from omegaconf import DictConfig, OmegaConf
except ModuleNotFoundError:
    # Fallbacks so module-level names still resolve without hydra installed:
    # DictConfig degrades to a plain Mapping alias (keeps annotations valid)
    # and OmegaConf is None (must never be called when _HAS_HYDRA is False).
    DictConfig = Mapping
    OmegaConf = None
    _HAS_HYDRA = False
|
|
|
|
|
|
|
|
def resolve_trainer_cfg(trainer_cfg: DictConfig) -> DictConfig:
    """
    Resolves and processes a trainer configuration.

    This function handles specific trainer configuration details:

    - For half precision setups, replaces precision settings with custom plugins
    - Instantiates strategy objects from mapping configurations
    - Instantiates custom callbacks from sequences

    Args:
        trainer_cfg: A DictConfig containing trainer configuration parameters

    Returns:
        A processed configuration (plain container) with resolved values, or
        the input unchanged when hydra/omegaconf are not installed.
    """
    # Guard must come BEFORE any OmegaConf use: without hydra/omegaconf,
    # OmegaConf is None and calling .to_container on it would raise.
    if not _HAS_HYDRA:
        return trainer_cfg

    # Resolve interpolations and convert to a plain mutable dict.
    trainer_cfg = OmegaConf.to_container(trainer_cfg, resolve=True)

    # Half precision: swap the 'precision' flag for the audio-aware plugin so
    # audio tensors are not downcast (see HalfPrecisionForAudio elsewhere in
    # this module).
    precision = trainer_cfg.get("precision")
    if precision in ("fp16-true", "bf16-true"):
        trainer_cfg.pop("precision", None)
        trainer_cfg["plugins"] = [HalfPrecisionForAudio(precision)]

    # A mapping-style strategy config (e.g. {_target_: ...}) is instantiated
    # into a real strategy object; string strategies pass through untouched.
    strategy = trainer_cfg.get("strategy")
    if strategy is not None and isinstance(strategy, Mapping):
        trainer_cfg["strategy"] = hydra.utils.instantiate(strategy)

    # Instantiate each callback config into a callback object.
    callbacks = trainer_cfg.get("callbacks")
    if callbacks is not None and isinstance(callbacks, Sequence):
        trainer_cfg["callbacks"] = [hydra.utils.instantiate(cb) for cb in callbacks]

    return trainer_cfg
|
|
|
|
|
|
|
|
class HalfPrecisionForAudio(HalfPrecision):
    """
    Adjusted Pytorch Lightning plugin for training with half precision.
    It avoids downcasting audio to bfloat16 when the mini-batch is a dict
    with 'audio' string in the keys corresponding to audio tensors.
    """

    @override
    def convert_input(self, data: Any) -> Any:
        """
        Converts input data to the appropriate precision format, preserving audio tensor precision.

        This method overrides the parent class implementation to avoid downcasting tensors
        with 'audio' in their dictionary keys. It processes input data recursively when
        encountering nested dictionaries.

        Args:
            data: The input data to convert (can be tensor, dict, or other types)

        Returns:
            The converted data with appropriate precision for each element
        """
        # Non-dict batches get the stock half-precision treatment.
        if not isinstance(data, dict):
            return super().convert_input(data)

        target_dtype = self._desired_input_dtype

        def _cast(value):
            if isinstance(value, dict):
                # Keep audio tensors untouched; recurse into everything else.
                return {
                    key: inner
                    if ("audio" in key and torch.is_tensor(inner))
                    else _cast(inner)
                    for key, inner in value.items()
                }
            if isinstance(value, torch.Tensor) and torch.is_floating_point(value):
                return value.to(target_dtype)
            # Non-floating tensors and non-tensor leaves pass through as-is.
            # NOTE(review): tensors nested inside lists/tuples are not visited
            # here — presumably batches are dict-shaped; confirm with callers.
            return value

        return _cast(data)
|
|
|