import os import re import uuid import json import inspect import builtins from typing import Optional import torch from fastapi import FastAPI, UploadFile, File, Form from fastapi.responses import HTMLResponse, JSONResponse, FileResponse from faster_whisper import WhisperModel from transformers import AutoTokenizer, AutoModelForCausalLM class EgyptianNormalizer: def __init__(self): self.replacements = [ (r"\bازيك\b", "إزيك"), (r"\bازى\b", "إزي"), (r"\bازاي\b", "إزاي"), (r"\bايه\b", "إيه"), (r"\bليه\b", "ليه"), (r"\bفين\b", "فين"), (r"\bامتى\b", "إمتى"), (r"\bعاوز\b", "عايز"), (r"\bعاوزه\b", "عايزه"), (r"\bعايز\b", "عايز"), (r"\bدلوقت\b", "دلوقتي"), (r"\bكدا\b", "كده"), (r"\bعلشان\b", "عشان"), (r"\bعميل ايه\b", "عامل إيه"), (r"\bعامله ايه\b", "عاملة إيه"), (r"\bماعرفش\b", "مش عارف"), (r"\bماكنتش\b", "مش كنت"), (r"\bمافهمتش\b", "مش فاهم"), (r"\bok\b", "تمام"), (r"\bokay\b", "تمام"), (r"\bsorry\b", "معلش"), ] self.question_words = ["ليه", "إيه", "فين", "إمتى", "إزاي", "إزي", "كام", "مين"] self.diacritics = re.compile(r"[\u0610-\u061A\u064B-\u065F\u0670\u06D6-\u06ED]") self.zero_width = re.compile(r"[\u200c\u200d\u200e\u200f\ufeff]") self.tatweel = "\u0640" def normalize(self, text: str, stage: str = "llm") -> str: if not text: return "" t = str(text) t = self.zero_width.sub("", t) t = t.replace(self.tatweel, "") t = t.replace("\n", " ") t = self.diacritics.sub("", t) t = t.replace("?", "؟") t = t.replace(",", "،") t = t.replace(";", "؛") t = re.sub(r"[أإآٱ]", "ا", t) t = t.replace("ى", "ي") t = re.sub(r"(.)\1{2,}", r"\1", t) t = t.translate(str.maketrans("٠١٢٣٤٥٦٧٨٩", "0123456789")) if stage == "tts": t = t.replace("%", " في المية ") for pattern, repl in self.replacements: t = re.sub(pattern, repl, t, flags=re.IGNORECASE) t = t.replace("قلتلك", "قلت لك") t = t.replace("قولتلك", "قلت لك") t = t.replace("قلتلهم", "قلت لهم") t = t.replace("قولتلهم", "قلت لهم") t = re.sub(r"\s+", " ", t).strip() if any(w in t for w in self.question_words) and not t.endswith(("؟", "!", ".")): t += "؟" if stage == "tts": t = re.sub(r"\b[a-zA-Z]+\b", "", t) t = re.sub(r"\s+", " ", t).strip() return t # ---------------------------- # Caches (HF Spaces friendly) # ---------------------------- os.environ.setdefault("HF_HOME", "/data/huggingface") os.environ.setdefault("HF_HUB_CACHE", "/data/huggingface/hub") os.environ.setdefault("TRANSFORMERS_CACHE", "/data/huggingface/transformers") os.environ.setdefault("XDG_CACHE_HOME", "/data/cache") os.environ.setdefault("XDG_DATA_HOME", "/data/local/share") # ---------------------------- # Coqui CPML non-interactive acceptance # ---------------------------- os.makedirs("/data/local/share/tts", exist_ok=True) try: with open("/data/local/share/tts/.tos_agreed", "w") as f: f.write("y") except Exception: pass _real_input = builtins.input def _auto_input(prompt=""): p = (prompt or "").lower() if "cpml" in p or "license" in p or "[y/n]" in p: return os.environ.get("COQUI_TOS", "y") try: for frame in inspect.stack(): fname = frame.filename.replace("\\", "/") if fname.endswith("/TTS/utils/manage.py") and frame.function == "ask_tos": return os.environ.get("COQUI_TOS", "y") except Exception: pass return _real_input(prompt) builtins.input = _auto_input # ---------------------------- # Optional CAMeL Tools normalization # ---------------------------- try: from camel_tools.utils.normalize import ( normalize_alef_maksura_ar, normalize_alef_ar, normalize_teh_marbuta_ar, normalize_unicode, ) CAMEL_OK = True except Exception: CAMEL_OK = False # ---------------------------- # Config # ---------------------------- QWEN_MODEL_ID = os.getenv("QWEN_MODEL_ID", "Qwen/Qwen2.5-3B-Instruct") WHISPER_SIZE = os.getenv("WHISPER_SIZE", "small") XTTS_MODEL_ID = os.getenv("XTTS_MODEL_ID", "tts_models/multilingual/multi-dataset/xtts_v2") SYSTEM_PROMPT = os.getenv( "SYSTEM_PROMPT", "انت مساعد مصري. رد باللهجة المصرية فقط وبالعربي فقط. ممنوع تستخدم اي كلام انجليزي. ردودك قصيرة وواضحة." ) MAX_TURNS = int(os.getenv("MAX_TURNS", "8")) MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "256")) HAS_GPU = torch.cuda.is_available() DEVICE_STR = "cuda" if HAS_GPU else "cpu" normalizer = EgyptianNormalizer() # ---------------------------- # Load Whisper + Qwen once # ---------------------------- whisper_compute = "float16" if HAS_GPU else "int8" whisper_model = WhisperModel(WHISPER_SIZE, device=DEVICE_STR, compute_type=whisper_compute) tokenizer = AutoTokenizer.from_pretrained(QWEN_MODEL_ID, trust_remote_code=True) qwen = AutoModelForCausalLM.from_pretrained( QWEN_MODEL_ID, torch_dtype=torch.float16 if HAS_GPU else torch.float32, device_map="auto" if HAS_GPU else None, trust_remote_code=True, ) # XTTS lazy init tts = None def get_tts(): global tts if tts is None: from TTS.api import TTS tts = TTS(XTTS_MODEL_ID, gpu=HAS_GPU) return tts # ---------------------------- # Whisper # ---------------------------- def transcribe_file(path: str) -> str: segments, _info = whisper_model.transcribe(path, language="ar") return " ".join(seg.text for seg in segments).strip() # ---------------------------- # Qwen reply # history format: list of [user, assistant] # ---------------------------- def qwen_reply(history, user_text: str) -> str: messages = [{"role": "system", "content": SYSTEM_PROMPT or ""}] for u, a in (history or [])[-MAX_TURNS:]: u = "" if u is None else str(u) a = "" if a is None else str(a) if u.strip(): messages.append({"role": "user", "content": u}) if a.strip(): messages.append({"role": "assistant", "content": a}) user_text = "" if user_text is None else str(user_text) if not user_text.strip(): return "ممكن تكتب سؤالك تاني؟" messages.append({"role": "user", "content": user_text}) prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) inputs = tokenizer(prompt, return_tensors="pt") if HAS_GPU: inputs = {k: v.to(qwen.device) for k, v in inputs.items()} with torch.no_grad(): out_ids = qwen.generate( **inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=True, temperature=0.7, top_p=0.9, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.eos_token_id, ) # Decode only the new tokens (not the prompt) prompt_len = inputs["input_ids"].shape[1] gen_ids = out_ids[0][prompt_len:] text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip() # Extra cleanup if model still outputs role words text = re.sub(r"^(system|user|assistant)\s*[:\-]?\s*", "", text, flags=re.I) text = text.replace("assistant", "").replace("user", "").replace("system", "").strip() return text # ---------------------------- # XTTS # ---------------------------- def xtts_speak(text: str, speaker_wav_path: Optional[str] = None) -> str: tts_local = get_tts() out_path = f"/tmp/{uuid.uuid4().hex}.wav" kwargs = { "language": "ar", } if speaker_wav_path: kwargs["speaker_wav"] = speaker_wav_path else: speakers = getattr(tts_local, "speakers", None) if speakers and len(speakers) > 0: kwargs["speaker"] = speakers[0] tts_local.tts_to_file(text=text, file_path=out_path, **kwargs) return out_path # ---------------------------- # FastAPI # ---------------------------- app = FastAPI(title="Arabic Dialect Speech Chatbot") INDEX_HTML = """
اكتب نص أو ارفع ملف صوتي. الرد يرجع نص + صوت.