Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import uuid | |
| import json | |
| import inspect | |
| import builtins | |
| from typing import Optional | |
| import torch | |
| from fastapi import FastAPI, UploadFile, File, Form | |
| from fastapi.responses import HTMLResponse, JSONResponse, FileResponse | |
| from faster_whisper import WhisperModel | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline | |
# ----------------------------
# Caches (HF Spaces friendly)
# ----------------------------
# Point HF/XDG caches at the persistent /data volume so model downloads
# survive Space restarts. NOTE(review): transformers / faster_whisper are
# imported above this point, and HF cache env vars are typically read at
# import time — these setdefault calls may come too late; confirm and
# consider moving them before the imports.
os.environ.setdefault("HF_HOME", "/data/huggingface")
os.environ.setdefault("HF_HUB_CACHE", "/data/huggingface/hub")
os.environ.setdefault("TRANSFORMERS_CACHE", "/data/huggingface/transformers")
os.environ.setdefault("XDG_CACHE_HOME", "/data/cache")
os.environ.setdefault("XDG_DATA_HOME", "/data/local/share")
# ----------------------------
# Coqui CPML non-interactive acceptance
# ----------------------------
# Pre-write the Coqui TTS terms-of-service marker so the XTTS download does
# not stop to ask; best-effort (the input() patch below is the backup plan).
os.makedirs("/data/local/share/tts", exist_ok=True)
try:
    with open("/data/local/share/tts/.tos_agreed", "w") as f:
        f.write("y")
except Exception:
    # /data may be read-only in some environments; the input() patch still covers us.
    pass
| _real_input = builtins.input | |
| def _auto_input(prompt=""): | |
| p = (prompt or "").lower() | |
| if "cpml" in p or "license" in p or "[y/n]" in p: | |
| return os.environ.get("COQUI_TOS", "y") | |
| try: | |
| for frame in inspect.stack(): | |
| fname = frame.filename.replace("\\", "/") | |
| if fname.endswith("/TTS/utils/manage.py") and frame.function == "ask_tos": | |
| return os.environ.get("COQUI_TOS", "y") | |
| except Exception: | |
| pass | |
| return _real_input(prompt) | |
| builtins.input = _auto_input | |
# ----------------------------
# Egyptian Normalizer (LLM + TTS)
# ----------------------------
class EgyptianNormalizer:
    """Arabic text normalizer with two stages.

    stage="llm": Unicode cleanup, diacritic removal, letter unification,
    Egyptian-dialect spelling fixes, and question-mark enforcement.
    stage="tts": all of the above, plus "%" spelled out in Arabic and Latin
    words removed, so the synthesizer only ever sees Arabic text.
    """

    def __init__(self):
        # Whole-word dialect spelling fixes (regex). Applied AFTER the hamza
        # stripping in normalize(), so the replacement side deliberately
        # restores hamza forms (e.g. "ايه" -> "إيه"). The last three entries
        # map common English fillers (ok/okay/sorry) to Arabic equivalents.
        self.replacements = [
            (r"\bازيك\b", "إزيك"),
            (r"\bازى\b", "إزي"),
            (r"\bازاي\b", "إزاي"),
            (r"\bايه\b", "إيه"),
            (r"\bامتى\b", "إمتى"),
            (r"\bعاوز\b", "عايز"),
            (r"\bعاوزه\b", "عايزه"),
            (r"\bدلوقت\b", "دلوقتي"),
            (r"\bكدا\b", "كده"),
            (r"\bعلشان\b", "عشان"),
            (r"\bعميل ايه\b", "عامل إيه"),
            (r"\bعامله ايه\b", "عاملة إيه"),
            (r"\bماعرفش\b", "مش عارف"),
            (r"\bماكنتش\b", "مش كنت"),
            (r"\bمافهمتش\b", "مش فاهم"),
            (r"\bok\b", "تمام"),
            (r"\bokay\b", "تمام"),
            (r"\bsorry\b", "معلش"),
        ]
        # Interrogatives: if one appears and the text lacks terminal
        # punctuation, normalize() appends an Arabic question mark.
        self.question_words = ["ليه", "إيه", "فين", "إمتى", "إزاي", "إزي", "كام", "مين"]
        # Arabic diacritics (tashkeel) plus Quranic annotation ranges.
        self.diacritics = re.compile(r"[\u0610-\u061A\u064B-\u065F\u0670\u06D6-\u06ED]")
        # Zero-width joiners, bidi marks, and BOM.
        self.zero_width = re.compile(r"[\u200c\u200d\u200e\u200f\ufeff]")
        # Tatweel (kashida) elongation character — stripped entirely.
        self.tatweel = "\u0640"

    def normalize(self, text: str, stage: str = "llm") -> str:
        """Return the normalized form of *text* for the given stage
        ("llm" or "tts"). Empty/None input yields ""."""
        if not text:
            return ""
        t = str(text)
        # Unicode cleanup
        t = self.zero_width.sub("", t)
        t = t.replace(self.tatweel, "")
        t = t.replace("\n", " ")
        # remove diacritics
        t = self.diacritics.sub("", t)
        # punctuation: Latin -> Arabic equivalents
        t = t.replace("?", "؟").replace(",", "،").replace(";", "؛")
        # letters: unify hamza/alef variants; final alef maqsura -> ya
        t = re.sub(r"[أإآٱ]", "ا", t)
        t = t.replace("ى", "ي")
        # elongation: collapse any character repeated 3+ times down to one
        t = re.sub(r"(.)\1{2,}", r"\1", t)
        # numbers: Eastern Arabic digits -> ASCII digits
        t = t.translate(str.maketrans("٠١٢٣٤٥٦٧٨٩", "0123456789"))
        if stage == "tts":
            # spell out the percent sign so TTS pronounces it
            t = t.replace("%", " في المية ")
        for pattern, repl in self.replacements:
            t = re.sub(pattern, repl, t, flags=re.IGNORECASE)
        # small fixes: split fused verb+pronoun spellings
        t = t.replace("قلتلك", "قلت لك").replace("قولتلك", "قلت لك")
        t = t.replace("قلتلهم", "قلت لهم").replace("قولتلهم", "قلت لهم")
        t = re.sub(r"\s+", " ", t).strip()
        # enforce question mark on interrogative sentences
        if any(w in t for w in self.question_words) and not t.endswith(("؟", "!", ".")):
            t += "؟"
        if stage == "tts":
            # remove english words (the "ar" XTTS voice should not read Latin text)
            t = re.sub(r"\b[a-zA-Z]+\b", "", t)
            t = re.sub(r"\s+", " ", t).strip()
        return t
