"""chatbot/app.py — Egyptian Arabic assistant (input → MSA → Qwen → Egyptian)."""
# NOTE(review): removed non-Python Hugging Face Space page residue that was
# pasted above the imports ("SherinMohamed's picture / Update app.py / 8115984
# verified") — it made the file a SyntaxError.
import re
import gradio as gr
import spaces
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
# =========================
# 0) Config
# =========================
# Hugging Face Hub model identifiers:
#   TRANSLATOR_MODEL — bidirectional Egyptian-Arabic <-> MSA translator,
#                      direction selected by an appended tag (<ar>/<arz>)
#   ASR_MODEL        — Whisper small, speech-to-text
#   LLM_MODEL        — Qwen2.5 3B Instruct, produces the actual replies
TRANSLATOR_MODEL = "oddadmix/Masrawy-BiLingual-v1"
ASR_MODEL = "openai/whisper-small"
LLM_MODEL = "Qwen/Qwen2.5-3B-Instruct"
USE_GPU = torch.cuda.is_available()
DEVICE = 0 if USE_GPU else -1  # pipeline() device index: 0 = first GPU, -1 = CPU
# =========================
# 1) Load models (once)
# =========================
# All models are loaded at import time so every request reuses the same
# in-memory weights (no per-call reload).
translator = pipeline("translation", model=TRANSLATOR_MODEL, device=DEVICE)
asr = pipeline(
    "automatic-speech-recognition",
    model=ASR_MODEL,
    device=DEVICE
)
tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL,
    torch_dtype="auto",                      # pick dtype from the checkpoint
    device_map="auto" if USE_GPU else None,  # shard onto GPU(s) when available
    trust_remote_code=True
)
if not USE_GPU:
    # Explicit CPU placement when no GPU is present (device_map is None above).
    model = model.to("cpu")
# =========================
# 2) Translator helpers (explicit direction, non-ambiguous)
# =========================
def to_msa(text: str) -> str:
    """Normalize any Arabic variety (Egyptian/MSA/mix) into MSA.

    The translator model selects direction from an appended ``<ar>`` tag.
    Blank or ``None`` input yields ``""``.
    """
    cleaned = (text or "").strip()
    if not cleaned:
        return ""
    outputs = translator(f"{cleaned} <ar>")
    return outputs[0]["translation_text"]
def to_egyptian(text: str) -> str:
    """Convert MSA text into Egyptian Arabic.

    The translator model selects direction from an appended ``<arz>`` tag.
    Blank or ``None`` input yields ``""``.
    """
    cleaned = (text or "").strip()
    if not cleaned:
        return ""
    outputs = translator(f"{cleaned} <arz>")
    return outputs[0]["translation_text"]
# =========================
# 3) Output cleaning (Detox / style shaping)
# =========================
# Phrases that make replies sound apologetic/meta ("as an assistant", "sorry",
# hedging words); they are stripped from every final answer.
_BANNED_PHRASES = [
    "كمساعد", "كمساعد ذكي", "معلش", "آسف", "اعتذر", "مش عارف", "لا أستطيع", "غير قادر",
    "لا يمكنني", "لا أقدر", "لا أملك معلومات", "قد لا يكون", "ربما", "عادةً", "بشكل عام"
]

def clean_egyptian(text: str) -> str:
    """Strip meta/defensive phrasing and tidy whitespace and punctuation.

    Best-effort, intentionally simple cleanup. If everything gets stripped
    away, return a friendly default question instead of an empty reply.
    """
    result = (text or "").strip()
    # Drop every banned phrase via plain substring removal.
    for phrase in _BANNED_PHRASES:
        result = result.replace(phrase, "")
    # Squeeze runs of whitespace, then turn 3+ dots/Arabic commas into "…".
    result = re.sub(r"\s+", " ", result).strip()
    result = re.sub(r"[.،]{3,}", "…", result).strip()
    return result or "تمام—قولي انت فاضي ولا عندك شغل/مذاكرة النهارده؟"
