drrobot9 commited on
Commit
6d8a49c
·
verified ·
1 Parent(s): df818e6

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +113 -258
app/main.py CHANGED
@@ -1,274 +1,129 @@
1
- import io
2
- import re
3
- import os
4
- import gc
5
- import time
6
- import logging
7
- import asyncio
8
- import numpy as np
9
- import torch
10
- import soundfile as sf
11
- from typing import Optional, List, AsyncGenerator
12
- from contextlib import asynccontextmanager
13
  from fastapi import FastAPI, HTTPException
14
- from fastapi.responses import StreamingResponse, Response, JSONResponse
15
- from fastapi.middleware.cors import CORSMiddleware
16
  from pydantic import BaseModel, Field
17
- from transformers import AutoModelForCausalLM
18
- from yarngpt.audiotokenizer import AudioTokenizerV2
19
- from concurrent.futures import ThreadPoolExecutor
20
-
21
- logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", datefmt="%H:%M:%S")
22
- log = logging.getLogger("yarngpt-tts")
23
-
24
- BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
25
- if os.path.exists("/data"):
26
- CACHE_DIR = "/data/.cache"
27
- log.info("Using persistent storage at /data for cache")
28
- else:
29
- CACHE_DIR = os.path.join(BASE_DIR, ".cache")
30
- log.info(f"Using local cache directory: {CACHE_DIR}")
31
-
32
- os.environ['HF_HOME'] = CACHE_DIR
33
- os.environ['TRANSFORMERS_CACHE'] = os.path.join(CACHE_DIR, 'huggingface')
34
- os.environ['TORCH_HOME'] = os.path.join(CACHE_DIR, 'torch')
35
- os.environ['HUGGINGFACE_HUB_CACHE'] = os.path.join(CACHE_DIR, 'huggingface')
36
-
37
- MODELS_DIR = "/app/models"
38
- os.makedirs(MODELS_DIR, exist_ok=True)
39
- os.makedirs(CACHE_DIR, exist_ok=True)
40
-
41
- MODEL_ID = os.getenv("MODEL_ID", "saheedniyi/YarnGPT2b")
42
- WAV_TOKENIZER_CONFIG = os.getenv("WAV_TOKENIZER_CONFIG", "/app/models/wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml")
43
- WAV_TOKENIZER_CKPT = os.getenv("WAV_TOKENIZER_CKPT", "/app/models/wavtokenizer_large_speech_320_24k.ckpt")
44
- SAMPLE_RATE = int(os.getenv("SAMPLE_RATE", "24000"))
45
- WORD_LIMIT = int(os.getenv("WORD_LIMIT", "25"))
46
- MAX_TEXT_LENGTH = int(os.getenv("MAX_TEXT_LENGTH", "30000"))
47
- GENERATION_TEMP = float(os.getenv("GENERATION_TEMP", "0.1"))
48
- REPEAT_PENALTY = float(os.getenv("REPEAT_PENALTY", "1.1"))
49
- MAX_GEN_LENGTH = int(os.getenv("MAX_GEN_LENGTH", "4000"))
50
- SILENCE_TOKEN = int(os.getenv("SILENCE_TOKEN", "453"))
51
- SILENCE_FRAMES = int(os.getenv("SILENCE_FRAMES", "38"))
52
-
53
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
54
- DTYPE = torch.bfloat16 if DEVICE == "cuda" and torch.cuda.is_bf16_supported() else torch.float32
55
- log.info(f"Device: {DEVICE} | dtype: {DTYPE}")
56
-
57
- DEFAULT_LANGUAGE = "english"
58
- DEFAULT_SPEAKER = "jude"
59
- SUPPORTED_LANGUAGES = ["english", "yoruba", "igbo", "hausa", "pidgin"]
60
- SUPPORTED_SPEAKERS = ["idera", "jude", "tayo", "zainab", "chisom", "regina", "umar", "emma", "osagie", "amara"]
61
 
