"""FastAPI service that synthesizes speech with a VITS model and returns WAV audio."""

import os

# Point every model/cache directory at /tmp.  These assignments sit before the
# torch / transformers imports further down so the values are already in the
# environment when those libraries are loaded.
# NOTE(review): presumably /tmp is the only writable location in the deployment
# target -- confirm.  Recent transformers releases prefer HF_HOME over
# TRANSFORMERS_CACHE; both are set here.
os.environ["HF_HOME"] = "/tmp"
os.environ["TRANSFORMERS_CACHE"] = "/tmp"
os.environ["TORCH_HOME"] = "/tmp"
os.environ["XDG_CACHE_HOME"] = "/tmp"

# These imports intentionally follow the os.environ cache settings above.
import io
import math
import re

import numpy as np
import scipy.io.wavfile
import torch
from fastapi import FastAPI, Query
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, VitsModel


# ASGI application object; the route handlers below register themselves on it.
app = FastAPI()

# Load the VITS weights and the tokenizer once, at import time.
# NOTE(review): the weights and the tokenizer come from two different Hub
# repositories -- confirm the tokenizer vocabulary matches the model.
model = VitsModel.from_pretrained("Somali-tts/somali_tts_model")
tokenizer = AutoTokenizer.from_pretrained("saleolow/somali-mms-tts")

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Inference only: switch the model to evaluation mode.
model.eval()


# Lookup table from integers to the words used when spelling numbers out:
# every value 0-19, each multiple of ten up to 90, plus 100 and 1000.
number_words = {
    0: "eber", 1: "koow", 2: "labo", 3: "seddex", 4: "afar", 5: "shan",
    6: "lix", 7: "todobo", 8: "sideed", 9: "sagaal", 10: "toban",
    11: "toban iyo koow", 12: "toban iyo labo", 13: "toban iyo seddex",
    14: "toban iyo afar", 15: "toban iyo shan", 16: "toban iyo lix",
    17: "toban iyo todobo", 18: "toban iyo sideed", 19: "toban iyo sagaal",
    20: "labaatan", 30: "sodon", 40: "afartan", 50: "konton",
    60: "lixdan", 70: "todobaatan", 80: "sideetan", 90: "sagaashan",
    100: "boqol", 1000: "kun",
}


def number_to_words(number: int) -> str:
    """Spell out an integer using the entries of ``number_words``.

    Values below one billion are written in words; the parts of a compound
    number are joined with the connector "iyo".  Values of one billion or
    more are returned unchanged as a decimal digit string.

    Args:
        number: The integer to convert.  Expected to be non-negative.

    Returns:
        The spelled-out number, or ``str(number)`` when it is >= 1_000_000_000.

    Raises:
        KeyError: If ``number`` is negative (there is no table entry for it).
    """
    if number >= 1000000000:
        # Too large for the vocabulary: leave the digits as they are.
        return str(number)

    if number < 20:
        return number_words[number]

    if number < 100:
        tens, unit = divmod(number, 10)
        words = number_words[tens * 10]
        if unit:
            words += " iyo " + number_words[unit]
        return words

    if number < 1000:
        hundreds, remainder = divmod(number, 100)
        # Exactly one hundred is the bare word, without a leading count.
        words = "boqol" if hundreds == 1 else number_words[hundreds] + " boqol"
        if remainder:
            words += " iyo " + number_to_words(remainder)
        return words

    # Thousands and millions share one shape: "<count> <scale> iyo <rest>".
    if number < 1000000:
        scale, scale_word = 1000, "kun"
    else:
        scale, scale_word = 1000000, "milyan"

    count, remainder = divmod(number, scale)
    # A count of exactly one is the bare scale word, as with "boqol" above.
    words = scale_word if count == 1 else number_to_words(count) + " " + scale_word
    if remainder:
        # The connector is used for millions too, matching the smaller scales.
        words += " iyo " + number_to_words(remainder)
    return words


def normalize_text(text: str) -> str:
    """Prepare raw input text for the tokenizer.

    Two steps are applied in order:

    1. Every run of digits is replaced by its spelled-out form from
       ``number_to_words``.
    2. Some upper-case letter groups are rewritten: "KH" -> "qa", "Z" -> "S",
       "SH" -> "SHa'a", "DH" -> "Dha'a".  Lower-case text is not touched.
       (Presumably these steer the model's pronunciation -- confirm.)

    Args:
        text: The text to normalize.

    Returns:
        The normalized text.
    """
    # Substitute each digit run where it stands, in one pass.  Replacing the
    # numbers one by one with str.replace would corrupt the text whenever one
    # number is a substring of another (e.g. "1" inside "10").
    text = re.sub(r"\d+", lambda match: number_to_words(int(match.group())), text)

    # Order matters: "Z" becomes "S" before the "SH" rule runs, so "ZH" ends
    # up as "SHa'a".  After the Z -> S step no "Z" is left, which is why no
    # separate "ZamZam" -> "SamSam" rule is needed.
    text = text.replace("KH", "qa").replace("Z", "S")
    text = text.replace("SH", "SHa'a").replace("DH", "Dha'a")
    return text


