# HuggingFace Spaces app (Gradio) — Egyptian Arabic assistant.
| import re | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM | |
# =========================
# 0) Config
# =========================
# Egyptian<->MSA translator; output dialect is selected by a trailing tag
# ("<ar>" or "<arz>") appended to the input text.
TRANSLATOR_MODEL = "oddadmix/Masrawy-BiLingual-v1"
# Speech-to-text model for the audio tab.
ASR_MODEL = "openai/whisper-small"
# Chat LLM; prompted to answer in plain MSA for stability.
LLM_MODEL = "Qwen/Qwen2.5-3B-Instruct"
USE_GPU = torch.cuda.is_available()
# transformers pipeline device index: 0 = first CUDA device, -1 = CPU.
DEVICE = 0 if USE_GPU else -1
# =========================
# 1) Load models (once)
# =========================
# Translation pipeline; direction is chosen per call by appending a tag
# ("<ar>" -> MSA, "<arz>" -> Egyptian) to the input text.
translator = pipeline("translation", model=TRANSLATOR_MODEL, device=DEVICE)
# Whisper ASR pipeline for transcribing uploaded audio files.
asr = pipeline(
    "automatic-speech-recognition",
    model=ASR_MODEL,
    device=DEVICE
)
tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL,
    torch_dtype="auto",  # use the checkpoint's native dtype
    device_map="auto" if USE_GPU else None,  # auto-place on GPU(s) when available
    trust_remote_code=True
)
# With device_map=None the model may land wherever transformers defaults to;
# pin it to CPU explicitly for the no-GPU case.
if not USE_GPU:
    model = model.to("cpu")
# =========================
# 2) Translator helpers (explicit direction, non-ambiguous)
# =========================
def to_msa(text: str) -> str:
    """Translate any Arabic input (Egyptian, MSA, or a mix) into MSA.

    The translator model reads a trailing "<ar>" tag as a request for
    Modern Standard Arabic output. Blank/None input yields "".
    """
    stripped = (text or "").strip()
    if not stripped:
        return ""
    outputs = translator(stripped + " <ar>")
    return outputs[0]["translation_text"]
def to_egyptian(text: str) -> str:
    """Translate MSA text into Egyptian Arabic.

    The trailing "<arz>" tag tells the translator model to emit
    Egyptian-dialect output. Blank/None input yields "".
    """
    stripped = (text or "").strip()
    if stripped:
        return translator(stripped + " <arz>")[0]["translation_text"]
    return ""
# =========================
# 3) Output cleaning (Detox / style shaping)
# =========================
# Meta/defensive filler phrases stripped from model output.
_BANNED_PHRASES = [
    "كمساعد", "كمساعد ذكي", "معلش", "آسف", "اعتذر", "مش عارف", "لا أستطيع", "غير قادر",
    "لا يمكنني", "لا أقدر", "لا أملك معلومات", "قد لا يكون", "ربما", "عادةً", "بشكل عام"
]


def clean_egyptian(text: str) -> str:
    """Remove annoying meta/defensive phrasing from a reply.

    Best-effort substring cleanup: delete banned phrases, normalize
    whitespace, collapse runs of 3+ dots/Arabic commas into an ellipsis.
    If nothing survives, return a friendly default prompt instead.
    """
    result = (text or "").strip()
    for phrase in _BANNED_PHRASES:
        result = result.replace(phrase, "")
    result = re.sub(r"\s+", " ", result).strip()
    result = re.sub(r"[.،]{3,}", "…", result).strip()
    if result:
        return result
    return "تمام—قولي انت فاضي ولا عندك شغل/مذاكرة النهارده؟"
