taMASRIBERTs / app.py
T0KII's picture
Update app.py
7373841 verified
Raw
History Blame Contribute Delete
13.2 kB
"""
kalamna-space / app.py
Gradio + FastAPI Space for taMASRIBERT — the unified deep-fusion model.
Routing Logic:
- ALL text uses Deep Fusion (BERT + FastText).
- Any non-Arabic text (English or Franco) is dynamically extracted, translated/transliterated
into Egyptian Arabic via NAMAA, and stitched back into the sequence before inference.
"""
import os, re, time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
import fasttext
import uvicorn
from fastapi import FastAPI, HTTPException, Security
from fastapi.security.api_key import APIKeyHeader
from pydantic import BaseModel
from typing import Optional
from transformers import AutoTokenizer, AutoModel, MarianTokenizer, MarianMTModel
from huggingface_hub import hf_hub_download, login
# Try to login if HF_TOKEN is set
# Authenticate against the Hugging Face Hub only when a token is supplied
# through the environment; otherwise skip the login call entirely.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

# Prefer the GPU when torch can see one, fall back to the CPU otherwise.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Sequence length used for tokenizer padding/truncation and for the number of
# rows in the FastText embedding matrix.
MAX_LEN = 128
# Width of one FastText word vector.
FT_DIM = 300
RNN_HID = 256

# Label tables; list positions are used to decode the model's output indices.
CLASSNAMES = [
    "none",
    "anger",
    "joy",
    "sadness",
    "love",
    "sympathy",
    "surprise",
    "fear",
]
SENT_LABELS = ["negative", "neutral", "positive"]
SARC_LABELS = ["not sarcastic", "sarcastic"]

# Minimum top-class probability required before an emotion other than "none"
# is reported.
EMO_THRESHOLD = 0.45
# ── Regex and Cleaning ──────────────────────────────────────────────────────
# Arabic combining marks (diacritics) plus the tatweel elongation character
# (U+0640); every match is deleted by clean_text.
_DIACRITICS_RE = re.compile(r"[\u064B-\u065F\u0670\u06D6-\u06DC\u06DF-\u06E4\u06E7\u06E8\u06EA-\u06ED\u0640]")
_ARABIC_RE = re.compile(r"[\u0600-\u06FF]")
_LATIN_RE = re.compile(r"[a-zA-Z]")

# Letter-variant folding applied by clean_text. The code points are written as
# escapes so the table cannot be damaged by a wrong file encoding: the literal
# characters previously embedded in the patterns were mis-decoded and no
# longer matched Arabic input.
_LETTER_FOLDS = str.maketrans({
    "\u0625": "\u0627",  # alef with hamza below -> bare alef
    "\u0623": "\u0627",  # alef with hamza above -> bare alef
    "\u0622": "\u0627",  # alef with madda       -> bare alef
    "\u0649": "\u064A",  # alef maksura          -> yeh
    "\u0624": "\u0621",  # waw with hamza        -> hamza
    "\u0626": "\u0621",  # yeh with hamza        -> hamza
    "\u06AF": "\u0643",  # gaf                   -> kaf
})


def has_latin(text: str) -> bool:
    """Return True when *text* contains at least one ASCII letter."""
    return bool(_LATIN_RE.search(text))


