# absa-app / inference.py
import torch
import json
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel, AutoConfig
from peft import LoraConfig, get_peft_model, PeftModel
from araberta_setting.modeling_bilstm_crf import BERT_BiLSTM_CRF
from seq2seq_inference import infer_t5_prompt
from huggingface_hub import hf_hub_download
# Registry of supported backends. Each UI-facing model name maps to the
# Hugging Face Hub ids of its base checkpoint and fine-tuned adapter.
MODEL_OPTIONS = {
    "Araberta": {
        "base": "asmashayea/absa-araberta",
        "adapter": "asmashayea/absa-araberta",
    },
    "mT5": {
        "base": "google/mt5-base",
        "adapter": "asmashayea/mt4-absa",
    },
    "mBART": {
        "base": "facebook/mbart-large-50-many-to-many-mmt",
        "adapter": "asmashayea/mbart-absa",
    },
    "GPT3.5": {
        "base": "bigscience/bloom-560m",  # placeholder
        "adapter": "asmashayea/gpt-absa",
    },
    "GPT4o": {
        "base": "bigscience/bloom-560m",  # placeholder
        "adapter": "asmashayea/gpt-absa",
    },
}

# Process-wide cache, filled lazily: model name -> (tokenizer, model).
cached_models = {}
def load_araberta():
    """Build the AraBERT token-classification model and cache it.

    Downloads the tokenizer, the LoRA-adapted encoder and the BiLSTM-CRF
    head from the Hub, assembles the full model on the available device,
    stores the pair in ``cached_models`` and returns ``(tokenizer, model)``.
    """
    # NOTE(review): this repo id is "absa-arabert" while MODEL_OPTIONS lists
    # "asmashayea/absa-araberta" -- confirm which repository is canonical.
    repo_id = "asmashayea/absa-arabert"
    target = "cuda" if torch.cuda.is_available() else "cpu"

    tok = AutoTokenizer.from_pretrained(repo_id)
    encoder = AutoModel.from_pretrained(repo_id)

    # Attach the LoRA adapter to the base encoder.
    peft_cfg = LoraConfig.from_pretrained(repo_id)
    adapted = get_peft_model(encoder, peft_cfg)

    # Fetch the CRF head weights and rebuild the combined model.
    head_file = hf_hub_download(repo_id=repo_id, filename="bilstm_crf_head.pt")
    cfg = AutoConfig.from_pretrained(repo_id)
    full_model = BERT_BiLSTM_CRF(adapted, cfg)
    # Always map the checkpoint onto the device we are actually running on.
    full_model.load_state_dict(
        torch.load(head_file, map_location=torch.device(target))
    )
    full_model.to(target).eval()

    cached_models["Araberta"] = (tok, full_model)
    return tok, full_model
def _flush_span(aspects, span_tokens, sentiment):
    """Append the accumulated aspect span (if any) to *aspects* in place."""
    if span_tokens:
        aspects.append({
            # Merge WordPiece continuations ("##") back into surface form.
            "aspect": " ".join(span_tokens).replace("##", ""),
            "sentiment": sentiment,
        })


def _decode_bio(tokens, labels):
    """Group BIO-tagged tokens into aspect spans.

    Returns a list of ``{"aspect": str, "sentiment": str}`` dicts. A ``B-X``
    label starts a new span; an ``I-X`` label extends the current span only
    when its sentiment ``X`` matches; anything else (``O``, or a mismatched
    ``I-``) closes the open span.
    """
    aspects = []
    span_tokens, span_sentiment = [], None
    for token, label in zip(tokens, labels):
        if label.startswith("B-"):
            _flush_span(aspects, span_tokens, span_sentiment)
            span_tokens = [token]
            span_sentiment = label.split("-")[1]
        elif label.startswith("I-") and span_sentiment == label.split("-")[1]:
            span_tokens.append(token)
        else:
            _flush_span(aspects, span_tokens, span_sentiment)
            span_tokens, span_sentiment = [], None
    # Close a span that runs to the end of the sequence.
    _flush_span(aspects, span_tokens, span_sentiment)
    return aspects


def infer_araberta(text):
    """Run AraBERT BiLSTM-CRF ABSA inference on *text*.

    Lazily loads (and caches) the model, tags every token with a BIO
    sentiment label, and returns a dict with:
      - ``token_predictions``: "token: label | token: label | ..." string
      - ``aspects``: list of ``{"aspect", "sentiment"}`` spans
    """
    if "Araberta" not in cached_models:
        tokenizer, model = load_araberta()
    else:
        tokenizer, model = cached_models["Araberta"]

    device = next(model.parameters()).device
    inputs = tokenizer(
        text,
        return_tensors='pt',
        truncation=True,
        padding='max_length',
        max_length=128,
    )
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    # NOTE(review): 'logits' is consumed as a sequence of label ids, so the
    # CRF head presumably returns decoded tag ids under that key -- confirm.
    predicted_ids = outputs['logits'][0].cpu().tolist()
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu())
    predicted_labels = [model.config.id2label.get(p, 'O') for p in predicted_ids]

    # Drop special tokens ([CLS], [SEP], [PAD], ...) together with their labels.
    clean_tokens = [t for t in tokens if t not in tokenizer.all_special_tokens]
    clean_labels = [l for t, l in zip(tokens, predicted_labels)
                    if t not in tokenizer.all_special_tokens]

    # Horizontal token:label pairs for display.
    pairs = [f"{token}: {label}" for token, label in zip(clean_tokens, clean_labels)]
    horizontal_output = " | ".join(pairs)

    return {
        "token_predictions": horizontal_output,
        "aspects": _decode_bio(clean_tokens, clean_labels),
    }
def load_model(model_key):
    """Return ``(tokenizer, model)`` for *model_key*, loading it on first use.

    The base seq2seq checkpoint and LoRA adapter ids come from
    ``MODEL_OPTIONS``; the assembled pair is memoized in ``cached_models``.
    """
    cached = cached_models.get(model_key)
    if cached is not None:
        return cached

    entry = MODEL_OPTIONS[model_key]
    target = "cuda" if torch.cuda.is_available() else "cpu"

    tok = AutoTokenizer.from_pretrained(entry["adapter"])
    backbone = AutoModelForSeq2SeqLM.from_pretrained(entry["base"]).to(target)
    peft_model = PeftModel.from_pretrained(backbone, entry["adapter"]).to(target)
    peft_model.eval()

    cached_models[model_key] = (tok, peft_model)
    return tok, peft_model
def predict_absa(text, model_choice):
    """Dispatch *text* to the inference backend named by *model_choice*.

    Returns the backend's decoded output, or an ``{"error": ...}`` dict for
    an unsupported model name.
    """
    if model_choice == 'Araberta':
        return infer_araberta(text)
    if model_choice in ('mT5', 'mBART'):
        tokenizer, model = load_model(model_choice)
        return infer_t5_prompt(text, tokenizer, model)
    return {"error": f"Model {model_choice} not supported"}