# Final_Speech / app.py
# (Hugging Face Space page residue: "Shroukkkk — Update app.py — 91d7c2e verified")
import os
import re
import uuid
import json
import inspect
import builtins
from typing import Optional
import torch
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import HTMLResponse, JSONResponse, FileResponse
from faster_whisper import WhisperModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# ----------------------------
# Caches (HF Spaces friendly)
# ----------------------------
# Point every cache at the persistent /data volume so model downloads survive restarts.
_CACHE_DEFAULTS = {
    "HF_HOME": "/data/huggingface",
    "HF_HUB_CACHE": "/data/huggingface/hub",
    "TRANSFORMERS_CACHE": "/data/huggingface/transformers",
    "XDG_CACHE_HOME": "/data/cache",
    "XDG_DATA_HOME": "/data/local/share",
}
for _key, _value in _CACHE_DEFAULTS.items():
    os.environ.setdefault(_key, _value)
# ----------------------------
# Coqui CPML non-interactive acceptance
# ----------------------------
os.makedirs("/data/local/share/tts", exist_ok=True)
try:
    # Best effort: pre-write the agreement marker so Coqui never prompts on download.
    with open("/data/local/share/tts/.tos_agreed", "w") as f:
        f.write("y")
except Exception:
    pass
_real_input = builtins.input


def _auto_input(prompt=""):
    """Intercept stdin prompts so Coqui's CPML license question never blocks a headless run.

    Any prompt mentioning "cpml", "license" or "[y/n]" — or any call originating
    from Coqui's ask_tos() — is auto-answered with $COQUI_TOS (default "y").
    Everything else falls through to the real builtins.input.
    """
    lowered = (prompt or "").lower()
    answer = os.environ.get("COQUI_TOS", "y")
    if any(marker in lowered for marker in ("cpml", "license", "[y/n]")):
        return answer
    try:
        # Fallback detection: walk the call stack looking for Coqui's ask_tos().
        for info in inspect.stack():
            path = info.filename.replace("\\", "/")
            if info.function == "ask_tos" and path.endswith("/TTS/utils/manage.py"):
                return answer
    except Exception:
        pass
    return _real_input(prompt)


