import logging
from typing import List, Optional
|
|
import torch.nn as nn
import torch.nn.functional as F
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class FairseqDropout(nn.Module):
    """Dropout module that can optionally remain active at inference time.

    Behaves like standard dropout during training; ``make_generation_fast_``
    can re-enable dropout during generation for selected modules.
    """

    def __init__(self, p, module_name=None):
        super().__init__()
        self.p = p
        self.module_name = module_name
        self.apply_during_inference = False
|
|
    def forward(self, x, inplace: bool = False):
        # Apply dropout while training, or at inference time if dropout was
        # explicitly retained via make_generation_fast_(). Skip the call
        # entirely when p == 0, since it would be a no-op.
        if self.p > 0 and (self.training or self.apply_during_inference):
            return F.dropout(x, p=self.p, training=True, inplace=inplace)
        else:
            return x
|
|
    def make_generation_fast_(
        self,
        name: str,
        retain_dropout: bool = False,
        retain_dropout_modules: Optional[List[str]] = None,
        **kwargs,
    ):
        if retain_dropout:
            if retain_dropout_modules is not None and self.module_name is None:
                logger.warning(
                    "Cannot enable dropout during inference for module {} "
                    "because module_name was not set".format(name)
                )
            elif (
                retain_dropout_modules is None  # if None, apply to all modules
                or self.module_name in retain_dropout_modules
            ):
                logger.info(
                    "Enabling dropout during inference for module: {}".format(name)
                )
                self.apply_during_inference = True
            else:
                logger.info("Disabling dropout for module: {}".format(name))
|
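# Usage sketch (illustrative only; the module name and parameter values below
# are hypothetical, not part of this file). FairseqDropout is a drop-in
# replacement for nn.Dropout whose inference-time behavior can be toggled
# per named module:
#
#     import torch
#
#     drop = FairseqDropout(p=0.1, module_name="TransformerEncoderLayer")
#     x = torch.ones(2, 4)
#
#     drop.train()
#     y = drop(x)  # dropout applied, as with nn.Dropout
#
#     drop.eval()
#     y = drop(x)  # identity: dropout is skipped at inference by default
#
#     drop.make_generation_fast_(
#         "encoder.layers.0.dropout_module",
#         retain_dropout=True,
#         retain_dropout_modules=["TransformerEncoderLayer"],
#     )
#     y = drop(x)  # dropout now applied even in eval mode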