forwarder1121 committed on
Commit
dd7709d
·
verified ·
1 Parent(s): 1dce8fe

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +38 -0
handler.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# handler.py
import io

import torch
import torchaudio

from models import StudentForAudioClassification
5
+
6
# 1. Wav2Vec2 feature extractor loaded once at module import time and shared
# by preprocess(). NOTE(review): per the original (Korean) comment, WAV2VEC2_BASE
# is an example placeholder — "replace with the W2V model actually used".
bundle = torchaudio.pipelines.WAV2VEC2_BASE
w2v_model = bundle.get_model()
w2v_model.eval()  # inference mode: disables dropout / training-only behavior
10
+
11
def preprocess(audio_bytes, target_sr=16000):
    """Decode raw audio bytes into a single pooled Wav2Vec2 embedding.

    Args:
        audio_bytes: Raw bytes of an audio file in any format torchaudio can decode.
        target_sr: Sample rate the Wav2Vec2 model expects (default 16 kHz,
            matching the original hard-coded value).

    Returns:
        Tensor of shape (1, D): frame-level W2V features mean-pooled over time.
        For WAV2VEC2_BASE the feature dimension D is 768 (the original comment's
        "512" was incorrect for this bundle).
    """
    # 1. Load the audio into a waveform tensor of shape (channels, samples).
    waveform, orig_sr = torchaudio.load(io.BytesIO(audio_bytes))
    # Collapse to mono; keepdim preserves the (1, samples) layout the model expects.
    waveform = waveform.mean(dim=0, keepdim=True)
    if orig_sr != target_sr:
        resampler = torchaudio.transforms.Resample(orig_sr, target_sr)
        waveform = resampler(waveform)

    # 2. Extract W2V embeddings and mean-pool over the time axis.
    with torch.no_grad():
        features = w2v_model(waveform)[0]  # (1, T, D)
        x_w2v = features.mean(dim=1)       # (1, D)
    return x_w2v
25
+
26
def inference(model, inputs):
    """Run the classifier on pre-extracted features and post-process its logits.

    Returns a dict with the per-class softmax probabilities
    ([not_stressed_prob, stressed_prob]) and the argmax class index as an int.
    """
    with torch.no_grad():
        logits = model(inputs).logits
        probabilities = torch.softmax(logits, dim=-1)
        predicted = probabilities.argmax(dim=-1)
    return {
        "probabilities": probabilities.squeeze(0).tolist(),  # [not_stressed_prob, stressed_prob]
        "label": int(predicted[0]),
    }
34
+
35
def init():
    """Load the fine-tuned student classifier from the current directory.

    Returns the model in eval mode, ready to be passed to inference().
    """
    classifier = StudentForAudioClassification.from_pretrained(
        ".", trust_remote_code=True
    )
    classifier.eval()
    return classifier