builtins.input = _auto_input
# ----------------------------
# Egyptian Normalizer (LLM + TTS)
# ----------------------------
class EgyptianNormalizer:
    """Text normalizer for Egyptian Arabic, shared by the LLM and TTS paths.

    stage="llm" produces cleaned Arabic for translation/generation;
    stage="tts" additionally verbalizes "%" and strips Latin-script words
    before speech synthesis.
    """

    def __init__(self) -> None:
        # Dialect/spelling canonicalization. These run AFTER the hamza
        # (أإآٱ -> ا) and ى -> ي folds inside normalize(), so each pattern
        # must be written in that already-folded form.
        raw_replacements = [
            (r"\bازيك\b", "إزيك"),
            # Fixed: was r"\bازى\b", which could never match because ى is
            # rewritten to ي before the replacement pass runs.
            (r"\bازي\b", "إزي"),
            (r"\bازاي\b", "إزاي"),
            (r"\bايه\b", "إيه"),
            (r"\bامتى\b", "إمتى"),
            (r"\bعاوز\b", "عايز"),
            (r"\bعاوزه\b", "عايزه"),
            (r"\bدلوقت\b", "دلوقتي"),
            (r"\bكدا\b", "كده"),
            (r"\bعلشان\b", "عشان"),
            (r"\bعميل ايه\b", "عامل إيه"),
            (r"\bعامله ايه\b", "عاملة إيه"),
            (r"\bماعرفش\b", "مش عارف"),
            (r"\bماكنتش\b", "مش كنت"),
            (r"\bمافهمتش\b", "مش فاهم"),
            (r"\bok\b", "تمام"),
            (r"\bokay\b", "تمام"),
            (r"\bsorry\b", "معلش"),
        ]
        # Pre-compile once instead of recompiling on every normalize() call
        # (IGNORECASE only matters for the Latin-script patterns).
        self.replacements = [
            (re.compile(pat, re.IGNORECASE), repl) for pat, repl in raw_replacements
        ]
        # Interrogative words: if one appears and the text lacks terminal
        # punctuation, a "؟" is appended.
        self.question_words = ["ليه", "إيه", "فين", "إمتى", "إزاي", "إزي", "كام", "مين"]
        # Arabic diacritics (harakat and Quranic annotation marks) to strip.
        self.diacritics = re.compile(r"[\u0610-\u061A\u064B-\u065F\u0670\u06D6-\u06ED]")
        # Zero-width joiners/direction marks and BOM.
        self.zero_width = re.compile(r"[\u200c\u200d\u200e\u200f\ufeff]")
        # Tatweel (kashida) elongation character.
        self.tatweel = "\u0640"

    def normalize(self, text: str, stage: str = "llm") -> str:
        """Return *text* normalized for the given stage ("llm" or "tts")."""
        if not text:
            return ""
        t = str(text)
        # Unicode cleanup.
        t = self.zero_width.sub("", t)
        t = t.replace(self.tatweel, "")
        t = t.replace("\n", " ")
        # Remove diacritics.
        t = self.diacritics.sub("", t)
        # Latin punctuation -> Arabic punctuation.
        t = t.replace("?", "؟").replace(",", "،").replace(";", "؛")
        # Letter folding: hamza variants -> plain alif, alif maqsura -> ya.
        t = re.sub(r"[أإآٱ]", "ا", t)
        t = t.replace("ى", "ي")
        # Collapse 3+ repeats of any char (chat-style elongation).
        t = re.sub(r"(.)\1{2,}", r"\1", t)
        # Arabic-Indic digits -> ASCII digits.
        t = t.translate(str.maketrans("٠١٢٣٤٥٦٧٨٩", "0123456789"))
        if stage == "tts":
            t = t.replace("%", " في المية ")
        for pattern, repl in self.replacements:
            t = pattern.sub(repl, t)
        # Common agglutinated spellings -> separated form.
        t = t.replace("قلتلك", "قلت لك").replace("قولتلك", "قلت لك")
        t = t.replace("قلتلهم", "قلت لهم").replace("قولتلهم", "قلت لهم")
        t = re.sub(r"\s+", " ", t).strip()
        # Ensure questions end with a question mark.
        if any(w in t for w in self.question_words) and not t.endswith(("؟", "!", ".")):
            t += "؟"
        if stage == "tts":
            # Drop remaining Latin-script words before synthesis.
            t = re.sub(r"\b[a-zA-Z]+\b", "", t)
            t = re.sub(r"\s+", " ", t).strip()
        return t


normalizer = EgyptianNormalizer()
# ----------------------------
# Config
# ----------------------------
# Model identifiers and runtime knobs; every one is overridable via environment.
TRANSLATOR_MODEL = os.getenv("TRANSLATOR_MODEL", "oddadmix/Masrawy-BiLingual-v1")
QWEN_MODEL_ID = os.getenv("QWEN_MODEL_ID", "Qwen/Qwen2.5-3B-Instruct")
WHISPER_SIZE = os.getenv("WHISPER_SIZE", "small")
XTTS_MODEL_ID = os.getenv("XTTS_MODEL_ID", "tts_models/multilingual/multi-dataset/xtts_v2")
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "256"))

# Device selection: prefer CUDA when available.
USE_GPU = torch.cuda.is_available()
if USE_GPU:
    DEVICE_STR = "cuda"
    PIPELINE_DEVICE = 0  # transformers pipeline device index
else:
    DEVICE_STR = "cpu"
    PIPELINE_DEVICE = -1  # -1 means CPU for transformers pipelines
# ----------------------------
# Load models once
# ----------------------------
# Whisper
# compute_type: float16 requires CUDA; int8 keeps the CPU footprint manageable.
whisper_compute = "float16" if USE_GPU else "int8"
whisper_model = WhisperModel(WHISPER_SIZE, device=DEVICE_STR, compute_type=whisper_compute)
# Translator
# Output language is selected by a trailing " <ar>" / " <arz>" tag on the input
# (see to_msa / to_egyptian below).
translator = pipeline("translation", model=TRANSLATOR_MODEL, device=PIPELINE_DEVICE)
# Qwen
tokenizer = AutoTokenizer.from_pretrained(QWEN_MODEL_ID, trust_remote_code=True)
qwen = AutoModelForCausalLM.from_pretrained(
    QWEN_MODEL_ID,
    torch_dtype=torch.float16 if USE_GPU else torch.float32,
    device_map="auto" if USE_GPU else None,  # let accelerate place layers on GPU
    trust_remote_code=True,
)
if not USE_GPU:
    # device_map=None leaves placement implicit; pin the model to CPU explicitly.
    qwen = qwen.to("cpu")