def clean_text(text: str) -> str:
    """Normalise raw user text before it is fed to the model.

    Steps, in order: drop URLs, e-mail addresses and @mentions; remove '#';
    turn newlines into spaces; fold Arabic letter variants; delete diacritics
    and tatweel; squeeze any character repeated four or more times down to
    two; remove every character outside the Arabic block, ASCII letters and
    digits, whitespace and a small punctuation set; collapse whitespace.

    Args:
        text: The raw input. Any non-string value yields an empty string.

    Returns:
        The cleaned, stripped text (possibly empty).
    """
    if not isinstance(text, str):
        return ""
    text = re.sub(r"(?:https?://|www\.)\S+", "", text)
    text = re.sub(r"\S+@\S+", "", text)
    text = re.sub(r"@\w+", "", text)
    text = re.sub(r"#", "", text)
    text = re.sub(r"\n+", " ", text)
    # No fold target is itself a fold source, so a single translate pass is
    # equivalent to applying the substitutions one after another.
    text = text.translate(_LETTER_FOLDS)
    text = _DIACRITICS_RE.sub("", text)
    text = re.sub(r"(.)\1{3,}", r"\1\1", text)
    text = re.sub(r"[^\u0600-\u06FFa-zA-Z0-9\s\.\,\!\?\;\:\"\'\(\)\[\]\{\}\-\+\=\/\\]", "", text)
    return re.sub(r"\s+", " ", text).strip()
# ── NAMAA EN→EGY translation ────────────────────────────────────────────────
# Load the Marian tokenizer and model once, at import time, move the model to
# the selected device and switch it to evaluation mode.
print("Loading NAMAA ENโ†’EGY translatorโ€ฆ")
_EGY_MODEL = "NAMAA-Space/masrawy-english-to-egyptian-arabic-translator-v2.9"
egy_tokenizer = MarianTokenizer.from_pretrained(_EGY_MODEL)
egy_model = MarianMTModel.from_pretrained(_EGY_MODEL)
egy_model = egy_model.to(DEVICE)
egy_model.eval()
def translate_en_to_egy(text: str) -> str:
    """Run *text* through the NAMAA Marian model and return the decoded output.

    The input is tokenised as a single-item batch truncated to 128 tokens,
    generation uses 4 beams with at most 128 new tokens, and special tokens
    are dropped from the decoded string.
    """
    batch = egy_tokenizer(
        [text],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    )
    batch = batch.to(DEVICE)
    with torch.no_grad():
        generated = egy_model.generate(**batch, num_beams=4, max_new_tokens=128)
    first_sequence = generated[0]
    return egy_tokenizer.decode(first_sequence, skip_special_tokens=True)
# ── Force Arabic Translation Pre-Processing ─────────────────────────────────
# Greedy match of contiguous runs of ASCII letters/digits, joined across
# whitespace, so a whole English sentence or Franco chunk forms one span.
_LATIN_SPAN_RE = re.compile(r"[a-zA-Z0-9]+(?:\s+[a-zA-Z0-9]+)*")


def force_arabic_translation(raw: str) -> tuple[str, str]:
    """Replace Latin-script spans in *raw* with their translation, then clean.

    Text without any ASCII letter is only passed through clean_text. Otherwise
    every Latin/digit span is sent to translate_en_to_egy and substituted in
    place, padded with spaces so it does not fuse with neighbouring words.

    Args:
        raw: The user-supplied text.

    Returns:
        A ``(cleaned_text, route_tag)`` pair, where ``route_tag`` is a
        human-readable label naming the path that was taken.
    """
    if not has_latin(raw):
        return clean_text(raw), "Pure Arabic (Deep Fusion)"

    def translate_span(match):
        original = match.group(0)
        span = original.strip()
        # Digit-only spans (e.g. a number between Arabic words) carry nothing
        # to translate; leave them untouched, as the pure-Arabic path does.
        if not has_latin(span):
            return original
        try:
            translated = translate_en_to_egy(span)
        except Exception:
            # Best effort: keep the untranslated span when the translator
            # fails. Only Exception is caught so KeyboardInterrupt and
            # SystemExit still propagate.
            return original
        return " " + translated + " "

    processed = _LATIN_SPAN_RE.sub(translate_span, raw)
    return clean_text(processed), "Translated to Arabic (Deep Fusion)"