def waveform_to_wav_bytes(waveform: torch.Tensor, sample_rate: int = 22050) -> bytes:
    """Encode a waveform tensor as a 16-bit PCM mono WAV file held in memory.

    A 3-D tensor is reduced to its first element along the leading axis; a 2-D
    array is then averaged over its first axis, leaving a 1-D signal.  Samples
    are clipped to [-1.0, 1.0] and scaled to the int16 range.

    Args:
        waveform: Audio samples as a tensor with 1, 2 or 3 dimensions.
        sample_rate: Sampling rate, in Hz, written to the WAV header.

    Returns:
        The complete WAV file contents.
    """
    # detach() lets tensors that still track gradients be converted as well.
    samples = waveform.detach().cpu().numpy()

    if samples.ndim == 3:
        samples = samples[0]
    if samples.ndim == 2:
        samples = samples.mean(axis=0)

    # Clip first so the int16 conversion below cannot overflow.
    samples = np.clip(samples, -1.0, 1.0).astype(np.float32)
    pcm_samples = (samples * 32767).astype(np.int16)

    buffer = io.BytesIO()
    scipy.io.wavfile.write(buffer, rate=sample_rate, data=pcm_samples)
    return buffer.getvalue()


class TextIn(BaseModel):
    """JSON request body accepted by ``POST /synthesize``."""

    # Raw text to normalize and synthesize.
    inputs: str


def _synthesize_response(text: str):
    """Run the TTS pipeline on ``text`` and wrap the audio in an HTTP response.

    The text is normalized, tokenized and fed to the model without gradient
    tracking; the resulting waveform is encoded as WAV.

    Args:
        text: Raw text to synthesize.

    Returns:
        A ``StreamingResponse`` with media type ``audio/wav``, or a
        ``JSONResponse`` with status 500 and an ``"error"`` key when no
        waveform can be located in the model output.
    """
    normalized = normalize_text(text)
    inputs = tokenizer(normalized, return_tensors="pt").to(device)

    # NOTE(review): this is a synchronous call made from async handlers, so the
    # event loop is busy for the duration of the inference.
    with torch.no_grad():
        output = model(**inputs)

    # Accept the output shapes a model may return: attribute, mapping, sequence.
    if hasattr(output, "waveform"):
        waveform = output.waveform
    elif isinstance(output, dict) and "waveform" in output:
        waveform = output["waveform"]
    elif isinstance(output, (tuple, list)):
        waveform = output[0]
    else:
        # Same JSON body as before, but reported as a server error rather
        # than as a successful (HTTP 200) response.
        return JSONResponse(
            status_code=500,
            content={"error": "Waveform not found in model output"},
        )

    sample_rate = getattr(model.config, "sampling_rate", 22050)
    wav_bytes = waveform_to_wav_bytes(waveform, sample_rate=sample_rate)
    return StreamingResponse(io.BytesIO(wav_bytes), media_type="audio/wav")


@app.post("/synthesize")
async def synthesize_post(data: TextIn):
    """Synthesize the text in the JSON body and return it as WAV audio."""
    return _synthesize_response(data.inputs)


@app.get("/synthesize")
async def synthesize_get(
    text: str = Query(..., description="Text to synthesize"),
    test: bool = Query(False),
):
    """Synthesize the ``text`` query parameter and return it as WAV audio.

    When ``test`` is true the model is bypassed and a fixed two-second 440 Hz
    sine tone at half amplitude is returned instead.  ``text`` is still a
    required parameter in that case.
    """
    if test:
        duration_s = 2.0
        sample_rate = 22050
        freq = 440
        t = np.linspace(0, duration_s, int(sample_rate * duration_s), endpoint=False)
        tone = 0.5 * np.sin(2 * math.pi * freq * t).astype(np.float32)
        # Use the same WAV encoder as the real synthesis path.
        wav_bytes = waveform_to_wav_bytes(torch.from_numpy(tone), sample_rate=sample_rate)
        return StreamingResponse(io.BytesIO(wav_bytes), media_type="audio/wav")

    return _synthesize_response(text)