farazmoradi98 commited on
Commit
4b34089
·
verified ·
1 Parent(s): 4466951

Add custom handler for TTS inference

Browse files
Files changed (1) hide show
  1. handler.py +155 -0
handler.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
Handler for Sesame CSM-1B TTS model deployment on Hugging Face Inference Endpoints
"""

import os
import base64
import io
import torch
import numpy as np
from typing import Dict, Any, List
from transformers import AutoTokenizer, AutoModelForCausalLM
import scipy.io.wavfile as wavfile

# Global variables for model and tokenizer
# Both are None until init() runs; every request path reads these globals.
model = None
tokenizer = None
19
def init():
    """
    Load the CSM-1B model and tokenizer into the module-level globals.

    Called once when the endpoint starts; later requests reuse the loaded
    objects through the ``model`` / ``tokenizer`` globals.
    """
    global model, tokenizer

    print("Initializing CSM-1B model...")

    # Prefer GPU when available; the dtype choice below mirrors this.
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    print(f"Using device: {device}")

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            "farazmoradi98/csm-1b",  # Use your forked model
            trust_remote_code=True,
        )

        # fp16 on GPU halves memory; fp32 keeps CPU inference numerically safe.
        model = AutoModelForCausalLM.from_pretrained(
            "farazmoradi98/csm-1b",  # Use your forked model
            trust_remote_code=True,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map="auto",
        )

        print("✅ Model and tokenizer loaded successfully!")

    except Exception as e:
        # Log and re-raise so the endpoint fails fast on a broken startup.
        print(f"❌ Error loading model: {e}")
        raise
52
+
53
def generate_speech(text: str, speaker: int = 0) -> bytes:
    """
    Generate speech from text using the CSM-1B model.

    Args:
        text (str): Input text to convert to speech
        speaker (int): Speaker ID (0-3 for CSM-1B)

    Returns:
        bytes: 16-bit PCM WAV audio data at 24 kHz

    Raises:
        Exception: re-raises any tokenization/generation/decoding failure
            after logging it.
    """
    global model, tokenizer

    try:
        # Tokenize input text and move the tensors to the model's device
        inputs = tokenizer(text, return_tensors="pt").to(model.device)

        # Generate speech (sampling parameters chosen for varied prosody)
        with torch.no_grad():
            output = model.generate(
                **inputs,
                speaker=speaker,
                max_new_tokens=1024,  # Adjust as needed
                do_sample=True,
                temperature=0.8,
                top_p=0.9,
                repetition_penalty=1.1
            )

        # Keep only the newly generated audio tokens (strip the text prompt)
        audio_tokens = output[0][inputs.input_ids.shape[1]:]
        # NOTE(review): decode_audio is assumed to return a float waveform
        # in [-1, 1] — confirm against the forked model's remote code.
        audio_array = model.decode_audio(audio_tokens)

        # BUGFIX: decode_audio most likely returns a torch tensor (possibly
        # on GPU); tensors have no .astype(), so convert to numpy first.
        if torch.is_tensor(audio_array):
            audio_array = audio_array.detach().cpu().float().numpy()
        audio_array = np.asarray(audio_array, dtype=np.float32).squeeze()

        # BUGFIX: clip before scaling — samples slightly outside [-1, 1]
        # would otherwise wrap around in the int16 cast (audible clicks).
        audio_array = np.clip(audio_array, -1.0, 1.0)

        # Convert to 16-bit PCM
        audio_array = (audio_array * 32767).astype(np.int16)

        # Serialize to an in-memory WAV buffer
        wav_buffer = io.BytesIO()
        wavfile.write(wav_buffer, 24000, audio_array)  # CSM-1B uses 24kHz
        wav_buffer.seek(0)

        return wav_buffer.getvalue()

    except Exception as e:
        print(f"❌ Error generating speech: {e}")
        raise
100
+
101
def handler(request: Dict[str, Any]) -> Dict[str, Any]:
    """
    Main handler function for Hugging Face Inference API.

    Args:
        request (dict): Request payload; ``"inputs"`` may be a plain string
            (the text) or a dict with ``"text"`` and an optional ``"speaker"``.

    Returns:
        dict: On success, base64-encoded WAV audio plus metadata; on any
            failure, a dict containing a single ``"error"`` key.
    """
    try:
        payload = request.get("inputs", {})

        # Accept either a bare string or a {"text": ..., "speaker": ...} dict.
        if isinstance(payload, str):
            text, speaker = payload, 0
        elif isinstance(payload, dict):
            text = payload.get("text", "")
            speaker = payload.get("speaker", 0)
        else:
            return {
                "error": "Invalid input format. Expected string or dict with 'text' field."
            }

        if not text:
            return {
                "error": "No text provided for speech generation."
            }

        print(f"Generating speech for: '{text}' (speaker: {speaker})")

        # Synthesize, then ship the WAV bytes as base64 JSON-safe text.
        wav_bytes = generate_speech(text, speaker)
        encoded = base64.b64encode(wav_bytes).decode('utf-8')

        return {
            "audio": encoded,
            "format": "wav",
            "sample_rate": 24000,
            "speaker": speaker,
        }

    except Exception as e:
        # Boundary catch: never let an exception escape the endpoint handler.
        print(f"❌ Handler error: {e}")
        return {
            "error": f"Speech generation failed: {str(e)}"
        }
152
+
153
# Initialize model on startup: inference endpoints import this module
# (they never run it as a script), so loading happens at import time.
if __name__ != "__main__":
    init()