"""Egyptian Arabic assistant (Gradio app).

Pipeline: input text (or ASR transcript) -> translator (to MSA) -> Qwen chat
model (answers in MSA) -> translator (to Egyptian) -> light text cleanup.
"""

import re

# Third-party imports are kept in the same relative order as the original file.
import gradio as gr
import spaces
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

# =========================
# 0) Config
# =========================
TRANSLATOR_MODEL = "oddadmix/Masrawy-BiLingual-v1"
ASR_MODEL = "openai/whisper-small"
LLM_MODEL = "Qwen/Qwen2.5-3B-Instruct"

USE_GPU = torch.cuda.is_available()
# transformers pipelines take a CUDA device index, or -1 for CPU.
DEVICE = 0 if USE_GPU else -1

# NOTE(review): the docstrings of to_msa()/to_egyptian() in the original file
# mention a direction "tag", but no tag text is present in the code: both
# functions send exactly `text + " "` to the translator and are therefore
# identical. The tags look like they were lost from the source. They are left
# empty here so runtime behavior is unchanged -- fill them in from the
# translator model's documentation.
TO_MSA_TAG = ""
TO_EGYPTIAN_TAG = ""

# =========================
# 1) Load models (once)
# =========================
translator = pipeline("translation", model=TRANSLATOR_MODEL, device=DEVICE)

asr = pipeline(
    "automatic-speech-recognition",
    model=ASR_MODEL,
    device=DEVICE,
)

# NOTE(review): trust_remote_code=True allows code shipped in the model repo to
# run locally; confirm it is actually required for LLM_MODEL.
tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL,
    torch_dtype="auto",
    device_map="auto" if USE_GPU else None,
    trust_remote_code=True,
)
if not USE_GPU:
    model = model.to("cpu")


# =========================
# 2) Translator helpers (explicit direction, non-ambiguous)
# =========================
def _translate(text: str, direction_tag: str) -> str:
    """Run the translator on `text` followed by a space and `direction_tag`.

    Returns "" for empty/None/whitespace-only input without calling the model.
    """
    text = (text or "").strip()
    if not text:
        return ""
    return translator(text + " " + direction_tag)[0]["translation_text"]


def to_msa(text: str) -> str:
    """Convert any Arabic input (Egyptian / MSA / mixed) to MSA.

    The direction is selected by TO_MSA_TAG (see the review note next to it).
    Returns "" for empty input.
    """
    return _translate(text, TO_MSA_TAG)


def to_egyptian(text: str) -> str:
    """Convert MSA text to Egyptian Arabic.

    The direction is selected by TO_EGYPTIAN_TAG (see the review note next to
    it). Returns "" for empty input.
    """
    return _translate(text, TO_EGYPTIAN_TAG)


# =========================
# 3) Output cleaning (Detox / style shaping)
# =========================
_BANNED_PHRASES = [
    "كمساعد",
    "كمساعد ذكي",
    "معلش",
    "آسف",
    "اعتذر",
    "مش عارف",
    "لا أستطيع",
    "غير قادر",
    "لا يمكنني",
    "لا أقدر",
    "لا أملك معلومات",
    "قد لا يكون",
    "ربما",
    "عادةً",
    "بشكل عام",
]

# Longest phrases first: otherwise a short phrase ("كمساعد") is stripped before
# a longer phrase that contains it ("كمساعد ذكي"), leaving the tail behind.
_BANNED_PHRASES_LONGEST_FIRST = tuple(sorted(_BANNED_PHRASES, key=len, reverse=True))


def clean_egyptian(text: str) -> str:
    """Strip meta/defensive phrasing from the final Egyptian text.

    Steps: remove every banned phrase (plain substring removal, longest phrase
    first), collapse whitespace runs to a single space, and replace runs of
    three or more '.'/'،' characters with a single ellipsis. If nothing is
    left, a default follow-up question is returned instead of an empty string.

    Deliberately simple; not meant to be a perfect filter.
    """
    cleaned = (text or "").strip()

    for phrase in _BANNED_PHRASES_LONGEST_FIRST:
        cleaned = cleaned.replace(phrase, "")

    # Collapse extra spaces left behind by the removals.
    cleaned = re.sub(r"\s+", " ", cleaned).strip()

    # Remove repeated punctuation.
    cleaned = re.sub(r"[.،]{3,}", "…", cleaned).strip()

    # If it becomes empty, fall back to a helpful default.
    if not cleaned:
        cleaned = "تمام—قولي انت فاضي ولا عندك شغل/مذاكرة النهارده؟"

    return cleaned


# =========================
# 4) Qwen generation (in MSA for stability)
# =========================
def qwen_generate_msa(msa_prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
    """Generate a reply to `msa_prompt` with the chat LLM.

    Args:
        msa_prompt: user message (expected to already be in MSA).
        max_new_tokens: generation length cap; coerced to int because UI
            sliders may deliver a float.
        temperature: sampling temperature.
        top_p: nucleus sampling threshold.

    Returns:
        The decoded completion only (prompt tokens are cut off), stripped.
        "" when the prompt is empty.
    """
    msa_prompt = (msa_prompt or "").strip()
    if not msa_prompt:
        return ""

    # Behavior-first system message: asks for short, practical answers in
    # simple MSA without apologies or talk about the assistant's limits.
    system_msg = (
        "أنت مساعد شخصي عملي. "
        "إذا كان سؤال المستخدم عامًا أو مفتوحًا، اقترح خطة أو خطوات عملية من نفسك فورًا "
        "بدون اعتذار وبدون تبرير لحدودك. "
        "اجعل الرد قصيرًا ومباشرًا ومفيدًا. "
        "اكتب باللغة العربية الفصحى البسيطة فقط."
    )

    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": msa_prompt},
    ]

    # return_dict=True also yields the attention mask, so generate() does not
    # have to guess it from the input ids.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )
    # Always follow the model's device (no-op on CPU).
    inputs = inputs.to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            eos_token_id=tokenizer.eos_token_id,
        )

    # Keep only the newly generated tokens.
    gen_ids = output_ids[0][prompt_len:]
    return tokenizer.decode(gen_ids, skip_special_tokens=True).strip()


