"""Base model with lexical features only."""

from typing import Optional

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F


class BaseMecariGNN(pl.LightningModule):
    """Base class for Mecari morpheme GNNs."""

    def __init__(
        self,
        hidden_dim: int = 512,
        num_classes: int = 1,
        learning_rate: float = 1e-3,
        lexical_feature_dim: int = 100000,
    ) -> None:
        super().__init__()
        self.save_hyperparameters()

        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.learning_rate = learning_rate
        self.lexical_feature_dim = lexical_feature_dim

        # Sparse lexical features are embedded densely; index 0 is the padding slot.
        self.lexical_embedding = nn.Embedding(
            num_embeddings=lexical_feature_dim, embedding_dim=hidden_dim, padding_idx=0, sparse=False
        )
        # Initialize the non-padding rows and keep the padding row at zero.
        nn.init.xavier_uniform_(self.lexical_embedding.weight[1:])
        self.lexical_embedding.weight.data[0].fill_(0)

        self.lexical_norm = nn.LayerNorm(hidden_dim)
        self.lexical_dropout = nn.Dropout(0.2)

        self.residual_proj = nn.Linear(hidden_dim, hidden_dim)

        # Per-node binary classifier head producing a single logit.
        self.node_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.2), nn.Linear(hidden_dim, 1)
        )

    def _process_features(
        self, lexical_indices: torch.Tensor, lexical_values: torch.Tensor, bert_features: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Embed, weight, and pool the sparse lexical features for each node.

        ``bert_features`` is accepted for interface compatibility but ignored by
        this lexical-only base class.
        """
        # (num_nodes, max_features) indices -> (num_nodes, max_features, hidden_dim)
        embedded = self.lexical_embedding(lexical_indices)
        # Scale each embedding by its feature value.
        weighted = embedded * lexical_values.unsqueeze(-1)
        # Sum-pool over the feature dimension; padding rows are zero and drop out.
        aggregated = weighted.sum(dim=1)
        processed = self.lexical_dropout(self.lexical_norm(aggregated))
        return processed

    def forward(self, lexical_indices, lexical_values, edge_index, bert_features=None, edge_attr=None):
        """Forward pass (implemented in subclasses)."""
        raise NotImplementedError("Subclasses must implement forward method")

    def training_step(self, batch, batch_idx):
        node_predictions = self(
            batch.lexical_indices,
            batch.lexical_values,
            batch.edge_index,
            None,  # no bert_features in the lexical-only base model
            batch.edge_attr if hasattr(batch, "edge_attr") else None,
        ).squeeze(-1)  # collapse the logit dimension only (safe when a graph has a single node)

        valid_mask = batch.valid_mask
        valid_predictions = node_predictions[valid_mask]
        valid_targets = batch.y[valid_mask]

        loss = self._compute_bce_loss(valid_predictions, valid_targets, stage="train")

        with torch.no_grad():
            pred_probs = torch.sigmoid(valid_predictions)
            pred_binary = (pred_probs > 0.5).float()
            correct = (pred_binary == valid_targets).sum()
            accuracy = correct / valid_targets.numel()
            error_rate = 1.0 - accuracy

        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True)
        self.log("train_error", error_rate, prog_bar=True, on_step=True, on_epoch=True)

        if self.trainer and self.trainer.optimizers:
            current_lr = self.trainer.optimizers[0].param_groups[0]["lr"]
            self.log("learning_rate", current_lr, on_step=True, on_epoch=False)

        return loss

    def validation_step(self, batch, batch_idx):
        node_predictions = self(
            batch.lexical_indices,
            batch.lexical_values,
            batch.edge_index,
            None,  # no bert_features in the lexical-only base model
            batch.edge_attr if hasattr(batch, "edge_attr") else None,
        ).squeeze(-1)

        valid_mask = batch.valid_mask
        valid_predictions = node_predictions[valid_mask]
        valid_targets = batch.y[valid_mask]

        loss = self._compute_bce_loss(valid_predictions, valid_targets, stage="val")

        with torch.no_grad():
            pred_probs = torch.sigmoid(valid_predictions)
            pred_binary = (pred_probs > 0.5).float()
            correct = (pred_binary == valid_targets).sum()
            accuracy = correct / valid_targets.numel()
            error_rate = 1.0 - accuracy
self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True) |
|
|
self.log("val_error", error_rate, prog_bar=True, on_step=True, on_epoch=True) |
|
|
|
|
|
self.log("val_loss_epoch", loss, on_step=False, on_epoch=True) |
|
|
self.log("val_error_epoch", error_rate, on_step=False, on_epoch=True) |
|
|
|
|
|
return loss |
|
|
|
|
|

    def configure_optimizers(self):
        """Configure the optimizer and an optional linear warmup schedule."""
        tc = getattr(self, "training_config", {}) or {}
        optimizer_config = tc.get("optimizer", {}) or {}
        optimizer_type = optimizer_config.get("type", "adamw")

        if optimizer_type == "adamw":
            optimizer = torch.optim.AdamW(
                self.parameters(), lr=self.learning_rate, weight_decay=optimizer_config.get("weight_decay", 0.01)
            )
        else:
            optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)

        warmup_steps = int(tc.get("warmup_steps", 0) or 0)
        warmup_start_lr = float(tc.get("warmup_start_lr", 0.0) or 0.0)
        if warmup_steps > 0 and self.learning_rate > 0.0:
            start_factor = max(0.0, min(1.0, warmup_start_lr / float(self.learning_rate)))

            def lr_lambda(step: int):
                if step <= 0:
                    return start_factor
                if step < warmup_steps:
                    return start_factor + (1.0 - start_factor) * (step / float(warmup_steps))
                return 1.0

            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "interval": "step",
                    "frequency": 1,
                    "name": "linear_warmup",
                },
            }

        return {"optimizer": optimizer}

    def test_step(self, batch, batch_idx):
        node_predictions = self(
            batch.lexical_indices,
            batch.lexical_values,
            batch.edge_index,
            None,  # no bert_features in the lexical-only base model
            batch.edge_attr if hasattr(batch, "edge_attr") else None,
        ).squeeze(-1)

        valid_mask = batch.valid_mask
        valid_predictions = node_predictions[valid_mask]
        valid_targets = batch.y[valid_mask]

        with torch.no_grad():
            pred_probs = torch.sigmoid(valid_predictions)
            pred_binary = (pred_probs > 0.5).float()
            correct = (pred_binary == valid_targets).sum()
            accuracy = correct / valid_targets.numel()
            error_rate = 1.0 - accuracy

        self.log("test_error", error_rate, on_step=False, on_epoch=True)
        self.log("test_accuracy", accuracy, on_step=False, on_epoch=True)

        return error_rate

    def _compute_bce_loss(self, logits: torch.Tensor, targets: torch.Tensor, stage: str = "train") -> torch.Tensor:
        """BCEWithLogits loss with optional label smoothing and pos_weight.

        - label_smoothing: smooth targets toward 0.5 by epsilon.
        - pos_weight: handle class imbalance with a per-batch neg/pos ratio,
          clamped to a safe range.

        ``stage`` is currently unused.
        """
        loss_cfg = (getattr(self, "training_config", {}) or {}).get("loss", {})
        eps = float(loss_cfg.get("label_smoothing", 0.0) or 0.0)
        use_pos_weight = bool(loss_cfg.get("use_pos_weight", True))

        # Per-batch positive/negative counts for the pos_weight estimate.
        pos = torch.clamp(targets.sum(), min=0.0)
        total = torch.tensor(targets.numel(), device=targets.device, dtype=targets.dtype)
        neg = total - pos
        pos_weight = None
        if use_pos_weight and pos > 0 and neg > 0:
            # Clamp the ratio so a single skewed batch cannot blow up the loss.
            pw = (neg / pos).detach()
            pw = torch.clamp(pw, 0.5, 50.0)
            pos_weight = pw

        if eps > 0.0:
            # Label smoothing pulls hard 0/1 targets toward 0.5 by eps.
            targets = (1.0 - eps) * targets + 0.5 * eps

        loss = F.binary_cross_entropy_with_logits(
            logits,
            targets,
            pos_weight=pos_weight,
        )

        return loss
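

# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the training pipeline): one way a
# subclass might implement ``forward``, plus an example of the optional
# ``training_config`` attribute read by ``configure_optimizers`` and
# ``_compute_bce_loss``. The name ``_ExampleLexicalGNN`` and the synthetic
# tensors below are hypothetical and exist purely for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _ExampleLexicalGNN(BaseMecariGNN):
        """Minimal subclass: pooled lexical features + residual MLP head, no message passing."""

        def forward(self, lexical_indices, lexical_values, edge_index, bert_features=None, edge_attr=None):
            node_features = self._process_features(lexical_indices, lexical_values, bert_features)
            node_features = node_features + self.residual_proj(node_features)
            return self.node_classifier(node_features)

    model = _ExampleLexicalGNN(hidden_dim=64, lexical_feature_dim=1000)
    # Hypothetical config; only the keys consumed above are shown.
    model.training_config = {
        "optimizer": {"type": "adamw", "weight_decay": 0.01},
        "warmup_steps": 500,
        "warmup_start_lr": 1e-5,
        "loss": {"label_smoothing": 0.05, "use_pos_weight": True},
    }

    # Synthetic mini-graph: 4 nodes with up to 3 lexical features each (index 0 = padding).
    lexical_indices = torch.randint(1, 1000, (4, 3))
    lexical_values = torch.rand(4, 3)
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])

    logits = model(lexical_indices, lexical_values, edge_index).squeeze(-1)
    targets = torch.tensor([1.0, 0.0, 1.0, 0.0])
    print(logits.shape, float(model._compute_bce_loss(logits, targets)))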