yuthrb committed
Commit b970ee2 · verified · 1 Parent(s): bd0768b

Update handler.py

Files changed (1)
  1. handler.py +19 -33
handler.py CHANGED
@@ -1,43 +1,29 @@
-from typing import Dict, Any
+from typing import Dict, List, Any
 from transformers import AutoProcessor, MusicgenForConditionalGeneration
-import scipy
-import io
 
 class EndpointHandler:
-    def __init__(self, path=""):
-        # Explicitly load processor with local files
-        self.processor = AutoProcessor.from_pretrained(
-            path,
-            local_files_only=True,
-            trust_remote_code=True
-        )
-        self.model = MusicgenForConditionalGeneration.from_pretrained(
-            path,
-            local_files_only=True,
-            trust_remote_code=True
-        )
+import torch
+    def __init__(self, model_path):
+        self.processor = AutoProcessor.from_pretrained(model_path)
+        self.model = MusicgenForConditionalGeneration.from_pretrained(model_path)
+        if torch.cuda.is_available():
+            self.model = self.model.to("cuda")
 
-    def __call__(self, data: Dict[str, Any]) -> bytes:
-        text = data.get("inputs", "")
-        duration = data.get("parameters", {}).get("duration", 5)
-
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         inputs = self.processor(
-            text=[text],
-            return_tensors="pt",
+            text=data["text"],
+            audio=data.get("audio", None),
             padding=True,
-            truncation=True
+            sampling_rate=data.get("sampling_rate", None),
+            return_tensors="pt",
         )
-
+        if torch.cuda.is_available():
+            inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
         audio_values = self.model.generate(
             **inputs,
-            max_new_tokens=int(duration * 50)
+            do_sample=data.get("do_sample", True),
+            guidance_scale=data.get("guidance_scale", 3),
+            max_new_tokens=data.get("max_new_tokens", 256),
         )
-
-        sampling_rate = self.model.config.audio_encoder.sampling_rate
-        with io.BytesIO() as wav_io:
-            scipy.io.wavfile.write(
-                wav_io,
-                rate=sampling_rate,
-                data=audio_values[0, 0].numpy()
-            )
-            return wav_io.getvalue()
+        return {"audio_values": audio_values.cpu().numpy()}