| import torch | |
| from transformers import PreTrainedModel | |
| from ..utils import torch_gc | |
class CPUTextEncoderWrapper(PreTrainedModel):
    """Wrap a text encoder so that it executes on the CPU in float32.

    The wrapped encoder is moved to the CPU and cast to ``torch.float32``.
    When the wrapper is called, the input is sent to the encoder's device and
    every output is moved back to the input's device and cast to
    ``torch_dtype``. ``device`` reports the device the encoder was on before
    it was moved and ``dtype`` reports ``torch_dtype``.
    """

    def __init__(self, text_encoder, torch_dtype):
        """Move ``text_encoder`` to the CPU and remember its original placement.

        Args:
            text_encoder: The encoder to wrap. Its ``config`` and ``device``
                attributes are read, and it is moved to the CPU as float32.
            torch_dtype: The dtype that outputs are cast to and that
                ``dtype`` reports.
        """
        super().__init__(text_encoder.config)
        self.config = text_encoder.config
        # Record the device *before* the encoder is moved to the CPU below.
        self._device = text_encoder.device
        # The CPU does not support float16, so run the encoder in float32.
        self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
        self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
        self.torch_dtype = torch_dtype
        del text_encoder
        # NOTE(review): torch_gc is a project helper; presumably it releases
        # cached memory left behind by the move -- confirm.
        torch_gc()

    def __call__(self, x, **kwargs):
        """Run the wrapped encoder on ``x`` and return converted outputs.

        Args:
            x: Input tensor; it is moved to the wrapped encoder's device.
            **kwargs: Forwarded unchanged to the wrapped encoder.

        Returns:
            The encoder's output mapping, updated in place: every value is
            moved to the device ``x`` was on and cast to ``torch_dtype``.
            Tuple values are replaced by lists of converted tensors.
        """
        input_device = x.device
        original_output = self.text_encoder(x.to(self.text_encoder.device), **kwargs)
        # Only existing keys are reassigned, so the mapping does not change
        # size while it is being iterated.
        for k, v in original_output.items():
            if isinstance(v, tuple):
                original_output[k] = [
                    t.to(device=input_device, dtype=self.torch_dtype) for t in v
                ]
            else:
                original_output[k] = v.to(device=input_device, dtype=self.torch_dtype)
        return original_output

    @property
    def dtype(self):
        """The dtype outputs are cast to (``torch_dtype`` given at construction)."""
        return self.torch_dtype

    @property
    def device(self) -> torch.device:
        """
        `torch.device`: The device the wrapped encoder was on when this wrapper was created (the encoder itself is
        moved to the CPU).
        """
        return self._device