# AspectBERT / src/model.py
# itismeTithi -- "Deploy AspectBERT Streamlit app" (commit 31f6bcb, 3.63 kB)
"""AspectBERT model: DistilBERT backbone + custom classification head.
Architecture:
- distilbert-base-uncased backbone (6 transformer layers)
- First 4 transformer layers (and embeddings) frozen, last 2 fine-tuned
- Classification head: Linear(768->256) -> GELU -> Dropout(0.2) -> Linear(256->3)
The [CLS] token's last hidden state is fed to the classification head to
produce 3-way (negative/neutral/positive) sentiment logits per
"{review_text} aspect: {aspect_name}" input.
"""
import os
import sys
import torch
import torch.nn as nn
from transformers import DistilBertModel
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from constants import MAX_LENGTH, MODEL_NAME, NUM_LABELS # noqa: E402
class AspectBERT(nn.Module):
    """DistilBERT encoder topped with a two-layer classification head.

    The final-layer hidden state at sequence position 0 is passed through
    ``Linear(dim -> 256) -> GELU -> Dropout(0.2) -> Linear(256 -> num_labels)``
    to produce one row of logits per input sequence.

    Args:
        model_name: Identifier handed to ``DistilBertModel.from_pretrained``.
        num_labels: Output width of the last linear layer.
        freeze_layers: Number of leading transformer layers whose parameters
            get ``requires_grad = False``. The embedding parameters are
            frozen regardless of this value.
    """

    def __init__(self, model_name=MODEL_NAME, num_labels=NUM_LABELS, freeze_layers=4):
        super().__init__()
        self.distilbert = DistilBertModel.from_pretrained(model_name)

        # Encoder output width as reported by the loaded config.
        encoder_width = self.distilbert.config.dim
        head_layers = [
            nn.Linear(encoder_width, 256),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_labels),
        ]
        self.classifier = nn.Sequential(*head_layers)

        self._freeze_layers(freeze_layers)

    def _freeze_layers(self, n_frozen):
        """Disable gradients for the embeddings and the first `n_frozen` layers."""

        def _disable_grad(module):
            for weight in module.parameters():
                weight.requires_grad = False

        _disable_grad(self.distilbert.embeddings)
        for idx, block in enumerate(self.distilbert.transformer.layer):
            if idx < n_frozen:
                _disable_grad(block)

    def forward(self, input_ids, attention_mask):
        """Return classification logits for a tokenized batch.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.

        Returns:
            The classifier output computed from the position-0 hidden state
            of each sequence, i.e. one row of ``num_labels`` logits per input.
        """
        encoded = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 of every sequence: [batch, hidden_size].
        first_token_state = encoded.last_hidden_state[:, 0, :]
        return self.classifier(first_token_state)

    def trainable_parameter_summary(self):
        """Count all parameters and the subset that still requires gradients.

        Returns:
            A dict with ``total_params``, ``trainable_params`` and
            ``trainable_pct`` (trainable share of the total, in percent).
        """
        total = 0
        trainable = 0
        for weight in self.parameters():
            count = weight.numel()
            total += count
            if weight.requires_grad:
                trainable += count
        return {
            "total_params": total,
            "trainable_params": trainable,
            "trainable_pct": 100.0 * trainable / total,
        }
if __name__ == "__main__":
    # Manual smoke test: build the model, report which parameters are
    # trainable, and push one tokenized example through a forward pass.
    from transformers import DistilBertTokenizerFast

    print("Building AspectBERT model...")
    net = AspectBERT()
    net.eval()

    stats = net.trainable_parameter_summary()
    total_count = stats["total_params"]
    trainable_count = stats["trainable_params"]
    trainable_share = stats["trainable_pct"]
    print(f"Total params: {total_count:,}")
    print(f"Trainable params: {trainable_count:,} "
          f"({trainable_share:.2f}%)")

    print("\nFrozen vs trainable transformer layers:")
    for idx, block in enumerate(net.distilbert.transformer.layer):
        # A layer is reported as trainable as soon as one parameter needs grad.
        status = "frozen"
        for weight in block.parameters():
            if weight.requires_grad:
                status = "trainable"
                break
        print(f" layer {idx}: {status}")

    print("\nRunning a forward pass with dummy input...")
    tok = DistilBertTokenizerFast.from_pretrained(MODEL_NAME)
    sample = "The battery life is amazing and lasts all day. aspect: battery"
    batch = tok(
        sample,
        truncation=True,
        padding="max_length",
        max_length=MAX_LENGTH,
        return_tensors="pt",
    )

    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        scores = net(batch["input_ids"], batch["attention_mask"])
        class_probs = torch.softmax(scores, dim=-1)

    print(f"Input: {sample!r}")
    print(f"Logits shape: {tuple(scores.shape)}")
    print(f"Logits: {scores.tolist()}")
    print(f"Probabilities (negative/neutral/positive): {class_probs.tolist()}")
    print("\nForward pass OK.")