Spaces:
Sleeping
Sleeping
"""
kalamna-space / app.py
Gradio + FastAPI Space for taMASRIBERT — the unified deep-fusion model.
Routing Logic:
- ALL text uses Deep Fusion (BERT + FastText).
- Any non-Arabic text (English or Franco) is dynamically extracted, translated/transliterated
  into Egyptian Arabic via NAMAA, and stitched back into the sequence before inference.
"""
import logging
import os
import re
import time
from typing import Optional

import fasttext
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import uvicorn
from fastapi import FastAPI, HTTPException, Security
from fastapi.security.api_key import APIKeyHeader
from huggingface_hub import hf_hub_download, login
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel, MarianTokenizer, MarianMTModel
# Log in to the Hugging Face Hub only when the HF_TOKEN environment variable is set.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

# Run on the GPU when one is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_LEN = 128  # tokenizer max_length and row count of the FastText matrix
FT_DIM = 300  # width of each row of the FastText matrix
RNN_HID = 256  # NOTE(review): not referenced anywhere in the visible code
# Label lists; index i names column i of the matching head's probability vector.
CLASSNAMES = ['none', 'anger', 'joy', 'sadness', 'love', 'sympathy', 'surprise', 'fear']
SENT_LABELS = ['negative', 'neutral', 'positive']
SARC_LABELS = ['not sarcastic', 'sarcastic']
# Minimum top emotion probability; below it the reported emotion is "none".
EMO_THRESHOLD = 0.45
# ── Regex and Cleaning ───────────────────────────────────────────────────────
# Arabic combining marks (U+064B-U+065F, U+0670, U+06D6-U+06ED subsets) plus
# the tatweel/kashida U+0640; used to strip them from text.
_DIACRITICS_RE = re.compile(r"[\u064B-\u065F\u0670\u06D6-\u06DC\u06DF-\u06E4\u06E7\u06E8\u06EA-\u06ED\u0640]")
# Any code point in the Arabic Unicode block.
_ARABIC_RE = re.compile(r"[\u0600-\u06FF]")
# Any ASCII letter.
_LATIN_RE = re.compile(r"[a-zA-Z]")


def has_latin(text: str) -> bool:
    """Return True when *text* contains at least one ASCII letter (a-z, A-Z)."""
    return _LATIN_RE.search(text) is not None
def clean_text(text: str) -> str:
    """Normalise raw user text before it is handed to the model.

    Non-string input yields "". Otherwise, in order: URLs, e-mail-like tokens
    and @mentions are removed; "#" is dropped (the hashtag word stays);
    newlines become spaces; a list of Arabic letter substitutions is applied;
    characters matched by _DIACRITICS_RE are stripped; any character repeated
    four or more times in a row is cut down to two; characters outside the
    Arabic block, ASCII letters/digits, whitespace and a small punctuation
    whitelist are removed; finally whitespace is collapsed and trimmed.

    NOTE(review): the Arabic literals in the substitution table look
    mis-encoded (mojibake) in this copy of the file - confirm them against
    the upstream source before relying on them.
    """
    if not isinstance(text, str):
        return ""

    # Links, e-mail addresses, mentions and the hash sign carry no signal.
    for noise_pattern in (r"(?:https?://|www\.)\S+", r"\S+@\S+", r"@\w+", r"#"):
        text = re.sub(noise_pattern, "", text)
    text = re.sub(r"\n+", " ", text)

    # Letter normalisation, applied in this exact order.
    letter_substitutions = (
        (r"[ุฅุฃุขุง]", "ุง"),
        (r"ู", "ู"),
        (r"ุค", "ุก"),
        (r"ุฆ", "ุก"),
        (r"ฺฏ", "ู"),
    )
    for pattern, replacement in letter_substitutions:
        text = re.sub(pattern, replacement, text)

    text = _DIACRITICS_RE.sub("", text)
    # Elongations such as "looool": 4+ identical characters shrink to 2.
    text = re.sub(r"(.)\1{3,}", r"\1\1", text)
    # Keep only Arabic, ASCII alphanumerics, whitespace and basic punctuation.
    text = re.sub(r"[^\u0600-\u06FFa-zA-Z0-9\s\.\,\!\?\;\:\"\'\(\)\[\]\{\}\-\+\=\/\\]", "", text)
    collapsed = re.sub(r"\s+", " ", text)
    return collapsed.strip()
