File size: 6,435 Bytes
ad9b287
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import io
import os
from typing import Optional, Literal, Dict, Any, List

import numpy as np
from fastapi import FastAPI, HTTPException, Query
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
import torch
import nltk

from transformers import AutoTokenizer, AutoFeatureExtractor
from parler_tts import ParlerTTSForConditionalGeneration

# --- one-time setup ---
# Sentence tokenizer data used by split_text_into_chunks() below.
nltk.download("punkt_tab")

# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
DEVICE = (
    "cuda:0" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
# bfloat16 on accelerators to halve memory; full float32 on CPU.
TORCH_DTYPE = torch.bfloat16 if DEVICE != "cpu" else torch.float32

# finetuned model only
FINETUNED_REPO_ID = "ai4bharat/indic-parler-tts"

# Loaded once at import time; downloads weights on first run.
model = ParlerTTSForConditionalGeneration.from_pretrained(
    FINETUNED_REPO_ID, attn_implementation="eager", torch_dtype=TORCH_DTYPE
).to(DEVICE)

# tokenizers / feature extractor
# NOTE: the base repo id provides tokenizer & feature extractor
BASE_REPO_FOR_TOK = "ai4bharat/indic-parler-tts-pretrained"
tokenizer = AutoTokenizer.from_pretrained(BASE_REPO_FOR_TOK)
# Voice descriptions are encoded with a separate T5 tokenizer.
description_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
feature_extractor = AutoFeatureExtractor.from_pretrained(BASE_REPO_FOR_TOK)

# Output sample rate is dictated by the model's feature extractor.
SAMPLE_RATE = feature_extractor.sampling_rate

# --- FastAPI app ---
app = FastAPI(title="Indic Parler-TTS (finetuned) API", version="1.0.0")

# Optional default voice descriptions per language.
# Used by /tts when the client omits `voice_description`; keys must match the
# lowercase `language` values accepted by TTSRequest.
DEFAULT_DESCRIPTIONS: Dict[str, str] = {
    "english": (
        "A calm, neutral male voice speaks natural English at a moderate pace. "
        "Very clear audio with no background noise."
    ),
    "urdu": (
        "A warm, neutral female voice speaks natural Urdu at a moderate pace. "
        "Very clear audio with no background noise."
    ),
    "punjabi": (
        "A friendly, neutral male voice speaks natural Punjabi at a moderate pace. "
        "Very clear audio with no background noise."
    ),
}

def numpy_to_mp3(audio_array: np.ndarray, sampling_rate: int) -> bytes:
    """
    Encode a mono audio array as MP3 (320 kbps) via pydub/ffmpeg.

    Float input is peak-normalized and converted to int16 up front so both
    the MP3 path and the fallback see identical PCM data (previously the
    fallback wrote un-normalized floats).  If pydub or its ffmpeg backend is
    unavailable, falls back to a PCM_16 WAV container built with the stdlib
    `wave` module — no third-party audio dependency needed for the fallback.

    Args:
        audio_array: 1-D mono samples, float or integer dtype.
        sampling_rate: sample rate in Hz.

    Returns:
        Encoded audio bytes: MP3, or WAV on fallback.  Callers can sniff the
        leading b"RIFF" magic to tell which container they got.
    """
    if np.issubdtype(audio_array.dtype, np.floating):
        # Guard the empty case: np.max on a zero-length array raises.
        peak = float(np.max(np.abs(audio_array))) if audio_array.size else 0.0
        # All-zero (silent) input stays silent instead of dividing by zero.
        scale = 32767.0 / peak if peak > 0 else 0.0
        audio_array = (audio_array * scale).astype(np.int16)

    try:
        from pydub import AudioSegment

        seg = AudioSegment(
            audio_array.tobytes(),
            frame_rate=sampling_rate,
            sample_width=audio_array.dtype.itemsize,
            channels=1,
        )
        buf = io.BytesIO()
        seg.export(buf, format="mp3", bitrate="320k")
        return buf.getvalue()
    except Exception:
        # fallback: WAV to keep things working even without ffmpeg
        import wave

        pcm = audio_array.astype(np.int16, copy=False)
        buf = io.BytesIO()
        with wave.open(buf, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)  # PCM_16
            wf.setframerate(sampling_rate)
            wf.writeframes(pcm.tobytes())
        return buf.getvalue()

def split_text_into_chunks(text: str, max_words: int = 25) -> List[str]:
    """Greedily pack whole sentences into chunks of roughly `max_words` words.

    Sentences come from nltk's sentence tokenizer; a sentence is never split,
    so a single long sentence may exceed the word budget on its own.
    """
    chunks: List[str] = []
    buffer = ""
    for sentence in nltk.sent_tokenize(text):
        merged = f"{buffer} {sentence}".strip() if buffer else sentence
        # Flush the buffer once adding this sentence would hit the budget,
        # but never emit an empty chunk.
        if buffer and len(merged.split()) >= max_words:
            chunks.append(buffer.strip())
            buffer = sentence
        else:
            buffer = merged
    if buffer.strip():
        chunks.append(buffer.strip())
    return chunks

def synthesize(text: str, description: str) -> np.ndarray:
    """Generate speech for `text` in the voice described by `description`.

    The text is split into sentence chunks, each chunk is synthesized
    separately, and the resulting waveforms are concatenated into one
    float32 mono array at the model's sample rate.

    Raises:
        RuntimeError: if no chunk yielded audio.
    """
    desc_inputs = description_tokenizer(description, return_tensors="pt").to(DEVICE)

    pieces: List[np.ndarray] = []
    for chunk in split_text_into_chunks(text, max_words=25):
        prompt_inputs = tokenizer(chunk, return_tensors="pt").to(DEVICE)
        out = model.generate(
            input_ids=desc_inputs.input_ids,
            attention_mask=desc_inputs.attention_mask,
            prompt_input_ids=prompt_inputs.input_ids,
            prompt_attention_mask=prompt_inputs.attention_mask,
            do_sample=True,
            return_dict_in_generate=True,
        )
        # Skip any generation output that lacks audio fields.
        if not (hasattr(out, "sequences") and hasattr(out, "audios_length")):
            continue
        wav = out.sequences[0, : out.audios_length[0]]
        wav_np = wav.to(torch.float32).cpu().numpy().squeeze()
        if wav_np.ndim > 1:
            wav_np = wav_np.flatten()
        pieces.append(wav_np)

    if not pieces:
        raise RuntimeError("TTS generation produced no audio.")

    return np.concatenate(pieces)

# ---- API schemas ----
class TTSRequest(BaseModel):
    """Request body for POST /tts."""
    # Text to synthesize (required; the endpoint rejects empty/whitespace).
    text: str
    # Selects a default voice description when `voice_description` is omitted.
    language: Optional[Literal["english", "urdu", "punjabi"]] = None
    # Free-form Parler-TTS voice prompt; takes precedence over `language`.
    voice_description: Optional[str] = None
    # "mp3" (default) or "wav" (force WAV fallback)
    format: Optional[Literal["mp3", "wav"]] = "mp3"

@app.get("/healthz")
def health() -> Dict[str, Any]:
    """Liveness probe reporting the inference device and output sample rate."""
    payload: Dict[str, Any] = {
        "status": "ok",
        "device": DEVICE,
        "sample_rate": SAMPLE_RATE,
    }
    return payload

@app.post("/tts")
def tts(body: TTSRequest):
    """Synthesize speech for `body.text` and stream it back as MP3 or WAV.

    The voice description is resolved in priority order: explicit
    `voice_description`, then the per-language default, then a generic
    neutral-voice fallback.
    """
    if not (body.text and body.text.strip()):
        raise HTTPException(status_code=400, detail="`text` is required.")

    # choose description (first truthy value wins)
    description = body.voice_description
    if not description:
        description = DEFAULT_DESCRIPTIONS.get((body.language or "").lower())
    if not description:
        description = "The speaker speaks naturally with a neutral tone. The recording is very high quality with no background noise."

    try:
        audio = synthesize(body.text, description)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"generation_error: {e}")

    # explicit WAV request: encode with soundfile and stream
    if body.format == "wav":
        import soundfile as sf
        wav_buf = io.BytesIO()
        sf.write(wav_buf, audio, SAMPLE_RATE, format="WAV", subtype="PCM_16")
        wav_buf.seek(0)
        return StreamingResponse(wav_buf, media_type="audio/wav")

    # default: mp3 (falls back to WAV inside helper if mp3 fails);
    # sniff the RIFF magic to label a fallback WAV correctly
    payload = numpy_to_mp3(audio, SAMPLE_RATE)
    media_type = "audio/wav" if payload[:4] == b"RIFF" else "audio/mpeg"
    return StreamingResponse(io.BytesIO(payload), media_type=media_type)


# uvicorn entrypoint (Spaces sets PORT)
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))