Commit
·
4483827
1
Parent(s):
91801e8
Update custom.py
Browse files
custom.py
CHANGED
|
@@ -11,7 +11,7 @@ import logging
|
|
| 11 |
from torch import Tensor
|
| 12 |
import torch
|
| 13 |
import torch.nn as nn
|
| 14 |
-
from speechbrain.lobes.models.
|
| 15 |
try:
|
| 16 |
from transformers import GPT2LMHeadModel
|
| 17 |
from transformers import GPT2Tokenizer
|
|
@@ -23,7 +23,7 @@ except ImportError:
|
|
| 23 |
logger = logging.getLogger(__name__)
|
| 24 |
|
| 25 |
|
| 26 |
-
class HuggingFaceGPT_expanded(
|
| 27 |
"""This lobe enables the integration of HuggingFace pretrained GPT model.
|
| 28 |
Source paper whisper:
|
| 29 |
https://life-extension.github.io/2020/05/27/GPT%E6%8A%80%E6%9C%AF%E5%88%9D%E6%8E%A2/language-models.pdf
|
|
|
|
| 11 |
from torch import Tensor
|
| 12 |
import torch
|
| 13 |
import torch.nn as nn
|
| 14 |
+
from speechbrain.lobes.models.huggingface_transformers.gpt import GPT
|
| 15 |
try:
|
| 16 |
from transformers import GPT2LMHeadModel
|
| 17 |
from transformers import GPT2Tokenizer
|
|
|
|
| 23 |
logger = logging.getLogger(__name__)
|
| 24 |
|
| 25 |
|
| 26 |
+
class HuggingFaceGPT_expanded(GPT):
|
| 27 |
"""This lobe enables the integration of HuggingFace pretrained GPT model.
|
| 28 |
Source paper whisper:
|
| 29 |
https://life-extension.github.io/2020/05/27/GPT%E6%8A%80%E6%9C%AF%E5%88%9D%E6%8E%A2/language-models.pdf
|