Spaces:
Sleeping
Sleeping
| """Inference utilities for AspectBERT. | |
| - load_model(): load tokenizer + DistilBERT backbone + classifier head from a | |
| local checkpoint directory or a HuggingFace Hub repo (HF_MODEL_NAME). | |
| - predict_aspect() / predict_all_aspects(): run sentiment prediction for one | |
| or all 8 aspects on a single review. | |
| - explain_with_lime(): word-importance explanation for a (review, aspect) pair. | |
| """ | |
| import argparse | |
| import json | |
| import os | |
| import sys | |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import DistilBertModel, DistilBertTokenizerFast | |
| from constants import ASPECTS, ID2LABEL, MAX_LENGTH, MODEL_NAME, format_input # noqa: E402 | |
| from model import AspectBERT # noqa: E402 | |
def _resolve_classifier_head(model_path):
    """Locate ``classifier_head.pt`` for ``model_path``.

    The file is looked up inside ``model_path`` on the local filesystem
    first. If it is not there, ``model_path`` is treated as a HuggingFace
    Hub repo id and the file is downloaded via ``hf_hub_download``.

    Returns:
        The path to the head weights, or ``None`` when the file is not
        available locally and the Hub lookup fails for any reason
        (including ``huggingface_hub`` not being installed).
    """
    head_filename = "classifier_head.pt"
    candidate = os.path.join(model_path, head_filename)
    if not os.path.exists(candidate):
        # Best-effort Hub fallback: any failure means "no head available".
        try:
            from huggingface_hub import hf_hub_download
            candidate = hf_hub_download(repo_id=model_path, filename=head_filename)
        except Exception:
            candidate = None
    return candidate
def load_model(model_path=None, device=None):
    """Load AspectBERT for inference.

    `model_path` may be:
    - None: read from HF_MODEL_NAME env var, falling back to the base
      MODEL_NAME weights with an untrained head.
    - a local checkpoint directory (as written by train.save_checkpoint)
    - a HuggingFace Hub repo id containing the backbone, tokenizer, and
      classifier_head.pt

    Args:
        model_path: Checkpoint directory or Hub repo id (see above).
        device: Device to place the model on. When falsy, CUDA is used if
            available, otherwise CPU.

    Returns:
        A ``(model, tokenizer, device)`` tuple; the model is moved to
        ``device`` and switched to eval mode.

    Warnings about a missing backbone or classifier head are written to
    stderr so they do not mix with results printed on stdout.
    """
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_path = model_path or os.environ.get("HF_MODEL_NAME") or MODEL_NAME

    tokenizer = DistilBertTokenizerFast.from_pretrained(model_path)
    model = AspectBERT(model_name=MODEL_NAME)

    if model_path != MODEL_NAME:
        # Swap in the fine-tuned backbone; keep the base weights on failure.
        try:
            model.distilbert = DistilBertModel.from_pretrained(model_path)
        except Exception as exc:
            print(f"Warning: could not load fine-tuned backbone from "
                  f"'{model_path}' ({exc}); using base distilbert weights.",
                  file=sys.stderr)

    classifier_head_path = _resolve_classifier_head(model_path)
    if classifier_head_path:
        # The head may have been downloaded from an arbitrary Hub repo;
        # weights_only=True restricts unpickling to tensors/primitive
        # containers instead of executing arbitrary pickled code.
        state_dict = torch.load(classifier_head_path, map_location="cpu",
                                weights_only=True)
        model.classifier.load_state_dict(state_dict)
    else:
        print(f"Warning: no classifier_head.pt found for '{model_path}'; "
              "classification head is randomly initialized (untrained).",
              file=sys.stderr)

    model.to(device)
    model.eval()
    return model, tokenizer, device
