Update model_hf.py
Browse files- model_hf.py +6 -5
model_hf.py
CHANGED
|
@@ -6,9 +6,9 @@ import torch
|
|
| 6 |
import torch.nn as nn
|
| 7 |
import torch.nn.functional as F
|
| 8 |
from torch import Tensor
|
| 9 |
-
import fairseq
|
| 10 |
from .config_ssl import SSLConfig
|
| 11 |
from huggingface_hub import hf_hub_download
|
|
|
|
| 12 |
|
| 13 |
___author__ = "Hemlata Tak"
|
| 14 |
__email__ = "tak@eurecom.fr"
|
|
@@ -21,10 +21,11 @@ __email__ = "tak@eurecom.fr"
|
|
| 21 |
class SSLModel(nn.Module):
|
| 22 |
def __init__(self,device):
|
| 23 |
super(SSLModel, self).__init__()
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
|
|
|
| 28 |
self.model = model[0]
|
| 29 |
self.model_device=device
|
| 30 |
self.out_dim = 1024
|
|
|
|
| 6 |
import torch.nn as nn
|
| 7 |
import torch.nn.functional as F
|
| 8 |
from torch import Tensor
|
|
|
|
| 9 |
from .config_ssl import SSLConfig
|
| 10 |
from huggingface_hub import hf_hub_download
|
| 11 |
+
from transformers import Wav2Vec2ForPreTraining
|
| 12 |
|
| 13 |
___author__ = "Hemlata Tak"
|
| 14 |
__email__ = "tak@eurecom.fr"
|
|
|
|
| 21 |
class SSLModel(nn.Module):
|
| 22 |
def __init__(self,device):
|
| 23 |
super(SSLModel, self).__init__()
|
| 24 |
+
# eliminate fairseq dependency
|
| 25 |
+
repo_id = "facebook/wav2vec2-xlsr-300m"
|
| 26 |
+
model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-xlsr-300m")
|
| 27 |
+
# cp_path = hf_hub_download(repo_id=repo_id, filename=fname) # Change the pre-trained XLSR model path.
|
| 28 |
+
# model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
|
| 29 |
self.model = model[0]
|
| 30 |
self.model_device=device
|
| 31 |
self.out_dim = 1024
|