MAS-AI-0000 committed on
Commit
ad9b287
·
verified ·
1 Parent(s): fe79872

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +181 -0
app.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ from typing import Optional, Literal, Dict, Any, List
4
+
5
+ import numpy as np
6
+ from fastapi import FastAPI, HTTPException, Query
7
+ from fastapi.responses import StreamingResponse, JSONResponse
8
+ from pydantic import BaseModel
9
+ import torch
10
+ import nltk
11
+
12
+ from transformers import AutoTokenizer, AutoFeatureExtractor
13
+ from parler_tts import ParlerTTSForConditionalGeneration
14
+
15
# --- one-time setup ---
# Runs at import time: download the NLTK "punkt_tab" resource
# (nltk.sent_tokenize is used further down to split text into chunks).
nltk.download("punkt_tab")

# Device preference order: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

# float32 on CPU, bfloat16 on any accelerator.
TORCH_DTYPE = torch.float32 if DEVICE == "cpu" else torch.bfloat16

# finetuned model only
FINETUNED_REPO_ID = "ai4bharat/indic-parler-tts"

model = ParlerTTSForConditionalGeneration.from_pretrained(
    FINETUNED_REPO_ID,
    attn_implementation="eager",
    torch_dtype=TORCH_DTYPE,
).to(DEVICE)

# tokenizers / feature extractor
# NOTE: the base repo id provides tokenizer & feature extractor
BASE_REPO_FOR_TOK = "ai4bharat/indic-parler-tts-pretrained"
# `tokenizer` encodes the text to be spoken; `description_tokenizer` encodes
# the voice description (see synthesize()).
tokenizer = AutoTokenizer.from_pretrained(BASE_REPO_FOR_TOK)
description_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
feature_extractor = AutoFeatureExtractor.from_pretrained(BASE_REPO_FOR_TOK)

# Sample rate used for every audio response produced by this service.
SAMPLE_RATE = feature_extractor.sampling_rate
40
+
41
# --- FastAPI app ---
# Single application object; the route decorators below attach to it.
app = FastAPI(
    title="Indic Parler-TTS (finetuned) API",
    version="1.0.0",
)
43
+
44
# Optional default voice descriptions per language.
# Every entry is "<voice sentence> <shared audio-quality sentence>"; the
# mapping is keyed by lowercase language name.
DEFAULT_DESCRIPTIONS: Dict[str, str] = {
    language: f"{voice} Very clear audio with no background noise."
    for language, voice in (
        ("english", "A calm, neutral male voice speaks natural English at a moderate pace."),
        ("urdu", "A warm, neutral female voice speaks natural Urdu at a moderate pace."),
        ("punjabi", "A friendly, neutral male voice speaks natural Punjabi at a moderate pace."),
    )
}
59
+
60
def numpy_to_mp3(audio_array: np.ndarray, sampling_rate: int) -> bytes:
    """Encode a mono audio array as MP3 bytes (320 kbps) via pydub.

    Floating-point input is peak-normalised and converted to int16 before
    encoding. If anything in the MP3 path raises (pydub missing, export
    failure, ...), the audio is written as 16-bit PCM WAV with soundfile
    instead, so callers must be prepared to receive WAV bytes.

    Args:
        audio_array: Single-channel samples (int16 or float).
        sampling_rate: Sample rate in Hz.

    Returns:
        The encoded MP3 bytes, or WAV bytes when the fallback was taken.
    """
    try:
        from pydub import AudioSegment

        # Float samples: scale so the loudest sample maps to 32767. An
        # all-zero signal has peak 0, so divide by 1.0 in that case.
        if np.issubdtype(audio_array.dtype, np.floating):
            peak = np.max(np.abs(audio_array)) or 1.0
            audio_array = (audio_array / peak) * 32767
            audio_array = audio_array.astype(np.int16)

        segment = AudioSegment(
            audio_array.tobytes(),
            frame_rate=sampling_rate,
            sample_width=audio_array.dtype.itemsize,
            channels=1,
        )
        with io.BytesIO() as mp3_buffer:
            segment.export(mp3_buffer, format="mp3", bitrate="320k")
            return mp3_buffer.getvalue()
    except Exception:
        # Deliberate best-effort fallback: emit WAV so the service keeps
        # working even without pydub/ffmpeg.
        import soundfile as sf

        wav_buffer = io.BytesIO()
        sf.write(wav_buffer, audio_array, sampling_rate, format="WAV", subtype="PCM_16")
        return wav_buffer.getvalue()
90
+
91
def split_text_into_chunks(text: str, max_words: int = 25) -> List[str]:
    """Group the sentences of *text* into chunks of fewer than *max_words* words.

    Sentences come from ``nltk.sent_tokenize`` and are merged greedily: a
    sentence joins the pending chunk only while the merged word count stays
    below ``max_words``. A sentence is never split, so a single long sentence
    can form a chunk that exceeds the limit on its own.

    Args:
        text: Text to split.
        max_words: Word-count threshold at which a merged chunk is closed.

    Returns:
        The stripped, non-empty chunks in their original order.
    """
    chunks: List[str] = []
    pending = ""
    for sentence in nltk.sent_tokenize(text):
        if not pending:
            # Nothing buffered yet: the sentence opens a chunk whatever its length.
            pending = sentence
            continue

        merged = f"{pending} {sentence}".strip()
        if len(merged.split()) < max_words:
            pending = merged
        else:
            # Adding this sentence would reach the limit: close the chunk.
            chunks.append(pending.strip())
            pending = sentence

    tail = pending.strip()
    if tail:
        chunks.append(tail)
    return chunks
