othsueh committed on
Commit
2b7a995
·
verified ·
1 Parent(s): aa9901a

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +20 -4
handler.py CHANGED
@@ -10,6 +10,7 @@ class EndpointHandler():
10
  def __init__(self, model_dir: str, **kwargs: Any) -> None:
11
  # Load config and model with trust_remote_code
12
  device = 'cuda'
 
13
  self.model = UpstreamFinetune.from_pretrained(
14
  model_dir,
15
  device=device,
@@ -25,14 +26,29 @@ class EndpointHandler():
25
  waveform, sr = torchaudio.load(io.BytesIO(audio))
26
  if sr != sampling_rate:
27
  waveform = torchaudio.functional.resample(waveform, sr, sampling_rate)
 
28
  # Forward pass
29
  with torch.no_grad():
30
  cat_logits, reg_outputs = self.model(
31
  waveform,
32
  sampling_rate
33
  )
34
- # Postprocess to Python types
35
- return [
36
- { "label": "arousal", "score" : reg_outputs[0]},
37
- { "label": "valence", "score": reg_outputs[1]}
 
 
 
 
 
 
 
 
 
 
 
 
38
  ]
 
 
 
10
  def __init__(self, model_dir: str, **kwargs: Any) -> None:
11
  # Load config and model with trust_remote_code
12
  device = 'cuda'
13
+ self.emotions = ['angry', 'sad', 'disgust', 'contempt', 'fear', 'neutral', 'surprise', 'happy']
14
  self.model = UpstreamFinetune.from_pretrained(
15
  model_dir,
16
  device=device,
 
26
  waveform, sr = torchaudio.load(io.BytesIO(audio))
27
  if sr != sampling_rate:
28
  waveform = torchaudio.functional.resample(waveform, sr, sampling_rate)
29
+
30
  # Forward pass
31
  with torch.no_grad():
32
  cat_logits, reg_outputs = self.model(
33
  waveform,
34
  sampling_rate
35
  )
36
+
37
+ # Convert logits to probabilities using softmax
38
+ emotion_probs = torch.nn.functional.softmax(cat_logits, dim=1)
39
+
40
+ # Create emotion predictions
41
+ emotion_predictions = []
42
+ for i, emotion in enumerate(self.emotions):
43
+ emotion_predictions.append({
44
+ "label": emotion,
45
+ "score": float(emotion_probs[0, i]) # Convert tensor to float
46
+ })
47
+
48
+ # Add arousal and valence predictions
49
+ result = emotion_predictions + [
50
+ {"label": "arousal", "score": float(reg_outputs[0, 0])},
51
+ {"label": "valence", "score": float(reg_outputs[0, 1])}
52
  ]
53
+
54
+ return result