# ── Model definition ────────────────────────────────────────────────────────
class TaskHead(nn.Module):
    """Classification head: Linear -> LayerNorm -> GELU -> Dropout -> Linear.

    The layers live in ``self.net`` at indices 0-4; the indices are part of the
    parameter names in the state dict and must not change.
    """

    def __init__(self, input_size: int, n_classes: int, dropout: float = 0.3):
        super().__init__()
        hidden_width = 256
        stack = [
            nn.Linear(input_size, hidden_width),
            nn.LayerNorm(hidden_width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_width, n_classes),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return the unnormalised class scores for *x*."""
        logits = self.net(x)
        return logits
class UnifiedMASRIHead(nn.Module):
    """Fuse a BERT encoder with BiLSTM/BiGRU features and feed three task heads.

    The first-token vector of the encoder is concatenated with the last
    time-step outputs of a bidirectional LSTM and a bidirectional GRU that
    both read the ``ft_embeds`` input. The fused vector is shared by the
    emotion (8 classes), sentiment (3) and sarcasm (2) heads.
    """

    def __init__(self, bert_model_name: str = "T0KII/MASRIBERTv3", ft_dim: int = 300, rnn_hidden: int = 256, num_layers: int = 2, dropout: float = 0.3):
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_model_name)

        # Both recurrent encoders share the same configuration.
        rnn_options = dict(batch_first=True, bidirectional=True, dropout=dropout)
        self.bilstm = nn.LSTM(ft_dim, rnn_hidden, num_layers, **rnn_options)
        self.bigru = nn.GRU(ft_dim, rnn_hidden, num_layers, **rnn_options)
        self.rnn_dropout = nn.Dropout(dropout)

        # Each bidirectional RNN contributes 2 * rnn_hidden features.
        combined_dim = self.bert.config.hidden_size + rnn_hidden * 4
        self.sarcasm_head = TaskHead(combined_dim, 2, dropout=0.5)
        self.sentiment_head = TaskHead(combined_dim, 3, dropout=0.3)
        self.emotion_head = TaskHead(combined_dim, 8, dropout=0.3)

    def forward(self, input_ids, attention_mask, ft_embeds):
        """Return ``(emotion_logits, sentiment_logits, sarcasm_logits)``."""
        encoder_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_vec = encoder_output.last_hidden_state[:, 0, :]

        lstm_seq, _ = self.bilstm(ft_embeds)
        gru_seq, _ = self.bigru(ft_embeds)
        last_steps = torch.cat([lstm_seq[:, -1, :], gru_seq[:, -1, :]], dim=1)
        rnn_feat = self.rnn_dropout(last_steps)

        fused = torch.cat([cls_vec, rnn_feat], dim=1)
        emotion_logits = self.emotion_head(fused)
        sentiment_logits = self.sentiment_head(fused)
        sarcasm_logits = self.sarcasm_head(fused)
        return emotion_logits, sentiment_logits, sarcasm_logits
# ── Load FastText ───────────────────────────────────────────────────────────
# Download (or reuse the cached copy of) the FastText binary and load it once
# at import time.
print("Loading FastText arz vectorsโ€ฆ")
_ft_path = hf_hub_download(repo_id="facebook/fasttext-arz-vectors", filename="model.bin")
_ft_model = fasttext.load_model(_ft_path)
def get_ft_embedding(text: str) -> np.ndarray:
    """Build a fixed-size FastText matrix for *text*.

    The text is split on whitespace and only the first ``MAX_LEN`` tokens are
    used. Row ``i`` holds the vector of token ``i``; rows past the last token
    stay zero.

    Returns:
        A float32 array of shape ``(MAX_LEN, FT_DIM)``.
    """
    matrix = np.zeros((MAX_LEN, FT_DIM), dtype=np.float32)
    for i, tok in enumerate(text.split()[:MAX_LEN]):
        try:
            matrix[i] = _ft_model.get_word_vector(tok)
        except Exception:
            # Best effort: a token the model cannot embed keeps its zero row.
            # Only Exception is caught so KeyboardInterrupt and SystemExit
            # are no longer swallowed.
            continue
    return matrix
# ── Load taMASRIBERT ────────────────────────────────────────────────────────
# Load the tokenizer and the fused model once at import time.
print("Loading taMASRIBERT tokenizerโ€ฆ")
_REPO = "T0KII/taMASRIBERT"
ta_tokenizer = AutoTokenizer.from_pretrained(_REPO)
print("Initialising UnifiedMASRIHeadโ€ฆ")
ta_model = UnifiedMASRIHead(bert_model_name="T0KII/MASRIBERTv3").to(DEVICE)

# NOTE(review): torch.load unpickles the checkpoint; only load files from a
# repository you trust.
_state_dict = torch.load(hf_hub_download(repo_id=_REPO, filename="pytorch_model.bin"), map_location=DEVICE)

# strict=False tolerates key mismatches, which would otherwise leave layers at
# their random initialisation without any sign of trouble. Report whatever
# did not line up so a bad checkpoint is visible in the startup log.
_load_report = ta_model.load_state_dict(_state_dict, strict=False)
if _load_report.missing_keys:
    print(f"WARNING: {len(_load_report.missing_keys)} parameter(s) missing from checkpoint: {_load_report.missing_keys}")
if _load_report.unexpected_keys:
    print(f"WARNING: {len(_load_report.unexpected_keys)} unexpected key(s) in checkpoint: {_load_report.unexpected_keys}")
# The values have been copied into the model; drop the extra reference.
del _state_dict

ta_model.eval()
print("โœ“ taMASRIBERT ready")
# ── Core inference ──────────────────────────────────────────────────────────
@torch.no_grad()
def _run_inference(cleaned: str):
    """Score one cleaned string with the fused model.

    Returns:
        A 3-tuple of 1-D NumPy arrays holding the softmax probabilities for
        emotion, sentiment and sarcasm, in that order.
    """
    encoded = ta_tokenizer(
        cleaned,
        padding="max_length",
        truncation=True,
        max_length=MAX_LEN,
        return_tensors="pt",
    )
    # Wrap the single embedding matrix in a batch dimension of size one.
    ft_batch = np.array([get_ft_embedding(cleaned)], dtype=np.float32)
    ft_emb = torch.from_numpy(ft_batch).to(DEVICE)

    input_ids = encoded["input_ids"].to(DEVICE)
    attention_mask = encoded["attention_mask"].to(DEVICE)
    emotion_logits, sentiment_logits, sarcasm_logits = ta_model(input_ids, attention_mask, ft_emb)

    emotion_probs = F.softmax(emotion_logits, -1).cpu().numpy()[0]
    sentiment_probs = F.softmax(sentiment_logits, -1).cpu().numpy()[0]
    sarcasm_probs = F.softmax(sarcasm_logits, -1).cpu().numpy()[0]
    return (emotion_probs, sentiment_probs, sarcasm_probs)
def predict(text: str):
    """Gradio callback: analyse *text* and return the five UI outputs.

    Returns:
        ``(emotion_probs, sentiment_probs, sarcasm_probs, summary, cleaned)``
        where the first three are label -> probability dicts, ``summary`` is
        a Markdown string and ``cleaned`` is the pre-processed text. Blank
        input, or input that is empty after pre-processing, yields empty
        dicts and placeholder strings.
    """
    if not text or not text.strip():
        return {}, {}, {}, "โ€”", ""

    cleaned, route_tag = force_arabic_translation(text)
    if not cleaned:
        return {}, {}, {}, "โ€”", ""

    e_p, st_p, sc_p = _run_inference(cleaned)

    # The top emotion is only reported when it clears the confidence
    # threshold; below it the label falls back to "none".
    max_emo_p = float(np.max(e_p))
    if max_emo_p >= EMO_THRESHOLD:
        top_emo = CLASSNAMES[int(np.argmax(e_p))]
    else:
        top_emo = "none"
    top_sent = SENT_LABELS[int(np.argmax(st_p))]
    top_sarc = SARC_LABELS[int(np.argmax(sc_p))]

    summary = (
        f"**Emotion:** {top_emo.title()} ({max_emo_p:.1%}) | "
        f"**Sentiment:** {top_sent.title()} ({float(np.max(st_p)):.1%}) | "
        f"**Sarcasm:** {top_sarc.title()} ({float(np.max(sc_p)):.1%}) \n\n"
        f"๐Ÿท๏ธ **Tag:** {route_tag}"
    )

    emotion_scores = {label: float(e_p[i]) for i, label in enumerate(CLASSNAMES)}
    sentiment_scores = {label: float(st_p[i]) for i, label in enumerate(SENT_LABELS)}
    sarcasm_scores = {label: float(sc_p[i]) for i, label in enumerate(SARC_LABELS)}
    return (emotion_scores, sentiment_scores, sarcasm_scores, summary, cleaned)
# ── Gradio UI ───────────────────────────────────────────────────────────────
# Gradio front end: one input box, a read-only box showing the pre-processed
# text, an Analyze button, a Markdown summary and three label widgets.
# NOTE(review): the indentation of this section was lost in this copy of the
# file, so the exact nesting of the gr.Row() contexts cannot be read from the
# code below; restore it from the original before running.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="green")) as demo:
gr.Markdown("# Kalamna: Egyptian Arabic Analyzer")
gr.Markdown("**taMASRIBERT** with Automatic Arabic Script Normalization.")
with gr.Row():
# txt is the user input; ctx displays the fifth value returned by predict.
txt = gr.Textbox(label="Input (Arabic / Franco / English / Mixed)", placeholder="ุงูƒุชุจ ู‡ู†ุง ...", lines=3, rtl=True)
ctx = gr.Textbox(label="Pre-processed (what the model sees)", interactive=False, rtl=True)
btn = gr.Button("Analyze", variant="primary")
sum_md = gr.Markdown()
with gr.Row():
# One label widget per task; num_top_classes equals each task's class count.
e_out = gr.Label(label="Emotion", num_top_classes=8)
st_out = gr.Label(label="Sentiment", num_top_classes=3)
sc_out = gr.Label(label="Sarcasm", num_top_classes=2)
# Clickable sample inputs that populate the input textbox.
gr.Examples(
examples=[
"ุงู„ู…ูˆุชูˆุณูŠูƒู„ ุงู„ูƒูŠ ูˆุงูŠ ุณูˆุจุฑ ู„ุงูŠุช ุทู„ุน ุนูŠู†ูŠ ููŠ ุงู„ุณูุฑูŠุฉ ุจุชุงุนุฉ ุงู„ุณุงุญู„ุŒ ุงู„ุชูˆูƒูŠู„ ุงุณูˆุฃ ู…ุง ูŠู…ูƒู†",
"ya gd3an el match elgai is crucial, el bad performance dh msh hnkml beeh",
"ู…ุง ุดุงุก ุงู„ู„ู‡ ุนู„ู‰ ุณุฑุนุฉ ุงู„ู†ุชุŒ ุงู„ุณู„ุญูุงุฉ ุจุชุณุจู‚ู‡ุŒ ุจุฌุฏ ุงุญุณู† ุฎุฏู…ุฉ ุนู…ู„ุงุก ููŠ ุงู„ุฏู†ูŠุง.",
"ana 3mlt update l flutter w el app crash, this is extremely frustrating",
"ุงู†ุง ุทู„ุจุช ุงู„ุงูˆุฑุฏุฑ ู…ู† ุดู‡ุฑ ูˆ ู„ุณู‡ ู…ุฌุงุดุŒ your delivery service is completely useless and I want a refund",
"el boxy fit shirt dh shklo gamed awy, perfect style w el material quality is top notch"
],
inputs=txt,
)
# The button click and Enter in the textbox both run predict; its five return
# values map, in order, onto the five output components listed here.
btn.click(predict, txt, [e_out, st_out, sc_out, sum_md, ctx])
txt.submit(predict, txt, [e_out, st_out, sc_out, sum_md, ctx])
# ── FastAPI + Gradio mounting ───────────────────────────────────────────────
class DetectRequest(BaseModel):
    # Request body of POST /detect: the raw text to analyse.
    text: str
class SarcasmResult(BaseModel):
    # Sarcasm verdict: the winning label and its probability.
    label: str
    score: float
class DetectResponse(BaseModel):
    # Response body of POST /detect.
    emotion: str
    confidence: float
    sentiment: str
    sarcasm: SarcasmResult
    urgent: bool
    latency_ms: float
    # "model" for a real prediction, "fallback" for the canned response.
    source: str
    route: str
    cleaned: Optional[str] = None
# Canned response returned by detect_api for blank input, for input that is
# empty after pre-processing, and when inference raises. Callers hand out a
# copy so this template is never mutated.
# NOTE(review): "neutral" is not one of the CLASSNAMES emotion labels (the
# model path reports "none") — confirm API consumers expect this value.
_FALLBACK = DetectResponse(
    emotion="neutral",
    confidence=0.0,
    sentiment="neutral",
    sarcasm=SarcasmResult(label="not sarcastic", score=0.0),
    urgent=False,
    latency_ms=0.0,
    source="fallback",
    route="none",
    cleaned=None,
)
# FastAPI application hosting the JSON endpoint; root_path comes from the
# ROOT_PATH environment variable and defaults to the empty string.
fapp = FastAPI(
    title="Kalamna Emotion API",
    version="2.0.0",
    root_path=os.environ.get("ROOT_PATH", ""),
)
@fapp.post("/detect", response_model=DetectResponse)
def detect_api(body: DetectRequest):
    """Analyse ``body.text`` and return emotion, sentiment and sarcasm labels.

    The endpoint never raises for processing errors: blank input, input that
    is empty after pre-processing, and any exception during translation or
    inference all produce a copy of the fallback response. Exceptions are
    logged with their traceback so failures are visible to operators.

    Returns:
        A DetectResponse. ``urgent`` is True when the emotion is sadness, fear
        or anger and the sentiment is negative. ``latency_ms`` covers
        pre-processing plus inference.
    """
    # Function-scope import: the module does not import logging at top level.
    import logging

    if not body.text or not body.text.strip():
        return _FALLBACK.model_copy()

    t0 = time.perf_counter()
    try:
        cleaned, route_tag = force_arabic_translation(body.text)
        if not cleaned:
            return _FALLBACK.model_copy()

        e_p, st_p, sc_p = _run_inference(cleaned)
        latency_ms = (time.perf_counter() - t0) * 1000

        # Below the confidence threshold the emotion is reported as "none".
        max_emo_p = float(np.max(e_p))
        top_emo = CLASSNAMES[int(np.argmax(e_p))] if max_emo_p >= EMO_THRESHOLD else "none"
        top_sent = SENT_LABELS[int(np.argmax(st_p))]
        top_sarc = SARC_LABELS[int(np.argmax(sc_p))]
        urgent = top_emo in {"sadness", "fear", "anger"} and top_sent == "negative"

        return DetectResponse(
            emotion=top_emo,
            confidence=max_emo_p,
            sentiment=top_sent,
            sarcasm=SarcasmResult(label=top_sarc, score=float(np.max(sc_p))),
            urgent=urgent,
            latency_ms=latency_ms,
            source="model",
            route=route_tag,
            cleaned=cleaned,
        )
    except Exception:
        # Keep the best-effort contract, but leave a trace. The request text
        # is deliberately not logged.
        logging.getLogger(__name__).exception("detect_api failed; returning fallback response")
        return _FALLBACK.model_copy()
# Mount the Gradio UI on the FastAPI application at the root path; `app` is
# the ASGI object that gets served.
app = gr.mount_gradio_app(fapp, demo, path="/")

if __name__ == "__main__":
    # Listen on every interface on port 7860 and honour proxy headers
    # forwarded from any address.
    uvicorn.run(app, host="0.0.0.0", port=7860, proxy_headers=True, forwarded_allow_ips="*")