# Shared module-level normalizer instance used by the pipeline and endpoints.
normalizer = EgyptianNormalizer()
# ----------------------------
# Config
# ----------------------------
TRANSLATOR_MODEL = os.getenv("TRANSLATOR_MODEL", "oddadmix/Masrawy-BiLingual-v1")  # Egyptian <-> MSA translator
QWEN_MODEL_ID = os.getenv("QWEN_MODEL_ID", "Qwen/Qwen2.5-3B-Instruct")  # chat LLM (answers in MSA)
WHISPER_SIZE = os.getenv("WHISPER_SIZE", "small")  # faster-whisper model size
XTTS_MODEL_ID = os.getenv("XTTS_MODEL_ID", "tts_models/multilingual/multi-dataset/xtts_v2")
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "256"))  # LLM generation budget per reply
USE_GPU = torch.cuda.is_available()
DEVICE_STR = "cuda" if USE_GPU else "cpu"
PIPELINE_DEVICE = 0 if USE_GPU else -1  # transformers pipeline device index (-1 = CPU)
# ----------------------------
# Load models once
# ----------------------------
# Whisper (faster-whisper): fp16 on GPU, int8 quantization on CPU for speed.
whisper_compute = "float16" if USE_GPU else "int8"
whisper_model = WhisperModel(WHISPER_SIZE, device=DEVICE_STR, compute_type=whisper_compute)
# Translator (Egyptian <-> MSA; direction is selected by a suffix tag on the input text).
translator = pipeline("translation", model=TRANSLATOR_MODEL, device=PIPELINE_DEVICE)
# Qwen chat model (fp16 + device_map="auto" on GPU, fp32 on CPU).
tokenizer = AutoTokenizer.from_pretrained(QWEN_MODEL_ID, trust_remote_code=True)
qwen = AutoModelForCausalLM.from_pretrained(
    QWEN_MODEL_ID,
    torch_dtype=torch.float16 if USE_GPU else torch.float32,
    device_map="auto" if USE_GPU else None,
    trust_remote_code=True,
)
if not USE_GPU:
    # device_map=None leaves placement implicit; pin to CPU explicitly.
    qwen = qwen.to("cpu")
# XTTS is heavy; construct it lazily on first use and reuse the instance.
tts = None


