# speech_project / app.py
# (Hugging Face Spaces page chrome removed: uploader avatar, "Update app.py", commit 55c5783)
import os
import re
import uuid
import json
import inspect
import builtins
from typing import Optional
import torch
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import HTMLResponse, JSONResponse, FileResponse
from faster_whisper import WhisperModel
from transformers import AutoTokenizer, AutoModelForCausalLM
class EgyptianNormalizer:
def __init__(self):
self.replacements = [
(r"\bุงุฒูŠูƒ\b", "ุฅุฒูŠูƒ"),
(r"\bุงุฒู‰\b", "ุฅุฒูŠ"),
(r"\bุงุฒุงูŠ\b", "ุฅุฒุงูŠ"),
(r"\bุงูŠู‡\b", "ุฅูŠู‡"),
(r"\bู„ูŠู‡\b", "ู„ูŠู‡"),
(r"\bููŠู†\b", "ููŠู†"),
(r"\bุงู…ุชู‰\b", "ุฅู…ุชู‰"),
(r"\bุนุงูˆุฒ\b", "ุนุงูŠุฒ"),
(r"\bุนุงูˆุฒู‡\b", "ุนุงูŠุฒู‡"),
(r"\bุนุงูŠุฒ\b", "ุนุงูŠุฒ"),
(r"\bุฏู„ูˆู‚ุช\b", "ุฏู„ูˆู‚ุชูŠ"),
(r"\bูƒุฏุง\b", "ูƒุฏู‡"),
(r"\bุนู„ุดุงู†\b", "ุนุดุงู†"),
(r"\bุนู…ูŠู„ ุงูŠู‡\b", "ุนุงู…ู„ ุฅูŠู‡"),
(r"\bุนุงู…ู„ู‡ ุงูŠู‡\b", "ุนุงู…ู„ุฉ ุฅูŠู‡"),
(r"\bู…ุงุนุฑูุด\b", "ู…ุด ุนุงุฑู"),
(r"\bู…ุงูƒู†ุชุด\b", "ู…ุด ูƒู†ุช"),
(r"\bู…ุงูู‡ู…ุชุด\b", "ู…ุด ูุงู‡ู…"),
(r"\bok\b", "ุชู…ุงู…"),
(r"\bokay\b", "ุชู…ุงู…"),
(r"\bsorry\b", "ู…ุนู„ุด"),
]
self.question_words = ["ู„ูŠู‡", "ุฅูŠู‡", "ููŠู†", "ุฅู…ุชู‰", "ุฅุฒุงูŠ", "ุฅุฒูŠ", "ูƒุงู…", "ู…ูŠู†"]
self.diacritics = re.compile(r"[\u0610-\u061A\u064B-\u065F\u0670\u06D6-\u06ED]")
self.zero_width = re.compile(r"[\u200c\u200d\u200e\u200f\ufeff]")
self.tatweel = "\u0640"
def normalize(self, text: str, stage: str = "llm") -> str:
if not text:
return ""
t = str(text)
t = self.zero_width.sub("", t)
t = t.replace(self.tatweel, "")
t = t.replace("\n", " ")
t = self.diacritics.sub("", t)
t = t.replace("?", "ุŸ")
t = t.replace(",", "ุŒ")
t = t.replace(";", "ุ›")
t = re.sub(r"[ุฃุฅุขูฑ]", "ุง", t)
t = t.replace("ู‰", "ูŠ")
t = re.sub(r"(.)\1{2,}", r"\1", t)
t = t.translate(str.maketrans("ู ูกูขูฃูคูฅูฆูงูจูฉ", "0123456789"))
if stage == "tts":
t = t.replace("%", " ููŠ ุงู„ู…ูŠุฉ ")
for pattern, repl in self.replacements:
t = re.sub(pattern, repl, t, flags=re.IGNORECASE)
t = t.replace("ู‚ู„ุชู„ูƒ", "ู‚ู„ุช ู„ูƒ")
t = t.replace("ู‚ูˆู„ุชู„ูƒ", "ู‚ู„ุช ู„ูƒ")
t = t.replace("ู‚ู„ุชู„ู‡ู…", "ู‚ู„ุช ู„ู‡ู…")
t = t.replace("ู‚ูˆู„ุชู„ู‡ู…", "ู‚ู„ุช ู„ู‡ู…")
t = re.sub(r"\s+", " ", t).strip()
if any(w in t for w in self.question_words) and not t.endswith(("ุŸ", "!", ".")):
t += "ุŸ"
if stage == "tts":
t = re.sub(r"\b[a-zA-Z]+\b", "", t)
t = re.sub(r"\s+", " ", t).strip()
return t
# ----------------------------
# Caches (HF Spaces friendly)
# ----------------------------
# Point every model/tool cache at the persistent /data volume; setdefault
# keeps any values the deployment already configured.
_CACHE_ENV_DEFAULTS = {
    "HF_HOME": "/data/huggingface",
    "HF_HUB_CACHE": "/data/huggingface/hub",
    "TRANSFORMERS_CACHE": "/data/huggingface/transformers",
    "XDG_CACHE_HOME": "/data/cache",
    "XDG_DATA_HOME": "/data/local/share",
}
for _env_name, _env_default in _CACHE_ENV_DEFAULTS.items():
    os.environ.setdefault(_env_name, _env_default)
