"""FastAPI service exposing the finetuned Indic Parler-TTS model over HTTP."""

import io
import logging
import os
from typing import Optional, Literal, Dict, Any, List

import numpy as np
from fastapi import FastAPI, HTTPException, Query
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel

import torch
import nltk
from transformers import AutoTokenizer, AutoFeatureExtractor
from parler_tts import ParlerTTSForConditionalGeneration

logger = logging.getLogger(__name__)

# --- one-time setup ---
# Only hit the network when the sentence-tokenizer data is not installed yet,
# so the service can also start offline / from a pre-baked image.
try:
    nltk.data.find("tokenizers/punkt_tab")
except LookupError:
    nltk.download("punkt_tab")

DEVICE = (
    "cuda:0"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
TORCH_DTYPE = torch.bfloat16 if DEVICE != "cpu" else torch.float32

# finetuned model only
FINETUNED_REPO_ID = "ai4bharat/indic-parler-tts"
model = ParlerTTSForConditionalGeneration.from_pretrained(
    FINETUNED_REPO_ID,
    attn_implementation="eager",
    torch_dtype=TORCH_DTYPE,
).to(DEVICE)

# tokenizers / feature extractor
# NOTE: the base repo id provides tokenizer & feature extractor
BASE_REPO_FOR_TOK = "ai4bharat/indic-parler-tts-pretrained"
tokenizer = AutoTokenizer.from_pretrained(BASE_REPO_FOR_TOK)
description_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
feature_extractor = AutoFeatureExtractor.from_pretrained(BASE_REPO_FOR_TOK)
SAMPLE_RATE = feature_extractor.sampling_rate

# --- FastAPI app ---
app = FastAPI(title="Indic Parler-TTS (finetuned) API", version="1.0.0")

# Optional default voice descriptions per language
DEFAULT_DESCRIPTIONS: Dict[str, str] = {
    "english": (
        "A calm, neutral male voice speaks natural English at a moderate pace. "
        "Very clear audio with no background noise."
    ),
    "urdu": (
        "A warm, neutral female voice speaks natural Urdu at a moderate pace. "
        "Very clear audio with no background noise."
    ),
    "punjabi": (
        "A friendly, neutral male voice speaks natural Punjabi at a moderate pace. "
        "Very clear audio with no background noise."
    ),
}

# Used when the request carries neither a voice description nor a language
# that has an entry in DEFAULT_DESCRIPTIONS.
FALLBACK_DESCRIPTION = (
    "The speaker speaks naturally with a neutral tone. "
    "The recording is very high quality with no background noise."
)


def _encode_wav(audio_array: np.ndarray, sampling_rate: int) -> bytes:
    """Encode *audio_array* as a 16-bit PCM WAV file and return its bytes."""
    # Imported lazily, as in the rest of this module, so that importing the
    # service does not itself require soundfile.
    import soundfile as sf

    with io.BytesIO() as buf:
        sf.write(buf, audio_array, sampling_rate, format="WAV", subtype="PCM_16")
        return buf.getvalue()


def numpy_to_mp3(audio_array: np.ndarray, sampling_rate: int) -> bytes:
    """Convert a mono int16/float array to MP3 (320 kbps).

    Uses pydub/ffmpeg. If the MP3 path fails for any reason (pydub missing,
    ffmpeg missing, encoding error) the audio is returned as WAV instead, so
    callers must inspect the payload to learn which container they received.

    Args:
        audio_array: Mono samples. Floating-point input is peak-normalised
            and converted to int16 before MP3 encoding.
        sampling_rate: Sample rate of ``audio_array`` in Hz.

    Returns:
        The encoded MP3 bytes, or WAV bytes when MP3 encoding failed.
    """
    try:
        from pydub import AudioSegment

        pcm = audio_array
        # normalize float -> int16
        if np.issubdtype(pcm.dtype, np.floating):
            # np.max raises on an empty array, so guard it explicitly; a zero
            # peak (silence) must not be used as a divisor either.
            peak = float(np.max(np.abs(pcm))) if pcm.size else 0.0
            pcm = (pcm / (peak or 1.0)) * 32767
        pcm = pcm.astype(np.int16)

        segment = AudioSegment(
            pcm.tobytes(),
            frame_rate=sampling_rate,
            sample_width=pcm.dtype.itemsize,
            channels=1,
        )
        with io.BytesIO() as buf:
            segment.export(buf, format="mp3", bitrate="320k")
            return buf.getvalue()
    except Exception:
        # Deliberate best-effort: keep things working even without ffmpeg,
        # but leave a trace of why the MP3 path was abandoned.
        logger.warning("MP3 encoding failed; falling back to WAV", exc_info=True)
        # Always encode the caller's original samples so the fallback matches
        # the explicit WAV path regardless of where the MP3 path failed.
        return _encode_wav(audio_array, sampling_rate)


def split_text_into_chunks(text: str, max_words: int = 25) -> List[str]:
    """Group the sentences of *text* into chunks for generation.

    Sentences are accumulated into the current chunk; once adding the next
    sentence would bring the chunk to ``max_words`` words or more, the current
    chunk is closed and the sentence starts a new one. A single sentence is
    never split, so a chunk can exceed ``max_words``.

    Args:
        text: Input text; split into sentences with ``nltk.sent_tokenize``.
        max_words: Word-count threshold at which a chunk is closed.

    Returns:
        The non-empty, whitespace-stripped chunks in original order.
    """
    chunks: List[str] = []
    current = ""
    for sentence in nltk.sent_tokenize(text):
        candidate = f"{current} {sentence}".strip() if current else sentence
        if current and len(candidate.split()) >= max_words:
            chunks.append(current.strip())
            current = sentence
        else:
            current = candidate
    if current.strip():
        chunks.append(current.strip())
    return chunks


def synthesize(text: str, description: str) -> np.ndarray:
    """Generate speech for *text* spoken in the voice given by *description*.

    The text is split into chunks, each chunk is generated separately with
    the same description conditioning, and the audio is concatenated.

    Args:
        text: Text to speak.
        description: Natural-language description of the desired voice.

    Returns:
        A 1-D float32 array holding the concatenated audio of all chunks.

    Raises:
        RuntimeError: If no chunk produced any audio.
    """
    # The description is identical for every chunk, so tokenize it once.
    inputs = description_tokenizer(description, return_tensors="pt").to(DEVICE)
    chunks = split_text_into_chunks(text, max_words=25)

    all_audio = []
    for index, chunk in enumerate(chunks):
        prompt = tokenizer(chunk, return_tensors="pt").to(DEVICE)
        generation = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            prompt_input_ids=prompt.input_ids,
            prompt_attention_mask=prompt.attention_mask,
            do_sample=True,
            return_dict_in_generate=True,
        )
        if not (
            hasattr(generation, "sequences")
            and hasattr(generation, "audios_length")
        ):
            # Still skipped as before, but no longer silently: a missing
            # chunk means the returned audio is shorter than the text.
            logger.warning(
                "Skipping chunk %d/%d: generation output has no audio fields",
                index + 1,
                len(chunks),
            )
            continue

        # Trim the first (only) batch item to its reported audio length.
        audio = generation.sequences[0, : generation.audios_length[0]]
        # bfloat16 has no NumPy equivalent, hence the float32 cast.
        # reshape(-1) always yields a 1-D array; squeeze() could yield a 0-d
        # array for a one-sample result, which np.concatenate rejects.
        all_audio.append(audio.to(torch.float32).cpu().numpy().reshape(-1))

    if not all_audio:
        raise RuntimeError("TTS generation produced no audio.")
    return np.concatenate(all_audio)


