Update modeling_hunyuan.py
Browse files- modeling_hunyuan.py +1 -0
modeling_hunyuan.py
CHANGED
|
@@ -8,6 +8,7 @@ import warnings
|
|
| 8 |
from typing import List, Optional, Tuple, Union
|
| 9 |
|
| 10 |
import torch
|
|
|
|
| 11 |
from torch import Tensor
|
| 12 |
import torch.nn.functional as F
|
| 13 |
import torch.utils.checkpoint
|
|
|
|
| 8 |
from typing import List, Optional, Tuple, Union
|
| 9 |
|
| 10 |
import torch
|
| 11 |
+
torch.set_default_dtype(torch.float32)
|
| 12 |
from torch import Tensor
|
| 13 |
import torch.nn.functional as F
|
| 14 |
import torch.utils.checkpoint
|