# ── NAMAA EN→EGY translation ─────────────────────────────────────────────────
# Load the NAMAA MarianMT tokenizer and model once, at import time.
print("Loading NAMAA ENโEGY translatorโฆ")
_EGY_MODEL = "NAMAA-Space/masrawy-english-to-egyptian-arabic-translator-v2.9"
egy_tokenizer = MarianTokenizer.from_pretrained(_EGY_MODEL)
egy_model = MarianMTModel.from_pretrained(_EGY_MODEL).to(DEVICE)
egy_model.eval()  # evaluation mode (e.g. dropout disabled) for inference
def translate_en_to_egy(text: str) -> str:
    """Translate *text* with the NAMAA MarianMT model and return the decoded string.

    The input is tokenised as a batch of one and truncated to 128 tokens, then
    decoded with 4-beam search producing at most 128 new tokens. Special
    tokens are dropped from the returned text.
    """
    batch = egy_tokenizer(
        [text],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    )
    batch = batch.to(DEVICE)
    # Generation never needs gradients.
    with torch.no_grad():
        generated = egy_model.generate(**batch, num_beams=4, max_new_tokens=128)
    best_sequence = generated[0]
    return egy_tokenizer.decode(best_sequence, skip_special_tokens=True)
# ── Force Arabic Translation Pre-Processing ──────────────────────────────────
# Contiguous runs of ASCII letters/digits, optionally joined by whitespace.
# This captures full English sentences, Franco blocks, or mixed chunks.
# Compiled once at import time instead of on every call.
_LATIN_SPAN_RE = re.compile(r"[a-zA-Z0-9]+(?:\s+[a-zA-Z0-9]+)*")


def _translate_span(match: "re.Match[str]") -> str:
    """Return the translation of one matched Latin span, padded with spaces.

    The original span is returned unchanged when it holds no ASCII letter
    (a digit-only run such as "2024" has nothing to translate, and digits
    also pass through untouched on the pure-Arabic route) or when the
    translator raises.
    """
    span = match.group(0)
    if not has_latin(span):
        return span
    try:
        translated = translate_en_to_egy(span)
    except Exception:
        # Best effort: a failed span must not abort the whole request, but the
        # failure is logged instead of being silently swallowed. `Exception`
        # (not a bare except) lets KeyboardInterrupt/SystemExit propagate.
        logging.getLogger(__name__).warning(
            "Span translation failed; keeping the original text", exc_info=True
        )
        return span
    # Pad so the translation cannot fuse with neighbouring Arabic words;
    # clean_text collapses the extra whitespace afterwards.
    return " " + translated + " "


def force_arabic_translation(raw: str) -> tuple[str, str]:
    """Rewrite Latin-script spans of *raw* into Egyptian Arabic, then clean it.

    Args:
        raw: User text; may be Arabic, English, Franco-Arabic or a mix.

    Returns:
        A ``(cleaned_text, route_tag)`` pair. The tag is
        "Pure Arabic (Deep Fusion)" when *raw* has no ASCII letters and
        "Translated to Arabic (Deep Fusion)" otherwise.
    """
    if not has_latin(raw):
        return clean_text(raw), "Pure Arabic (Deep Fusion)"
    processed = _LATIN_SPAN_RE.sub(_translate_span, raw)
    return clean_text(processed), "Translated to Arabic (Deep Fusion)"
# ── Model definition ─────────────────────────────────────────────────────────
class TaskHead(nn.Module):
    """Classification head: Linear -> LayerNorm -> GELU -> Dropout -> Linear."""

    def __init__(self, input_size: int, n_classes: int, dropout: float = 0.3):
        super().__init__()
        hidden = 256
        layers = [
            nn.Linear(input_size, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, n_classes),
        ]
        # The attribute name and layer order fix the state-dict keys
        # ("net.<index>.*"), so both must stay as they are.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the unnormalised class scores for *x*."""
        logits = self.net(x)
        return logits