# =========================
# 4) Qwen generation (in MSA for stability)
# =========================
def qwen_generate_msa(msa_prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
    """Generate an MSA reply from the LLM for the given MSA prompt.

    The system message pins the behavior: practical, direct, no apologies
    or meta talk, plain Modern Standard Arabic only. Blank input yields "".
    """
    prompt = (msa_prompt or "").strip()
    if not prompt:
        return ""
    # Behavior-first system message (MOST IMPORTANT CHANGE)
    system_msg = (
        "أنت مساعد شخصي عملي. "
        "إذا كان سؤال المستخدم عامًا أو مفتوحًا، اقترح خطة أو خطوات عملية من نفسك فورًا "
        "بدون اعتذار وبدون تبرير لحدودك. "
        "اجعل الرد قصيرًا ومباشرًا ومفيدًا. "
        "اكتب باللغة العربية الفصحى البسيطة فقط."
    )
    chat = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": prompt},
    ]
    prompt_ids = tokenizer.apply_chat_template(
        chat,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    if USE_GPU:
        prompt_ids = prompt_ids.to(model.device)
    with torch.no_grad():
        generated = model.generate(
            prompt_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Keep only the newly generated tokens (drop the echoed prompt).
    new_token_ids = generated[0][prompt_ids.shape[-1]:]
    return tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()
| # ========================= | |
| # 5) Core pipeline (stable + non-ambiguous) | |
| # ========================= | |
| def _pipeline_from_text(user_text: str, max_new_tokens: int, temperature: float, top_p: float): | |
| """ | |
| Input -> (to MSA) -> Qwen (MSA) -> (to Egyptian) -> clean | |
| Returns: msa_in, llm_msa, final_egy | |
| """ | |
| user_text = (user_text or "").strip() | |
| if not user_text: | |
| return "", "", "" | |
| # 1) Normalize input to MSA (stable for LLM) | |
| msa_in = to_msa(user_text) | |
| # 2) LLM outputs in MSA (behavior controlled by system prompt) | |
| llm_msa = qwen_generate_msa(msa_in, max_new_tokens, temperature, top_p) | |
| # 3) Force Egyptian output + clean | |
| final_egy = clean_egyptian(to_egyptian(llm_msa)) | |
| return msa_in, llm_msa, final_egy | |
def process_text(user_text: str, max_new_tokens: int, temperature: float, top_p: float, show_debug: bool):
    """Text-tab handler.

    Returns (msa_debug, llm_msa_debug, final_egyptian); the two debug
    slots are blanked out when show_debug is False.
    """
    msa_in, llm_msa, final_egy = _pipeline_from_text(user_text, max_new_tokens, temperature, top_p)
    if not show_debug:
        # Hide the intermediate outputs from the UI.
        msa_in, llm_msa = "", ""
    return msa_in, llm_msa, final_egy
def process_audio(audio_path: str, max_new_tokens: int, temperature: float, top_p: float, show_debug: bool):
    """Audio-tab handler: ASR, then the text pipeline.

    Returns (asr_text, msa_debug, llm_msa_debug, final_egyptian). The two
    debug slots are blank when show_debug is False; all four slots are
    blank when no audio was provided or transcription came back empty.
    """
    # Fix: the original guarded these early exits with `if show_debug:` /
    # else branches that returned the identical 4-tuple in both arms —
    # dead branching collapsed to a single return each.
    if not audio_path:
        return "", "", "", ""
    asr_out = asr(audio_path)
    # Whisper pipelines return {"text": ...}; be defensive about other shapes.
    asr_text = (asr_out.get("text", "") if isinstance(asr_out, dict) else str(asr_out)).strip()
    if not asr_text:
        return "", "", "", ""
    msa_in, llm_msa, final_egy = _pipeline_from_text(asr_text, max_new_tokens, temperature, top_p)
    if show_debug:
        return asr_text, msa_in, llm_msa, final_egy
    # hide debug outputs except ASR text + final
    return asr_text, "", "", final_egy
# =========================
# 6) Gradio UI
# =========================
with gr.Blocks(title="Egyptian Arabic Assistant") as demo:
    # Header explaining the fixed processing chain to the user.
    gr.Markdown(
        "## Egyptian Arabic Assistant\n"
        "منطق ثابت وواضح:\n"
        "**Input → (to MSA) → Qwen (MSA) → (to Egyptian) → Output**\n\n"
        "السلوك: رد عملي ومباشر، بدون اعتذار وبدون كلام Meta."
    )
    # Shared generation controls, wired into both tabs' click handlers.
    with gr.Row():
        max_new_tokens = gr.Slider(64, 512, value=256, step=16, label="Max new tokens")
        temperature = gr.Slider(0.1, 1.2, value=0.7, step=0.05, label="Temperature")
        top_p = gr.Slider(0.5, 1.0, value=0.9, step=0.05, label="Top-p")
        show_debug = gr.Checkbox(value=False, label="Show debug outputs")
    with gr.Tabs():
        with gr.TabItem("Text Input"):
            txt_in = gr.Textbox(lines=4, placeholder="اكتب هنا (مصري/فصحى)", label="Input")
            txt_btn = gr.Button("Generate")
            # Debug boxes stay empty unless "Show debug outputs" is checked.
            dbg_msa_in = gr.Textbox(lines=2, label="(Debug) Input after to_msa")
            dbg_llm_msa = gr.Textbox(lines=3, label="(Debug) Qwen output (MSA)")
            out_egy = gr.Textbox(lines=5, label="Final Output (Egyptian)")
            txt_btn.click(
                process_text,
                inputs=[txt_in, max_new_tokens, temperature, top_p, show_debug],
                outputs=[dbg_msa_in, dbg_llm_msa, out_egy],
            )
        with gr.TabItem("Audio Input"):
            aud_in = gr.Audio(type="filepath", label="Upload Audio (WAV/MP3)")
            aud_btn = gr.Button("Transcribe + Generate")
            asr_txt = gr.Textbox(lines=2, label="ASR Text")
            dbg_msa_in_a = gr.Textbox(lines=2, label="(Debug) ASR after to_msa")
            dbg_llm_msa_a = gr.Textbox(lines=3, label="(Debug) Qwen output (MSA)")
            out_egy_a = gr.Textbox(lines=5, label="Final Output (Egyptian)")
            aud_btn.click(
                process_audio,
                inputs=[aud_in, max_new_tokens, temperature, top_p, show_debug],
                outputs=[asr_txt, dbg_msa_in_a, dbg_llm_msa_a, out_egy_a],
            )
demo.launch()