# taMASRIBERT / inference.py
# Uploaded by T0KII: "Initial Deployment Package with Inference Script"
# Commit 31d7b01 (verified), 2.77 kB
import os
import torch
import numpy as np
from transformers import AutoTokenizer
from modeling import UnifiedMASRIHead
import fasttext
from huggingface_hub import hf_hub_download
class taMASRIBERT:
    """Single-text inference wrapper around the ``UnifiedMASRIHead`` model.

    Each prediction feeds the model three inputs built from the same text:
    tokenizer ``input_ids``, the matching ``attention_mask``, and a matrix of
    FastText word vectors computed from the whitespace-split text.

    Constructing an instance downloads (or reuses cached copies of) the
    tokenizer, the FastText vectors and the model weights from the Hugging
    Face Hub, so it needs network access on first use.
    """

    # Sequence length shared by the tokenizer padding and the FastText matrix.
    # Both were hard-coded to 128; keeping one constant stops them drifting
    # apart. NOTE(review): presumably UnifiedMASRIHead expects the two lengths
    # to match — confirm against modeling.py.
    MAX_SEQ_LEN = 128
    # Default width of one FastText row.
    FT_EMBED_DIM = 300

    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Load tokenizer, FastText vectors and model weights.

        Args:
            device: Torch device string the model and input tensors are moved
                to. The default is chosen once, when the class is defined:
                ``'cuda'`` if available at that moment, otherwise ``'cpu'``.
        """
        self.device = device
        self.repo_id = "T0KII/taMASRIBERT"

        print("Loading Tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.repo_id)

        print("Loading FastText embeddings (~3.3GB)...")
        self.ft_model = fasttext.load_model(
            hf_hub_download("facebook/fasttext-arz-vectors", "model.bin")
        )

        print("Initializing Deep Fusion Architecture...")
        self.model = UnifiedMASRIHead(bert_model_name="T0KII/MASRIBERTv3").to(self.device)

        print("Fetching Model Weights...")
        # Automatically downloads weights if they aren't local
        weights_path = hf_hub_download(repo_id=self.repo_id, filename="pytorch_model.bin")
        # NOTE(security): torch.load unpickles the checkpoint. On PyTorch
        # versions where ``weights_only`` defaults to False, a tampered file
        # can execute arbitrary code. Consider ``weights_only=True`` once the
        # minimum supported torch version is known to accept it.
        state_dict = torch.load(weights_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        print("✓ taMASRIBERT is ready for inference.")

    def _get_ft_embedding(self, text, max_len=MAX_SEQ_LEN, embed_dim=FT_EMBED_DIM):
        """Build a zero-padded FastText matrix for ``text``.

        The text is converted with ``str()``, split on whitespace and
        truncated to ``max_len`` tokens. Row ``i`` holds the vector of token
        ``i``; unused rows stay zero.

        Args:
            text: Input text (anything ``str()`` accepts).
            max_len: Number of rows in the matrix and maximum tokens used.
            embed_dim: Number of columns; must match the FastText vector width.

        Returns:
            A ``float32`` tensor of shape ``(1, max_len, embed_dim)`` on CPU.

        Raises:
            ValueError: If a returned vector cannot be stored in a row of
                width ``embed_dim``. Previously this was swallowed and the
                model silently received an all-zero matrix.
        """
        # Function-scope import keeps this block self-contained.
        import warnings

        tokens = str(text).split()[:max_len]
        matrix = np.zeros((max_len, embed_dim), dtype=np.float32)
        for row, tok in enumerate(tokens):
            try:
                vector = self.ft_model.get_word_vector(tok)
            except Exception as err:
                # Best-effort, as before: a failed lookup leaves the row at
                # zero, but the failure is now reported instead of hidden.
                warnings.warn(
                    f"FastText lookup failed for token {tok!r}; using a zero vector ({err!r})",
                    RuntimeWarning,
                    stacklevel=2,
                )
                continue
            # Deliberately outside the try: a width mismatch is a
            # configuration error and must surface.
            matrix[row] = vector
        return torch.from_numpy(matrix).unsqueeze(0)

    def predict(self, text, task='sarcasm'):
        """Classify one text for the given task.

        Args:
            text: The text to classify (a single example, not a batch).
            task: Task name forwarded unchanged to the model. The demo at the
                bottom of this file uses ``'sarcasm'``, ``'sentiment'`` and
                ``'emotion'``; the accepted set is defined by the model.

        Returns:
            A tuple ``(pred, probs)``: ``pred`` is the integer index of the
            highest-probability class, ``probs`` is the list of softmax
            probabilities for the first batch row.
        """
        enc = self.tokenizer(
            text,
            return_tensors='pt',
            truncation=True,
            max_length=self.MAX_SEQ_LEN,
            padding='max_length',
        )
        input_ids = enc['input_ids'].to(self.device)
        attention_mask = enc['attention_mask'].to(self.device)
        # Pass the length explicitly so both inputs use the same value.
        ft_embeds = self._get_ft_embedding(text, max_len=self.MAX_SEQ_LEN).to(self.device)

        with torch.no_grad():
            logits = self.model(input_ids, attention_mask, ft_embeds, task)
            probs = torch.softmax(logits, dim=1)
            # .item() requires exactly one element, i.e. a batch of one.
            pred = torch.argmax(probs, dim=1).item()
        return pred, probs[0].cpu().tolist()
if __name__ == "__main__":
    # Demo: score one sample sentence with each of the three task names.
    model = taMASRIBERT()
    text = "يا سلام عليك يا عبقري"

    # Run every prediction before printing, so output order is unchanged.
    outcomes = [model.predict(text, task=name) for name in ('sarcasm', 'sentiment', 'emotion')]
    (sarc_pred, sarc_probs), (sent_pred, sent_probs), (emo_pred, emo_probs) = outcomes

    # Probabilities are shown with four decimal places.
    print(f"\nText: {text}")
    print(f"Sarcasm: {sarc_pred} | Probs: {[f'{p:.4f}' for p in sarc_probs]}")
    print(f"Sentiment: {sent_pred} | Probs: {[f'{p:.4f}' for p in sent_probs]}")
    print(f"Emotion: {emo_pred} | Probs: {[f'{p:.4f}' for p in emo_probs]}")