# Scraped HF Spaces page header (status lines), kept as comments so the file parses:
# Spaces: Sleeping
# Sleeping
| import os | |
| import re | |
| import uuid | |
| import json | |
| import inspect | |
| import builtins | |
| from typing import Optional | |
| import torch | |
| from fastapi import FastAPI, UploadFile, File, Form | |
| from fastapi.responses import HTMLResponse, JSONResponse, FileResponse | |
| from faster_whisper import WhisperModel | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
class EgyptianNormalizer:
    """Rule-based text normalizer for Egyptian-Arabic chatbot text.

    Two stages share most rules:
      - stage="llm" (default): clean user text before it is sent to the LLM.
      - stage="tts": additionally spell out the percent sign and strip
        remaining Latin-script words so the Arabic TTS voice can read it.
    """

    def __init__(self):
        # Ordered (regex, replacement) word-level substitutions, applied with
        # re.IGNORECASE. Order matters: they run after the character-level
        # cleanup in normalize(). Includes a few English->Arabic swaps
        # ("ok"/"okay"/"sorry").
        # NOTE(review): the Arabic literals below appear mojibake-garbled in
        # this view — verify them against the original encoding.
        self.replacements = [
            (r"\bุงุฒูู\b", "ุฅุฒูู"),
            (r"\bุงุฒู\b", "ุฅุฒู"),
            (r"\bุงุฒุงู\b", "ุฅุฒุงู"),
            (r"\bุงูู\b", "ุฅูู"),
            (r"\bููู\b", "ููู"),
            (r"\bููู\b", "ููู"),
            (r"\bุงู ุชู\b", "ุฅู ุชู"),
            (r"\bุนุงูุฒ\b", "ุนุงูุฒ"),
            (r"\bุนุงูุฒู\b", "ุนุงูุฒู"),
            (r"\bุนุงูุฒ\b", "ุนุงูุฒ"),
            (r"\bุฏูููุช\b", "ุฏูููุชู"),
            (r"\bูุฏุง\b", "ูุฏู"),
            (r"\bุนูุดุงู\b", "ุนุดุงู"),
            (r"\bุนู ูู ุงูู\b", "ุนุงู ู ุฅูู"),
            (r"\bุนุงู ูู ุงูู\b", "ุนุงู ูุฉ ุฅูู"),
            (r"\bู ุงุนุฑูุด\b", "ู ุด ุนุงุฑู"),
            (r"\bู ุงููุชุด\b", "ู ุด ููุช"),
            (r"\bู ุงููู ุชุด\b", "ู ุด ูุงูู "),
            (r"\bok\b", "ุชู ุงู "),
            (r"\bokay\b", "ุชู ุงู "),
            (r"\bsorry\b", "ู ุนูุด"),
        ]
        # Interrogative words: if any is present and the text lacks terminal
        # punctuation, normalize() appends a question mark.
        self.question_words = ["ููู", "ุฅูู", "ููู", "ุฅู ุชู", "ุฅุฒุงู", "ุฅุฒู", "ูุงู ", "ู ูู"]
        # Arabic diacritic / annotation mark ranges to strip.
        self.diacritics = re.compile(r"[\u0610-\u061A\u064B-\u065F\u0670\u06D6-\u06ED]")
        # Zero-width joiners, directional marks and BOM.
        self.zero_width = re.compile(r"[\u200c\u200d\u200e\u200f\ufeff]")
        # Tatweel (kashida) elongation character.
        self.tatweel = "\u0640"

    def normalize(self, text: str, stage: str = "llm") -> str:
        """Return `text` normalized for the given stage ("llm" or "tts").

        Empty/falsy input yields "".
        """
        if not text:
            return ""
        t = str(text)
        t = self.zero_width.sub("", t)   # drop zero-width chars / BOM
        t = t.replace(self.tatweel, "")  # drop elongation
        t = t.replace("\n", " ")         # single-line output
        t = self.diacritics.sub("", t)   # strip diacritics
        # Latin punctuation -> Arabic punctuation.
        t = t.replace("?", "ุ")
        t = t.replace(",", "ุ")
        t = t.replace(";", "ุ")
        # Unify letter variants (presumably alef forms — verify chars against
        # the original, pre-mojibake source).
        t = re.sub(r"[ุฃุฅุขูฑ]", "ุง", t)
        t = t.replace("ู", "ู")
        # Collapse any character repeated 3+ times down to one occurrence.
        t = re.sub(r"(.)\1{2,}", r"\1", t)
        # Arabic-Indic digits -> ASCII digits.
        # NOTE(review): str.maketrans requires both strings to have equal
        # length — the first literal looks garbled here; confirm it is exactly
        # the ten Arabic-Indic digits in the original file.
        t = t.translate(str.maketrans("ู ูกูขูฃูคูฅูฆูงูจูฉ", "0123456789"))
        if stage == "tts":
            # Spell out "%" so the TTS voice reads it.
            t = t.replace("%", " ูู ุงูู ูุฉ ")
        for pattern, repl in self.replacements:
            t = re.sub(pattern, repl, t, flags=re.IGNORECASE)
        # Split common fused verb+pronoun forms into separate words.
        t = t.replace("ููุชูู", "ููุช ูู")
        t = t.replace("ูููุชูู", "ููุช ูู")
        t = t.replace("ููุชููู ", "ููุช ููู ")
        t = t.replace("ูููุชููู ", "ููุช ููู ")
        t = re.sub(r"\s+", " ", t).strip()
        # Append a question mark to unterminated questions.
        if any(w in t for w in self.question_words) and not t.endswith(("ุ", "!", ".")):
            t += "ุ"
        if stage == "tts":
            # The synthesized voice is Arabic-only: drop leftover Latin words.
            t = re.sub(r"\b[a-zA-Z]+\b", "", t)
            t = re.sub(r"\s+", " ", t).strip()
        return t
