Spaces:
Sleeping
Sleeping
| """ | |
| models/bert_model.py | |
| DistilBERT fine-tuned sentiment classifier. | |
| Training is done on Google Colab (GPU required) — see notebooks/colab_train.py. | |
| This file handles inference only, loading the saved checkpoint from disk. | |
| Public API (used by app.py): | |
| predict(text) -> {"label": str, "score": float, "keywords": list[str]} | |
| """ | |
| import os | |
| import sys | |
| import numpy as np | |
| import torch | |
| from transformers import ( | |
| DistilBertTokenizerFast, | |
| DistilBertForSequenceClassification, | |
| ) | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) | |
| from data.labels import LABEL_NAMES # {0: "Negative", 1: "Positive", 2: "Neutral"} | |
# ── Paths ─────────────────────────────────────────────────────────────────────
# Local checkpoint directory, resolved relative to this file.
SAVE_DIR = os.path.join(
    os.path.dirname(__file__),
    "saved",
    "bert",
    "bert_sentiment",
)
# Hub identifier used as a fallback when the local checkpoint is not present.
HUB_MODEL = "DanTan05/bert-sentiment"

# ── Module-level cache ────────────────────────────────────────────────────────
# Same pattern as baseline.py — load from disk once, reuse on every predict().
# All three stay None until _load_models() has run.
_tokenizer = _model = _device = None
def _load_models():
    """Load the tokenizer and model once and cache them at module level.

    Subsequent calls return immediately once ``_model`` is set. The
    checkpoint is read from ``SAVE_DIR`` when that path exists, otherwise
    from the ``HUB_MODEL`` identifier.

    The cache globals are assigned only after every step has succeeded, so
    a failure part-way through (for example while moving the model to the
    device) leaves ``_model`` as ``None`` and the next call retries instead
    of reusing a half-initialised model.
    """
    global _tokenizer, _model, _device
    if _model is not None:
        return

    # Use local checkpoint if present (dev), otherwise download from Hub (Spaces).
    source = SAVE_DIR if os.path.exists(SAVE_DIR) else HUB_MODEL
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Loading BERT model from '{source}' on {device}...")

    tokenizer = DistilBertTokenizerFast.from_pretrained(source)
    # NOTE(review): eager attention is presumably requested because predict()
    # calls the model with output_attentions=True — confirm against the
    # transformers version in use.
    model = DistilBertForSequenceClassification.from_pretrained(
        source, attn_implementation="eager"
    )
    model.to(device)
    model.eval()  # disables dropout — important for deterministic inference

    # Publish the cache last; _model doubles as the "already loaded" flag,
    # so it is assigned only once everything else is in place.
    _tokenizer = tokenizer
    _device = device
    _model = model
| # ── Inference ───────────────────────────────────────────────────────────────── | |
def predict(text: str) -> dict:
    """
    Classify the sentiment of ``text`` and list its most-attended tokens.

    Returns the same inference contract dict as baseline.py:
        {
            "label":    "Positive" | "Negative" | "Neutral",
            "score":    float,      # confidence in the returned class (0–1)
            "keywords": list[str],  # tokens with highest attention weights
        }

    Why output_attentions=True?
        DistilBERT has 6 transformer layers, each with 12 attention heads.
        Each head produces a (seq_len × seq_len) attention matrix showing
        how much each token "attended to" every other token. Those weights
        are used here as a proxy for token importance.
    """
    _load_models()

    # Encode as PyTorch tensors, truncated to DistilBERT's 512-token limit,
    # then move every tensor onto the device the model lives on.
    encoded = _tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding=True,
    )
    inputs = {name: tensor.to(_device) for name, tensor in encoded.items()}

    # Inference only: no gradient tracking, which saves memory and time.
    with torch.no_grad():
        outputs = _model(**inputs, output_attentions=True)

    # logits are (1, n_classes) raw scores; softmax turns them into
    # probabilities that sum to 1.
    proba = torch.softmax(outputs.logits, dim=-1)[0].cpu().numpy()
    class_idx = int(np.argmax(proba))
    score = float(proba[class_idx])

    # The neutral class was trained on Twitter data that was actually pos/neg,
    # so the model over-predicts neutral for short opinionated text.
    # A low-confidence neutral win is handed to the stronger of
    # Negative(0) / Positive(1) instead.
    neutral_idx = 2
    neutral_threshold = 0.60
    weak_neutral = class_idx == neutral_idx and score < neutral_threshold
    if weak_neutral:
        class_idx = int(np.argmax(proba[:2]))
        score = float(proba[class_idx])

    # LABEL_NAMES is consulted first; the model's own id2label entry
    # (e.g. "LABEL_1") is only the fallback for an index it does not cover.
    config_label = _model.config.id2label[class_idx]
    return {
        "label": LABEL_NAMES.get(class_idx, config_label),
        "score": score,
        "keywords": _extract_keywords_from_attention(
            outputs.attentions, inputs, top_n=10
        ),
    }