# ----------------------------
# Coqui CPML non-interactive acceptance
# ----------------------------
# Best-effort: pre-write the "terms agreed" marker so Coqui TTS never prompts.
# Wrapped as a whole because os.makedirs was previously unguarded and crashed
# the app at import time when /data was read-only or missing.
try:
    os.makedirs("/data/local/share/tts", exist_ok=True)
    with open("/data/local/share/tts/.tos_agreed", "w") as f:
        f.write("y")
except Exception:
    pass

_real_input = builtins.input


def _auto_input(prompt=""):
    """input() shim that auto-answers Coqui's CPML license prompt.

    Returns COQUI_TOS (default "y") when the prompt mentions the license, or
    when the call originates from TTS's `ask_tos` downloader frame; any other
    prompt falls through to the real input() so interactive use still works.
    """
    p = (prompt or "").lower()
    if "cpml" in p or "license" in p or "[y/n]" in p:
        return os.environ.get("COQUI_TOS", "y")
    try:
        # Fallback: the downloader may prompt with unexpected wording, so
        # detect its ask_tos frame on the call stack instead.
        for frame in inspect.stack():
            fname = frame.filename.replace("\\", "/")
            if fname.endswith("/TTS/utils/manage.py") and frame.function == "ask_tos":
                return os.environ.get("COQUI_TOS", "y")
    except Exception:
        pass
    return _real_input(prompt)


builtins.input = _auto_input
# ----------------------------
# Optional CAMeL Tools normalization
# ----------------------------
# CAMeL Tools is an optional dependency; CAMEL_OK records whether its
# normalizer helpers are importable. (Broad `except Exception` is deliberate:
# a broken install should degrade gracefully, not crash startup.)
CAMEL_OK = True
try:
    from camel_tools.utils.normalize import (
        normalize_alef_maksura_ar,
        normalize_alef_ar,
        normalize_teh_marbuta_ar,
        normalize_unicode,
    )
except Exception:
    CAMEL_OK = False
