JustinDuc commited on
Commit
531f25a
·
verified ·
1 Parent(s): e09c24a

Update saute_model.py

Browse files
Files changed (1) hide show
  1. saute_model.py +2 -2
saute_model.py CHANGED
@@ -2,7 +2,7 @@ import torch
2
  import torch.nn as nn
3
  from transformers import PreTrainedModel, BertModel, BertTokenizerFast
4
  from transformers.modeling_outputs import MaskedLMOutput
5
- # from config import SAUTEConfig
6
 
7
  activation_to_class = {
8
  "gelu" : nn.GELU,
@@ -108,7 +108,7 @@ class EDUSpeakerAwareMLM(nn.Module):
108
 
109
 
110
  class UtteranceEmbedings(PreTrainedModel):
111
- # config_class = SAUTEConfig
112
 
113
  def __init__(self, config):
114
  super().__init__(config)
 
2
  import torch.nn as nn
3
  from transformers import PreTrainedModel, BertModel, BertTokenizerFast
4
  from transformers.modeling_outputs import MaskedLMOutput
5
+ from .saute_config import SAUTEConfig
6
 
7
  activation_to_class = {
8
  "gelu" : nn.GELU,
 
108
 
109
 
110
  class UtteranceEmbedings(PreTrainedModel):
111
+ config_class = SAUTEConfig
112
 
113
  def __init__(self, config):
114
  super().__init__(config)