def _extract_keywords_from_attention(attentions, inputs, top_n: int = 10, tokenizer=None) -> list:
    """
    Derives the most important tokens using the last layer's attention weights.

    Steps:
    1. Take the last transformer layer's attention tensor
       Expected shape: (1, n_heads, seq_len, seq_len)
    2. Average across all heads → shape: (seq_len, seq_len)
    3. Sum each token's incoming attention (column sum) — this measures
       how much the rest of the sequence attended TO this token
    4. Convert token IDs back to strings, skip special tokens
       ([CLS], [SEP], [PAD]) which always get high attention artificially
    5. Return the top_n distinct tokens by attention score

    Args:
        attentions: Sequence of per-layer attention tensors; only the last
            entry is used, and only its first batch element.
        inputs: Mapping holding an "input_ids" tensor; only the first row
            is read.
        top_n: Maximum number of keywords to return. Values <= 0 yield an
            empty list.
        tokenizer: Optional object providing ``all_special_ids`` and
            ``convert_ids_to_tokens``. Defaults to the module-level
            ``_tokenizer``.

    Returns:
        Up to ``top_n`` distinct token strings, highest score first.
        WordPiece continuation pieces are returned without their "##"
        prefix, and cleaned tokens shorter than two characters are dropped.

    Caveat (worth knowing):
        Attention weights ≠ explanation. Research (Jain & Wallace 2019) shows
        attention doesn't always correlate with feature importance.
        For a demo this is fine; for production use SHAP or integrated gradients.
    """
    # Guard first: the limit used to be checked only after appending, so a
    # non-positive top_n still produced one keyword.
    if top_n <= 0:
        return []

    tok = _tokenizer if tokenizer is None else tokenizer

    # attentions holds one tensor per layer — we want the last one
    last_layer_attn = attentions[-1]                     # (1, heads, seq, seq)
    avg_attn = last_layer_attn[0].mean(dim=0)            # (seq, seq)
    token_scores = avg_attn.sum(dim=0).cpu().tolist()    # (seq,) plain floats

    # Plain Python ints keep the special-id membership test and the
    # tokenizer lookup independent of array scalar types.
    input_ids = inputs["input_ids"][0].cpu().tolist()
    special_ids = set(tok.all_special_ids)
    tokens = tok.convert_ids_to_tokens(input_ids)

    # Pair each surviving token with its score
    scored = []
    for token, score, tid in zip(tokens, token_scores, input_ids):
        if tid in special_ids:
            continue
        # WordPiece continuation pieces start with "##" — strip only that
        # leading marker, never a "##" occurring elsewhere in the token.
        clean = token[2:] if token.startswith("##") else token
        if len(clean) < 2:  # skip single characters
            continue
        scored.append((clean, score))

    # Stable sort by score descending; dict.fromkeys keeps the first
    # (highest-scoring) occurrence of each word in that order.
    scored.sort(key=lambda pair: pair[1], reverse=True)
    unique_words = dict.fromkeys(word for word, _ in scored)
    return list(unique_words)[:top_n]