File size: 2,769 Bytes
31d7b01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64

import logging
import os

import fasttext
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

from modeling import UnifiedMASRIHead

class taMASRIBERT:
    """Inference wrapper that loads the taMASRIBERT model and runs one text at a time.

    Construction downloads (or reuses cached copies of) the tokenizer, the
    FastText model and the ``UnifiedMASRIHead`` weights, then puts the model
    in eval mode.  ``predict`` returns the arg-max class and the softmax
    probabilities for the requested task.
    """

    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Load the tokenizer, FastText vectors, model and weights onto *device*.

        Args:
            device: Torch device string.  The default is resolved once, when
                the class is defined: ``'cuda'`` if available, else ``'cpu'``.
        """
        self.device = device
        self.repo_id = "T0KII/taMASRIBERT"

        print("Loading Tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.repo_id)

        print("Loading FastText embeddings (~3.3GB)...")
        ft_path = hf_hub_download("facebook/fasttext-arz-vectors", "model.bin")
        self.ft_model = fasttext.load_model(ft_path)

        print("Initializing Deep Fusion Architecture...")
        self.model = UnifiedMASRIHead(bert_model_name="T0KII/MASRIBERTv3").to(self.device)

        print("Fetching Model Weights...")
        # hf_hub_download fetches the weights if they are not already cached locally.
        weights_path = hf_hub_download(repo_id=self.repo_id, filename="pytorch_model.bin")
        # NOTE(review): torch.load unpickles the downloaded file; depending on
        # the installed torch version this can execute arbitrary code from a
        # tampered checkpoint.  Consider `weights_only=True` once the minimum
        # supported torch version is confirmed to accept it.
        state_dict = torch.load(weights_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        print("✓ taMASRIBERT is ready for inference.")

    def _get_ft_embedding(self, text, max_len=128, embed_dim=300):
        """Build a zero-padded FastText embedding matrix for *text*.

        The text is split on whitespace and truncated to *max_len* tokens;
        row ``i`` of the matrix holds the FastText vector of token ``i`` and
        all remaining rows stay zero.

        Lookup is best-effort: if a vector cannot be fetched or does not fit
        a row (for example the FastText model's dimension differs from
        *embed_dim*), that row is left as zeros.  Such failures are no longer
        silent; one warning per call summarises them.

        Args:
            text: Input text; converted with ``str()`` before splitting.
            max_len: Number of rows in the returned matrix.
            embed_dim: Number of columns in the returned matrix.

        Returns:
            A ``float32`` tensor of shape ``(1, max_len, embed_dim)``.
        """
        tokens = str(text).split()[:max_len]
        matrix = np.zeros((max_len, embed_dim), dtype=np.float32)

        failures = []
        for row, token in enumerate(tokens):
            try:
                matrix[row] = self.ft_model.get_word_vector(token)
            except Exception as err:
                # Deliberately broad: one bad token must not abort inference.
                # The row stays zero and the failure is reported below.
                failures.append((token, err))

        if failures:
            first_token, first_err = failures[0]
            logging.getLogger(__name__).warning(
                "FastText embedding skipped for %d of %d token(s); rows left as zeros "
                "(first failure: %r -> %s: %s)",
                len(failures),
                len(tokens),
                first_token,
                type(first_err).__name__,
                first_err,
            )

        return torch.from_numpy(matrix).unsqueeze(0)

    def predict(self, text, task='sarcasm'):
        """Classify *text* with the head selected by *task*.

        Args:
            text: The text to classify.
            task: Task name forwarded unchanged to the model (the demo at the
                bottom of this file uses ``'sarcasm'``, ``'sentiment'`` and
                ``'emotion'``).

        Returns:
            A ``(pred, probs)`` tuple: the arg-max class index as an ``int``
            and the softmax probabilities of the first batch row as a list
            of floats.
        """
        enc = self.tokenizer(
            text,
            return_tensors='pt',
            truncation=True,
            max_length=128,
            padding='max_length',
        )
        input_ids = enc['input_ids'].to(self.device)
        attention_mask = enc['attention_mask'].to(self.device)
        ft_embeds = self._get_ft_embedding(text).to(self.device)

        # Inference only, so skip gradient tracking.
        with torch.no_grad():
            logits = self.model(input_ids, attention_mask, ft_embeds, task)
            probs = torch.softmax(logits, dim=1)
            pred = torch.argmax(probs, dim=1).item()
        return pred, probs.cpu().numpy().tolist()[0]

if __name__ == "__main__":
    # Demo: load the model once and classify a single sample sentence with
    # each of the three task heads.
    model = taMASRIBERT()
    text = "يا سلام عليك يا عبقري"

    # Run every task before printing anything, in the fixed order
    # sarcasm -> sentiment -> emotion.
    task_names = ('sarcasm', 'sentiment', 'emotion')
    outputs = [model.predict(text, task=name) for name in task_names]

    print(f"\nText: {text}")
    for name, (label, scores) in zip(task_names, outputs):
        formatted = [f'{p:.4f}' for p in scores]
        print(f"{name.capitalize()}: {label} | Probs: {formatted}")