|
|
| import os |
| import torch |
| import numpy as np |
| from transformers import AutoTokenizer |
| from modeling import UnifiedMASRIHead |
| import fasttext |
| from huggingface_hub import hf_hub_download |
|
|
class taMASRIBERT:
    """Inference wrapper around ``UnifiedMASRIHead``.

    Bundles a Hugging Face tokenizer, a fastText word-vector model and the
    fused classification head. It exposes ``predict`` for classifying one
    text at a time, with the task selected by name through the ``task``
    argument.
    """

    # Padding/truncation length used for both the BERT tokenizer and the
    # fastText matrix. Keeping it in one place stops the two inputs from
    # drifting apart.
    MAX_LEN = 128
    # Expected fastText vector size; every looked-up vector is checked against it.
    FT_DIM = 300

    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Download and assemble every component needed for inference.

        Args:
            device: torch device string that the model and its inputs are
                moved to. The default is chosen when the class is defined:
                ``'cuda'`` if CUDA is available, otherwise ``'cpu'``.
        """
        self.device = device
        self.repo_id = "T0KII/taMASRIBERT"

        print("Loading Tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.repo_id)

        print("Loading FastText embeddings (~3.3GB)...")
        self.ft_model = fasttext.load_model(
            hf_hub_download("facebook/fasttext-arz-vectors", "model.bin")
        )

        print("Initializing Deep Fusion Architecture...")
        self.model = UnifiedMASRIHead(bert_model_name="T0KII/MASRIBERTv3").to(self.device)

        print("Fetching Model Weights...")
        weights_path = hf_hub_download(repo_id=self.repo_id, filename="pytorch_model.bin")
        # NOTE(review): torch.load unpickles the downloaded file, which can run
        # arbitrary code if the Hub repo is ever compromised. Consider passing
        # ``weights_only=True`` (PyTorch >= 1.13) once the checkpoint is
        # confirmed to be a plain state_dict.
        state_dict = torch.load(weights_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        # Switch to evaluation mode (affects layers such as dropout).
        self.model.eval()
        print("✓ taMASRIBERT is ready for inference.")

    def _get_ft_embedding(self, text, max_len=MAX_LEN, embed_dim=FT_DIM):
        """Build a zero-padded fastText embedding matrix for *text*.

        ``str(text)`` is split on whitespace. The first ``max_len`` tokens
        fill the matrix rows in order, and any remaining rows stay zero.

        Args:
            text: input text; coerced with ``str()``.
            max_len: number of rows; tokens beyond this are dropped.
            embed_dim: expected vector size of ``self.ft_model``.

        Returns:
            A float32 CPU tensor of shape ``(1, max_len, embed_dim)``.

        Raises:
            ValueError: if a looked-up vector does not have shape
                ``(embed_dim,)``. Raising here avoids silently returning an
                all-zero matrix when ``embed_dim`` is misconfigured.
        """
        matrix = np.zeros((max_len, embed_dim), dtype=np.float32)
        for i, tok in enumerate(str(text).split()[:max_len]):
            try:
                vec = self.ft_model.get_word_vector(tok)
            except (ValueError, RuntimeError):
                # Best-effort per token: if the fastText backend rejects a
                # token (e.g. it cannot be encoded), that row stays zero.
                continue
            vec = np.asarray(vec, dtype=np.float32)
            if vec.shape != (embed_dim,):
                raise ValueError(
                    f"fastText vector for {tok!r} has shape {vec.shape}, "
                    f"expected ({embed_dim},)"
                )
            matrix[i] = vec
        return torch.from_numpy(matrix).unsqueeze(0)

    def predict(self, text, task='sarcasm'):
        """Classify a single text for the given task.

        Args:
            text: input string.
            task: task name, forwarded unchanged to the model head.

        Returns:
            ``(pred, probs)``: ``pred`` is the argmax class index as an int,
            and ``probs`` is the list of softmax probabilities for this text.
        """
        enc = self.tokenizer(
            text,
            return_tensors='pt',
            truncation=True,
            max_length=self.MAX_LEN,
            padding='max_length',
        )
        input_ids = enc['input_ids'].to(self.device)
        attention_mask = enc['attention_mask'].to(self.device)
        ft_embeds = self._get_ft_embedding(text, max_len=self.MAX_LEN).to(self.device)

        with torch.no_grad():
            logits = self.model(input_ids, attention_mask, ft_embeds, task)
            # Assumes logits are (batch=1, num_classes); TODO confirm against
            # UnifiedMASRIHead.forward.
            probs = torch.softmax(logits, dim=1)
            pred = torch.argmax(probs, dim=1).item()
        return pred, probs.cpu().numpy().tolist()[0]
|
|
if __name__ == "__main__":
    # Demo: load the model once, run one sample sentence through every task
    # head, then print the predicted class and formatted probabilities.
    model = taMASRIBERT()
    text = "يا سلام عليك يا عبقري"

    # (display label, task name) pairs in the order they are evaluated and reported.
    demo_tasks = (
        ("Sarcasm", "sarcasm"),
        ("Sentiment", "sentiment"),
        ("Emotion", "emotion"),
    )

    # Run every prediction before printing anything.
    outcomes = []
    for label, task_name in demo_tasks:
        label_pred, label_probs = model.predict(text, task=task_name)
        outcomes.append((label, label_pred, label_probs))

    print(f"\nText: {text}")
    for label, label_pred, label_probs in outcomes:
        rounded = [f'{p:.4f}' for p in label_probs]
        print(f"{label}: {label_pred} | Probs: {rounded}")
|
|