File size: 4,578 Bytes
d334bcd
 
 
851663b
 
 
 
 
 
d334bcd
851663b
 
 
 
 
 
 
 
036ffff
d334bcd
 
 
 
4fadd0f
 
 
 
 
f79db70
d334bcd
914eb9e
d334bcd
 
 
fc72e9f
 
b06930a
 
 
 
 
 
d334bcd
 
 
f79db70
 
 
d334bcd
fc72e9f
d334bcd
f036d34
aaaab74
d334bcd
4fadd0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d334bcd
4fadd0f
 
d334bcd
 
 
f036d34
d334bcd
4fadd0f
d334bcd
 
fc72e9f
d334bcd
 
 
 
 
 
 
 
 
 
4fadd0f
b06930a
 
 
 
 
 
 
 
 
 
 
4fadd0f
 
 
 
 
 
 
 
 
 
 
 
 
b06930a
 
 
 
 
 
 
4fadd0f
 
 
 
 
 
 
 
 
 
 
 
 
 
d2e46d8
4fadd0f
 
 
17a0dd0
 
 
4fadd0f
 
 
 
 
 
 
b06930a
 
 
4fadd0f
 
 
b06930a
 
 
4fadd0f
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
# ===============================
# FORCE CPU ONLY (VERY TOP)
# ===============================
import functools
import os

# Hide all GPUs from this process. Set before `import torch` below so that
# CUDA, if it is ever initialised, sees no devices.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# NOTE(review): TORCH_FORCE_CPU is not an environment variable documented by
# PyTorch itself; presumably read by other code in this project — confirm.
os.environ["TORCH_FORCE_CPU"] = "1"

import torch

# ---- HARD FORCE torch.load → CPU ----
# Keep a reference to the real loader; the wrapper below delegates to it.
_original_torch_load = torch.load


@functools.wraps(_original_torch_load)
def cpu_only_torch_load(*args, **kwargs):
    """Call the original ``torch.load`` with ``map_location`` forced to CPU.

    ``map_location`` is the second positional parameter of ``torch.load``.
    If the caller supplied it positionally, it is replaced in place;
    otherwise it is set as a keyword argument. (Unconditionally setting the
    keyword would make positional callers fail with
    "got multiple values for argument 'map_location'".)
    """
    cpu = torch.device("cpu")
    if len(args) >= 2:
        args = (args[0], cpu, *args[2:])
    else:
        kwargs["map_location"] = cpu
    return _original_torch_load(*args, **kwargs)


# Route every later torch.load call in this process through the CPU wrapper.
torch.load = cpu_only_torch_load
# Make device-selection code that asks torch.cuda.is_available() pick CPU.
torch.cuda.is_available = lambda: False

# ===============================
# STANDARD IMPORTS
# ===============================
import base64
import io
import threading

import numpy as np
from fastapi import FastAPI
from pydantic import BaseModel
from scipy.io.wavfile import write as write_wav

from src.chatterbox.mtl_tts import ChatterboxMultilingualTTS

# ===============================
# GLOBAL MODEL CACHE
# ===============================
# Process-wide model instance. It starts as None, is populated lazily by
# get_or_load_model(), and the same instance is then reused by every request.
MODEL: "ChatterboxMultilingualTTS | None" = None

# ===============================
# MAX QUOTA (from ENV)
# ===============================
# Maximum number of /tts requests served, read once at import time from the
# TTS_MAX_QUOTA environment variable (default 10). A non-integer value makes
# int() raise ValueError when this module is imported.
# NOTE(review): user-facing messages call this a *daily* limit, but nothing
# visible in this file resets tts_usage, so in practice the limit applies for
# the lifetime of the process — confirm the intended behaviour.
TTS_MAX_QUOTA: int = int(os.getenv("TTS_MAX_QUOTA", 10))  # default 10
# Requests charged against the quota. Kept in process memory only, so it is
# lost on restart and not shared between worker processes.
tts_usage: int = 0  # simple in-memory counter for demo

# ===============================
# MODEL LOADER
# ===============================
def get_or_load_model():
    """Return the cached TTS model, loading it on the first call.

    The loaded instance is stored in the module-level ``MODEL`` so later
    calls return the same object without reloading. The model is created via
    ``ChatterboxMultilingualTTS.from_pretrained("cpu")`` (the argument is
    presumably the target device).
    """
    global MODEL
    if MODEL is not None:
        return MODEL

    print("🔄 Loading ChatterboxMultilingualTTS (CPU ONLY)")
    MODEL = ChatterboxMultilingualTTS.from_pretrained("cpu")
    print("✅ Model loaded on CPU")
    return MODEL

# ===============================
# SINGING FORMATTER
# ===============================
# Translation table that doubles each lowercase ASCII vowel ("a" -> "aa", ...).
_SING_VOWEL_STRETCH = str.maketrans({vowel: vowel * 2 for vowel in "aeiou"})


def format_for_singing(lyrics: str) -> str:
    """Turn lyrics into a "sung" text prompt for the TTS model.

    Each line is stripped of surrounding whitespace, and blank lines are
    dropped. Every lowercase vowel (a, e, i, o, u) is doubled to lightly
    stretch the sound, and " ♪ ..." is appended to each line. Uppercase and
    non-ASCII vowels are left unchanged. The resulting lines are joined with
    newlines.
    """
    sung_lines = (
        f"{raw.strip().translate(_SING_VOWEL_STRETCH)} ♪ ..."
        for raw in lyrics.splitlines()
        if raw.strip()
    )
    return "\n".join(sung_lines)

