import os

# Send every cache directory used by the Hugging Face / torch stack to /tmp.
# NOTE(review): presumably /tmp is the only writable location in the target
# deployment -- confirm. These assignments are placed before the
# transformers/torch imports so those libraries see the values when they
# are imported.
os.environ["HF_HOME"] = "/tmp"
os.environ["TRANSFORMERS_CACHE"] = "/tmp"
os.environ["TORCH_HOME"] = "/tmp"
os.environ["XDG_CACHE_HOME"] = "/tmp"

import io
import math
import re

import numpy as np
import scipy.io.wavfile
import torch
from fastapi import FastAPI, Query
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import VitsModel, AutoTokenizer

# ASGI application object; the route handlers in this module register on it.
app = FastAPI()

# Checkpoint identifier shared by the model and its tokenizer.
_MODEL_ID = "najiib9/somali_tts_final_model"

# Loaded once at import time so every request reuses the same weights.
model = VitsModel.from_pretrained(_MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)

# Use the GPU when one is visible, otherwise fall back to CPU. The model is
# only used for inference, so it is switched to eval mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Spoken forms used by number_to_words(): every value from 0 to 19, each
# multiple of ten up to 90, and the scale words for 100 and 1000.
number_words = {
    0: "eber", 1: "koow", 2: "labo", 3: "seddex", 4: "afar", 5: "shan",
    6: "lix", 7: "todobo", 8: "sideed", 9: "sagaal", 10: "toban",
    11: "toban iyo koow", 12: "toban iyo labo", 13: "toban iyo seddex",
    14: "toban iyo afar", 15: "toban iyo shan", 16: "toban iyo lix",
    17: "toban iyo todobo", 18: "toban iyo sideed", 19: "toban iyo sagaal",
    20: "labaatan", 30: "sodon", 40: "afartan", 50: "konton",
    60: "lixdan", 70: "todobaatan", 80: "sideetan", 90: "sagaashan",
    100: "boqol", 1000: "kun",
}

# Chat-style abbreviations and the full phrase each one is expanded to by
# normalize_text(). Keys are lowercase because the matched word is lowercased
# before it is looked up.
shortcut_map = {
    "asc": "asalaamu caleykum",
    "wcs": "wacaleykum salaam",
    "fcn": "fiican",
    "xld": "xaaladda ka waran",
    "kwrn": "kawaran",
    "scw": "salalaahu caleyhi wa salam",
    "alx": "alxamdu lilaahi",
    "m.a": "maasha allah",
    "sthy": "side tahey",
    "sxp": "saaxiib",
}

# English country/language names and the spelling substituted for them by
# normalize_text().
#
# Every key is lowercase: normalize_text() lowercases the matched word before
# looking it up, so a capitalised key (previously "Yeman", "English", "Spain",
# "Brazil", "Italy", "Saudi Arabia") could never be found and the word was
# only lowercased instead of translated.
country_map = {
    "somalia": "Soomaaliya",
    "ethiopia": "Itoobiya",
    "kenya": "Kenya",
    "djibouti": "Jabuuti",
    "sudan": "Suudaan",
    # "yeman" is the spelling the table originally carried; the standard
    # spelling "yemen" is accepted as well.
    "yeman": "yemaan",
    "yemen": "yemaan",
    "uganda": "Ugaandha",
    "tanzania": "Tansaaniya",
    "egypt": "Masar",
    "libya": "Liibiya",
    "algeria": "Aljeeriya",
    "morocco": "Morooko",
    "tunisia": "Tuniisiya",
    "eritrea": "Eriteriya",
    "malawi": "Malaawi",
    "english": "ingiriis",
    "spain": "isbeen",
    "brazil": "baraasiil",
    "niger": "Niyjer",
    "italy": "itaaliya",
    "united states": "Maraykanka",
    "china": "Shiinaha",
    "india": "Hindiya",
    "russia": "Ruushka",
    "saudi arabia": "Sucuudi Carabiya",
    "germany": "Jarmalka",
    "france": "Faransiiska",
    "japan": "Jabaan",
    "canada": "Kanada",
    "australia": "Australia",
}

