msabonkudi commited on
Commit
4e08dd5
·
verified ·
1 Parent(s): 5038cb2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -0
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import base64
3
+ import torch
4
+ import torchaudio
5
+ from pydub import AudioSegment
6
+ from chatterbox import mtl_tts
7
+ from huggingface_hub import snapshot_download
8
+ from safetensors.torch import load_file as load_safetensors
9
+ from fastapi import FastAPI
10
+ from fastapi.responses import JSONResponse
11
+ from pydantic import BaseModel
12
+
13
+ app = FastAPI()
14
+
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+
17
+ print("⏳ Loading model...")
18
+
19
+ ckpt_dir = snapshot_download(
20
+ repo_id="NAMAA-Space/NAMAA-Saudi-TTS",
21
+ repo_type="model"
22
+ )
23
+
24
+ model = mtl_tts.ChatterboxMultilingualTTS.from_pretrained(device=device)
25
+
26
+ t3_state = load_safetensors(f"{ckpt_dir}/t3_mtl23ls_v2.safetensors", device=device)
27
+ model.t3.load_state_dict(t3_state)
28
+
29
+ print(f"✅ Model loaded on {device}")
30
+
31
+
32
+ class TTSRequest(BaseModel):
33
+ text: str
34
+
35
+
36
+ @app.post("/tts")
37
+ def tts(req: TTSRequest):
38
+ try:
39
+ with torch.no_grad():
40
+ audio_tensor = model.generate(req.text)
41
+
42
+ audio_np = audio_tensor.cpu().numpy().squeeze()
43
+ sample_rate = 24000
44
+
45
+ audio_io = io.BytesIO()
46
+ torchaudio.save(
47
+ audio_io,
48
+ torch.from_numpy(audio_np).unsqueeze(0),
49
+ sample_rate,
50
+ format="wav"
51
+ )
52
+ audio_io.seek(0)
53
+
54
+ audio = AudioSegment.from_wav(audio_io)
55
+ mp3_io = io.BytesIO()
56
+ audio.export(mp3_io, format="mp3", bitrate="192k")
57
+ mp3_io.seek(0)
58
+
59
+ audio_base64 = base64.b64encode(mp3_io.read()).decode("utf-8")
60
+
61
+ return {"audio": audio_base64}
62
+
63
+ except Exception as e:
64
+ return JSONResponse({"error": str(e)}, status_code=500)