# ===============================
# FASTAPI APP + LIFESPAN
# ===============================
from contextlib import asynccontextmanager

@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook that loads the TTS model before serving.

    Code before ``yield`` runs once at startup; nothing runs at shutdown.
    """
    # Warm-up: populate the model cache at startup rather than on first request.
    _ = get_or_load_model()
    yield


app = FastAPI(lifespan=lifespan)

# ===============================
# HEALTH CHECK
# ===============================
@app.get("/health")
def health():
    """Liveness endpoint reporting the device this service runs on.

    ``cuda_available`` comes from ``torch.cuda.is_available()``, which this
    module replaces at startup with a function that always returns False.
    """
    cuda_flag = torch.cuda.is_available()
    return dict(status="ok", device="cpu", cuda_available=cuda_flag)

# ===============================
# QUOTA INFO
# ===============================
@app.get("/quota")
def get_quota():
    """Report the in-memory TTS quota: requests used, the limit, and how many remain (never negative)."""
    used, limit = tts_usage, TTS_MAX_QUOTA
    return {"used": used, "limit": limit, "remaining": max(limit - used, 0)}

# ===============================
# TTS INPUT SCHEMA
# ===============================
class TTSPayload(BaseModel):
    """Request body for ``POST /tts``."""

    # Text to synthesise; in Sing mode it is treated as lyrics (one phrase per line).
    text: str
    # Language code passed through to ``model.generate`` as ``language_id``.
    language_id: str = "en"
    # Exactly "Sing 🎵" selects singing formatting; any other value is handled as Speak.
    mode: str = "Speak 🗣️"  # or "Sing 🎵"

# ===============================
# TTS ENDPOINT
# ===============================
# Guards every read-modify-write of ``tts_usage``. FastAPI runs sync (``def``)
# endpoints in a worker threadpool, so without a lock concurrent requests can
# all pass the quota check before any of them increments the counter.
_TTS_QUOTA_LOCK = threading.Lock()


def _reserve_tts_quota():
    """Atomically claim one quota slot.

    Returns the usage count including this claim, or ``None`` when the quota
    is already exhausted (in which case nothing is claimed).
    """
    global tts_usage
    with _TTS_QUOTA_LOCK:
        if tts_usage >= TTS_MAX_QUOTA:
            return None
        tts_usage += 1
        return tts_usage


def _release_tts_quota():
    """Return a slot claimed by ``_reserve_tts_quota`` for a request that produced no audio."""
    global tts_usage
    with _TTS_QUOTA_LOCK:
        tts_usage -= 1


def _wav_bytes(wav, sr):
    """Encode a sample array as an in-memory float32 WAV file and return its bytes."""
    buf = io.BytesIO()
    write_wav(buf, sr, wav.astype(np.float32))
    return buf.getvalue()


@app.post("/tts")
def generate_tts(payload: TTSPayload):
    """Synthesise speech (or stylised "singing") for ``payload`` as base64 WAV.

    A quota slot is claimed atomically before any work is done. It is given
    back if the text is empty or if synthesis raises, so only requests that
    return audio remain counted.

    Returns:
        On success, a dict with ``sr`` (the model's sample rate),
        ``audio_base64`` (float32 WAV bytes, base64-encoded), ``quota_used``
        (usage including this request) and ``quota_limit``. On rejection, a
        dict with an ``"error"`` key (plus ``"message"`` when the quota is
        exhausted), still returned with FastAPI's default 200 status.
    """
    quota_used = _reserve_tts_quota()
    if quota_used is None:
        # NOTE(review): nothing visible in this file resets tts_usage, so the
        # "Daily"/"tomorrow" promise below only holds if the process restarts.
        return {
            "error": "Quota exceeded",
            "message": f"Daily limit of {TTS_MAX_QUOTA} TTS requests reached. Try again tomorrow."
        }

    try:
        model = get_or_load_model()

        # Determine final text
        if payload.mode == "Sing 🎵":
            if not payload.text.strip():
                _release_tts_quota()
                return {"error": "Lyrics required for Sing mode."}
            final_text = format_for_singing(payload.text)
        else:
            if not payload.text.strip():
                _release_tts_quota()
                return {"error": "Text required for Speak mode."}
            final_text = payload.text

        # CPU-safe inference. The 300-char cap is applied after
        # format_for_singing, which lengthens text, so Sing mode keeps fewer
        # of the original characters than Speak mode.
        with torch.no_grad():
            wav = model.generate(
                final_text[:300],
                language_id=payload.language_id,
            )
            # Assumes generate returns a (1, num_samples) tensor — TODO confirm.
            wav = wav.squeeze(0).detach().cpu().numpy()
            sr = model.sr

        audio_bytes = _wav_bytes(wav, sr)
    except BaseException:
        # Synthesis failed (or was interrupted): do not charge this request.
        _release_tts_quota()
        raise

    # Return as base64
    return {
        "sr": sr,
        "audio_base64": base64.b64encode(audio_bytes).decode("utf-8"),
        "quota_used": quota_used,
        "quota_limit": TTS_MAX_QUOTA
    }

# ===============================
# RUN: uvicorn app:app --host 0.0.0.0 --port 7860
# ===============================