def number_to_words(number):
    """Spell out a non-negative integer using the ``number_words`` table.

    Args:
        number: An ``int``, or anything ``int()`` accepts (for example a
            string of digits).

    Returns:
        The spelled-out number for values below 1,000,000,000. Larger values
        are returned unchanged as a decimal digit string.

    Raises:
        ValueError: If ``number`` is negative or ``int()`` cannot convert it.
    """
    number = int(number)
    if number < 0:
        # Previously this surfaced as an opaque KeyError from the table lookup.
        raise ValueError(
            f"number_to_words expects a non-negative integer, got {number!r}"
        )
    if number < 20:
        return number_words[number]
    if number < 100:
        tens, unit = divmod(number, 10)
        words = number_words[tens * 10]
        return f"{words} iyo {number_words[unit]}" if unit else words
    if number < 1000:
        hundreds, remainder = divmod(number, 100)
        # Exactly one hundred is the bare scale word, with no leading count.
        words = f"{number_words[hundreds]} boqol" if hundreds > 1 else "boqol"
        return f"{words} iyo {number_to_words(remainder)}" if remainder else words

    if number < 1_000_000:
        scale, scale_word = 1_000, "kun"
    elif number < 1_000_000_000:
        scale, scale_word = 1_000_000, "milyan"
    else:
        # No scale word is defined beyond millions; keep the digits.
        return str(number)

    count, remainder = divmod(number, scale)
    # As with hundreds, a count of one is the bare scale word.
    words = f"{number_to_words(count)} {scale_word}" if count > 1 else scale_word
    if remainder:
        # The millions branch used to append the remainder without "iyo",
        # unlike every other branch; both scales now join it the same way.
        words += f" iyo {number_to_words(remainder)}"
    return words

def _replace_whole_words(text, mapping):
    """Replace whole-word, case-insensitive occurrences of ``mapping`` keys.

    Keys are compared in lowercase, so the casing used in ``mapping`` does not
    affect whether an entry can be found.
    """
    if not mapping:
        return text
    lookup = {key.lower(): value for key, value in mapping.items()}
    # Longest keys first so a multi-word key is tried before any shorter key
    # that matches the same starting position.
    alternatives = sorted(lookup, key=len, reverse=True)
    pattern = re.compile(
        r'\b(' + '|'.join(re.escape(key) for key in alternatives) + r')\b',
        re.IGNORECASE,
    )

    def substitute(match):
        found = match.group(0)
        return lookup.get(found.lower(), found)

    return pattern.sub(substitute, text)


def normalize_text(text):
    """Rewrite raw input text into the spelled-out form passed to the tokenizer.

    The steps run in this order:

    1. Whole-word "zamzam" (any case) becomes "samsam".
    2. Abbreviations in ``shortcut_map`` are expanded.
    3. Names in ``country_map`` are replaced. The lookup is case-insensitive
       on both sides, so entries whose keys contain capitals are honoured
       (previously they were never found and the word was only lowercased).
    4. Thousands separators are removed, decimal fractions are dropped, and
       each remaining digit run is spelled out with ``number_to_words``.
    5. The symbols ``$``, ``=``, ``+`` and ``#`` are replaced by words.
    6. The uppercase sequences "KH", "Z", "SH" and "DH" are rewritten.
    7. If the text ends in "zamzam"/"samsam" (optionally followed by
       whitespace or punctuation), " m" is appended.

    Args:
        text: The raw input string.

    Returns:
        The normalized string.
    """
    text = re.sub(r'(?i)(?<!\w)zamzam(?!\w)', 'samsam', text)

    text = _replace_whole_words(text, shortcut_map)
    text = _replace_whole_words(text, country_map)

    # "1,234,567" -> "1234567" so the figure is read as a single number.
    text = re.sub(r'(\d{1,3})(,\d{3})+', lambda m: m.group(0).replace(",", ""), text)
    # Decimal fractions are discarded rather than read out.
    text = re.sub(r'\.\d+', '', text)
    text = re.sub(r'\d+', lambda m: number_to_words(m.group()), text)

    symbol_map = {
        '$': 'doolar',
        '=': 'egwal',
        '+': 'balaas',
        '#': 'haash',
    }
    for symbol, spoken in symbol_map.items():
        text = text.replace(symbol, ' ' + spoken + ' ')

    # Order matters: "Z" becomes "S" before the "SH" rewrite runs, so an
    # original "ZH" is also caught by the "SH" rule.
    text = text.replace("KH", "qa").replace("Z", "S")
    text = text.replace("SH", "SHa'a").replace("DH", "Dha'a")

    # NOTE(review): the trailing " m" presumably works around how the model
    # renders a final "samsam" -- confirm with the model owner.
    if re.search(r'(?i)(zamzam|samsam)[\s\.,!?]*$', text.strip()):
        text += " m"

    return text