# ----------------------------
# Config
# ----------------------------
# Model selection and generation knobs, all overridable via environment vars.
QWEN_MODEL_ID = os.getenv("QWEN_MODEL_ID", "Qwen/Qwen2.5-3B-Instruct")
WHISPER_SIZE = os.getenv("WHISPER_SIZE", "small")
XTTS_MODEL_ID = os.getenv("XTTS_MODEL_ID", "tts_models/multilingual/multi-dataset/xtts_v2")
# System prompt pinning the assistant to short, Egyptian-Arabic-only replies.
SYSTEM_PROMPT = os.getenv(
    "SYSTEM_PROMPT",
    "ุงู†ุช ู…ุณุงุนุฏ ู…ุตุฑูŠ. ุฑุฏ ุจุงู„ู„ู‡ุฌุฉ ุงู„ู…ุตุฑูŠุฉ ูู‚ุท ูˆุจุงู„ุนุฑุจูŠ ูู‚ุท. ู…ู…ู†ูˆุน ุชุณุชุฎุฏู… ุงูŠ ูƒู„ุงู… ุงู†ุฌู„ูŠุฒูŠ. ุฑุฏูˆุฏูƒ ู‚ุตูŠุฑุฉ ูˆูˆุงุถุญุฉ."
)
MAX_TURNS: int = int(os.getenv("MAX_TURNS", "8"))  # history pairs kept per request
MAX_NEW_TOKENS: int = int(os.getenv("MAX_NEW_TOKENS", "256"))  # generation cap
HAS_GPU: bool = torch.cuda.is_available()
DEVICE_STR: str = "cuda" if HAS_GPU else "cpu"
# Shared normalizer instance (stateless after construction).
normalizer = EgyptianNormalizer()
# ----------------------------
# Load Whisper + Qwen once
# ----------------------------
# faster-whisper: fp16 on GPU; int8 quantization keeps CPU memory reasonable.
whisper_compute = "float16" if HAS_GPU else "int8"
whisper_model = WhisperModel(WHISPER_SIZE, device=DEVICE_STR, compute_type=whisper_compute)
# Qwen chat model loaded eagerly at import time (first request stays fast);
# device_map="auto" lets accelerate place layers across available GPUs.
tokenizer = AutoTokenizer.from_pretrained(QWEN_MODEL_ID, trust_remote_code=True)
qwen = AutoModelForCausalLM.from_pretrained(
    QWEN_MODEL_ID,
    torch_dtype=torch.float16 if HAS_GPU else torch.float32,
    device_map="auto" if HAS_GPU else None,
    trust_remote_code=True,
)
# XTTS is constructed lazily: importing/loading it is heavy and only needed
# once the first synthesis request arrives.
tts = None


def get_tts():
    """Return the process-wide XTTS instance, creating it on first use."""
    global tts
    if tts is not None:
        return tts
    from TTS.api import TTS  # deferred: heavy import, only paid on first call
    tts = TTS(XTTS_MODEL_ID, gpu=HAS_GPU)
    return tts
# ----------------------------
# Whisper
# ----------------------------
def transcribe_file(path: str) -> str:
    """Transcribe the audio file at `path` as Arabic and return the joined text."""
    segments, _ = whisper_model.transcribe(path, language="ar")
    pieces = [segment.text for segment in segments]
    return " ".join(pieces).strip()
# ----------------------------
# Qwen reply
# history format: list of [user, assistant]
# ----------------------------
def qwen_reply(history, user_text: str) -> str:
    """Generate one assistant turn with Qwen.

    `history` is a list of [user, assistant] pairs; only the last MAX_TURNS
    are included in the prompt. Returns the decoded reply with any leaked
    role labels stripped, or a canned re-ask when `user_text` is blank.
    """
    user_text = "" if user_text is None else str(user_text)
    if not user_text.strip():
        return "ู…ู…ูƒู† ุชูƒุชุจ ุณุคุงู„ูƒ ุชุงู†ูŠุŸ"

    messages = [{"role": "system", "content": SYSTEM_PROMPT or ""}]
    for user_turn, assistant_turn in (history or [])[-MAX_TURNS:]:
        user_turn = "" if user_turn is None else str(user_turn)
        assistant_turn = "" if assistant_turn is None else str(assistant_turn)
        if user_turn.strip():
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn.strip():
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": user_text})

    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt")
    if HAS_GPU:
        inputs = {k: v.to(qwen.device) for k, v in inputs.items()}
    with torch.no_grad():
        out_ids = qwen.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, skipping the prompt prefix.
    new_tokens = out_ids[0][inputs["input_ids"].shape[1]:]
    reply = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    # Strip role labels the model may have echoed into the text.
    reply = re.sub(r"^(system|user|assistant)\s*[:\-]?\s*", "", reply, flags=re.I)
    for role_word in ("assistant", "user", "system"):
        reply = reply.replace(role_word, "")
    return reply.strip()