105
+
106
def synthesize(text: str, description: str) -> np.ndarray:
    """Generate speech for *text* in the voice given by *description*.

    The description is tokenized once with ``description_tokenizer``. The
    text is split with ``split_text_into_chunks`` and each chunk is passed
    through ``model.generate`` separately; the per-chunk waveforms are
    concatenated in order.

    Args:
        text: Text to speak.
        description: Natural-language description of the desired voice.

    Returns:
        A 1-D float32 array holding the concatenated audio of all chunks.

    Raises:
        RuntimeError: If no chunk produced any audio.
    """
    description_inputs = description_tokenizer(description, return_tensors="pt").to(DEVICE)

    waveforms = []
    for piece in split_text_into_chunks(text, max_words=25):
        prompt_inputs = tokenizer(piece, return_tensors="pt").to(DEVICE)
        output = model.generate(
            input_ids=description_inputs.input_ids,
            attention_mask=description_inputs.attention_mask,
            prompt_input_ids=prompt_inputs.input_ids,
            prompt_attention_mask=prompt_inputs.attention_mask,
            do_sample=True,
            return_dict_in_generate=True,
        )
        # Skip outputs that lack the fields needed to extract the audio.
        if not (hasattr(output, "sequences") and hasattr(output, "audios_length")):
            continue

        # Keep only the valid samples of the first (only) batch item.
        valid_length = output.audios_length[0]
        trimmed = output.sequences[0, :valid_length]
        waveform = trimmed.to(torch.float32).cpu().numpy().squeeze()
        if waveform.ndim > 1:
            waveform = waveform.flatten()
        waveforms.append(waveform)

    if not waveforms:
        raise RuntimeError("TTS generation produced no audio.")

    return np.concatenate(waveforms)
132
+
133
# ---- API schemas ----
class TTSRequest(BaseModel):
    """JSON body of a ``POST /tts`` request."""

    # Text to synthesize (required).
    text: str
    # Optional language used to look up a default voice description.
    language: Optional[Literal["english", "urdu", "punjabi"]] = None
    # Optional free-form voice description; when set it takes precedence
    # over the per-language default.
    voice_description: Optional[str] = None
    # "mp3" (default) or "wav" (force WAV fallback)
    format: Optional[Literal["mp3", "wav"]] = "mp3"
140
+
141
@app.get("/healthz")
def health() -> Dict[str, Any]:
    """Return a liveness payload with the active device and output sample rate."""
    payload: Dict[str, Any] = {
        "status": "ok",
        "device": DEVICE,
        "sample_rate": SAMPLE_RATE,
    }
    return payload
144
+
145
@app.post("/tts")
def tts(body: TTSRequest):
    """Synthesize speech for the request and stream it back as audio.

    The voice description is chosen in this order: the request's
    ``voice_description``, the default for the request's ``language``, then a
    generic neutral description.

    Args:
        body: Validated request payload.

    Returns:
        A ``StreamingResponse`` with media type ``audio/wav`` when WAV was
        requested (or when MP3 encoding fell back to WAV), otherwise
        ``audio/mpeg``.

    Raises:
        HTTPException: 400 if ``text`` is empty or whitespace only; 500 if
            generation fails (the original exception is chained as the cause).
    """
    if not body.text or not body.text.strip():
        raise HTTPException(status_code=400, detail="`text` is required.")

    # choose description
    description = (
        body.voice_description
        or DEFAULT_DESCRIPTIONS.get((body.language or "").lower(), None)
        or (
            "The speaker speaks naturally with a neutral tone. "
            "The recording is very high quality with no background noise."
        )
    )

    try:
        audio = synthesize(body.text, description)
    except Exception as e:
        # Chain the cause so the underlying traceback is kept for debugging.
        raise HTTPException(status_code=500, detail=f"generation_error: {e}") from e

    if body.format == "wav":
        import soundfile as sf

        buf = io.BytesIO()
        sf.write(buf, audio, SAMPLE_RATE, format="WAV", subtype="PCM_16")
        buf.seek(0)  # rewind so the response streams from the start
        return StreamingResponse(buf, media_type="audio/wav")

    # default: mp3 (the helper falls back to WAV if MP3 encoding fails)
    encoded = numpy_to_mp3(audio, SAMPLE_RATE)
    # WAV data starts with the "RIFF" magic; label the response accordingly.
    media_type = "audio/wav" if encoded[:4] == b"RIFF" else "audio/mpeg"
    return StreamingResponse(io.BytesIO(encoded), media_type=media_type)
176
+
177
+
178
# uvicorn entrypoint (Spaces sets PORT)
if __name__ == "__main__":
    import uvicorn

    # Use the PORT environment variable when set, otherwise 7860.
    listen_port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)