def predict_aspect(model, tokenizer, device, text, aspect):
    """Predict the sentiment label + class probabilities for one aspect.

    Args:
        model: Callable invoked as ``model(input_ids, attention_mask)`` and
            returning logits (assumed batch-first; only row 0 is used).
        tokenizer: Tokenizer called on the formatted input string.
        device: Device the encoded tensors are moved to before the forward
            pass.
        text: Review text.
        aspect: Aspect passed, together with ``text``, to ``format_input``.

    Returns:
        A dict with ``"label"`` (the ``ID2LABEL`` entry of the most probable
        class) and ``"scores"`` (a mapping from each label to its softmax
        probability as a float).
    """
    inp = format_input(text, aspect)
    enc = tokenizer(inp, truncation=True, padding="max_length",
                    max_length=MAX_LENGTH, return_tensors="pt")
    enc = {k: v.to(device) for k, v in enc.items()}

    # Inference only: without no_grad the output carries autograd history,
    # which wastes memory and makes .numpy() raise on a tensor that
    # requires grad.
    with torch.no_grad():
        logits = model(enc["input_ids"], enc["attention_mask"])
        probs = F.softmax(logits, dim=-1).cpu().numpy()[0]

    label_idx = int(np.argmax(probs))
    return {
        "label": ID2LABEL[label_idx],
        "scores": {ID2LABEL[i]: float(p) for i, p in enumerate(probs)},
    }
def predict_all_aspects(model, tokenizer, device, text, aspects=None):
    """Run ``predict_aspect`` for several aspects of a single review.

    When ``aspects`` is ``None`` or empty, every entry of ``ASPECTS`` is
    used. Returns a dict mapping each aspect to its ``predict_aspect``
    result.
    """
    selected = aspects if aspects else ASPECTS
    results = {}
    for name in selected:
        results[name] = predict_aspect(model, tokenizer, device, text, name)
    return results
def explain_with_lime(model, tokenizer, device, text, aspect, num_features=10, num_samples=200):
    """Build a LIME explanation for the (text, aspect) prediction.

    Requires the `lime` package (imported lazily inside the function).

    Args:
        model, tokenizer, device: Same objects used for prediction.
        text: Review text to explain.
        aspect: Aspect combined with each perturbed text via ``format_input``.
        num_features: Forwarded to ``explain_instance``.
        num_samples: Forwarded to ``explain_instance``.

    Returns:
        The object returned by ``LimeTextExplainer.explain_instance``,
        requested for every label index in ``ID2LABEL``.
    """
    from lime.lime_text import LimeTextExplainer

    label_ids = list(range(len(ID2LABEL)))
    class_names = [ID2LABEL[idx] for idx in label_ids]
    explainer = LimeTextExplainer(class_names=class_names)

    def _score_one(sample):
        # One forward pass per perturbed text; returns its probability row.
        encoded = tokenizer(
            format_input(sample, aspect),
            truncation=True,
            padding="max_length",
            max_length=MAX_LENGTH,
            return_tensors="pt",
        )
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
        with torch.no_grad():
            logits = model(encoded["input_ids"], encoded["attention_mask"])
        return F.softmax(logits, dim=-1).cpu().numpy()[0]

    def predict_proba(texts):
        # Classifier function handed to LIME: one probability row per text.
        return np.array([_score_one(sample) for sample in texts])

    return explainer.explain_instance(
        text,
        predict_proba,
        num_features=num_features,
        num_samples=num_samples,
        labels=label_ids,
    )
def main():
    """Command-line entry point.

    Parses the review text plus optional ``--model_path`` and ``--aspect``,
    loads the model, predicts either the single requested aspect or all
    aspects, and prints the result as indented JSON.
    """
    cli = argparse.ArgumentParser(description="Run AspectBERT inference on a review.")
    cli.add_argument("text", help="Review text to analyze.")
    cli.add_argument(
        "--model_path",
        default=None,
        help="Local checkpoint dir or HF Hub repo id (defaults to HF_MODEL_NAME env var).",
    )
    cli.add_argument(
        "--aspect",
        default=None,
        help="Single aspect to predict; default = all 8 aspects.",
    )
    opts = cli.parse_args()

    model, tokenizer, device = load_model(opts.model_path)

    if not opts.aspect:
        output = predict_all_aspects(model, tokenizer, device, opts.text)
    else:
        single = predict_aspect(model, tokenizer, device, opts.text, opts.aspect)
        output = {opts.aspect: single}

    print(json.dumps(output, indent=2))


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()