def waveform_to_wav_bytes(waveform: torch.Tensor, sample_rate: int = 22050) -> bytes:
    """Encode a waveform tensor as a mono 16-bit PCM WAV file.

    Args:
        waveform: Tensor of samples. A 3-D tensor is reduced to its first
            item along axis 0; a 2-D tensor is averaged over axis 0; a 1-D
            tensor is used as is. Samples are clipped to ``[-1.0, 1.0]``.
        sample_rate: Sample rate written into the WAV header.

    Returns:
        The complete WAV file contents.
    """
    # detach() makes this safe for tensors that still track gradients, for
    # which Tensor.numpy() would otherwise raise.
    samples = waveform.detach().cpu().numpy()
    if samples.ndim == 3:
        samples = samples[0]
    if samples.ndim == 2:
        samples = samples.mean(axis=0)
    # Clip before scaling so out-of-range samples saturate instead of
    # wrapping around in the int16 conversion.
    samples = np.clip(samples, -1.0, 1.0).astype(np.float32)
    pcm = (samples * 32767).astype(np.int16)
    buffer = io.BytesIO()
    scipy.io.wavfile.write(buffer, rate=sample_rate, data=pcm)
    return buffer.getvalue()

class TextIn(BaseModel):
    """Request body accepted by ``POST /synthesize``."""

    # Raw text to synthesize; the POST handler splits it on newlines and
    # synthesizes each non-empty line as its own paragraph.
    inputs: str

def _extract_waveform(output):
    """Return the waveform held by a model output, or ``None`` if there is none.

    Accepts an object with a ``waveform`` attribute, a dict with a
    ``"waveform"`` key, or a tuple/list whose first element is the waveform.
    """
    if hasattr(output, "waveform"):
        return output.waveform
    if isinstance(output, dict) and "waveform" in output:
        return output["waveform"]
    if isinstance(output, (tuple, list)):
        return output[0]
    return None


def _synthesize_waveform(text):
    """Normalize ``text``, run the model on it, and return the waveform or ``None``."""
    normalized = normalize_text(text)
    inputs = tokenizer(normalized, return_tensors="pt").to(device)
    with torch.no_grad():
        output = model(**inputs)
    return _extract_waveform(output)


# NOTE(review): both handlers are ``async def`` but run model inference
# synchronously, which blocks the event loop for the duration of a request.
# Consider plain ``def`` handlers or an executor if concurrency matters.


@app.post("/synthesize")
async def synthesize_post(data: TextIn):
    """Synthesize each non-empty line of ``data.inputs`` and return one WAV.

    Each line is normalized and synthesized on its own; one second of silence
    is appended after every synthesized line, and the pieces are concatenated
    along the last axis.

    Returns:
        A ``StreamingResponse`` with media type ``audio/wav``, or the dict
        ``{"error": "No audio generated."}`` when no line produced audio.
    """
    paragraphs = [line.strip() for line in data.inputs.split('\n') if line.strip()]
    sample_rate = getattr(model.config, "sampling_rate", 22050)

    segments = []
    for paragraph in paragraphs:
        waveform = _synthesize_waveform(paragraph)
        if waveform is None:
            continue
        segments.append(waveform)
        # One second of silence, created directly with the waveform's dtype,
        # device and leading dimensions so torch.cat() always lines up.
        silence = waveform.new_zeros((*waveform.shape[:-1], sample_rate))
        segments.append(silence)

    if not segments:
        return {"error": "No audio generated."}

    final_waveform = torch.cat(segments, dim=-1)
    wav_bytes = waveform_to_wav_bytes(final_waveform, sample_rate=sample_rate)
    return StreamingResponse(io.BytesIO(wav_bytes), media_type="audio/wav")


@app.get("/synthesize")
async def synthesize_get(text: str = Query(..., description="Text to synthesize"), test: bool = Query(False)):
    """Synthesize ``text`` as a single utterance and return it as WAV.

    With ``test=true`` the model is bypassed and a 440 Hz sine tone is
    returned instead, six seconds long per line of ``text``, at 22050 Hz.

    Returns:
        A ``StreamingResponse`` with media type ``audio/wav``, or a dict with
        an ``"error"`` key when there is nothing to synthesize or the model
        output carries no waveform.
    """
    if test:
        paragraph_count = text.count("\n") + 1
        duration_s = paragraph_count * 6
        sample_rate = 22050
        t = np.linspace(0, duration_s, int(sample_rate * duration_s), endpoint=False)
        freq = 440
        tone = 0.5 * np.sin(2 * math.pi * freq * t).astype(np.float32)
        pcm_tone = (tone * 32767).astype(np.int16)
        buf = io.BytesIO()
        scipy.io.wavfile.write(buf, rate=sample_rate, data=pcm_tone)
        buf.seek(0)
        return StreamingResponse(buf, media_type="audio/wav")

    # Mirror the POST handler: blank input yields an error payload instead of
    # being handed to the tokenizer and model.
    if not text.strip():
        return {"error": "No audio generated."}

    waveform = _synthesize_waveform(text)
    if waveform is None:
        return {"error": "Waveform not found in model output"}
    sample_rate = getattr(model.config, "sampling_rate", 22050)
    wav_bytes = waveform_to_wav_bytes(waveform, sample_rate=sample_rate)
    return StreamingResponse(io.BytesIO(wav_bytes), media_type="audio/wav")