| |
| |
| |
| |
| |
| from typing import Dict, List, Protocol, Union |
|
|
| import torch.nn as nn |
|
|
| from torchtitan.config_manager import JobConfig |
| from torchtitan.distributed import ParallelDims |
| from torchtitan.tools.logging import logger |
|
|
|
|
class ModelConverter(Protocol):
    """Interface shared by all model converters.

    A model converter applies a modification to a PyTorch model. Typical
    use cases include:
    - quantization, e.g. swapping in QAT or FP8 specialized linear layers;
    - fused or optimized layers (e.g. flash-attention, norms, ...).
    """

    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
        """Build the converter from the job config and the parallel dimensions."""

    def convert(self, model: nn.Module):
        """Convert ``model`` in place."""

    def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
        """Optional post-optimizer hook (e.g. to compute weight statistics).

        ``model`` may be a single module or a list of modules. The default
        implementation does nothing.
        """
|
|
|
|
# Registry of model converter classes, keyed by the name used in the
# `model.converters` job config. Populated via `register_model_converter`.
_registry_model_converter_cls: Dict[str, type[ModelConverter]] = dict()
|
|
|
|
def register_model_converter(converter_cls: type[ModelConverter], name: str):
    """Register a model converter class under ``name``.

    Once registered, the converter can be applied to any model by listing
    ``name`` in the ``model.converters`` job config.

    Args:
        converter_cls: Model converter class to register.
        name: Registry key; must not already be registered.

    Raises:
        AssertionError: If a converter is already registered under ``name``.
    """
    # Explicit raise instead of an ``assert`` statement: asserts are stripped
    # under ``python -O``, which would let a second registration silently
    # overwrite the first. AssertionError is kept so existing callers that
    # expect it are unaffected.
    if name in _registry_model_converter_cls:
        raise AssertionError(f"A model converter '{name}' is already registered.")
    _registry_model_converter_cls[name] = converter_cls
|
|
|
|
class ModelConvertersContainer(ModelConverter):
    """Sequential container of model converters.

    Builds the converters listed in the ``model.converters`` job config (in
    that order) and applies them to the model one after another.
    """

    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
        # Resolve every name before instantiating anything, so a typo in the
        # config fails fast with a helpful message instead of a bare KeyError.
        converter_classes = []
        for name in job_config.model.converters:
            if name not in _registry_model_converter_cls:
                # Same exception type as a plain dict lookup, for compatibility.
                raise KeyError(
                    f"Unknown model converter '{name}'. Registered converters: "
                    f"{sorted(_registry_model_converter_cls)}"
                )
            converter_classes.append(_registry_model_converter_cls[name])
        self.converters: List[ModelConverter] = [
            converter_cls(job_config, parallel_dims)
            for converter_cls in converter_classes
        ]
        self.print_after_conversion = job_config.model.print_after_conversion

    def convert(self, model: nn.Module):
        """Apply every converter to ``model`` in place, in configuration order.

        If ``model.print_after_conversion`` is enabled, the converted model is
        logged afterwards.
        """
        for converter in self.converters:
            converter.convert(model)
        if self.print_after_conversion:
            logger.info(f"Model definition after conversion:\n\n{model}\n\n")

    def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
        """Forward the post-optimizer hook to every converter, in order."""
        for converter in self.converters:
            converter.post_optimizer_hook(model)
|
|
|
|
def build_model_converters(
    job_config: JobConfig, parallel_dims: ParallelDims
) -> ModelConvertersContainer:
    """Create the container holding every model converter for this job.

    The container instantiates each converter named in the
    ``model.converters`` job config and applies them sequentially.
    """
    container = ModelConvertersContainer(job_config, parallel_dims)
    return container
|
|