# =========================
# 4) Qwen generation (in MSA for stability)
# =========================
def qwen_generate_msa(msa_prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
    """Generate an MSA reply from Qwen for an already-MSA prompt.

    Args:
        msa_prompt: user message, normalized to MSA upstream.
        max_new_tokens: generation length cap.
        temperature: sampling temperature (UI keeps it >= 0.1).
        top_p: nucleus sampling threshold.

    Returns:
        Generated MSA text with the prompt stripped, or "" for empty input.
    """
    msa_prompt = (msa_prompt or "").strip()
    if not msa_prompt:
        return ""
    # Behavior-first system message: act immediately, no apologies or meta talk.
    system_msg = (
        "أنت مساعد شخصي عملي. "
        "إذا كان سؤال المستخدم عامًا أو مفتوحًا، اقترح خطة أو خطوات عملية من نفسك فورًا "
        "بدون اعتذار وبدون تبرير لحدودك. "
        "اجعل الرد قصيرًا ومباشرًا ومفيدًا. "
        "اكتب باللغة العربية الفصحى البسيطة فقط."
    )
    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": msa_prompt},
    ]
    # Fix: request a dict so we also get the attention mask. The original
    # passed only input_ids, which makes transformers infer the mask from the
    # pad token — ambiguous for Qwen, where pad == eos (it also triggers the
    # "attention mask and pad token id were not set" warning).
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )
    # model.device is valid on CPU as well — no need to special-case USE_GPU.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,  # explicit: Qwen has no pad token
        )
    # Keep only the newly generated tokens (everything after the prompt).
    gen_ids = output_ids[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
# =========================
# 5) Core pipeline (stable + non-ambiguous)
# =========================
def _pipeline_from_text(user_text: str, max_new_tokens: int, temperature: float, top_p: float):
    """Run the full chain: input -> MSA -> Qwen (MSA) -> Egyptian -> clean.

    Returns:
        Tuple ``(msa_in, llm_msa, final_egy)``; three empty strings when the
        input is blank.
    """
    query = (user_text or "").strip()
    if not query:
        return "", "", ""
    # Normalize to MSA first: a stable register keeps the LLM consistent.
    msa_in = to_msa(query)
    llm_msa = qwen_generate_msa(msa_in, max_new_tokens, temperature, top_p)
    # Force the final answer into Egyptian, then detox the phrasing.
    final_egy = clean_egyptian(to_egyptian(llm_msa))
    return msa_in, llm_msa, final_egy
@spaces.GPU
def process_text(user_text: str, max_new_tokens: int, temperature: float, top_p: float, show_debug: bool):
    """Text-tab handler: run the pipeline; blank debug fields unless requested."""
    msa_in, llm_msa, final_egy = _pipeline_from_text(user_text, max_new_tokens, temperature, top_p)
    if not show_debug:
        msa_in, llm_msa = "", ""
    return msa_in, llm_msa, final_egy
@spaces.GPU
def process_audio(audio_path: str, max_new_tokens: int, temperature: float, top_p: float, show_debug: bool):
    """Audio-tab handler: transcribe the file, then run the text pipeline.

    Args:
        audio_path: filepath from the Gradio Audio component (may be None).
        max_new_tokens / temperature / top_p: forwarded sampling controls.
        show_debug: when False, intermediate debug fields are blanked.

    Returns:
        Tuple ``(asr_text, msa_in, llm_msa, final_egy)``.
    """
    # Fix: the original branched on show_debug here yet returned identical
    # empty 4-tuples on both sides — collapsed into single early returns.
    if not audio_path:
        return "", "", "", ""
    asr_out = asr(audio_path)
    # Whisper pipelines return a dict with "text"; be defensive about shape.
    asr_text = (asr_out.get("text", "") if isinstance(asr_out, dict) else str(asr_out)).strip()
    if not asr_text:
        return "", "", "", ""
    msa_in, llm_msa, final_egy = _pipeline_from_text(asr_text, max_new_tokens, temperature, top_p)
    if show_debug:
        return asr_text, msa_in, llm_msa, final_egy
    # Keep ASR text + final answer visible; hide the intermediate debug fields.
    return asr_text, "", "", final_egy
# =========================
# 6) Gradio UI
# =========================
# Two tabs (text / audio) sharing one set of sampling sliders and a debug
# toggle. Component creation order matters: it fixes the on-screen layout.
with gr.Blocks(title="Egyptian Arabic Assistant") as demo:
    gr.Markdown(
        "## Egyptian Arabic Assistant\n"
        "منطق ثابت وواضح:\n"
        "**Input → (to MSA) → Qwen (MSA) → (to Egyptian) → Output**\n\n"
        "السلوك: رد عملي ومباشر، بدون اعتذار وبدون كلام Meta."
    )
    # Shared generation controls, passed into both handlers.
    with gr.Row():
        max_new_tokens = gr.Slider(64, 512, value=256, step=16, label="Max new tokens")
        temperature = gr.Slider(0.1, 1.2, value=0.7, step=0.05, label="Temperature")
        top_p = gr.Slider(0.5, 1.0, value=0.9, step=0.05, label="Top-p")
        show_debug = gr.Checkbox(value=False, label="Show debug outputs")
    with gr.Tabs():
        with gr.TabItem("Text Input"):
            txt_in = gr.Textbox(lines=4, placeholder="اكتب هنا (مصري/فصحى)", label="Input")
            txt_btn = gr.Button("Generate")
            # Debug boxes stay visible but receive "" when the toggle is off.
            dbg_msa_in = gr.Textbox(lines=2, label="(Debug) Input after to_msa")
            dbg_llm_msa = gr.Textbox(lines=3, label="(Debug) Qwen output (MSA)")
            out_egy = gr.Textbox(lines=5, label="Final Output (Egyptian)")
            txt_btn.click(
                process_text,
                inputs=[txt_in, max_new_tokens, temperature, top_p, show_debug],
                outputs=[dbg_msa_in, dbg_llm_msa, out_egy],
            )
        with gr.TabItem("Audio Input"):
            aud_in = gr.Audio(type="filepath", label="Upload Audio (WAV/MP3)")
            aud_btn = gr.Button("Transcribe + Generate")
            # ASR transcript is always shown; the two MSA boxes are debug-only.
            asr_txt = gr.Textbox(lines=2, label="ASR Text")
            dbg_msa_in_a = gr.Textbox(lines=2, label="(Debug) ASR after to_msa")
            dbg_llm_msa_a = gr.Textbox(lines=3, label="(Debug) Qwen output (MSA)")
            out_egy_a = gr.Textbox(lines=5, label="Final Output (Egyptian)")
            aud_btn.click(
                process_audio,
                inputs=[aud_in, max_new_tokens, temperature, top_p, show_debug],
                outputs=[asr_txt, dbg_msa_in_a, dbg_llm_msa_a, out_egy_a],
            )
# Start the web server (blocking call).
demo.launch()