gautamtata commited on
Commit
6d23eb9
·
1 Parent(s): 8d6b413

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +15 -1
handler.py CHANGED
@@ -1,10 +1,12 @@
1
  from transformers import AutoConfig, Wav2Vec2Processor
 
 
2
  from torch import nn
3
  import torch
4
  import io
5
  import torchaudio
6
  import torch.nn.functional as F
7
- from typing import Dict, List, Any
8
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
9
  import requests
10
  import tempfile
@@ -16,6 +18,18 @@ from transformers.models.wav2vec2.modeling_wav2vec2 import (
16
  )
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  class Wav2Vec2ClassificationHead(nn.Module):
20
  """Head for wav2vec classification task."""
21
 
 
1
  from transformers import AutoConfig, Wav2Vec2Processor
2
+ from transformers.file_utils import ModelOutput
3
+ from dataclasses import dataclass
4
  from torch import nn
5
  import torch
6
  import io
7
  import torchaudio
8
  import torch.nn.functional as F
9
+ from typing import Dict, List, Any, Optional, Tuple
10
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
11
  import requests
12
  import tempfile
 
18
  )
19
 
20
 
21
+
22
+
23
+
24
+
25
+ @dataclass
26
+ class SpeechClassifierOutput(ModelOutput):
27
+ loss: Optional[torch.FloatTensor] = None
28
+ logits: torch.FloatTensor = None
29
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
30
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
31
+
32
+
33
  class Wav2Vec2ClassificationHead(nn.Module):
34
  """Head for wav2vec classification task."""
35