62
- audio_tokenizer: Optional[AudioTokenizerV2] = None
63
- model_lm: Optional[AutoModelForCausalLM] = None
64
- executor = ThreadPoolExecutor(max_workers=4)
65
- semaphore = asyncio.Semaphore(600)
66
- request_queue = asyncio.Queue(maxsize=1000)
67
 
68
- class TTSRequest(BaseModel):
69
- text: str = Field(..., min_length=1, max_length=MAX_TEXT_LENGTH)
70
-
71
- @asynccontextmanager
72
- async def lifespan(app: FastAPI):
73
- global audio_tokenizer, model_lm
74
- log.info("Loading YarnGPT2b …")
75
- t0 = time.time()
76
-
77
- if not os.path.exists(WAV_TOKENIZER_CONFIG):
78
- raise RuntimeError(f"Model config not found at: {WAV_TOKENIZER_CONFIG}")
79
- if not os.path.exists(WAV_TOKENIZER_CKPT):
80
- raise RuntimeError(f"Model checkpoint not found at: {WAV_TOKENIZER_CKPT}")
81
-
82
- audio_tokenizer = AudioTokenizerV2(MODEL_ID, WAV_TOKENIZER_CKPT, WAV_TOKENIZER_CONFIG)
83
-
84
- os.environ["CUDA_VISIBLE_DEVICES"] = "0"
85
- torch.cuda.set_device(0)
86
-
87
- model_lm = AutoModelForCausalLM.from_pretrained(
88
- MODEL_ID,
89
- torch_dtype=DTYPE,
90
- device_map=None,
91
- low_cpu_mem_usage=True
92
- ).to(audio_tokenizer.device)
93
-
94
- model_lm.eval()
95
-
96
- if DEVICE == "cuda":
97
- torch.backends.cudnn.benchmark = True
98
-
99
- log.info(f"✓ Model loaded in {time.time()-t0:.1f}s")
100
- asyncio.create_task(queue_processor())
101
- yield
102
- del model_lm, audio_tokenizer
103
- gc.collect()
104
- if DEVICE == "cuda":
105
- torch.cuda.empty_cache()
106
- log.info("Model unloaded.")
107
 
108
- app = FastAPI(title="YarnGPT2b TTS Service", description="Nigerian-accented Text-to-Speech via YarnGPT2b", version="2.0.0", lifespan=lifespan)
109
- app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
 
110
 
111
- def split_text_into_chunks(text: str, word_limit: int = WORD_LIMIT) -> List[str]:
112
- text = re.sub(r"\s+", " ", text.strip())
113
- text = text.replace("...", ".")
114
- raw_sentences = re.split(r'(?<=[.!?])\s+', text)
115
- chunks: List[str] = []
116
- for sentence in raw_sentences:
117
- sentence = sentence.strip()
118
- if not sentence:
119
- continue
120
- chunks.append(".")
121
- words = sentence.split()
122
- for i in range(0, len(words), word_limit):
123
- chunks.append(" ".join(words[i:i + word_limit]))
124
- return chunks
125
 
126
- def speed_change_np(audio: np.ndarray, speed: float) -> np.ndarray:
127
- if speed == 1.0:
128
- return audio
129
- original_len = len(audio)
130
- target_len = int(original_len / speed)
131
- indices = np.linspace(0, original_len - 1, target_len)
132
- return np.interp(indices, np.arange(original_len), audio).astype(np.float32)
133
 
134
- async def generate_codes_for_chunk_async(chunk: str, language: str, speaker: str, temperature: float, repetition_penalty: float) -> List[int]:
135
- loop = asyncio.get_event_loop()
136
- def _generate():
137
- prompt = audio_tokenizer.create_prompt(chunk, lang=language, speaker_name=speaker)
138
- input_ids = audio_tokenizer.tokenize_prompt(prompt)
139
- if isinstance(input_ids, torch.Tensor):
140
- input_ids = input_ids.to(audio_tokenizer.device)
141
- with torch.inference_mode():
142
- output = model_lm.generate(
143
- input_ids=input_ids,
144
- temperature=temperature,
145
- repetition_penalty=repetition_penalty,
146
- max_length=MAX_GEN_LENGTH,
147
- do_sample=temperature > 0,
148
- pad_token_id=model_lm.config.eos_token_id,
149
- attention_mask=torch.ones_like(input_ids).to(audio_tokenizer.device)
150
- )
151
- return audio_tokenizer.get_codes(output)
152
- return await loop.run_in_executor(executor, _generate)
153
 