def get_tts():
    """Return the shared XTTS engine, constructing it on first call."""
    global tts
    if tts is not None:
        return tts
    from TTS.api import TTS
    tts = TTS(XTTS_MODEL_ID, gpu=USE_GPU)
    return tts
# Translation length bounds; raise TRANSLATE_MAX_LEN if you need longer outputs.
TRANSLATE_MAX_LEN = int(os.getenv("TRANSLATE_MAX_LEN", "256"))
TRANSLATE_MIN_LEN = int(os.getenv("TRANSLATE_MIN_LEN", "1"))
def to_msa(text: str) -> str:
    """Translate Egyptian Arabic *text* to Modern Standard Arabic.

    The bilingual translator selects its target by a trailing tag:
    " <ar>" requests MSA output. Empty input yields "".
    """
    cleaned = (text or "").strip()
    if not cleaned:
        return ""
    outputs = translator(
        f"{cleaned} <ar>",
        max_length=TRANSLATE_MAX_LEN,
        min_length=TRANSLATE_MIN_LEN,
        truncation=True,
    )
    return outputs[0]["translation_text"]
def to_egyptian(text: str) -> str:
    """Translate MSA *text* to Egyptian Arabic.

    The bilingual translator selects its target by a trailing tag:
    " <arz>" requests Egyptian output. Empty input yields "".
    """
    cleaned = (text or "").strip()
    if not cleaned:
        return ""
    outputs = translator(
        f"{cleaned} <arz>",
        max_length=TRANSLATE_MAX_LEN,
        min_length=TRANSLATE_MIN_LEN,
        truncation=True,
    )
    return outputs[0]["translation_text"]
# Phrases we never want in the final Egyptian reply (apologetic, hedging, or
# assistant-persona fillers that the LLM tends to emit).
_BANNED_PHRASES = [
    "كمساعد", "كمساعد ذكي", "معلش", "آسف", "اعتذر", "مش عارف", "لا أستطيع", "غير قادر",
    "لا يمكنني", "لا أقدر", "لا أملك معلومات", "قد لا يكون", "ربما", "عادةً", "بشكل عام"
]


def clean_egyptian(text: str) -> str:
    """Strip banned filler phrases from *text* and tidy whitespace/punctuation.

    Fixes over the naive substring version: phrases are removed with
    word-boundary lookarounds so a longer word containing a banned phrase
    (e.g. "مش عارفة" vs banned "مش عارف") is left intact, and phrases are
    removed longest-first so "كمساعد ذكي" is deleted whole rather than
    leaving a dangling "ذكي". Falls back to a friendly default prompt when
    nothing is left after cleaning.
    """
    t = (text or "").strip()
    # Longest-first so multi-word phrases win over their prefixes.
    for phrase in sorted(_BANNED_PHRASES, key=len, reverse=True):
        # Lookarounds instead of \b: robust even when a phrase ends in a
        # combining diacritic (e.g. the tanween in "عادةً").
        t = re.sub(r"(?<!\w)" + re.escape(phrase) + r"(?!\w)", "", t)
    t = re.sub(r"\s+", " ", t).strip()
    # Collapse runs of dots/Arabic commas into a single ellipsis.
    t = re.sub(r"[.،]{3,}", "…", t).strip()
    if not t:
        t = "تمام، قولي تحب تعمل ايه دلوقتي؟"
    return t
# ----------------------------
# Whisper
# ----------------------------
def transcribe_file(path: str) -> str:
    """Run faster-whisper over the audio file at *path* and return the
    joined Arabic transcript (stripped)."""
    segments, _ = whisper_model.transcribe(path, language="ar")
    pieces = [segment.text for segment in segments]
    return " ".join(pieces).strip()
# ----------------------------
# Qwen generation in MSA
# ----------------------------
def qwen_generate_msa(msa_prompt: str, max_new_tokens: int = MAX_NEW_TOKENS) -> str:
    """Generate a reply to *msa_prompt* with Qwen, in Modern Standard Arabic.

    The (Arabic) system prompt instructs the model to act as a practical
    personal assistant: propose concrete steps immediately, keep answers
    short and direct, no apologies, simple MSA only. Returns "" for an
    empty prompt. Sampling (temperature=0.7, top_p=0.9) makes the output
    non-deterministic.
    """
    msa_prompt = (msa_prompt or "").strip()
    if not msa_prompt:
        return ""
    system_msg = (
        "أنت مساعد شخصي عملي. "
        "إذا كان سؤال المستخدم عامًا أو مفتوحًا، اقترح خطة أو خطوات عملية فورًا بدون اعتذار. "
        "اجعل الرد قصيرًا ومباشرًا ومفيدًا. "
        "اكتب باللغة العربية الفصحى البسيطة فقط."
    )
    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": msa_prompt},
    ]
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    )
    if USE_GPU:
        # device_map="auto" may have placed the model elsewhere; move inputs to it.
        input_ids = input_ids.to(qwen.device)
    with torch.no_grad():
        output_ids = qwen.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Keep only the newly generated tokens (drop the prompt prefix).
    gen_ids = output_ids[0][input_ids.shape[-1]:]
    text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
    # hard cleanup if any role words appear
    text = re.sub(r"^(system|user|assistant)\s*[:\-]?\s*", "", text, flags=re.I).strip()
    text = text.replace("assistant", "").replace("user", "").replace("system", "").strip()
    return text
