# Listing metadata: file size 2,769 bytes; revision 31d7b01.
import os
import torch
import numpy as np
from transformers import AutoTokenizer
from modeling import UnifiedMASRIHead
import fasttext
from huggingface_hub import hf_hub_download
class taMASRIBERT:
    """Inference wrapper around the taMASRIBERT multi-task classifier.

    Loads a Hugging Face tokenizer, a FastText word-embedding model and a
    ``UnifiedMASRIHead`` whose weights are fetched from the Hugging Face Hub,
    then exposes :meth:`predict` for classifying a single text for one task.
    """

    # Sequence length shared by the BERT tokenizer and the FastText matrix:
    # both inputs are truncated/padded to exactly this many positions, so the
    # two values must never drift apart.
    MAX_LEN = 128
    # Width of each FastText word vector written into the embedding matrix.
    FT_DIM = 300

    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Download (if not cached) and load every component used for inference.

        Args:
            device: Torch device string that the model and all inputs are moved
                to. The default is computed once, when the class is defined:
                ``'cuda'`` if CUDA is available, otherwise ``'cpu'``.
        """
        self.device = device
        self.repo_id = "T0KII/taMASRIBERT"

        print("Loading Tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.repo_id)

        print("Loading FastText embeddings (~3.3GB)...")
        self.ft_model = fasttext.load_model(
            hf_hub_download("facebook/fasttext-arz-vectors", "model.bin")
        )

        print("Initializing Deep Fusion Architecture...")
        self.model = UnifiedMASRIHead(bert_model_name="T0KII/MASRIBERTv3").to(self.device)

        print("Fetching Model Weights...")
        # hf_hub_download returns the cached path, downloading only if missing.
        weights_path = hf_hub_download(repo_id=self.repo_id, filename="pytorch_model.bin")
        # NOTE(review): torch.load unpickles the downloaded file, which can run
        # arbitrary code if the Hub repo is ever compromised. For a plain
        # state_dict, consider passing weights_only=True (torch >= 1.13).
        state_dict = torch.load(weights_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()  # inference mode for dropout/normalization layers
        print("✓ taMASRIBERT is ready for inference.")

    def _get_ft_embedding(self, text, max_len=MAX_LEN, embed_dim=FT_DIM):
        """Build a zero-padded FastText embedding matrix for ``text``.

        ``text`` is converted with ``str()`` and split on whitespace. Each of
        the first ``max_len`` words fills one row with its FastText vector.
        Rows after the last word stay all-zero. If a lookup fails, that row
        also stays zero, and a ``RuntimeWarning`` is emitted instead of the
        failure being silently ignored.

        Returns:
            A CPU float32 tensor of shape ``(1, max_len, embed_dim)``.
        """
        import warnings

        words = str(text).split()[:max_len]
        matrix = np.zeros((max_len, embed_dim), dtype=np.float32)
        lookup = self.ft_model.get_word_vector  # hoisted out of the loop
        for row, word in enumerate(words):
            try:
                matrix[row] = lookup(word)
            except Exception as err:  # best-effort: keep a zero row, but report it
                warnings.warn(
                    f"FastText lookup failed for token {word!r} ({err}); "
                    "using a zero vector.",
                    RuntimeWarning,
                    stacklevel=2,
                )
        # Add a leading batch dimension of 1.
        return torch.from_numpy(matrix).unsqueeze(0)

    def predict(self, text, task='sarcasm'):
        """Classify a single text for one task.

        Args:
            text: The input string.
            task: Task name passed through to the model's forward call
                (e.g. ``'sarcasm'``).

        Returns:
            ``(pred, probs)``. ``pred`` is the argmax class index as a Python
            int. ``probs`` is the list of softmax probabilities for the input.
        """
        enc = self.tokenizer(
            text,
            return_tensors='pt',
            truncation=True,
            max_length=self.MAX_LEN,
            padding='max_length',
        )
        input_ids = enc['input_ids'].to(self.device)
        attention_mask = enc['attention_mask'].to(self.device)
        # Same length as the tokenizer output, so the fused inputs line up.
        ft_embeds = self._get_ft_embedding(text, max_len=self.MAX_LEN).to(self.device)

        with torch.no_grad():
            logits = self.model(input_ids, attention_mask, ft_embeds, task)
            # assumes logits is (batch=1, num_classes) — TODO confirm against UnifiedMASRIHead
            probs = torch.softmax(logits, dim=1)
            pred = torch.argmax(probs, dim=1).item()
        return pred, probs[0].cpu().tolist()
if __name__ == "__main__":
    # Demo: score one Arabic sample sentence with each task head in turn.
    model = taMASRIBERT()
    text = "يا سلام عليك يا عبقري"

    # (printed label, task name) pairs, in the order they are run and shown.
    demo_tasks = (
        ("Sarcasm", 'sarcasm'),
        ("Sentiment", 'sentiment'),
        ("Emotion", 'emotion'),
    )
    # Run every prediction first, then print, matching the original ordering.
    outcomes = [(label, model.predict(text, task=name)) for label, name in demo_tasks]

    print(f"\nText: {text}")
    for label, (pred, probs) in outcomes:
        formatted = [f'{p:.4f}' for p in probs]
        print(f"{label}: {pred} | Probs: {formatted}")
|