import io
import os
from typing import Optional, Literal, Dict, Any, List
import numpy as np
from fastapi import FastAPI, HTTPException, Query
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
import torch
import nltk
from transformers import AutoTokenizer, AutoFeatureExtractor
from parler_tts import ParlerTTSForConditionalGeneration
# --- one-time setup ---
nltk.download("punkt_tab")
DEVICE = (
    "cuda:0" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
TORCH_DTYPE = torch.bfloat16 if DEVICE != "cpu" else torch.float32
# finetuned model only
FINETUNED_REPO_ID = "ai4bharat/indic-parler-tts"
model = ParlerTTSForConditionalGeneration.from_pretrained(
    FINETUNED_REPO_ID, attn_implementation="eager", torch_dtype=TORCH_DTYPE
).to(DEVICE)
# tokenizers / feature extractor
# NOTE: the base repo id provides tokenizer & feature extractor
BASE_REPO_FOR_TOK = "ai4bharat/indic-parler-tts-pretrained"
tokenizer = AutoTokenizer.from_pretrained(BASE_REPO_FOR_TOK)
description_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
feature_extractor = AutoFeatureExtractor.from_pretrained(BASE_REPO_FOR_TOK)
SAMPLE_RATE = feature_extractor.sampling_rate
# --- FastAPI app ---
app = FastAPI(title="Indic Parler-TTS (finetuned) API", version="1.0.0")
# Optional default voice descriptions per language
DEFAULT_DESCRIPTIONS: Dict[str, str] = {
"english": (
"A calm, neutral male voice speaks natural English at a moderate pace. "
"Very clear audio with no background noise."
),
"urdu": (
"A warm, neutral female voice speaks natural Urdu at a moderate pace. "
"Very clear audio with no background noise."
),
"punjabi": (
"A friendly, neutral male voice speaks natural Punjabi at a moderate pace. "
"Very clear audio with no background noise."
),
}

def numpy_to_mp3(audio_array: np.ndarray, sampling_rate: int) -> bytes:
    """
    Convert a mono int16/float array to MP3 (320 kbps).
    Uses pydub/ffmpeg; falls back to WAV bytes if pydub/ffmpeg is unavailable
    or the MP3 export fails.
    """
    try:
        from pydub import AudioSegment

        # normalize float → int16
        if np.issubdtype(audio_array.dtype, np.floating):
            max_val = np.max(np.abs(audio_array)) or 1.0
            audio_array = (audio_array / max_val) * 32767
            audio_array = audio_array.astype(np.int16)
        seg = AudioSegment(
            audio_array.tobytes(),
            frame_rate=sampling_rate,
            sample_width=audio_array.dtype.itemsize,
            channels=1,
        )
        buf = io.BytesIO()
        seg.export(buf, format="mp3", bitrate="320k")
        out = buf.getvalue()
        buf.close()
        return out
    except Exception:
        # fallback: WAV to keep things working even without ffmpeg
        import soundfile as sf

        buf = io.BytesIO()
        sf.write(buf, audio_array, sampling_rate, format="WAV", subtype="PCM_16")
        return buf.getvalue()

def split_text_into_chunks(text: str, max_words: int = 25) -> List[str]:
    """Split text into sentence-aligned chunks of roughly `max_words` words each."""
    sentences = nltk.sent_tokenize(text)
    curr = ""
    chunks: List[str] = []
    for s in sentences:
        candidate = (curr + " " + s).strip() if curr else s
        if len(candidate.split()) >= max_words and curr:
            chunks.append(curr.strip())
            curr = s
        else:
            curr = candidate
    if curr.strip():
        chunks.append(curr.strip())
    return chunks

def synthesize(text: str, description: str) -> np.ndarray:
    """Generate audio chunk by chunk, conditioned on the voice `description`, and concatenate."""
    inputs = description_tokenizer(description, return_tensors="pt").to(DEVICE)
    chunks = split_text_into_chunks(text, max_words=25)
    all_audio = []
    for chunk in chunks:
        prompt = tokenizer(chunk, return_tensors="pt").to(DEVICE)
        generation = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            prompt_input_ids=prompt.input_ids,
            prompt_attention_mask=prompt.attention_mask,
            do_sample=True,
            return_dict_in_generate=True,
        )
        if hasattr(generation, "sequences") and hasattr(generation, "audios_length"):
            audio = generation.sequences[0, : generation.audios_length[0]]
            audio_np = audio.to(torch.float32).cpu().numpy().squeeze()
            if audio_np.ndim > 1:
                audio_np = audio_np.flatten()
            all_audio.append(audio_np)
    if not all_audio:
        raise RuntimeError("TTS generation produced no audio.")
    return np.concatenate(all_audio)
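
# Minimal local usage sketch (illustrative only, outside the API path):
#   audio = synthesize("Hello there.", DEFAULT_DESCRIPTIONS["english"])
#   with open("sample.mp3", "wb") as f:
#       f.write(numpy_to_mp3(audio, SAMPLE_RATE))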

# ---- API schemas ----
class TTSRequest(BaseModel):
    text: str
    language: Optional[Literal["english", "urdu", "punjabi"]] = None
    voice_description: Optional[str] = None
    # "mp3" (default) or "wav" (force WAV fallback)
    format: Optional[Literal["mp3", "wav"]] = "mp3"
@app.get("/healthz")
def health() -> Dict[str, Any]:
return {"status": "ok", "device": DEVICE, "sample_rate": SAMPLE_RATE}
@app.post("/tts")
def tts(body: TTSRequest):
if not body.text or not body.text.strip():
raise HTTPException(status_code=400, detail="`text` is required.")
# choose description
description = (
body.voice_description
or DEFAULT_DESCRIPTIONS.get((body.language or "").lower(), None)
or "The speaker speaks naturally with a neutral tone. The recording is very high quality with no background noise."
)
try:
audio = synthesize(body.text, description)
except Exception as e:
raise HTTPException(status_code=500, detail=f"generation_error: {e}")
# return bytes stream
if body.format == "wav":
import soundfile as sf
buf = io.BytesIO()
sf.write(buf, audio, SAMPLE_RATE, format="WAV", subtype="PCM_16")
buf.seek(0)
return StreamingResponse(buf, media_type="audio/wav")
# default: mp3 (falls back to WAV inside helper if mp3 fails)
mp3_bytes = numpy_to_mp3(audio, SAMPLE_RATE)
# crude detection if fallback produced WAV
if mp3_bytes[:4] == b"RIFF":
return StreamingResponse(io.BytesIO(mp3_bytes), media_type="audio/wav")
return StreamingResponse(io.BytesIO(mp3_bytes), media_type="audio/mpeg")

# uvicorn entrypoint (Spaces sets PORT)
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))