# ----------------------------
# Full pipeline: input -> MSA -> Qwen -> Egyptian
# ----------------------------
def generate_egyptian_reply(user_text: str) -> str:
    """End-to-end text pipeline: normalize the user's Egyptian input,
    translate it to MSA, answer with Qwen (in MSA), translate the answer
    back to Egyptian, clean it, and normalize the result."""
    cleaned_input = normalizer.normalize(user_text, stage="llm")
    if not cleaned_input:
        return "مسمعتش كويس. ممكن تعيد تاني؟"
    msa_prompt = to_msa(cleaned_input)
    msa_reply = qwen_generate_msa(msa_prompt, MAX_NEW_TOKENS)
    egyptian_reply = clean_egyptian(to_egyptian(msa_reply))
    return normalizer.normalize(egyptian_reply, stage="llm")
# ----------------------------
# XTTS
# ----------------------------
def xtts_speak(text: str, speaker_wav_path: Optional[str] = None) -> str:
    """Synthesize Arabic speech for *text* and return the path of the wav
    written under /tmp.

    Uses *speaker_wav_path* for voice cloning when given; otherwise falls
    back to the model's first built-in speaker (if any).
    """
    engine = get_tts()
    out_path = f"/tmp/{uuid.uuid4().hex}.wav"
    synth_kwargs = {"language": "ar"}
    if speaker_wav_path:
        synth_kwargs["speaker_wav"] = speaker_wav_path
    else:
        builtin_speakers = getattr(engine, "speakers", None)
        if builtin_speakers:
            synth_kwargs["speaker"] = builtin_speakers[0]
    engine.tts_to_file(text=text, file_path=out_path, **synth_kwargs)
    return out_path
# ----------------------------
# FastAPI + simple UI
# ----------------------------
app = FastAPI(title="Egyptian Arabic Speech Chatbot")
# Single-page UI: RTL Arabic page with a text box, optional audio upload,
# and optional speaker-reference upload; posts multipart form data to /chat
# and plays the wav returned via the /audio URL in the JSON response.
INDEX_HTML = """
<!doctype html>
<html lang="ar" dir="rtl">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Egyptian Voice Chatbot</title>
<style>
body { font-family: Arial, sans-serif; margin: 24px; }
.box { max-width: 780px; margin: 0 auto; }
textarea { width: 100%; min-height: 90px; }
.row { display: flex; gap: 12px; flex-wrap: wrap; }
.card { border: 1px solid #ddd; border-radius: 10px; padding: 12px; margin-top: 12px; }
button { padding: 10px 16px; cursor: pointer; }
input[type=file] { width: 100%; }
.muted { color: #666; font-size: 13px; }
</style>
</head>
<body>
<div class="box">
<h2>Egyptian Arabic Voice Chatbot</h2>
<p class="muted">اكتب نص أو ارفع ملف صوتي. الرد بيرجع نص + صوت باللهجة المصرية.</p>
<div class="card">
<label>Text</label>
<textarea id="text" placeholder="اكتب هنا..."></textarea>
<div class="row" style="margin-top: 12px;">
<div style="flex:1;">
<label>Audio input (optional)</label>
<input id="audio" type="file" accept="audio/*">
</div>
<div style="flex:1;">
<label>Speaker reference (optional)</label>
<input id="speaker" type="file" accept="audio/*">
</div>
</div>
<div style="margin-top: 12px;">
<button id="send">Send</button>
<span id="status" class="muted" style="margin-right: 10px;"></span>
</div>
</div>
<div class="card">
<h3>Result</h3>
<div><b>انت:</b> <span id="user_text"></span></div>
<div style="margin-top: 6px;"><b>المساعد:</b> <span id="assistant_text"></span></div>
<div style="margin-top: 10px;">
<audio id="player" controls></audio>
</div>
</div>
</div>
<script>
const sendBtn = document.getElementById("send");
const statusEl = document.getElementById("status");
const textEl = document.getElementById("text");
const audioEl = document.getElementById("audio");
const speakerEl = document.getElementById("speaker");
const userTextOut = document.getElementById("user_text");
const assistantTextOut = document.getElementById("assistant_text");
const player = document.getElementById("player");
sendBtn.addEventListener("click", async () => {
statusEl.textContent = "Sending...";
sendBtn.disabled = true;
try {
const form = new FormData();
form.append("text", textEl.value || "");
form.append("history", "[]");
if (audioEl.files.length > 0) form.append("audio", audioEl.files[0]);
if (speakerEl.files.length > 0) form.append("speaker_ref", speakerEl.files[0]);
const res = await fetch("/chat", { method: "POST", body: form });
if (!res.ok) {
const t = await res.text();
throw new Error(t);
}
const data = await res.json();
userTextOut.textContent = data.user_text || "";
assistantTextOut.textContent = data.assistant_text || "";
if (data.audio_url) {
player.src = data.audio_url;
player.load();
player.play().catch(() => {});
}
statusEl.textContent = "Done";
} catch (e) {
statusEl.textContent = "Error: " + (e.message || e);
} finally {
sendBtn.disabled = false;
setTimeout(() => statusEl.textContent = "", 4000);
}
});
</script>
</body>
</html>
"""
@app.get("/", response_class=HTMLResponse)
def index():
    """Serve the single-page chat UI.

    Fix: the handler was defined but never registered as a route, so the
    app served nothing at "/"; the decorator registers it.
    """
    return HTMLResponse(INDEX_HTML)
