AspectBERT / src /inference.py
itismeTithi's picture
Deploy AspectBERT Streamlit app
31f6bcb
Raw
History Blame Contribute Delete
5.59 kB
"""Inference utilities for AspectBERT.
- load_model(): load tokenizer + DistilBERT backbone + classifier head from a
local checkpoint directory or a HuggingFace Hub repo (HF_MODEL_NAME).
- predict_aspect() / predict_all_aspects(): run sentiment prediction for one
or all 8 aspects on a single review.
- explain_with_lime(): word-importance explanation for a (review, aspect) pair.
"""
import argparse
import json
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import numpy as np
import torch
import torch.nn.functional as F
from transformers import DistilBertModel, DistilBertTokenizerFast
from constants import ASPECTS, ID2LABEL, MAX_LENGTH, MODEL_NAME, format_input # noqa: E402
from model import AspectBERT # noqa: E402
def _resolve_classifier_head(model_path):
"""Find classifier_head.pt either locally or on the HF Hub."""
local_path = os.path.join(model_path, "classifier_head.pt")
if os.path.exists(local_path):
return local_path
try:
from huggingface_hub import hf_hub_download
return hf_hub_download(repo_id=model_path, filename="classifier_head.pt")
except Exception:
return None
def load_model(model_path=None, device=None):
    """Load AspectBERT for inference.

    `model_path` may be:
      - None: read from HF_MODEL_NAME env var, falling back to the base
        distilbert-base-uncased weights with an untrained head.
      - a local checkpoint directory (as written by train.save_checkpoint)
      - a HuggingFace Hub repo id containing the backbone, tokenizer, and
        classifier_head.pt

    Args:
        model_path: See above.
        device: Device to place the model on. Defaults to CUDA when
            available, otherwise CPU.

    Returns:
        ``(model, tokenizer, device)``, with the model moved to ``device``
        and switched to eval mode.

    Warnings are written to stderr, not stdout, so that callers printing
    machine-readable output on stdout (e.g. ``main``'s JSON) are not corrupted.
    """
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_path = model_path or os.environ.get("HF_MODEL_NAME") or MODEL_NAME

    tokenizer = DistilBertTokenizerFast.from_pretrained(model_path)
    # Build the architecture from the base backbone first. A fine-tuned
    # backbone (if any) replaces it below.
    model = AspectBERT(model_name=MODEL_NAME)

    if model_path != MODEL_NAME:
        try:
            model.distilbert = DistilBertModel.from_pretrained(model_path)
        except Exception as exc:  # best-effort: keep base weights, but report why
            print(f"Warning: could not load fine-tuned backbone from "
                  f"'{model_path}' ({exc}); using base distilbert weights.",
                  file=sys.stderr)

    classifier_head_path = _resolve_classifier_head(model_path)
    if classifier_head_path:
        # NOTE(security): torch.load unpickles the file. A classifier_head.pt
        # fetched from an untrusted Hub repo could execute arbitrary code.
        # Consider passing weights_only=True once the minimum supported torch
        # version is confirmed to accept it.
        state_dict = torch.load(classifier_head_path, map_location="cpu")
        model.classifier.load_state_dict(state_dict)
    else:
        print(f"Warning: no classifier_head.pt found for '{model_path}'; "
              "classification head is randomly initialized (untrained).",
              file=sys.stderr)

    model.to(device)
    model.eval()
    return model, tokenizer, device
@torch.no_grad()
def predict_aspect(model, tokenizer, device, text, aspect):
    """Classify the sentiment expressed toward a single aspect of a review.

    Args:
        model: Called as ``model(input_ids, attention_mask)``; must return logits.
        tokenizer: Used to encode ``format_input(text, aspect)``.
        device: Device the encoded tensors are moved to before the forward pass.
        text: Raw review text.
        aspect: Aspect name paired with the review.

    Returns:
        dict with ``"label"`` (name of the highest-probability class) and
        ``"scores"`` (class name -> softmax probability as a Python float).
    """
    encoded = tokenizer(
        format_input(text, aspect),
        truncation=True,
        padding="max_length",
        max_length=MAX_LENGTH,
        return_tensors="pt",
    )
    batch = {name: tensor.to(device) for name, tensor in encoded.items()}
    logits = model(batch["input_ids"], batch["attention_mask"])
    # A single example was encoded, so keep only the first row of the output.
    probabilities = F.softmax(logits, dim=-1).cpu().numpy()[0]
    best = int(np.argmax(probabilities))
    return {
        "label": ID2LABEL[best],
        "scores": {ID2LABEL[idx]: float(p) for idx, p in enumerate(probabilities)},
    }