# ---- API schemas ----
class TTSRequest(BaseModel):
    """Request body of the ``/tts`` endpoint."""

    # Text to synthesize; must contain non-whitespace characters.
    text: str
    # Selects a default voice description when none is given explicitly.
    language: Optional[Literal["english", "urdu", "punjabi"]] = None
    # Free-form voice description; takes precedence over `language`.
    voice_description: Optional[str] = None
    # "mp3" (default) or "wav" (force WAV fallback)
    format: Optional[Literal["mp3", "wav"]] = "mp3"


@app.get("/healthz")
def health() -> Dict[str, Any]:
    """Report liveness together with the active device and sample rate."""
    return {"status": "ok", "device": DEVICE, "sample_rate": SAMPLE_RATE}


@app.post("/tts")
def tts(body: TTSRequest):
    """Synthesize speech for the request and stream it back as audio.

    Returns ``audio/mpeg`` by default, or ``audio/wav`` when WAV was
    requested or when MP3 encoding fell back to WAV.

    Raises:
        HTTPException: 400 if ``text`` is empty or whitespace only; 500 if
            generation fails.
    """
    if not body.text or not body.text.strip():
        raise HTTPException(status_code=400, detail="`text` is required.")

    # choose description: explicit > per-language default > generic fallback
    description = (
        body.voice_description
        or DEFAULT_DESCRIPTIONS.get((body.language or "").lower())
        or FALLBACK_DESCRIPTION
    )

    try:
        audio = synthesize(body.text, description)
    except Exception as e:
        logger.exception("TTS generation failed")
        raise HTTPException(
            status_code=500, detail=f"generation_error: {e}"
        ) from e

    # return bytes stream
    if body.format == "wav":
        wav_bytes = _encode_wav(audio, SAMPLE_RATE)
        return StreamingResponse(io.BytesIO(wav_bytes), media_type="audio/wav")

    # default: mp3 (falls back to WAV inside helper if mp3 fails)
    payload = numpy_to_mp3(audio, SAMPLE_RATE)
    # crude detection if fallback produced WAV
    media_type = "audio/wav" if payload[:4] == b"RIFF" else "audio/mpeg"
    return StreamingResponse(io.BytesIO(payload), media_type=media_type)


# uvicorn entrypoint (Spaces sets PORT)
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))