# ----------------------------
# Caches (HF Spaces friendly)
# ----------------------------
# Point all HF/XDG caches at the persistent /data volume so model downloads
# survive Space restarts. setdefault() lets explicit env config win.
# NOTE(review): torch/transformers are imported at the top of this file, so
# HF_HOME / TRANSFORMERS_CACHE set here may be read too late to take effect —
# confirm, or move these assignments above the imports.
os.environ.setdefault("HF_HOME", "/data/huggingface")
os.environ.setdefault("HF_HUB_CACHE", "/data/huggingface/hub")
os.environ.setdefault("TRANSFORMERS_CACHE", "/data/huggingface/transformers")
os.environ.setdefault("XDG_CACHE_HOME", "/data/cache")
os.environ.setdefault("XDG_DATA_HOME", "/data/local/share")
# ----------------------------
# Coqui CPML non-interactive acceptance
# ----------------------------
# Pre-create the license-agreement marker so Coqui TTS does not prompt
# interactively on first model load.
os.makedirs("/data/local/share/tts", exist_ok=True)
try:
    with open("/data/local/share/tts/.tos_agreed", "w") as f:
        f.write("y")
except Exception:
    # Best effort: if the volume is read-only, the input() shim defined
    # further down in this file still answers the prompt.
    pass
| _real_input = builtins.input | |
| def _auto_input(prompt=""): | |
| p = (prompt or "").lower() | |
| if "cpml" in p or "license" in p or "[y/n]" in p: | |
| return os.environ.get("COQUI_TOS", "y") | |
| try: | |
| for frame in inspect.stack(): | |
| fname = frame.filename.replace("\\", "/") | |
| if fname.endswith("/TTS/utils/manage.py") and frame.function == "ask_tos": | |
| return os.environ.get("COQUI_TOS", "y") | |
| except Exception: | |
| pass | |
| return _real_input(prompt) | |
| builtins.input = _auto_input | |
# ----------------------------
# Optional CAMeL Tools normalization
# ----------------------------
# CAMeL Tools is an optional dependency; CAMEL_OK records availability.
CAMEL_OK = False
try:
    from camel_tools.utils.normalize import (
        normalize_alef_ar,
        normalize_alef_maksura_ar,
        normalize_teh_marbuta_ar,
        normalize_unicode,
    )
except Exception:
    pass
else:
    CAMEL_OK = True
