# Spaces: Running on Zero
# Module setup: imports and the feature flag controlling torch.compile usage.
import torch
import torch.nn as nn
from typing import List, Optional, Tuple, Union

import f5_tts
from f5_tts.model.backbones.dit_mask import DiT as DiT_

# When True, each GPUDiT instance wraps its fast_forward in torch.compile.
_GPU_FM_TORCH_COMPILE = True
class GPUDiT(DiT_):
    """DiT variant whose ``fast_forward`` is optionally torch.compile'd.

    Construction is identical to the parent ``DiT``; when the module-level
    flag ``_GPU_FM_TORCH_COMPILE`` is True, the instance's ``fast_forward``
    is replaced with a compiled version at init time.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if _GPU_FM_TORCH_COMPILE:
            # Per-instance compilation. dynamic=False assumes static input
            # shapes; fullgraph=True requires the whole method to be
            # graph-capturable (torch.compile raises otherwise).
            self.fast_forward = torch.compile(
                self.fast_forward, dynamic=False, fullgraph=True
            )
# ---------------------------------------------------------------------------
# Monkey-patch: make the package resolve DiT in dit_mask to the compiled
# GPUDiT variant, so downstream code that imports it picks up this class.
# ---------------------------------------------------------------------------
print("========================= DO FM PATCH ============================")
f5_tts.model.backbones.dit_mask.DiT = GPUDiT