class UnifiedMASRIHead(nn.Module):
    """Multi-task model fusing a BERT encoder with BiLSTM/BiGRU FastText features.

    The first-position BERT hidden state is concatenated with the last time
    step of a bidirectional LSTM and a bidirectional GRU run over the FastText
    embeddings; three TaskHead classifiers (emotion: 8 classes, sentiment: 3,
    sarcasm: 2) read the fused vector.
    """

    def __init__(
        self,
        bert_model_name: str = "T0KII/MASRIBERTv3",
        ft_dim: int = 300,
        rnn_hidden: int = 256,
        num_layers: int = 2,
        dropout: float = 0.3,
    ):
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_model_name)

        # Both recurrent branches share the same configuration.
        rnn_options = dict(batch_first=True, bidirectional=True, dropout=dropout)
        self.bilstm = nn.LSTM(ft_dim, rnn_hidden, num_layers, **rnn_options)
        self.bigru = nn.GRU(ft_dim, rnn_hidden, num_layers, **rnn_options)
        self.rnn_dropout = nn.Dropout(dropout)

        # Two bidirectional RNNs contribute 2 * rnn_hidden features each.
        combined_dim = self.bert.config.hidden_size + 4 * rnn_hidden
        self.sarcasm_head = TaskHead(combined_dim, 2, dropout=0.5)
        self.sentiment_head = TaskHead(combined_dim, 3, dropout=0.3)
        self.emotion_head = TaskHead(combined_dim, 8, dropout=0.3)

    def forward(self, input_ids, attention_mask, ft_embeds):
        """Return ``(emotion_logits, sentiment_logits, sarcasm_logits)``.

        ``ft_embeds`` is batch-first: (batch, seq, ft_dim).
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_vec = encoded.last_hidden_state[:, 0, :]

        lstm_seq, _ = self.bilstm(ft_embeds)
        gru_seq, _ = self.bigru(ft_embeds)
        # Output at the final time step of each recurrent branch.
        last_steps = torch.cat([lstm_seq[:, -1, :], gru_seq[:, -1, :]], dim=1)

        fused = torch.cat([cls_vec, self.rnn_dropout(last_steps)], dim=1)
        emotion_logits = self.emotion_head(fused)
        sentiment_logits = self.sentiment_head(fused)
        sarcasm_logits = self.sarcasm_head(fused)
        return emotion_logits, sentiment_logits, sarcasm_logits
# ── Load FastText ────────────────────────────────────────────────────────────
# Download the FastText binary from the Hub and load it once, at import time.
print("Loading FastText arz vectorsโฆ")
_ft_path = hf_hub_download("facebook/fasttext-arz-vectors", "model.bin")
_ft_model = fasttext.load_model(_ft_path)
def get_ft_embedding(text: str) -> np.ndarray:
    """Build a fixed-size FastText matrix for *text*.

    The text is split on whitespace and at most MAX_LEN tokens are used.
    Row ``i`` holds the FastText vector of token ``i``; unused rows stay zero.

    Returns:
        A float32 array of shape (MAX_LEN, FT_DIM).
    """
    matrix = np.zeros((MAX_LEN, FT_DIM), dtype=np.float32)
    for row, token in enumerate(text.split()[:MAX_LEN]):
        try:
            matrix[row] = _ft_model.get_word_vector(token)
        except Exception:
            # Best effort: the row stays zero, but the failure is logged
            # rather than hidden. For example, a vector-size mismatch with
            # FT_DIM would otherwise silently zero every row.
            logging.getLogger(__name__).warning(
                "FastText lookup failed for token %d; leaving a zero vector",
                row,
                exc_info=True,
            )
    return matrix
# ── Load taMASRIBERT ─────────────────────────────────────────────────────────
# Load the tokenizer, build the fusion model and restore its trained weights
# once, at import time.
print("Loading taMASRIBERT tokenizerโฆ")
_REPO = "T0KII/taMASRIBERT"
ta_tokenizer = AutoTokenizer.from_pretrained(_REPO)
print("Initialising UnifiedMASRIHeadโฆ")
ta_model = UnifiedMASRIHead(bert_model_name="T0KII/MASRIBERTv3").to(DEVICE)
_ckpt_path = hf_hub_download(repo_id=_REPO, filename="pytorch_model.bin")
# NOTE(security): torch.load unpickles the downloaded file. If the installed
# torch version supports it, prefer torch.load(..., weights_only=True).
_state_dict = torch.load(_ckpt_path, map_location=DEVICE)
# strict=False tolerates key mismatches; report them so a checkpoint that does
# not match the architecture cannot silently leave layers randomly initialised.
_load_result = ta_model.load_state_dict(_state_dict, strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    logging.getLogger(__name__).warning(
        "taMASRIBERT checkpoint mismatch: missing_keys=%s unexpected_keys=%s",
        _load_result.missing_keys,
        _load_result.unexpected_keys,
    )
del _state_dict  # the weights now live in the model; drop the extra copy
ta_model.eval()
print("โ taMASRIBERT ready")
# ── Core inference ───────────────────────────────────────────────────────────
def _run_inference(cleaned: str):
    """Run the fusion model on *cleaned* text and return class probabilities.

    Returns:
        A 3-tuple of 1-D numpy arrays ``(emotion, sentiment, sarcasm)``, each
        the softmax over the matching head's logits for the single input.
    """
    enc = ta_tokenizer(
        cleaned,
        padding="max_length",
        truncation=True,
        max_length=MAX_LEN,
        return_tensors="pt",
    )
    # Add the batch dimension: (MAX_LEN, FT_DIM) -> (1, MAX_LEN, FT_DIM).
    ft_emb = torch.from_numpy(get_ft_embedding(cleaned)).unsqueeze(0).to(DEVICE)
    # Gradients are never needed here. Without no_grad the outputs of a model
    # whose parameters require grad also require grad, and Tensor.numpy()
    # then raises RuntimeError; it also avoids building the autograd graph.
    with torch.no_grad():
        e_l, st_l, sc_l = ta_model(
            enc["input_ids"].to(DEVICE),
            enc["attention_mask"].to(DEVICE),
            ft_emb,
        )
        return (
            F.softmax(e_l, -1).cpu().numpy()[0],
            F.softmax(st_l, -1).cpu().numpy()[0],
            F.softmax(sc_l, -1).cpu().numpy()[0],
        )
def predict(text: str):
    """Gradio callback: analyse *text* and return the values for the UI outputs.

    Returns a 5-tuple: emotion probabilities, sentiment probabilities and
    sarcasm probabilities (each a label -> float dict), a Markdown summary and
    the pre-processed text. Blank input, or input that is empty after
    pre-processing, yields empty dicts with a placeholder summary.
    """
    nothing_to_show = ({}, {}, {}, "โ", "")
    if not text or not text.strip():
        return nothing_to_show

    cleaned, route_tag = force_arabic_translation(text)
    if not cleaned:
        return nothing_to_show

    emo_probs, sent_probs, sarc_probs = _run_inference(cleaned)

    # The top emotion is reported only when it clears the confidence threshold.
    emo_conf = float(np.max(emo_probs))
    if emo_conf >= EMO_THRESHOLD:
        top_emo = CLASSNAMES[int(np.argmax(emo_probs))]
    else:
        top_emo = "none"
    top_sent = SENT_LABELS[int(np.argmax(sent_probs))]
    top_sarc = SARC_LABELS[int(np.argmax(sarc_probs))]
    sent_conf = float(np.max(sent_probs))
    sarc_conf = float(np.max(sarc_probs))

    summary = (
        f"**Emotion:** {top_emo.title()} ({emo_conf:.1%}) | "
        f"**Sentiment:** {top_sent.title()} ({sent_conf:.1%}) | "
        f"**Sarcasm:** {top_sarc.title()} ({sarc_conf:.1%}) \n\n"
        f"๐ท๏ธ **Tag:** {route_tag}"
    )

    emotion_scores = {label: float(p) for label, p in zip(CLASSNAMES, emo_probs)}
    sentiment_scores = {label: float(p) for label, p in zip(SENT_LABELS, sent_probs)}
    sarcasm_scores = {label: float(p) for label, p in zip(SARC_LABELS, sarc_probs)}
    return emotion_scores, sentiment_scores, sarcasm_scores, summary, cleaned
# ── Gradio UI ────────────────────────────────────────────────────────────────
# Gradio front end: one text input, a read-only view of the pre-processed
# text, a Markdown summary and one Label widget per task.
# NOTE(review): nesting was reconstructed because this copy of the file lost
# its indentation - confirm which widgets sit inside each gr.Row().
with gr.Blocks(theme=gr.themes.Soft(primary_hue="green")) as demo:
    gr.Markdown("# Kalamna: Egyptian Arabic Analyzer")
    gr.Markdown("**taMASRIBERT** with Automatic Arabic Script Normalization.")
    with gr.Row():
        txt = gr.Textbox(label="Input (Arabic / Franco / English / Mixed)", placeholder="ุงูุชุจ ููุง ...", lines=3, rtl=True)
        # Read-only echo of the cleaned text returned by predict().
        ctx = gr.Textbox(label="Pre-processed (what the model sees)", interactive=False, rtl=True)
    btn = gr.Button("Analyze", variant="primary")
    sum_md = gr.Markdown()
    with gr.Row():
        e_out = gr.Label(label="Emotion", num_top_classes=8)
        st_out = gr.Label(label="Sentiment", num_top_classes=3)
        sc_out = gr.Label(label="Sarcasm", num_top_classes=2)
    # Sample inputs offered for the text box.
    gr.Examples(
        examples=[
            "ุงูู ูุชูุณููู ุงููู ูุงู ุณูุจุฑ ูุงูุช ุทูุน ุนููู ูู ุงูุณูุฑูุฉ ุจุชุงุนุฉ ุงูุณุงุญูุ ุงูุชูููู ุงุณูุฃ ู ุง ูู ูู",
            "ya gd3an el match elgai is crucial, el bad performance dh msh hnkml beeh",
            "ู ุง ุดุงุก ุงููู ุนูู ุณุฑุนุฉ ุงููุชุ ุงูุณูุญูุงุฉ ุจุชุณุจููุ ุจุฌุฏ ุงุญุณู ุฎุฏู ุฉ ุนู ูุงุก ูู ุงูุฏููุง.",
            "ana 3mlt update l flutter w el app crash, this is extremely frustrating",
            "ุงูุง ุทูุจุช ุงูุงูุฑุฏุฑ ู ู ุดูุฑ ู ูุณู ู ุฌุงุดุ your delivery service is completely useless and I want a refund",
            "el boxy fit shirt dh shklo gamed awy, perfect style w el material quality is top notch"
        ],
        inputs=txt,
    )
    # Both the button and pressing Enter in the text box run predict(); its
    # five return values map onto these five outputs in order.
    btn.click(predict, txt, [e_out, st_out, sc_out, sum_md, ctx])
    txt.submit(predict, txt, [e_out, st_out, sc_out, sum_md, ctx])
# ── FastAPI + Gradio mounting ────────────────────────────────────────────────
# Request body: the raw text to analyse.
class DetectRequest(BaseModel):
    text: str


# Sarcasm verdict with the probability of the chosen label.
class SarcasmResult(BaseModel):
    label: str
    score: float


# Full API response for one analysed text.
class DetectResponse(BaseModel):
    emotion: str
    confidence: float
    sentiment: str
    sarcasm: SarcasmResult
    urgent: bool
    latency_ms: float
    source: str
    route: str
    cleaned: Optional[str] = None


# Template answer used when there is nothing to analyse or inference fails.
# NOTE(review): emotion is "neutral" here although CLASSNAMES uses "none" for
# the no-emotion case - confirm this difference is intended.
_FALLBACK = DetectResponse(
    emotion="neutral",
    confidence=0.0,
    sentiment="neutral",
    sarcasm=SarcasmResult(label="not sarcastic", score=0.0),
    urgent=False,
    latency_ms=0.0,
    source="fallback",
    route="none",
    cleaned=None,
)
fapp = FastAPI(title="Kalamna Emotion API", version="2.0.0", root_path=os.environ.get("ROOT_PATH", ""))


# NOTE(review): no route decorator (e.g. @fapp.post(...)) is visible on this
# handler, so as shown it is never registered on `fapp` - confirm against the
# upstream file and restore the registration if it was lost.
def detect_api(body: DetectRequest):
    """Analyse ``body.text`` and return a DetectResponse.

    Blank text, text that is empty after pre-processing, and any failure
    during processing all yield a copy of ``_FALLBACK``; this handler never
    raises. ``urgent`` is set when the top emotion is sadness, fear or anger
    and the top sentiment is negative. ``latency_ms`` covers pre-processing
    and model inference.
    """
    if not body.text or not body.text.strip():
        return _FALLBACK.model_copy()
    t0 = time.perf_counter()
    try:
        cleaned, route_tag = force_arabic_translation(body.text)
        if not cleaned:
            return _FALLBACK.model_copy()
        e_p, st_p, sc_p = _run_inference(cleaned)
        latency_ms = (time.perf_counter() - t0) * 1000
        max_emo_p = float(np.max(e_p))
        # The top emotion is reported only above the confidence threshold.
        top_emo = CLASSNAMES[int(np.argmax(e_p))] if max_emo_p >= EMO_THRESHOLD else "none"
        top_sent = SENT_LABELS[int(np.argmax(st_p))]
        top_sarc = SARC_LABELS[int(np.argmax(sc_p))]
        urgent = top_emo in {"sadness", "fear", "anger"} and top_sent == "negative"
        return DetectResponse(
            emotion=top_emo,
            confidence=max_emo_p,
            sentiment=top_sent,
            sarcasm=SarcasmResult(label=top_sarc, score=float(np.max(sc_p))),
            urgent=urgent,
            latency_ms=latency_ms,
            source="model",
            route=route_tag,
            cleaned=cleaned,
        )
    except Exception:
        # Keep the best-effort contract (callers always get a response), but
        # log the traceback so failures are not silently masked as "fallback".
        logging.getLogger(__name__).exception(
            "detect_api failed; returning the fallback response"
        )
        return _FALLBACK.model_copy()
# Mount the Gradio UI on the FastAPI app at "/"; `app` is what uvicorn serves.
app = gr.mount_gradio_app(fapp, demo, path="/")

# Script entry point: serve the combined app with uvicorn.
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",  # listen on all network interfaces
        port=7860,
        forwarded_allow_ips="*",  # accept forwarded headers from any upstream address
        proxy_headers=True
    )