# bayan-api/src/nlp/dialect/dialect_service.py
# Last commit caddab8 (youssefreda9):
#   fix: remove forced torch_dtype from dialect model - load with native weights
"""
Dialect-to-MSA (Modern Standard Arabic) conversion service.
Uses bayan10/dialect-to-msa-model (mT5 300M) to convert colloquial
Arabic dialects (Egyptian, Gulf, Levantine, Maghrebi) to formal MSA.
Singleton pattern — lazy-loads the model on first request to avoid
blocking server startup.
"""
import logging
import threading

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Module logger; all messages from this service are tagged "[DIALECT]".
logger = logging.getLogger(__name__)
# Lazily-created DialectConverter singleton. It stays None until
# get_dialect_model() builds it; is_loaded() reads it without triggering a load.
_instance = None
class DialectConverter:
    """Convert dialectal Arabic text to Modern Standard Arabic (MSA).

    Wraps the seq2seq checkpoint at ``REPO_ID``. The tokenizer and model are
    loaded when the object is constructed. The model is placed on CUDA when
    it is available and on CPU otherwise.
    """

    # Task prefix prepended to every input (roughly "convert to MSA: ").
    # Presumably this matches the prefix used during fine-tuning; TODO confirm.
    PREFIX = "حوّل إلى الفصحى: "
    REPO_ID = "bayan10/dialect-to-msa-model"
    # Token budgets. The encoder input (prefix included) is truncated to
    # MAX_INPUT_LENGTH tokens. MAX_OUTPUT_LENGTH is passed to generate() as
    # max_length.
    MAX_INPUT_LENGTH = 128
    MAX_OUTPUT_LENGTH = 128

    def __init__(self):
        """Load the tokenizer and model from ``REPO_ID`` and switch to eval mode."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(
            "[DIALECT] Loading model from '%s' on %s...", self.REPO_ID, self.device
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.REPO_ID)
        # No torch_dtype is forced; from_pretrained uses its default dtype.
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.REPO_ID).to(self.device)
        # Inference only: disables dropout and other train-time behaviour.
        self.model.eval()
        logger.info("[DIALECT] Model loaded successfully.")

    def convert(self, dialect_text: str, num_beams: int = 4) -> str:
        """Convert a single dialect sentence to MSA.

        Args:
            dialect_text: Dialectal Arabic input. Leading and trailing
                whitespace is stripped before the text reaches the model.
            num_beams: Beam width for generation. 1 means greedy decoding.

        Returns:
            The generated MSA text with special tokens removed. Empty,
            whitespace-only, or otherwise falsy input is returned unchanged
            and the model is not run.
        """
        if not dialect_text or not dialect_text.strip():
            return dialect_text

        input_text = self.PREFIX + dialect_text.strip()

        # The encoder sees at most MAX_INPUT_LENGTH tokens. Warn rather than
        # silently converting only the beginning of an over-long input.
        full_length = len(self.tokenizer(input_text)["input_ids"])
        if full_length > self.MAX_INPUT_LENGTH:
            logger.warning(
                "[DIALECT] Input is %d tokens; truncating to %d. "
                "Only the beginning of the text will be converted.",
                full_length,
                self.MAX_INPUT_LENGTH,
            )

        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            max_length=self.MAX_INPUT_LENGTH,
            truncation=True,
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=self.MAX_OUTPUT_LENGTH,
                num_beams=num_beams,
                # early_stopping only affects beam search. Passing True with
                # greedy decoding (num_beams=1) makes transformers emit a
                # generation-config warning, so enable it only for beams.
                early_stopping=num_beams != 1,
                no_repeat_ngram_size=3,
            )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def is_ready(self) -> bool:
        """Return True when both the model and the tokenizer are set."""
        return self.model is not None and self.tokenizer is not None
# Serializes first-time construction of the singleton so that concurrent
# first requests do not each load their own copy of the model.
_instance_lock = threading.Lock()


def get_dialect_model() -> DialectConverter:
    """Get or create the singleton DialectConverter instance.

    The first call blocks while the model loads. The unlocked ``None`` check
    keeps later calls lock-free. The second check, made while holding
    ``_instance_lock``, ensures only one thread constructs the converter.
    If construction raises, the singleton stays unset and a later call retries.

    Returns:
        The process-wide DialectConverter.
    """
    global _instance
    if _instance is None:
        with _instance_lock:
            if _instance is None:
                _instance = DialectConverter()
    return _instance
def is_loaded() -> bool:
    """Report whether the dialect singleton exists and is ready.

    Unlike get_dialect_model(), this never constructs the converter, so
    calling it does not trigger a model load.
    """
    instance = _instance
    if instance is None:
        return False
    return instance.is_ready()