Create handler.py

#32
by yuthrb - opened
Files changed (1) hide show
  1. handler.py +43 -0
handler.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any
2
+ from transformers import AutoProcessor, MusicgenForConditionalGeneration
3
+ import scipy
4
+ import io
5
+
6
+ class EndpointHandler:
7
+ def __init__(self, path=""):
8
+ # Explicitly load processor with local files
9
+ self.processor = AutoProcessor.from_pretrained(
10
+ path,
11
+ local_files_only=True,
12
+ trust_remote_code=True
13
+ )
14
+ self.model = MusicgenForConditionalGeneration.from_pretrained(
15
+ path,
16
+ local_files_only=True,
17
+ trust_remote_code=True
18
+ )
19
+
20
+ def __call__(self, data: Dict[str, Any]) -> bytes:
21
+ text = data.get("inputs", "")
22
+ duration = data.get("parameters", {}).get("duration", 5)
23
+
24
+ inputs = self.processor(
25
+ text=[text],
26
+ return_tensors="pt",
27
+ padding=True,
28
+ truncation=True
29
+ )
30
+
31
+ audio_values = self.model.generate(
32
+ **inputs,
33
+ max_new_tokens=int(duration * 50)
34
+ )
35
+
36
+ sampling_rate = self.model.config.audio_encoder.sampling_rate
37
+ with io.BytesIO() as wav_io:
38
+ scipy.io.wavfile.write(
39
+ wav_io,
40
+ rate=sampling_rate,
41
+ data=audio_values[0, 0].numpy()
42
+ )
43
+ return wav_io.getvalue()