# ----------------------------
# Config
# ----------------------------
# Every setting is overridable via environment variables (HF Spaces friendly).
QWEN_MODEL_ID = os.getenv("QWEN_MODEL_ID", "Qwen/Qwen2.5-3B-Instruct")  # chat LLM repo id
WHISPER_SIZE = os.getenv("WHISPER_SIZE", "small")  # faster-whisper model size
XTTS_MODEL_ID = os.getenv("XTTS_MODEL_ID", "tts_models/multilingual/multi-dataset/xtts_v2")  # Coqui TTS model
# System prompt (Arabic): instructs the model to answer in Egyptian dialect
# only, no English, short and clear replies.
SYSTEM_PROMPT = os.getenv(
    "SYSTEM_PROMPT",
    "ุงูุช ู ุณุงุนุฏ ู ุตุฑู. ุฑุฏ ุจุงูููุฌุฉ ุงูู ุตุฑูุฉ ููุท ูุจุงูุนุฑุจู ููุท. ู ู ููุน ุชุณุชุฎุฏู ุงู ููุงู ุงูุฌููุฒู. ุฑุฏูุฏู ูุตูุฑุฉ ููุงุถุญุฉ."
)
MAX_TURNS = int(os.getenv("MAX_TURNS", "8"))  # history pairs passed to the LLM
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "256"))  # generation length cap
HAS_GPU = torch.cuda.is_available()
DEVICE_STR = "cuda" if HAS_GPU else "cpu"
# Shared normalizer instance, used for both the LLM and TTS stages.
normalizer = EgyptianNormalizer()
# ----------------------------
# Load Whisper + Qwen once
# ----------------------------
# faster-whisper: fp16 on GPU, int8 quantization on CPU to fit small instances.
whisper_compute = "float16" if HAS_GPU else "int8"
whisper_model = WhisperModel(WHISPER_SIZE, device=DEVICE_STR, compute_type=whisper_compute)
# NOTE(review): trust_remote_code=True executes code from the model repo at
# load time — acceptable only if QWEN_MODEL_ID points at a trusted/pinned repo.
tokenizer = AutoTokenizer.from_pretrained(QWEN_MODEL_ID, trust_remote_code=True)
qwen = AutoModelForCausalLM.from_pretrained(
    QWEN_MODEL_ID,
    torch_dtype=torch.float16 if HAS_GPU else torch.float32,  # fp32 on CPU
    device_map="auto" if HAS_GPU else None,  # let HF place layers when on GPU
    trust_remote_code=True,
)
# XTTS is loaded lazily: synthesis may never be reached on a failed request,
# and the import itself is expensive.
tts = None


def get_tts():
    """Return the shared XTTS engine, constructing it on first use."""
    global tts
    if tts is not None:
        return tts
    from TTS.api import TTS
    tts = TTS(XTTS_MODEL_ID, gpu=HAS_GPU)
    return tts
# ----------------------------
# Whisper
# ----------------------------
def transcribe_file(path: str) -> str:
    """Transcribe an audio file with faster-whisper (forced Arabic) and
    return the segment texts joined into one stripped string."""
    segments, _ = whisper_model.transcribe(path, language="ar")
    pieces = [segment.text for segment in segments]
    return " ".join(pieces).strip()
# ----------------------------
# Qwen reply
# history format: list of [user, assistant]
# ----------------------------
def qwen_reply(history, user_text: str) -> str:
    """Generate an assistant reply from Qwen.

    `history` is a list of [user, assistant] pairs (only the last MAX_TURNS
    are used); `user_text` is the new user turn. Returns the decoded reply
    with any leaked role words stripped.
    """
    user_text = "" if user_text is None else str(user_text)
    if not user_text.strip():
        # Nothing to answer — ask the user to type the question again.
        return "ู ู ูู ุชูุชุจ ุณุคุงูู ุชุงููุ"

    # Build the chat-template message list: system prompt, recent history,
    # then the new user turn. Blank turns are skipped.
    messages = [{"role": "system", "content": SYSTEM_PROMPT or ""}]
    for pair in (history or [])[-MAX_TURNS:]:
        u, a = pair
        for role, content in (("user", u), ("assistant", a)):
            content = "" if content is None else str(content)
            if content.strip():
                messages.append({"role": role, "content": content})
    messages.append({"role": "user", "content": user_text})

    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt")
    if HAS_GPU:
        inputs = {k: v.to(qwen.device) for k, v in inputs.items()}

    with torch.no_grad():
        out_ids = qwen.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_len = inputs["input_ids"].shape[1]
    text = tokenizer.decode(out_ids[0][prompt_len:], skip_special_tokens=True).strip()

    # Defensive cleanup in case the model emits role markers.
    text = re.sub(r"^(system|user|assistant)\s*[:\-]?\s*", "", text, flags=re.I)
    text = text.replace("assistant", "").replace("user", "").replace("system", "").strip()
    return text
