Spaces:
Sleeping
Sleeping
| # Copyright (C) 2021-2024, Mindee. | |
| # This program is licensed under the Apache License 2.0. | |
| # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details. | |
| import logging | |
| from typing import Any, List, Optional, Tuple, Union | |
| import torch | |
| from torch import nn | |
| from doctr.utils.data import download_from_url | |
| __all__ = [ | |
| "load_pretrained_params", | |
| "conv_sequence_pt", | |
| "set_device_and_dtype", | |
| "export_model_to_onnx", | |
| "_copy_tensor", | |
| "_bf16_to_float32", | |
| ] | |
def _copy_tensor(x: torch.Tensor) -> torch.Tensor:
    """Return an independent copy of a tensor that is cut off from the autograd graph

    Args:
    ----
        x: the tensor to duplicate

    Returns:
    -------
        a cloned tensor, detached from the computation graph of `x`
    """
    duplicate = x.clone()
    return duplicate.detach()
def _bf16_to_float32(x: torch.Tensor) -> torch.Tensor:
    """Upcast a bfloat16 tensor to float32, leaving any other dtype untouched

    Args:
    ----
        x: the tensor to convert

    Returns:
    -------
        `x.float()` when `x` is bfloat16, otherwise `x` itself (same object)
    """
    # bfloat16 is not supported in .numpy(): torch/csrc/utils/tensor_numpy.cpp:aten_to_numpy_dtype
    if x.dtype != torch.bfloat16:
        return x
    return x.float()
def load_pretrained_params(
    model: nn.Module,
    url: Optional[str] = None,
    hash_prefix: Optional[str] = None,
    ignore_keys: Optional[List[str]] = None,
    **kwargs: Any,
) -> None:
    """Load a set of parameters onto a model

    >>> from doctr.models import load_pretrained_params
    >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip")

    Args:
    ----
        model: the PyTorch model to be loaded
        url: URL of the zipped set of parameters; when None, a warning is logged and the model is left untouched
        hash_prefix: first characters of SHA256 expected hash
        ignore_keys: list of weights to be ignored from the state_dict
        **kwargs: additional arguments to be passed to `doctr.utils.data.download_from_url`

    Raises:
    ------
        ValueError: if `ignore_keys` is provided and, after dropping those keys, the keys reported missing
            by the model differ from `ignore_keys` or the checkpoint holds keys the model does not expect
    """
    if url is None:
        logging.warning("Invalid model URL, using default initialization.")
        return

    archive_path = download_from_url(url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs)

    # Read state_dict
    # NOTE(security): torch.load unpickles the file, so only checkpoints from trusted sources should be
    # loaded here (passing `hash_prefix` lets the download be checked against an expected hash).
    state_dict = torch.load(archive_path, map_location="cpu")

    if not ignore_keys:
        # Nothing to filter out: strict loading, any key mismatch is reported by PyTorch itself
        model.load_state_dict(state_dict)
        return

    # Remove weights from the state_dict; a key that is already absent from the checkpoint is tolerated
    # here so that the consistency check below decides whether the load is valid (instead of a bare KeyError)
    for key in ignore_keys:
        state_dict.pop(key, None)

    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    # Exactly the ignored weights may be missing, and the checkpoint must not bring unknown weights
    if set(missing_keys) != set(ignore_keys) or len(unexpected_keys) > 0:
        raise ValueError("unable to load state_dict, due to non-matching keys.")
def conv_sequence_pt(
    in_channels: int,
    out_channels: int,
    relu: bool = False,
    bn: bool = False,
    **kwargs: Any,
) -> List[nn.Module]:
    """Assemble a 2D convolution optionally followed by batch normalization and ReLU

    >>> from torch.nn import Sequential
    >>> from doctr.models.utils import conv_sequence_pt
    >>> module = Sequential(*conv_sequence_pt(3, 32, True, True, kernel_size=3))

    Args:
    ----
        in_channels: number of input channels
        out_channels: number of output channels
        relu: whether a ReLU activation should be appended
        bn: whether a batch normalization layer should be appended after the convolution
        **kwargs: additional arguments to be passed to the convolutional layer

    Returns:
    -------
        list of layers, in order: convolution, then batch normalization (if `bn`), then ReLU (if `relu`)
    """
    # No bias before batch norm by default; an explicit `bias` keyword from the caller takes precedence
    kwargs.setdefault("bias", not bn)

    layers: List[nn.Module] = [nn.Conv2d(in_channels, out_channels, **kwargs)]
    layers += [nn.BatchNorm2d(out_channels)] if bn else []
    layers += [nn.ReLU(inplace=True)] if relu else []
    return layers
def set_device_and_dtype(
    model: Any, batches: List[torch.Tensor], device: Union[str, torch.device], dtype: torch.dtype
) -> Tuple[Any, List[torch.Tensor]]:
    """Move a model and its input batches to a given device and dtype

    >>> import torch
    >>> from torch import nn
    >>> from doctr.models.utils import set_device_and_dtype
    >>> model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
    >>> batches = [torch.rand(8) for _ in range(2)]
    >>> model, batches = set_device_and_dtype(model, batches, device="cuda", dtype=torch.float16)

    Args:
    ----
        model: the model to convert (anything exposing `.to(device=..., dtype=...)`)
        batches: the batches to convert
        device: the target device
        dtype: the target dtype

    Returns:
    -------
        the result of `model.to(...)` and a new list holding every batch converted the same way
    """
    target = {"device": device, "dtype": dtype}
    # The model is converted first, then each batch in order
    converted_model = model.to(**target)
    converted_batches = [batch.to(**target) for batch in batches]
    return converted_model, converted_batches
def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.Tensor, **kwargs: Any) -> str:
    """Export model to ONNX format.

    >>> import torch
    >>> from doctr.models.classification import resnet18
    >>> from doctr.models.utils import export_model_to_onnx
    >>> model = resnet18(pretrained=True)
    >>> export_model_to_onnx(model, "my_model", dummy_input=torch.randn(1, 3, 32, 32))

    Args:
    ----
        model: the PyTorch model to be exported
        model_name: the name for the exported model; the file is written to `<model_name>.onnx`
        dummy_input: the dummy input to the model
        kwargs: additional arguments to be passed to torch.onnx.export. They take precedence over the
            defaults set here (`input_names`, `output_names`, `dynamic_axes`, `export_params`, `verbose`);
            when overriding the input/output names, pass a matching `dynamic_axes` as well.

    Returns:
    -------
        the path to the exported model
    """
    onnx_path = f"{model_name}.onnx"

    # Defaults: a single "input" / "logits" pair whose first axis is exported as a dynamic batch dimension
    export_options = {
        "input_names": ["input"],
        "output_names": ["logits"],
        "dynamic_axes": {"input": {0: "batch_size"}, "logits": {0: "batch_size"}},
        "export_params": True,
        "verbose": False,
    }
    # Merging (rather than passing defaults and **kwargs side by side) lets callers override a default
    # without hitting "got multiple values for keyword argument"
    export_options.update(kwargs)

    torch.onnx.export(model, dummy_input, onnx_path, **export_options)
    logging.info(f"Model exported to {onnx_path}")
    return onnx_path