# =========================
# 5) Core pipeline (stable + non-ambiguous)
# =========================
def _pipeline_from_text(user_text: str, max_new_tokens: int, temperature: float, top_p: float):
    """Run the full text pipeline.

    Input -> (to MSA) -> Qwen (MSA) -> (to Egyptian) -> clean

    Returns:
        (msa_in, llm_msa, final_egy); three empty strings for empty input.
    """
    user_text = (user_text or "").strip()
    if not user_text:
        return "", "", ""

    # 1) Normalize input to MSA (stable for LLM).
    msa_in = to_msa(user_text)

    # 2) LLM outputs in MSA (behavior controlled by system prompt).
    llm_msa = qwen_generate_msa(msa_in, max_new_tokens, temperature, top_p)

    # 3) Force Egyptian output + clean.
    final_egy = clean_egyptian(to_egyptian(llm_msa))

    return msa_in, llm_msa, final_egy


@spaces.GPU
def process_text(user_text: str, max_new_tokens: int, temperature: float, top_p: float, show_debug: bool):
    """Gradio handler for the text tab.

    Returns (msa_in, llm_msa, final_egy); the first two are blanked unless
    `show_debug` is set.
    """
    msa_in, llm_msa, final_egy = _pipeline_from_text(user_text, max_new_tokens, temperature, top_p)
    if show_debug:
        return msa_in, llm_msa, final_egy
    # Hide debug outputs.
    return "", "", final_egy


@spaces.GPU
def process_audio(audio_path: str, max_new_tokens: int, temperature: float, top_p: float, show_debug: bool):
    """Gradio handler for the audio tab: transcribe, then run the text pipeline.

    Returns (asr_text, msa_in, llm_msa, final_egy). The two middle values are
    blanked unless `show_debug` is set. Four empty strings are returned when
    there is no audio file or the transcript is empty.
    """
    if not audio_path:
        return "", "", "", ""

    # ASR: the pipeline result is read as a dict with a "text" key; anything
    # else is stringified as a fallback.
    asr_out = asr(audio_path)
    asr_text = (asr_out.get("text", "") if isinstance(asr_out, dict) else str(asr_out)).strip()

    if not asr_text:
        return "", "", "", ""

    msa_in, llm_msa, final_egy = _pipeline_from_text(asr_text, max_new_tokens, temperature, top_p)

    if show_debug:
        return asr_text, msa_in, llm_msa, final_egy

    # Hide debug outputs except ASR text + final.
    return asr_text, "", "", final_egy


# =========================
# 6) Gradio UI
# =========================
with gr.Blocks(title="Egyptian Arabic Assistant") as demo:
    gr.Markdown(
        "## Egyptian Arabic Assistant\n"
        "منطق ثابت وواضح:\n"
        "**Input → (to MSA) → Qwen (MSA) → (to Egyptian) → Output**\n\n"
        "السلوك: رد عملي ومباشر، بدون اعتذار وبدون كلام Meta."
    )

    # Generation controls shared by both tabs.
    with gr.Row():
        max_new_tokens = gr.Slider(64, 512, value=256, step=16, label="Max new tokens")
        temperature = gr.Slider(0.1, 1.2, value=0.7, step=0.05, label="Temperature")
        top_p = gr.Slider(0.5, 1.0, value=0.9, step=0.05, label="Top-p")
        show_debug = gr.Checkbox(value=False, label="Show debug outputs")

    with gr.Tabs():
        with gr.TabItem("Text Input"):
            txt_in = gr.Textbox(lines=4, placeholder="اكتب هنا (مصري/فصحى)", label="Input")
            txt_btn = gr.Button("Generate")

            dbg_msa_in = gr.Textbox(lines=2, label="(Debug) Input after to_msa")
            dbg_llm_msa = gr.Textbox(lines=3, label="(Debug) Qwen output (MSA)")
            out_egy = gr.Textbox(lines=5, label="Final Output (Egyptian)")

            txt_btn.click(
                process_text,
                inputs=[txt_in, max_new_tokens, temperature, top_p, show_debug],
                outputs=[dbg_msa_in, dbg_llm_msa, out_egy],
            )

        with gr.TabItem("Audio Input"):
            aud_in = gr.Audio(type="filepath", label="Upload Audio (WAV/MP3)")
            aud_btn = gr.Button("Transcribe + Generate")

            asr_txt = gr.Textbox(lines=2, label="ASR Text")
            dbg_msa_in_a = gr.Textbox(lines=2, label="(Debug) ASR after to_msa")
            dbg_llm_msa_a = gr.Textbox(lines=3, label="(Debug) Qwen output (MSA)")
            out_egy_a = gr.Textbox(lines=5, label="Final Output (Egyptian)")

            aud_btn.click(
                process_audio,
                inputs=[aud_in, max_new_tokens, temperature, top_p, show_debug],
                outputs=[asr_txt, dbg_msa_in_a, dbg_llm_msa_a, out_egy_a],
            )

demo.launch()