# camenduru's picture
# thanks to NVIDIA ❤
# 7934b29
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional
import torch
from torch import nn as nn
from nemo.collections.common.parts import transformer_weights_init
from nemo.core.classes import Exportable, NeuralModule
from nemo.core.neural_types import ChannelType, NeuralType
__all__ = ['Classifier']
class Classifier(NeuralModule, Exportable):
    """
    A base class for modules that perform various classification tasks.

    Stores the hidden size and a dropout layer for concrete classifiers to use,
    and provides helpers for optional transformer-style weight initialization
    and for generating example inputs (e.g. for tracing/export).
    """

    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        """
        Returns definitions of module input ports.

        Implemented here since all NLP classifiers have the same inputs: a single
        ``hidden_states`` port declared with axes ``('B', 'T', 'D')`` and
        ``ChannelType`` elements.
        """
        return {"hidden_states": NeuralType(('B', 'T', 'D'), ChannelType())}

    def __init__(self, hidden_size: int, dropout: float = 0.0) -> None:
        """
        Initializes the Classifier base module.

        Args:
            hidden_size: the size of the hidden dimension
            dropout: dropout probability used to build ``self.dropout``
                (intended for the input hidden states)
        """
        super().__init__()
        self._hidden_size = hidden_size
        self.dropout = nn.Dropout(dropout)

    def post_init(self, use_transformer_init: bool) -> None:
        """
        Common post-processing to be called at the end of concrete Classifiers' init methods.

        Args:
            use_transformer_init: whether or not to apply ``transformer_weights_init``
                (with ``xavier=False``) to the modules visited by ``self.apply``
        """
        if use_transformer_init:
            self.apply(lambda module: transformer_weights_init(module, xavier=False))

    def input_example(self, max_batch=1, max_dim=256):
        """
        Generates input examples for tracing etc.

        Args:
            max_batch: size of the first dimension of the example tensor
            max_dim: size of the second dimension of the example tensor

        Returns:
            A one-element tuple holding a random tensor of shape
            ``(max_batch, max_dim, hidden_size)``. The tensor is moved to the
            device and dtype of the module's first parameter when one exists;
            otherwise it is left as created by ``torch.randn``.
        """
        # A module without any parameters (e.g. this base class, which only holds
        # a Dropout) would make a bare ``next(...)`` raise StopIteration, so fall
        # back to ``None`` and keep the default device/dtype in that case.
        sample = next(self.parameters(), None)
        example = torch.randn(max_batch, max_dim, self._hidden_size)
        if sample is not None:
            example = example.to(sample.device).to(sample.dtype)
        return (example,)

    def save_to(self, save_path: str):
        """
        Saves the module to the specified path.

        This implementation is intentionally a no-op: nothing is written.

        Args:
            save_path: Path to where to save the module.
        """
        pass

    @classmethod
    def restore_from(cls, restore_path: str):
        """
        Restores the module from the specified path.

        This implementation is intentionally a no-op: nothing is read and
        ``None`` is returned.

        Args:
            restore_path: Path to restore the module from.
        """
        pass