def predict_all_aspects(model, tokenizer, device, text, aspects=None):
    """Run ``predict_aspect`` for several aspects of the same review.

    Args:
        model, tokenizer, device: Passed straight through to ``predict_aspect``.
        text: Raw review text.
        aspects: Aspect names to score. ``None`` or any other empty/falsy
            value falls back to ``ASPECTS`` (all 8 aspects).

    Returns:
        dict mapping each aspect name to its ``predict_aspect`` result.
    """
    selected = aspects if aspects else ASPECTS
    return {
        name: predict_aspect(model, tokenizer, device, text, name)
        for name in selected
    }
def explain_with_lime(model, tokenizer, device, text, aspect, num_features=10,
                      num_samples=200, batch_size=32):
    """Return a LIME explanation object for the (text, aspect) prediction.

    Requires the `lime` package.

    LIME scores ``num_samples`` perturbed copies of ``text``. Those
    perturbations are scored ``batch_size`` texts per forward pass rather
    than one forward pass per text, which dominated runtime before.

    Args:
        model, tokenizer, device: As returned by ``load_model``.
        text: Review text to explain.
        aspect: Aspect whose prediction is explained.
        num_features: Maximum number of words in the explanation.
        num_samples: Number of perturbed texts LIME generates.
        batch_size: Perturbed texts scored per forward pass; must be >= 1.

    Returns:
        The explanation from ``LimeTextExplainer.explain_instance``, computed
        for every class label.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    from lime.lime_text import LimeTextExplainer

    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    class_names = [ID2LABEL[i] for i in range(len(ID2LABEL))]
    explainer = LimeTextExplainer(class_names=class_names)

    def predict_proba(texts):
        """Map raw texts to an array of shape (len(texts), num_classes)."""
        texts = list(texts)
        chunks = []
        for start in range(0, len(texts), batch_size):
            inputs = [format_input(t, aspect) for t in texts[start:start + batch_size]]
            # padding="max_length" gives every row the same length, so batching
            # does not change the per-example encoding.
            enc = tokenizer(inputs, truncation=True, padding="max_length",
                            max_length=MAX_LENGTH, return_tensors="pt")
            enc = {k: v.to(device) for k, v in enc.items()}
            with torch.no_grad():
                logits = model(enc["input_ids"], enc["attention_mask"])
            chunks.append(F.softmax(logits, dim=-1).cpu().numpy())
        if not chunks:
            # Keep a 2-D shape even when nothing was requested.
            return np.zeros((0, len(class_names)))
        return np.concatenate(chunks, axis=0)

    return explainer.explain_instance(
        text, predict_proba, num_features=num_features,
        num_samples=num_samples, labels=list(range(len(class_names))),
    )
def main():
    """Command-line entry point: print AspectBERT predictions for a review as JSON.

    Predicts only ``--aspect`` when it is given; otherwise predicts every aspect.
    """
    cli = argparse.ArgumentParser(description="Run AspectBERT inference on a review.")
    cli.add_argument("text", help="Review text to analyze.")
    cli.add_argument(
        "--model_path",
        default=None,
        help="Local checkpoint dir or HF Hub repo id (defaults to HF_MODEL_NAME env var).",
    )
    cli.add_argument(
        "--aspect",
        default=None,
        help="Single aspect to predict; default = all 8 aspects.",
    )
    opts = cli.parse_args()

    model, tokenizer, device = load_model(opts.model_path)
    if not opts.aspect:
        output = predict_all_aspects(model, tokenizer, device, opts.text)
    else:
        single = predict_aspect(model, tokenizer, device, opts.text, opts.aspect)
        output = {opts.aspect: single}
    print(json.dumps(output, indent=2))


if __name__ == "__main__":
    main()