"""Inference wrapper and demo for the taMASRIBERT multi-task classifier.

Combines a Hugging Face tokenizer, FastText word vectors and the
``UnifiedMASRIHead`` model from ``modeling``. Running this file directly
prints sarcasm, sentiment and emotion predictions for a sample sentence.
"""
import os
from typing import List, Tuple

import torch
import numpy as np
from transformers import AutoTokenizer
from modeling import UnifiedMASRIHead
import fasttext
from huggingface_hub import hf_hub_download


class taMASRIBERT:
    """Load the tokenizer, FastText vectors and fine-tuned weights, then predict.

    All artifacts are obtained with ``from_pretrained`` / ``hf_hub_download``
    during construction. Building an instance may therefore download several
    GB on first use; the FastText model alone is announced as ~3.3GB.
    """

    # Sequence length shared by the BERT tokenizer and the FastText matrix in
    # ``predict``. One constant keeps the two model inputs from drifting apart.
    MAX_LENGTH = 128

    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Download/load every component and put the model in eval mode.

        Args:
            device: torch device string that the model and inputs are moved to.
                The default is evaluated once, when the class is defined.
        """
        self.device = device
        self.repo_id = "T0KII/taMASRIBERT"

        print("Loading Tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.repo_id)

        print("Loading FastText embeddings (~3.3GB)...")
        self.ft_model = fasttext.load_model(
            hf_hub_download("facebook/fasttext-arz-vectors", "model.bin")
        )

        print("Initializing Deep Fusion Architecture...")
        self.model = UnifiedMASRIHead(bert_model_name="T0KII/MASRIBERTv3").to(self.device)

        print("Fetching Model Weights...")
        # Automatically downloads weights if they aren't local
        weights_path = hf_hub_download(repo_id=self.repo_id, filename="pytorch_model.bin")
        # The checkpoint is a pickle fetched over the network. weights_only=True
        # limits unpickling to tensors and plain containers, so a tampered file
        # cannot run code. This is torch's default from 2.6; the keyword
        # itself needs torch >= 1.13.
        state_dict = torch.load(weights_path, map_location=self.device, weights_only=True)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        print("✓ taMASRIBERT is ready for inference.")

    def _get_ft_embedding(self, text, max_len: int = 128, embed_dim: int = 300) -> torch.Tensor:
        """Build a zero-padded FastText embedding matrix for *text*.

        *text* is passed through ``str`` and split on whitespace. The first
        ``max_len`` tokens fill the leading rows; any remaining rows stay zero.
        Rows are indexed by whitespace token, not by the BERT tokenizer's
        sub-word positions. NOTE(review): presumably this matches the training
        preprocessing; confirm against the training code.

        Args:
            text: input text (any object; it is converted with ``str``).
            max_len: number of rows (tokens) in the matrix.
            embed_dim: row width; must equal the FastText model's dimension.

        Returns:
            float32 CPU tensor of shape ``(1, max_len, embed_dim)``.

        Raises:
            ValueError: if ``embed_dim`` differs from the FastText dimension.
        """
        ft_dim = self.ft_model.get_dimension()
        if ft_dim != embed_dim:
            # Without this check every row assignment below would fail with a
            # broadcast error. The per-token ``except`` would swallow it and
            # silently return an all-zero matrix.
            raise ValueError(
                f"embed_dim={embed_dim} does not match FastText dimension {ft_dim}"
            )

        tokens = str(text).split()[:max_len]
        matrix = np.zeros((max_len, embed_dim), dtype=np.float32)
        for i, tok in enumerate(tokens):
            try:
                matrix[i] = self.ft_model.get_word_vector(tok)
            except Exception:
                # Deliberate best-effort: a token whose lookup fails keeps a
                # zero row instead of aborting the whole prediction.
                pass
        return torch.from_numpy(matrix).unsqueeze(0)

    def predict(self, text, task='sarcasm') -> Tuple[int, List[float]]:
        """Classify *text* with the model output selected by *task*.

        Args:
            text: input string.
            task: task name forwarded to the model. The demo in this file uses
                ``'sarcasm'``, ``'sentiment'`` and ``'emotion'``.

        Returns:
            ``(pred, probs)``: ``pred`` is the argmax class index and ``probs``
            is the softmax distribution as a list of Python floats.
        """
        enc = self.tokenizer(
            text,
            return_tensors='pt',
            truncation=True,
            max_length=self.MAX_LENGTH,
            padding='max_length',
        )
        input_ids = enc['input_ids'].to(self.device)
        attention_mask = enc['attention_mask'].to(self.device)
        ft_embeds = self._get_ft_embedding(text, max_len=self.MAX_LENGTH).to(self.device)

        with torch.no_grad():
            # Assumes logits are (1, num_classes) for this single-example
            # batch (softmax over dim=1, row 0 returned) - TODO confirm.
            logits = self.model(input_ids, attention_mask, ft_embeds, task)
            probs = torch.softmax(logits, dim=1)
            pred = torch.argmax(probs, dim=1).item()

        return pred, probs[0].cpu().tolist()


def main():
    """Run every demo task on a sample sentence and print the results."""
    model = taMASRIBERT()
    text = "يا سلام عليك يا عبقري"
    tasks = (
        ("Sarcasm", "sarcasm"),
        ("Sentiment", "sentiment"),
        ("Emotion", "emotion"),
    )
    # Run all predictions before printing anything, as the original script did.
    results = [(label, *model.predict(text, task=task)) for label, task in tasks]

    print(f"\nText: {text}")
    for label, pred, probs in results:
        print(f"{label}: {pred} | Probs: {[f'{p:.4f}' for p in probs]}")


if __name__ == "__main__":
    main()