# XTTS lazy init
tts = None


def get_tts():
    """Lazily build and cache the global XTTS engine.

    The TTS import is deferred so the (slow, prompt-happy) Coqui package only
    loads on the first synthesis request.
    """
    global tts
    if tts is not None:
        return tts
    from TTS.api import TTS
    tts = TTS(XTTS_MODEL_ID, gpu=USE_GPU)
    return tts
# Translation length limits (env-overridable; raise TRANSLATE_MAX_LEN for longer passages).
TRANSLATE_MAX_LEN = int(os.getenv("TRANSLATE_MAX_LEN", "256"))
TRANSLATE_MIN_LEN = int(os.getenv("TRANSLATE_MIN_LEN", "1"))


def _translate(text: str, lang_tag: str) -> str:
    """Run the bilingual translator on *text*, steering direction with *lang_tag*.

    The translator model picks its output language from a trailing tag:
    "<ar>" produces MSA, "<arz>" produces Egyptian Arabic.
    Returns "" for empty/whitespace-only input without touching the model.
    """
    text = (text or "").strip()
    if not text:
        return ""
    result = translator(
        f"{text} {lang_tag}",
        max_length=TRANSLATE_MAX_LEN,
        min_length=TRANSLATE_MIN_LEN,
        truncation=True,
    )
    return result[0]["translation_text"]


def to_msa(text: str) -> str:
    """Translate Egyptian-Arabic input into Modern Standard Arabic."""
    return _translate(text, "<ar>")


def to_egyptian(text: str) -> str:
    """Translate MSA text into Egyptian Arabic dialect."""
    return _translate(text, "<arz>")
# Apologetic / hedging phrases we never want in the final reply.
_BANNED_PHRASES = [
    "كمساعد", "كمساعد ذكي", "معلش", "آسف", "اعتذر", "مش عارف", "لا أستطيع", "غير قادر",
    "لا يمكنني", "لا أقدر", "لا أملك معلومات", "قد لا يكون", "ربما", "عادةً", "بشكل عام"
]


def clean_egyptian(text: str) -> str:
    """Strip banned filler phrases and tidy punctuation; fall back to a canned prompt if nothing is left."""
    cleaned = (text or "").strip()
    for phrase in _BANNED_PHRASES:
        if phrase in cleaned:
            cleaned = cleaned.replace(phrase, "")
    cleaned = re.sub(r"\s+", " ", cleaned).strip()
    # Collapse runs of 3+ dots/commas into a single ellipsis.
    cleaned = re.sub(r"[.،]{3,}", "…", cleaned).strip()
    if cleaned:
        return cleaned
    return "تمام، قولي تحب تعمل ايه دلوقتي؟"
# ----------------------------
# Whisper
# ----------------------------
def transcribe_file(path: str) -> str:
    """Transcribe an audio file as Arabic with faster-whisper and return the joined text."""
    segments, _info = whisper_model.transcribe(path, language="ar")
    texts = [segment.text for segment in segments]
    return " ".join(texts).strip()
