"""Inference utilities for AspectBERT.

- load_model(): load tokenizer + DistilBERT backbone + classifier head from a
  local checkpoint directory or a HuggingFace Hub repo (HF_MODEL_NAME).
- predict_aspect() / predict_all_aspects(): run sentiment prediction for one
  or all 8 aspects on a single review.
- explain_with_lime(): word-importance explanation for a (review, aspect) pair.

Warnings are written to stderr so that the CLI's JSON output on stdout stays
machine-readable.
"""

import argparse
import json
import os
import sys

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import numpy as np  # noqa: E402
import torch  # noqa: E402
import torch.nn.functional as F  # noqa: E402
from transformers import DistilBertModel, DistilBertTokenizerFast  # noqa: E402

from constants import ASPECTS, ID2LABEL, MAX_LENGTH, MODEL_NAME, format_input  # noqa: E402
from model import AspectBERT  # noqa: E402


def _warn(message):
    """Print a warning to stderr (never stdout, which carries the JSON result)."""
    print(f"Warning: {message}", file=sys.stderr)


def _resolve_classifier_head(model_path):
    """Find classifier_head.pt either locally or on the HF Hub.

    Args:
        model_path: local checkpoint directory or HuggingFace Hub repo id.

    Returns:
        A filesystem path to ``classifier_head.pt``, or ``None`` if it could
        not be found. The lookup is best-effort: a Hub failure (missing
        ``huggingface_hub`` package, network error, missing file, ...) is
        reported on stderr and turned into ``None``.
    """
    local_path = os.path.join(model_path, "classifier_head.pt")
    if os.path.exists(local_path):
        return local_path
    if os.path.isdir(model_path):
        # A local checkpoint directory is not a Hub repo id; querying the Hub
        # would be pointless (and relative paths such as "ckpt/best" look like
        # valid "namespace/name" ids, which would trigger a network request).
        return None
    try:
        from huggingface_hub import hf_hub_download

        return hf_hub_download(repo_id=model_path, filename="classifier_head.pt")
    except Exception as exc:  # best-effort lookup; caller falls back
        _warn(f"could not fetch classifier_head.pt from Hub repo "
              f"'{model_path}' ({exc}).")
        return None


def load_model(model_path=None, device=None):
    """Load AspectBERT for inference.

    `model_path` may be:
      - None: read from HF_MODEL_NAME env var, falling back to the base
        distilbert-base-uncased weights with an untrained head.
      - a local checkpoint directory (as written by train.save_checkpoint)
      - a HuggingFace Hub repo id containing the backbone, tokenizer, and
        classifier_head.pt

    Args:
        model_path: see above.
        device: a ``torch.device`` or device string (e.g. ``"cpu"``); defaults
            to CUDA when available, else CPU.

    Returns:
        ``(model, tokenizer, device)`` with the model moved to ``device`` and
        switched to eval mode.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)
    model_path = model_path or os.environ.get("HF_MODEL_NAME") or MODEL_NAME

    tokenizer = DistilBertTokenizerFast.from_pretrained(model_path)
    model = AspectBERT(model_name=MODEL_NAME)

    if model_path == MODEL_NAME:
        # Base backbone only: there is no trained head to look for.
        classifier_head_path = None
    else:
        try:
            model.distilbert = DistilBertModel.from_pretrained(model_path)
        except Exception as exc:
            _warn(f"could not load fine-tuned backbone from "
                  f"'{model_path}' ({exc}); using base distilbert weights.")
        classifier_head_path = _resolve_classifier_head(model_path)

    if classifier_head_path:
        # weights_only=True: the file may come from the Hub, and a plain state
        # dict of tensors does not need arbitrary unpickling.
        state_dict = torch.load(classifier_head_path, map_location="cpu",
                                weights_only=True)
        model.classifier.load_state_dict(state_dict)
    else:
        _warn(f"no classifier_head.pt found for '{model_path}'; "
              "classification head is randomly initialized (untrained).")

    model.to(device)
    model.eval()
    return model, tokenizer, device


def _encode(tokenizer, device, inputs):
    """Tokenize already-formatted model inputs into tensors on ``device``.

    Every input is truncated/padded to MAX_LENGTH individually, so batching
    several inputs gives each row the same tokens it would get on its own.
    """
    enc = tokenizer(list(inputs), truncation=True, padding="max_length",
                    max_length=MAX_LENGTH, return_tensors="pt")
    return {k: v.to(device) for k, v in enc.items()}


@torch.no_grad()
def _predict_proba(model, tokenizer, device, inputs, batch_size=32):
    """Return softmax class probabilities for formatted inputs.

    Args:
        inputs: iterable of strings already passed through ``format_input``.
        batch_size: maximum number of inputs per forward pass.

    Returns:
        A numpy array with one row of class probabilities per input (an empty
        ``(0, len(ID2LABEL))`` array when ``inputs`` is empty).

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    inputs = list(inputs)
    chunks = []
    for start in range(0, len(inputs), batch_size):
        enc = _encode(tokenizer, device, inputs[start:start + batch_size])
        logits = model(enc["input_ids"], enc["attention_mask"])
        chunks.append(F.softmax(logits, dim=-1).cpu().numpy())
    if not chunks:
        return np.zeros((0, len(ID2LABEL)), dtype=np.float32)
    return np.concatenate(chunks, axis=0)