154
- def codes_to_audio_np(codes: List[int]) -> np.ndarray:
155
- audio_tensor = audio_tokenizer.get_audio(codes)
156
- return audio_tensor.squeeze().cpu().float().numpy()
157
 
158
- def make_wav_header(sample_rate: int, channels: int = 1, bits: int = 16) -> bytes:
159
- byte_rate = sample_rate * channels * bits // 8
160
- block_align = channels * bits // 8
161
- return (b"RIFF" + b"\xff\xff\xff\xff" + b"WAVEfmt " + (16).to_bytes(4, "little") + (1).to_bytes(2, "little") + channels.to_bytes(2, "little") + sample_rate.to_bytes(4, "little") + byte_rate.to_bytes(4, "little") + block_align.to_bytes(2, "little") + bits.to_bytes(2, "little") + b"data" + b"\xff\xff\xff\xff")
 
162
 
163
- SILENCE_AUDIO = np.zeros(int(SAMPLE_RATE * 0.5), dtype=np.float32)
 
 
 
 
 
164
 
165
- async def queue_processor():
166
- while True:
167
- try:
168
- future, req_data = await asyncio.wait_for(request_queue.get(), timeout=1.0)
169
- try:
170
- result = await process_request_async(req_data)
171
- future.set_result(result)
172
- except Exception as e:
173
- future.set_exception(e)
174
- finally:
175
- request_queue.task_done()
176
- except asyncio.TimeoutError:
177
- continue
178
- except Exception as e:
179
- log.error(f"Queue processor error: {e}")
180
 
181
- async def process_request_async(req: TTSRequest) -> bytes:
182
- chunks = split_text_into_chunks(req.text, WORD_LIMIT)
183
- all_codes = []
184
- for chunk in chunks:
185
- if chunk == ".":
186
- all_codes.extend([SILENCE_TOKEN] * SILENCE_FRAMES)
187
- else:
188
- try:
189
- codes = await generate_codes_for_chunk_async(chunk, DEFAULT_LANGUAGE, DEFAULT_SPEAKER, GENERATION_TEMP, REPEAT_PENALTY)
190
- all_codes.extend(codes)
191
- except Exception as e:
192
- log.error(f"Chunk error: {e}")
193
- all_codes.extend([SILENCE_TOKEN] * SILENCE_FRAMES)
194
- audio_np = codes_to_audio_np(all_codes)
195
- audio_np = speed_change_np(audio_np, 1.0)
196
- buf = io.BytesIO()
197
- sf.write(buf, audio_np, SAMPLE_RATE, format="WAV", subtype="PCM_16")
198
- return buf.getvalue()
199
 
200
- async def audio_stream_generator(req: TTSRequest) -> AsyncGenerator[bytes, None]:
201
- chunks = split_text_into_chunks(req.text, WORD_LIMIT)
202
- log.info(f"Streaming {len(chunks)} chunks")
203
- yield make_wav_header(SAMPLE_RATE)
204
- silence_bytes = (SILENCE_AUDIO * 32767).astype(np.int16).tobytes()
205
- for chunk in chunks:
206
- if chunk == ".":
207
- yield silence_bytes
208
- continue
209
- try:
210
- codes = await generate_codes_for_chunk_async(chunk, DEFAULT_LANGUAGE, DEFAULT_SPEAKER, GENERATION_TEMP, REPEAT_PENALTY)
211
- audio_np = codes_to_audio_np(codes)
212
- audio_np = speed_change_np(audio_np, 1.0)
213
- audio_i16 = (audio_np * 32767).astype(np.int16)
214
- yield audio_i16.tobytes()
215
- except Exception as e:
216
- log.error(f"Stream chunk error: {e}")
217
- yield silence_bytes
218
 