# ----------------------------
# XTTS
# ----------------------------
def xtts_speak(text: str, speaker_wav_path: Optional[str] = None) -> str:
    """Synthesize `text` to a wav file under /tmp and return its path.

    Uses the caller-supplied speaker reference when given; otherwise falls
    back to the model's first built-in speaker (XTTS needs one or the other).
    """
    engine = get_tts()
    out_path = f"/tmp/{uuid.uuid4().hex}.wav"
    synth_kwargs = {"language": "ar"}
    if speaker_wav_path:
        synth_kwargs["speaker_wav"] = speaker_wav_path
    else:
        builtin_speakers = getattr(engine, "speakers", None)
        if builtin_speakers and len(builtin_speakers) > 0:
            synth_kwargs["speaker"] = builtin_speakers[0]
    engine.tts_to_file(text=text, file_path=out_path, **synth_kwargs)
    return out_path
# ----------------------------
# FastAPI
# ----------------------------
app = FastAPI(title="Arabic Dialect Speech Chatbot")
# Single-page UI served at "/". It posts a multipart form (text, optional
# audio input, optional speaker reference) to /chat and plays the returned
# audio URL. Kept inline so the Space needs no static-file setup.
INDEX_HTML = """
<!doctype html>
<html lang="ar" dir="rtl">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Arabic Dialect Voice Chatbot</title>
<style>
body { font-family: Arial, sans-serif; margin: 24px; }
.box { max-width: 780px; margin: 0 auto; }
textarea { width: 100%; min-height: 90px; }
.row { display: flex; gap: 12px; flex-wrap: wrap; }
.card { border: 1px solid #ddd; border-radius: 10px; padding: 12px; margin-top: 12px; }
button { padding: 10px 16px; cursor: pointer; }
input[type=file] { width: 100%; }
.muted { color: #666; font-size: 13px; }
</style>
</head>
<body>
<div class="box">
<h2>Arabic Dialect Voice Chatbot</h2>
<p class="muted">ุงูƒุชุจ ู†ุต ุฃูˆ ุงุฑูุน ู…ู„ู ุตูˆุชูŠ. ุงู„ุฑุฏ ูŠุฑุฌุน ู†ุต + ุตูˆุช.</p>
<div class="card">
<label>Text</label>
<textarea id="text" placeholder="ุงูƒุชุจ ู‡ู†ุง..."></textarea>
<div class="row" style="margin-top: 12px;">
<div style="flex:1;">
<label>Audio input (optional)</label>
<input id="audio" type="file" accept="audio/*">
</div>
<div style="flex:1;">
<label>Speaker reference (optional)</label>
<input id="speaker" type="file" accept="audio/*">
</div>
</div>
<div style="margin-top: 12px;">
<button id="send">Send</button>
<span id="status" class="muted" style="margin-right: 10px;"></span>
</div>
</div>
<div class="card">
<h3>Result</h3>
<div><b>User text:</b> <span id="user_text"></span></div>
<div style="margin-top: 6px;"><b>Assistant:</b> <span id="assistant_text"></span></div>
<div style="margin-top: 10px;">
<audio id="player" controls></audio>
</div>
</div>
</div>
<script>
const sendBtn = document.getElementById("send");
const statusEl = document.getElementById("status");
const textEl = document.getElementById("text");
const audioEl = document.getElementById("audio");
const speakerEl = document.getElementById("speaker");
const userTextOut = document.getElementById("user_text");
const assistantTextOut = document.getElementById("assistant_text");
const player = document.getElementById("player");
sendBtn.addEventListener("click", async () => {
statusEl.textContent = "Sending...";
sendBtn.disabled = true;
try {
const form = new FormData();
form.append("text", textEl.value || "");
form.append("history", "[]");
if (audioEl.files.length > 0) {
form.append("audio", audioEl.files[0]);
}
if (speakerEl.files.length > 0) {
form.append("speaker_ref", speakerEl.files[0]);
}
const res = await fetch("/chat", { method: "POST", body: form });
if (!res.ok) {
const t = await res.text();
throw new Error(t);
}
const data = await res.json();
userTextOut.textContent = data.user_text || "";
assistantTextOut.textContent = data.assistant_text || "";
if (data.audio_url) {
player.src = data.audio_url;
player.load();
player.play().catch(() => {});
}
statusEl.textContent = "Done";
} catch (e) {
statusEl.textContent = "Error: " + e.message;
} finally {
sendBtn.disabled = false;
setTimeout(() => statusEl.textContent = "", 4000);
}
});
</script>
</body>
</html>
"""
@app.get("/", response_class=HTMLResponse)
def index():
    """Serve the single-page chat UI."""
    return HTMLResponse(content=INDEX_HTML)