# ----------------------------
# Qwen generation in MSA
# ----------------------------
def qwen_generate_msa(msa_prompt: str, max_new_tokens: int = MAX_NEW_TOKENS) -> str:
    """Generate a short, practical MSA reply from Qwen for the given prompt.

    Returns "" for empty input. Sampling (temperature/top_p) makes output
    non-deterministic by design.
    """
    prompt = (msa_prompt or "").strip()
    if not prompt:
        return ""
    system_msg = (
        "أنت مساعد شخصي عملي. "
        "إذا كان سؤال المستخدم عامًا أو مفتوحًا، اقترح خطة أو خطوات عملية فورًا بدون اعتذار. "
        "اجعل الرد قصيرًا ومباشرًا ومفيدًا. "
        "اكتب باللغة العربية الفصحى البسيطة فقط."
    )
    chat = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": prompt},
    ]
    input_ids = tokenizer.apply_chat_template(
        chat,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    if USE_GPU:
        input_ids = input_ids.to(qwen.device)
    with torch.no_grad():
        output_ids = qwen.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tail, not the echoed prompt.
    new_tokens = output_ids[0][input_ids.shape[-1]:]
    reply = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    # Scrub any leaked chat-role markers from the decoded text.
    reply = re.sub(r"^(system|user|assistant)\s*[:\-]?\s*", "", reply, flags=re.I).strip()
    for role in ("assistant", "user", "system"):
        reply = reply.replace(role, "")
    return reply.strip()
# ----------------------------
# Full pipeline: input -> MSA -> Qwen -> Egyptian
# ----------------------------
def generate_egyptian_reply(user_text: str) -> str:
    """Normalize the user's text, answer it via MSA + Qwen, and return an Egyptian-dialect reply."""
    cleaned_input = normalizer.normalize(user_text, stage="llm")
    if not cleaned_input:
        # Nothing intelligible came through — ask the user to repeat.
        return "مسمعتش كويس. ممكن تعيد تاني؟"
    msa_reply = qwen_generate_msa(to_msa(cleaned_input), MAX_NEW_TOKENS)
    egyptian = clean_egyptian(to_egyptian(msa_reply))
    return normalizer.normalize(egyptian, stage="llm")
# ----------------------------
# XTTS
# ----------------------------
def xtts_speak(text: str, speaker_wav_path: Optional[str] = None) -> str:
    """Synthesize Arabic speech into a fresh /tmp WAV and return its path.

    Clones the voice in *speaker_wav_path* when provided; otherwise falls
    back to the model's first built-in speaker (if it exposes any).
    """
    engine = get_tts()
    out_path = f"/tmp/{uuid.uuid4().hex}.wav"
    synth_kwargs = {"language": "ar"}
    if speaker_wav_path:
        synth_kwargs["speaker_wav"] = speaker_wav_path
    else:
        available = getattr(engine, "speakers", None) or []
        if available:
            synth_kwargs["speaker"] = available[0]
    engine.tts_to_file(text=text, file_path=out_path, **synth_kwargs)
    return out_path
# ----------------------------
# FastAPI + simple UI
# ----------------------------
# Single-process app: the models loaded above are module-level globals shared by all requests.
app = FastAPI(title="Egyptian Arabic Speech Chatbot")
# Inline single-page UI (RTL Arabic): posts text/audio to /chat and plays the returned WAV.
INDEX_HTML = """
<!doctype html>
<html lang="ar" dir="rtl">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Egyptian Voice Chatbot</title>
<style>
body { font-family: Arial, sans-serif; margin: 24px; }
.box { max-width: 780px; margin: 0 auto; }
textarea { width: 100%; min-height: 90px; }
.row { display: flex; gap: 12px; flex-wrap: wrap; }
.card { border: 1px solid #ddd; border-radius: 10px; padding: 12px; margin-top: 12px; }
button { padding: 10px 16px; cursor: pointer; }
input[type=file] { width: 100%; }
.muted { color: #666; font-size: 13px; }
</style>
</head>
<body>
<div class="box">
<h2>Egyptian Arabic Voice Chatbot</h2>
<p class="muted">اكتب نص أو ارفع ملف صوتي. الرد بيرجع نص + صوت باللهجة المصرية.</p>
<div class="card">
<label>Text</label>
<textarea id="text" placeholder="اكتب هنا..."></textarea>
<div class="row" style="margin-top: 12px;">
<div style="flex:1;">
<label>Audio input (optional)</label>
<input id="audio" type="file" accept="audio/*">
</div>
<div style="flex:1;">
<label>Speaker reference (optional)</label>
<input id="speaker" type="file" accept="audio/*">
</div>
</div>
<div style="margin-top: 12px;">
<button id="send">Send</button>
<span id="status" class="muted" style="margin-right: 10px;"></span>
</div>
</div>
<div class="card">
<h3>Result</h3>
<div><b>انت:</b> <span id="user_text"></span></div>
<div style="margin-top: 6px;"><b>المساعد:</b> <span id="assistant_text"></span></div>
<div style="margin-top: 10px;">
<audio id="player" controls></audio>
</div>
</div>
</div>
<script>
const sendBtn = document.getElementById("send");
const statusEl = document.getElementById("status");
const textEl = document.getElementById("text");
const audioEl = document.getElementById("audio");
const speakerEl = document.getElementById("speaker");
const userTextOut = document.getElementById("user_text");
const assistantTextOut = document.getElementById("assistant_text");
const player = document.getElementById("player");
sendBtn.addEventListener("click", async () => {
statusEl.textContent = "Sending...";
sendBtn.disabled = true;
try {
const form = new FormData();
form.append("text", textEl.value || "");
form.append("history", "[]");
if (audioEl.files.length > 0) form.append("audio", audioEl.files[0]);
if (speakerEl.files.length > 0) form.append("speaker_ref", speakerEl.files[0]);
const res = await fetch("/chat", { method: "POST", body: form });
if (!res.ok) {
const t = await res.text();
throw new Error(t);
}
const data = await res.json();
userTextOut.textContent = data.user_text || "";
assistantTextOut.textContent = data.assistant_text || "";
if (data.audio_url) {
player.src = data.audio_url;
player.load();
player.play().catch(() => {});
}
statusEl.textContent = "Done";
} catch (e) {
statusEl.textContent = "Error: " + (e.message || e);
} finally {
sendBtn.disabled = false;
setTimeout(() => statusEl.textContent = "", 4000);
}
});
</script>
</body>
</html>
"""
@app.get("/", response_class=HTMLResponse)
def index():
    """Serve the inline single-page chat UI."""
    return HTMLResponse(INDEX_HTML)
@app.get("/health")
def health():
    """Liveness probe; reports whether CUDA was detected at startup."""
    return {"ok": True, "gpu": USE_GPU}
@app.post("/chat")
async def chat(
    text: Optional[str] = Form(default=None),
    history: Optional[str] = Form(default="[]"),
    audio: Optional[UploadFile] = File(default=None),
    speaker_ref: Optional[UploadFile] = File(default=None),
):
    """Main chat endpoint: accepts text and/or audio, returns reply text plus a TTS audio URL.

    When an audio file is supplied it takes precedence over the text field.
    `history` is parsed for forward compatibility but not used yet.
    """
    try:
        _hist = json.loads(history or "[]")
    except Exception:
        _hist = []

    async def _save_upload(upload: UploadFile) -> str:
        # basename() strips any client-supplied directory components, so a
        # malicious filename like "../../etc/x" cannot escape /tmp (the
        # original code interpolated the raw filename into the write path).
        safe_name = os.path.basename(upload.filename or "upload")
        path = f"/tmp/{uuid.uuid4().hex}_{safe_name}"
        with open(path, "wb") as f:
            f.write(await upload.read())
        return path

    audio_path = await _save_upload(audio) if audio is not None else None
    speaker_path = await _save_upload(speaker_ref) if speaker_ref is not None else None

    if audio_path:
        user_text = transcribe_file(audio_path)
    else:
        user_text = (text or "").strip()

    assistant_text = generate_egyptian_reply(user_text)
    tts_text = normalizer.normalize(assistant_text, stage="tts")
    wav_path = xtts_speak(tts_text, speaker_path)
    return JSONResponse(
        {
            "user_text": user_text,
            "assistant_text": assistant_text,
            "audio_url": f"/audio?path={wav_path}",
        }
    )
@app.get("/audio")
def audio(path: str):
    """Serve a synthesized reply WAV.

    Security fix: the original handler returned whatever path the client
    asked for (arbitrary file read, e.g. ?path=/etc/passwd). Only real
    .wav files living directly under /tmp — where xtts_speak writes its
    output — are served now.
    """
    resolved = os.path.realpath(path)
    if not (resolved.startswith("/tmp/") and resolved.endswith(".wav") and os.path.isfile(resolved)):
        return JSONResponse({"error": "invalid audio path"}, status_code=404)
    return FileResponse(resolved, media_type="audio/wav", filename="reply.wav")
if __name__ == "__main__":
    # Local/dev entry point; HF Spaces provides $PORT (default 7860).
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))