# ----------------------------
# XTTS
# ----------------------------
def xtts_speak(text: str, speaker_wav_path: Optional[str] = None) -> str:
    """Synthesize `text` with XTTS and return the path of the output WAV.

    When `speaker_wav_path` is given it is used for voice cloning; otherwise
    the engine's first built-in speaker (if any) is selected.
    """
    engine = get_tts()
    out_path = f"/tmp/{uuid.uuid4().hex}.wav"
    synth_args = {"language": "ar"}
    if speaker_wav_path:
        synth_args["speaker_wav"] = speaker_wav_path
    else:
        builtin_speakers = getattr(engine, "speakers", None)
        if builtin_speakers and len(builtin_speakers) > 0:
            synth_args["speaker"] = builtin_speakers[0]
    engine.tts_to_file(text=text, file_path=out_path, **synth_args)
    return out_path
# ----------------------------
# FastAPI
# ----------------------------
app = FastAPI(title="Arabic Dialect Speech Chatbot")

# Single-page UI served at "/": a text box plus optional audio-input and
# speaker-reference uploads; it POSTs a multipart form to /chat and plays the
# returned audio_url. (Indentation below is reconstructed — the scraped source
# had it flattened; whitespace is insignificant inside HTML.)
INDEX_HTML = """
<!doctype html>
<html lang="ar" dir="rtl">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Arabic Dialect Voice Chatbot</title>
<style>
body { font-family: Arial, sans-serif; margin: 24px; }
.box { max-width: 780px; margin: 0 auto; }
textarea { width: 100%; min-height: 90px; }
.row { display: flex; gap: 12px; flex-wrap: wrap; }
.card { border: 1px solid #ddd; border-radius: 10px; padding: 12px; margin-top: 12px; }
button { padding: 10px 16px; cursor: pointer; }
input[type=file] { width: 100%; }
.muted { color: #666; font-size: 13px; }
</style>
</head>
<body>
<div class="box">
<h2>Arabic Dialect Voice Chatbot</h2>
<p class="muted">ุงูุชุจ ูุต ุฃู ุงุฑูุน ู ูู ุตูุชู. ุงูุฑุฏ ูุฑุฌุน ูุต + ุตูุช.</p>
<div class="card">
<label>Text</label>
<textarea id="text" placeholder="ุงูุชุจ ููุง..."></textarea>
<div class="row" style="margin-top: 12px;">
<div style="flex:1;">
<label>Audio input (optional)</label>
<input id="audio" type="file" accept="audio/*">
</div>
<div style="flex:1;">
<label>Speaker reference (optional)</label>
<input id="speaker" type="file" accept="audio/*">
</div>
</div>
<div style="margin-top: 12px;">
<button id="send">Send</button>
<span id="status" class="muted" style="margin-right: 10px;"></span>
</div>
</div>
<div class="card">
<h3>Result</h3>
<div><b>User text:</b> <span id="user_text"></span></div>
<div style="margin-top: 6px;"><b>Assistant:</b> <span id="assistant_text"></span></div>
<div style="margin-top: 10px;">
<audio id="player" controls></audio>
</div>
</div>
</div>
<script>
const sendBtn = document.getElementById("send");
const statusEl = document.getElementById("status");
const textEl = document.getElementById("text");
const audioEl = document.getElementById("audio");
const speakerEl = document.getElementById("speaker");
const userTextOut = document.getElementById("user_text");
const assistantTextOut = document.getElementById("assistant_text");
const player = document.getElementById("player");
sendBtn.addEventListener("click", async () => {
  statusEl.textContent = "Sending...";
  sendBtn.disabled = true;
  try {
    const form = new FormData();
    form.append("text", textEl.value || "");
    form.append("history", "[]");
    if (audioEl.files.length > 0) {
      form.append("audio", audioEl.files[0]);
    }
    if (speakerEl.files.length > 0) {
      form.append("speaker_ref", speakerEl.files[0]);
    }
    const res = await fetch("/chat", { method: "POST", body: form });
    if (!res.ok) {
      const t = await res.text();
      throw new Error(t);
    }
    const data = await res.json();
    userTextOut.textContent = data.user_text || "";
    assistantTextOut.textContent = data.assistant_text || "";
    if (data.audio_url) {
      player.src = data.audio_url;
      player.load();
      player.play().catch(() => {});
    }
    statusEl.textContent = "Done";
  } catch (e) {
    statusEl.textContent = "Error: " + e.message;
  } finally {
    sendBtn.disabled = false;
    setTimeout(() => statusEl.textContent = "", 4000);
  }
});
</script>
</body>
</html>
"""
# Route decorator restored: without it the handler is never registered and
# the UI is unreachable (decorators were evidently lost in the scraped copy).
@app.get("/", response_class=HTMLResponse)
def index():
    """Serve the single-page chat UI."""
    return HTMLResponse(INDEX_HTML)