def _to_result(probs):
    """Turn one row of class probabilities into the public result dict."""
    label_idx = int(np.argmax(probs))
    return {
        "label": ID2LABEL[label_idx],
        "scores": {ID2LABEL[i]: float(p) for i, p in enumerate(probs)},
    }


def predict_aspect(model, tokenizer, device, text, aspect):
    """Predict the sentiment label + class probabilities for one aspect.

    Returns:
        ``{"label": <label>, "scores": {<label>: <probability>, ...}}``.
    """
    probs = _predict_proba(model, tokenizer, device, [format_input(text, aspect)])
    return _to_result(probs[0])


def predict_all_aspects(model, tokenizer, device, text, aspects=None):
    """Predict sentiment for every aspect (default: all 8) on one review.

    All aspects are scored in a single batched forward pass.

    Returns:
        A dict mapping each aspect to its ``predict_aspect``-style result.
    """
    aspects = list(aspects or ASPECTS)
    inputs = [format_input(text, aspect) for aspect in aspects]
    probs = _predict_proba(model, tokenizer, device, inputs)
    return {aspect: _to_result(row) for aspect, row in zip(aspects, probs)}


def explain_with_lime(model, tokenizer, device, text, aspect,
                      num_features=10, num_samples=200, batch_size=32):
    """Return a LIME explanation object for the (text, aspect) prediction.

    Requires the `lime` package.

    Args:
        num_features: maximum number of words in the explanation.
        num_samples: number of perturbed reviews LIME generates.
        batch_size: perturbed reviews scored per forward pass.
    """
    from lime.lime_text import LimeTextExplainer

    class_names = [ID2LABEL[i] for i in range(len(ID2LABEL))]
    explainer = LimeTextExplainer(class_names=class_names)

    def predict_proba(texts):
        # LIME hands over the whole list of perturbed reviews; score them in
        # batches instead of one forward pass per sample.
        inputs = [format_input(t, aspect) for t in texts]
        return _predict_proba(model, tokenizer, device, inputs,
                              batch_size=batch_size)

    return explainer.explain_instance(
        text,
        predict_proba,
        num_features=num_features,
        num_samples=num_samples,
        labels=list(range(len(class_names))),
    )


def main():
    """CLI entry point: print per-aspect predictions for a review as JSON."""
    parser = argparse.ArgumentParser(description="Run AspectBERT inference on a review.")
    parser.add_argument("text", help="Review text to analyze.")
    parser.add_argument("--model_path", default=None,
                        help="Local checkpoint dir or HF Hub repo id "
                             "(defaults to HF_MODEL_NAME env var).")
    parser.add_argument("--aspect", default=None, choices=ASPECTS,
                        help="Single aspect to predict; default = all 8 aspects.")
    args = parser.parse_args()

    model, tokenizer, device = load_model(args.model_path)
    if args.aspect:
        result = {args.aspect: predict_aspect(model, tokenizer, device,
                                              args.text, args.aspect)}
    else:
        result = predict_all_aspects(model, tokenizer, device, args.text)
    print(json.dumps(result, indent=2))


if __name__ == "__main__":
    main()