"""AspectBERT: a DistilBERT encoder topped with a small sentiment classifier.

Network layout:
  * Backbone: distilbert-base-uncased (6 transformer layers).
  * Freezing: the embeddings and the first 4 transformer layers are frozen,
    the last 2 layers are fine-tuned.
  * Head: Linear(768->256) -> GELU -> Dropout(0.2) -> Linear(256->3).

Each input is formatted as "{review_text} aspect: {aspect_name}". The last
hidden state of the [CLS] token is passed through the head, which yields
3-way (negative/neutral/positive) sentiment logits.
"""

import os
import sys

import torch
import torch.nn as nn
from transformers import DistilBertModel

# Make the sibling ``constants`` module importable regardless of the CWD.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from constants import MAX_LENGTH, MODEL_NAME, NUM_LABELS  # noqa: E402


class AspectBERT(nn.Module):
    """DistilBERT backbone with a two-layer MLP head on the first token.

    Args:
        model_name: Identifier handed to ``DistilBertModel.from_pretrained``.
        num_labels: Number of output logits produced by the head.
        freeze_layers: Number of leading transformer layers to freeze
            (the embeddings are frozen as well).
    """

    def __init__(self, model_name=MODEL_NAME, num_labels=NUM_LABELS, freeze_layers=4):
        super().__init__()
        self.distilbert = DistilBertModel.from_pretrained(model_name)

        encoder_width = self.distilbert.config.dim  # 768 for distilbert-base

        # The head's layers are instantiated in the order they are applied,
        # and stay positional inside the Sequential so that parameter names
        # ("classifier.0.*", "classifier.3.*") remain unchanged.
        project = nn.Linear(encoder_width, 256)
        activation = nn.GELU()
        regularize = nn.Dropout(0.2)
        score = nn.Linear(256, num_labels)
        self.classifier = nn.Sequential(project, activation, regularize, score)

        self._freeze_layers(freeze_layers)

    def _freeze_layers(self, n_frozen):
        """Disable gradients for the embeddings and the first `n_frozen` layers.

        The embeddings are frozen unconditionally, even when `n_frozen` is 0.
        """
        backbone = self.distilbert
        backbone.embeddings.requires_grad_(False)
        for index, block in enumerate(backbone.transformer.layer):
            if index < n_frozen:
                block.requires_grad_(False)

    def forward(self, input_ids, attention_mask):
        """Return classification logits for a batch of tokenized inputs.

        Args:
            input_ids: Token ids forwarded to the DistilBERT backbone.
            attention_mask: Attention mask forwarded to the backbone.

        Returns:
            Logits of shape [batch, num_labels].
        """
        hidden_states = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        # Position 0 holds the [CLS] token: [batch, hidden_size].
        first_token = hidden_states[:, 0, :]
        return self.classifier(first_token)

    def trainable_parameter_summary(self):
        """Count all parameters and the subset that still requires gradients.

        Returns:
            Dict with "total_params", "trainable_params" and "trainable_pct"
            (the trainable share expressed as a percentage).
        """
        total = 0
        trainable = 0
        for tensor in self.parameters():
            size = tensor.numel()
            total += size
            if tensor.requires_grad:
                trainable += size
        return {
            "total_params": total,
            "trainable_params": trainable,
            "trainable_pct": 100.0 * trainable / total,
        }


if __name__ == "__main__":
    # Smoke test: build the model, report its parameter budget, list which
    # transformer layers are frozen, and run a single forward pass.
    from transformers import DistilBertTokenizerFast

    print("Building AspectBERT model...")
    model = AspectBERT()
    model.eval()

    summary = model.trainable_parameter_summary()
    total_params = summary["total_params"]
    trainable_params = summary["trainable_params"]
    trainable_pct = summary["trainable_pct"]
    print(f"Total params: {total_params:,}")
    print(f"Trainable params: {trainable_params:,} ({trainable_pct:.2f}%)")

    print("\nFrozen vs trainable transformer layers:")
    for i, layer in enumerate(model.distilbert.transformer.layer):
        # A layer counts as trainable if any of its parameters needs grads.
        status = "frozen"
        if any(p.requires_grad for p in layer.parameters()):
            status = "trainable"
        print(f" layer {i}: {status}")

    print("\nRunning a forward pass with dummy input...")
    tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL_NAME)
    text = "The battery life is amazing and lasts all day. aspect: battery"
    enc = tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=MAX_LENGTH,
        return_tensors="pt",
    )

    with torch.no_grad():
        logits = model(enc["input_ids"], enc["attention_mask"])
        probs = torch.softmax(logits, dim=-1)

    print(f"Input: {text!r}")
    print(f"Logits shape: {tuple(logits.shape)}")
    print(f"Logits: {logits.tolist()}")
    print(f"Probabilities (negative/neutral/positive): {probs.tolist()}")
    print("\nForward pass OK.")