# Route decorator restored: without it /health is never registered.
@app.get("/health")
def health():
    """Liveness probe: reports process health and GPU availability."""
    return {"ok": True, "gpu": HAS_GPU}
# Route decorator restored: the frontend POSTs to /chat, but this handler was
# never registered with the app.
@app.post("/chat")
async def chat(
    text: Optional[str] = Form(default=None),
    history: Optional[str] = Form(default="[]"),
    audio: Optional[UploadFile] = File(default=None),
    speaker_ref: Optional[UploadFile] = File(default=None),
):
    """Main chat endpoint: text or audio in -> assistant text + speech out.

    Form fields:
        text: user message; ignored when `audio` is supplied.
        history: JSON list of [user, assistant] turn pairs.
        audio: optional audio upload, transcribed with Whisper.
        speaker_ref: optional reference WAV for XTTS voice cloning.
    Returns JSON with user_text, assistant_text, updated history, audio_url.
    """
    try:
        hist = json.loads(history or "[]")
    except Exception:
        hist = []
    if not isinstance(hist, list):
        hist = []  # guard: valid JSON but not a list (e.g. "{}" or '"x"')

    async def _save_upload(upload: UploadFile) -> str:
        # Persist an upload under /tmp. basename() strips any client-supplied
        # directory components — the raw filename previously allowed writing
        # outside /tmp via "../" sequences.
        safe_name = os.path.basename(upload.filename or "upload")
        path = f"/tmp/{uuid.uuid4().hex}_{safe_name}"
        with open(path, "wb") as f:
            f.write(await upload.read())
        return path

    audio_path = await _save_upload(audio) if audio is not None else None
    speaker_path = await _save_upload(speaker_ref) if speaker_ref is not None else None

    # Speech input takes precedence over typed text.
    if audio_path:
        user_text = transcribe_file(audio_path)
    else:
        user_text = (text or "").strip()

    user_norm = normalizer.normalize(user_text, stage="llm")
    if not user_norm:
        # Nothing usable was heard/typed — ask the user to repeat.
        assistant_text = "ู ุณู ุนุชุด ูููุณ. ู ู ูู ุชุนูุฏ ุชุงููุ"
    else:
        assistant_text = qwen_reply(hist, user_norm)

    tts_text = normalizer.normalize(assistant_text, stage="tts")
    wav_path = xtts_speak(tts_text, speaker_path)

    hist = (hist or []) + [[user_text, assistant_text]]
    return JSONResponse(
        {
            "user_text": user_text,
            "assistant_text": assistant_text,
            "history": hist,
            "audio_url": f"/audio?path={wav_path}",
        }
    )
# Route decorator restored (the UI fetches /audio?path=...). Also fixes a
# path-traversal vulnerability: `path` comes straight from the query string
# and previously any readable file on disk could be downloaded.
@app.get("/audio")
def audio(path: str):
    """Serve a previously synthesized reply WAV.

    Only *.wav files that resolve inside /tmp (where xtts_speak writes) are
    served; anything else gets a 404.
    """
    real = os.path.realpath(path)
    if not (real.startswith("/tmp/") and real.endswith(".wav") and os.path.isfile(real)):
        return JSONResponse({"error": "not found"}, status_code=404)
    return FileResponse(real, media_type="audio/wav", filename="reply.wav")
if __name__ == "__main__":
    # Dev / Spaces entry point; PORT is set by the host (defaults to 7860,
    # the standard HF Spaces port).
    import uvicorn
    port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=port)