| import torch |
| from torch.nn.functional import mse_loss |
| from typing import Mapping |
|
|
|
|
def assert_model_parameters_fp32(model: torch.nn.Module, model_name: str) -> None:
    """Assert that *model* has at least one parameter and that every
    parameter's dtype is torch.float32.

    Args:
        model: module whose parameters are inspected.
        model_name: label used in the assertion messages.

    Raises:
        AssertionError: if the model has no parameters, or if any
            parameter is not torch.float32 (a sample of up to 5
            offenders is included in the message).
    """
    named_params = list(model.named_parameters())
    offenders = [
        {"name": param_name, "dtype": str(param.dtype)}
        for param_name, param in named_params
        if param.dtype != torch.float32
    ]

    assert len(named_params) > 0, f"{model_name} has no parameters."
    assert len(offenders) == 0, (
        f"{model_name} parameters must all be torch.float32. "
        f"non_fp32_count={len(offenders)} sample={offenders[:5]}"
    )
|
|
|
|
def assert_state_dict_floating_tensors_fp32(
    state_dict: Mapping[str, torch.Tensor],
    state_dict_name: str,
) -> None:
    """Assert that every floating-point tensor in *state_dict* is
    torch.float32.

    Non-floating tensors (e.g. integer buffers) are ignored, but every
    entry must at least be a tensor.

    Args:
        state_dict: mapping of names to tensors (checked in sorted key
            order so failure messages are deterministic).
        state_dict_name: label used in the assertion messages.

    Raises:
        AssertionError: if any entry is not a tensor, or if any floating
            tensor is not torch.float32 (a sample of up to 5 offenders
            is included in the message).
    """
    offenders: list[dict[str, str]] = []
    for key in sorted(state_dict):
        value = state_dict[key]
        assert torch.is_tensor(value), (
            f"{state_dict_name} state_dict entry must be a tensor. "
            f"name={key} type={type(value)}"
        )
        if value.is_floating_point() and value.dtype != torch.float32:
            offenders.append({"name": key, "dtype": str(value.dtype)})

    assert len(offenders) == 0, (
        f"{state_dict_name} floating tensors must be torch.float32. "
        f"non_fp32_count={len(offenders)} sample={offenders[:5]}"
    )
|
|
|
|
def assert_state_dict_equal(
    reference_state_dict: Mapping[str, torch.Tensor],
    candidate_state_dict: Mapping[str, torch.Tensor],
    context: str,
    max_report: int = 10,
) -> None:
    """Assert the two state_dicts have identical keys and exactly equal
    tensors.

    Fixes over the naive zip-based comparison:
    - key sets are compared explicitly, so missing/extra entries are
      caught (``zip`` silently drops trailing entries of the longer
      mapping) and key *order* no longer matters;
    - tensors are compared with ``torch.equal``, which is exact, works
      on integer buffers (``mse_loss`` raises on non-float dtypes),
      treats shape/dtype mismatches as unequal, and does not miss
      NaN-induced differences (an MSE of NaN fails ``diff > 0.0``).

    Args:
        reference_state_dict: ground-truth mapping of names to tensors.
        candidate_state_dict: mapping to compare against the reference.
        context: label used in the assertion message.
        max_report: maximum number of individual mismatch messages to
            include in the assertion message.

    Raises:
        AssertionError: if the key sets differ or any tensor pair is not
            exactly equal.
    """
    error_msgs: list[str] = []

    ref_keys = set(reference_state_dict.keys())
    cand_keys = set(candidate_state_dict.keys())
    if ref_keys != cand_keys:
        msg = (
            f"Key mismatch: missing={sorted(ref_keys - cand_keys)} "
            f"unexpected={sorted(cand_keys - ref_keys)}"
        )
        print(msg)
        error_msgs.append(msg)

    # Compare only the shared keys, in deterministic order.
    for name in sorted(ref_keys & cand_keys):
        ref_tensor = reference_state_dict[name]
        cand_tensor = candidate_state_dict[name]
        if torch.equal(ref_tensor, cand_tensor):
            continue
        if ref_tensor.shape != cand_tensor.shape or ref_tensor.dtype != cand_tensor.dtype:
            msg = (
                f"{name}: shape/dtype mismatch "
                f"{tuple(ref_tensor.shape)}/{ref_tensor.dtype} != "
                f"{tuple(cand_tensor.shape)}/{cand_tensor.dtype}"
            )
        elif ref_tensor.is_floating_point():
            # Diagnostic only; equality above is the actual criterion.
            diff = mse_loss(ref_tensor.float(), cand_tensor.float()).item()
            msg = f"{name}: {diff}"
        else:
            diff = (ref_tensor != cand_tensor).sum().item()
            msg = f"{name}: {diff}"
        print(msg)
        error_msgs.append(msg)

    assert not error_msgs, (
        f"{context} state_dict parity failed:{' | '.join(error_msgs[:max_report])}"
    )
|
|
|
|
def assert_models_fp32_and_equal(
    reference_model: torch.nn.Module,
    candidate_model: torch.nn.Module,
    context: str,
    max_report: int = 5,
) -> None:
    """Assert both models are fully fp32 and share an identical state_dict.

    Args:
        reference_model: ground-truth module.
        candidate_model: module compared against the reference.
        context: label prefixed to all assertion messages.
        max_report: forwarded to assert_state_dict_equal.

    Raises:
        AssertionError: from the delegated dtype/parity checks.
    """
    # Dtype checks first, so a precision problem is reported before any
    # value-level mismatch it may have caused.
    for role, module in (("reference", reference_model), ("candidate", candidate_model)):
        assert_model_parameters_fp32(model=module, model_name=f"{context} {role} model")
    assert_state_dict_equal(
        reference_state_dict=reference_model.state_dict(),
        candidate_state_dict=candidate_model.state_dict(),
        context=context,
        max_report=max_report,
    )
|
|