@app.get("/health")
def health():
    """Liveness probe; also reports whether CUDA is in use.

    Fix: the handler was defined but never registered as a route; the
    decorator registers it at /health.
    """
    return {"ok": True, "gpu": USE_GPU}
@app.post("/chat")
async def chat(
    text: Optional[str] = Form(default=None),
    history: Optional[str] = Form(default="[]"),
    audio: Optional[UploadFile] = File(default=None),
    speaker_ref: Optional[UploadFile] = File(default=None),
):
    """Main chat endpoint: text or audio in -> Egyptian reply text + wav out.

    Fixes: registers the route (the handler previously had no decorator, so
    the UI's POST to /chat returned 404) and sanitizes client-supplied
    upload filenames with os.path.basename so they cannot escape /tmp.

    If an audio file is provided it is transcribed and takes priority over
    the text field. Returns JSON with user_text, assistant_text, and an
    /audio URL for the synthesized reply.
    """
    # history is currently unused but parsed for forward compatibility.
    try:
        _hist = json.loads(history or "[]")
    except Exception:
        _hist = []

    async def _save_upload(upload: UploadFile) -> str:
        # Strip any directory components from the client filename (path traversal).
        safe_name = os.path.basename(upload.filename or "upload")
        dest = f"/tmp/{uuid.uuid4().hex}_{safe_name}"
        with open(dest, "wb") as f:
            f.write(await upload.read())
        return dest

    audio_path = await _save_upload(audio) if audio is not None else None
    speaker_path = await _save_upload(speaker_ref) if speaker_ref is not None else None

    if audio_path:
        user_text = transcribe_file(audio_path)
    else:
        user_text = (text or "").strip()

    assistant_text = generate_egyptian_reply(user_text)
    tts_text = normalizer.normalize(assistant_text, stage="tts")
    wav_path = xtts_speak(tts_text, speaker_path)
    return JSONResponse(
        {
            "user_text": user_text,
            "assistant_text": assistant_text,
            "audio_url": f"/audio?path={wav_path}",
        }
    )
@app.get("/audio")
def audio(path: str):
    """Serve a previously synthesized wav file.

    Fixes: registers the route (no decorator before, so the player's GET
    returned 404) and validates the untrusted ?path= query — previously any
    file on disk could be read. Only files under /tmp with a .wav suffix
    (where xtts_speak writes) are served.
    """
    real = os.path.realpath(path)
    if not (real.startswith("/tmp/") and real.endswith(".wav") and os.path.isfile(real)):
        return JSONResponse({"error": "invalid audio path"}, status_code=400)
    return FileResponse(real, media_type="audio/wav", filename="reply.wav")
if __name__ == "__main__":
    # Local / Spaces entry point; HF Spaces injects PORT (default 7860).
    import uvicorn
    port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=port)