219
- @app.get("/health", tags=["Meta"])
220
  async def health():
221
- gpu_info = {}
222
- if DEVICE == "cuda":
223
- gpu_info = {"gpu_name": torch.cuda.get_device_name(0), "vram_total_gb": round(torch.cuda.get_device_properties(0).total_memory / 1e9, 2), "vram_reserved_gb": round(torch.cuda.memory_reserved(0) / 1e9, 2), "vram_used_gb": round(torch.cuda.memory_allocated(0) / 1e9, 2)}
224
- models_status = {"config_exists": os.path.exists(WAV_TOKENIZER_CONFIG), "ckpt_exists": os.path.exists(WAV_TOKENIZER_CKPT), "models_dir": MODELS_DIR, "cache_dir": CACHE_DIR}
225
- return {"status": "ok", "model": MODEL_ID, "device": DEVICE, "dtype": str(DTYPE), "sample_rate": SAMPLE_RATE, "queue_size": request_queue.qsize(), "models": models_status, **gpu_info}
226
-
227
- @app.get("/voices", tags=["Meta"])
228
- async def list_voices():
229
- return {"languages": SUPPORTED_LANGUAGES, "speakers": SUPPORTED_SPEAKERS}
230
-
231
- @app.post("/tts/stream", tags=["TTS"])
232
- async def stream_tts(payload: TTSRequest):
233
- if model_lm is None:
234
- raise HTTPException(503, "Model not loaded yet.")
235
- async with semaphore:
236
- return StreamingResponse(audio_stream_generator(payload), media_type="audio/wav", headers={"Content-Disposition": "inline; filename=stream.wav"})
237
-
238
- @app.post("/tts/buffered", tags=["TTS"])
239
- async def buffered_tts(payload: TTSRequest):
240
- if model_lm is None:
241
- raise HTTPException(503, "Model not loaded yet.")
242
- async with semaphore:
243
- loop = asyncio.get_event_loop()
244
- future = loop.create_future()
245
- try:
246
- await asyncio.wait_for(request_queue.put((future, payload)), timeout=10000.0)
247
- audio_data = await asyncio.wait_for(future, timeout=120000.0)
248
- duration_s = len(audio_data) / (SAMPLE_RATE * 2)
249
- return Response(content=audio_data, media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=tts.wav", "X-Audio-Duration-Sec": str(round(duration_s, 2))})
250
- except asyncio.TimeoutError:
251
- raise HTTPException(503, "Service busy, please try again later")
252
-
253
- @app.post("/tts/codes", tags=["TTS (Advanced)"])
254
- async def get_codes(payload: TTSRequest):
255
- if model_lm is None:
256
- raise HTTPException(503, "Model not loaded yet.")
257
- async with semaphore:
258
- chunks = split_text_into_chunks(payload.text, WORD_LIMIT)
259
- all_codes = []
260
- for chunk in chunks:
261
- if chunk == ".":
262
- all_codes.extend([SILENCE_TOKEN] * SILENCE_FRAMES)
263
- else:
264
- try:
265
- codes = await generate_codes_for_chunk_async(chunk, DEFAULT_LANGUAGE, DEFAULT_SPEAKER, GENERATION_TEMP, REPEAT_PENALTY)
266
- all_codes.extend(codes)
267
- except Exception as e:
268
- log.error(f"Code generation failed: {e}")
269
- all_codes.extend([SILENCE_TOKEN] * SILENCE_FRAMES)
270
- return JSONResponse({"codes": all_codes, "total_tokens": len(all_codes)})
271
-
272
- if __name__ == "__main__":
273
- import uvicorn
274
- uvicorn.run("app.main:app", host="0.0.0.0", port=7860, reload=False, workers=1, limit_concurrency=1000, backlog=2048, timeout_keep_alive=5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/main.py
2
+
3
+ import json
4
+ import httpx
 
 
 
 
 
 
 
 
5
  from fastapi import FastAPI, HTTPException
6
+ from fastapi.responses import StreamingResponse, Response
 
7
  from pydantic import BaseModel, Field
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ CONFIG_PATH = "app/config.json"
 
 
 
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ def load_config() -> dict:
13
+ with open(CONFIG_PATH, "r") as f:
14
+ return json.load(f)
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ config = load_config()
 
 
 
 
 
 
18
 
19
+ EL_API_KEY = config["elevenlabs"]["api_key"]
20
+ VOICE_ID = config["elevenlabs"]["voice_id"]
21
+ MODEL_ID = config["elevenlabs"]["model_id"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ ELEVENLABS_STREAM_URL = (
24
+ f"https://api.elevenlabs.io/v1/text-to-speech/{VOICE_ID}/stream"
25
+ )
26
 
27
+ HEADERS = {
28
+ "xi-api-key": EL_API_KEY,
29
+ "Content-Type": "application/json",
30
+ "Accept": "audio/mpeg",
31
+ }
32
 
33
+ app = FastAPI(
34
+ title="Production TTS Service",
35
+ version="1.0.2",
36
+ docs_url="/docs",
37
+ redoc_url="/redoc",
38
+ )
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
+ class TTSRequest(BaseModel):
42
+ text: str = Field(..., min_length=1, max_length=5000)
43
+ stability: float = Field(0.5, ge=0.0, le=1.0)
44
+ similarity_boost: float = Field(0.5, ge=0.0, le=1.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
+ @app.get("/health")
48
  async def health():
49
+ return {"status": "ok"}
50
+
51
+
52
+ @app.post("/tts")
53
+ async def text_to_speech(payload: TTSRequest):
54
+ body = {
55
+ "text": payload.text,
56
+ "model_id": MODEL_ID,
57
+ "voice_settings": {
58
+ "stability": payload.stability,
59
+ "similarity_boost": payload.similarity_boost,
60
+ },
61
+ }
62
+
63
+ async def audio_stream():
64
+ async with httpx.AsyncClient(timeout=None) as client:
65
+ async with client.stream(
66
+ method="POST",
67
+ url=ELEVENLABS_STREAM_URL,
68
+ headers=HEADERS,
69
+ json=body,
70
+ ) as response:
71
+
72
+ if response.status_code != 200:
73
+ error = await response.aread()
74
+ raise HTTPException(
75
+ status_code=502,
76
+ detail=error.decode(),
77
+ )
78
+
79
+ async for chunk in response.aiter_bytes():
80
+ yield chunk
81
+
82
+ return StreamingResponse(
83
+ audio_stream(),
84
+ media_type="audio/mpeg",
85
+ headers={
86
+ "Content-Disposition": "inline; filename=tts.mp3",
87
+ },
88
+ )
89
+
90
+
91
+
92
+ @app.post("/tts/buffered")
93
+ async def text_to_speech_buffered(payload: TTSRequest):
94
+ body = {
95
+ "text": payload.text,
96
+ "model_id": MODEL_ID,
97
+ "voice_settings": {
98
+ "stability": payload.stability,
99
+ "similarity_boost": payload.similarity_boost,
100
+ },
101
+ }
102
+
103
+ async with httpx.AsyncClient(timeout=30.0) as client:
104
+ response = await client.post(
105
+ ELEVENLABS_STREAM_URL,
106
+ headers=HEADERS,
107
+ json=body,
108
+ )
109
+
110
+ if response.status_code != 200:
111
+ raise HTTPException(
112
+ status_code=502,
113
+ detail=response.text,
114
+ )
115
+
116
+ if not response.content:
117
+ raise HTTPException(
118
+ status_code=500,
119
+ detail="Received empty audio buffer",
120
+ )
121
+
122
+ return Response(
123
+ content=response.content,
124
+ media_type="audio/mpeg",
125
+ headers={
126
+ "Content-Disposition": "attachment; filename=tts.mp3",
127
+ "Content-Length": str(len(response.content)),
128
+ },
129
+ )