@app.get("/health")
def health():
    """Liveness probe; also reports whether CUDA was detected at startup."""
    return dict(ok=True, gpu=HAS_GPU)
@app.post("/chat")
async def chat(
    text: Optional[str] = Form(default=None),
    history: Optional[str] = Form(default="[]"),
    audio: Optional[UploadFile] = File(default=None),
    speaker_ref: Optional[UploadFile] = File(default=None),
):
    """Main chat endpoint: text or audio in, assistant text + synthesized audio out.

    Form fields:
        text: user message (ignored when an audio file is supplied).
        history: JSON-encoded list of [user, assistant] turn pairs.
        audio: optional audio file to transcribe instead of `text`.
        speaker_ref: optional reference clip for XTTS voice cloning.
    """
    try:
        hist = json.loads(history or "[]")
    except Exception:
        hist = []
    # Valid JSON that is not a list (e.g. an object) would crash the history
    # slicing in qwen_reply; coerce anything non-list to an empty history.
    if not isinstance(hist, list):
        hist = []

    async def _save_upload(upload: UploadFile) -> str:
        # basename() drops any client-supplied directory components, so a
        # crafted filename like "../../etc/cron.d/x" cannot escape /tmp.
        safe_name = os.path.basename(upload.filename or "upload")
        path = f"/tmp/{uuid.uuid4().hex}_{safe_name}"
        with open(path, "wb") as f:
            f.write(await upload.read())
        return path

    audio_path = await _save_upload(audio) if audio is not None else None
    speaker_path = await _save_upload(speaker_ref) if speaker_ref is not None else None
    try:
        if audio_path:
            user_text = transcribe_file(audio_path)
        else:
            user_text = (text or "").strip()
        user_norm = normalizer.normalize(user_text, stage="llm")
        if not user_norm:
            assistant_text = "ู…ุณู…ุนุชุด ูƒูˆูŠุณ. ู…ู…ูƒู† ุชุนูŠุฏ ุชุงู†ูŠุŸ"
        else:
            assistant_text = qwen_reply(hist, user_norm)
        tts_text = normalizer.normalize(assistant_text, stage="tts")
        wav_path = xtts_speak(tts_text, speaker_path)
    finally:
        # Uploaded inputs are one-shot; delete them so /tmp does not fill up
        # (previously they were leaked). The generated wav must survive for
        # the follow-up GET /audio.
        for tmp in (audio_path, speaker_path):
            if tmp:
                try:
                    os.remove(tmp)
                except OSError:
                    pass
    hist = hist + [[user_text, assistant_text]]
    return JSONResponse(
        {
            "user_text": user_text,
            "assistant_text": assistant_text,
            "history": hist,
            "audio_url": f"/audio?path={wav_path}",
        }
    )
@app.get("/audio")
def audio(path: str):
    """Serve a previously generated reply wav.

    Only paths that resolve into /tmp with a .wav suffix are served. The
    previous implementation returned whatever path the client asked for,
    which was an arbitrary-file-read hole (e.g. /audio?path=/etc/passwd).
    """
    resolved = os.path.realpath(path)
    if not (resolved.startswith("/tmp/") and resolved.endswith(".wav")):
        return JSONResponse({"error": "not found"}, status_code=404)
    return FileResponse(resolved, media_type="audio/wav", filename="reply.wav")
if __name__ == "__main__":
    # Local / container entry point; HF Spaces supplies PORT (default 7860).
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))