#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Base model with lexical features only."""
from typing import Optional
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
class BaseMecariGNN(pl.LightningModule):
"""Base class for Mecari morpheme GNNs."""
def __init__(
self,
hidden_dim: int = 512,
num_classes: int = 1,
learning_rate: float = 1e-3,
lexical_feature_dim: int = 100000,
) -> None:
super().__init__()
self.save_hyperparameters()
self.hidden_dim = hidden_dim
self.num_classes = num_classes
self.learning_rate = learning_rate
self.lexical_feature_dim = lexical_feature_dim
        # Index 0 is reserved as the padding index, so its row stays at zero.
        self.lexical_embedding = nn.Embedding(
            num_embeddings=lexical_feature_dim, embedding_dim=hidden_dim, padding_idx=0, sparse=False
        )
        # Xavier-initialize the non-padding rows, then explicitly zero the padding row.
        nn.init.xavier_uniform_(self.lexical_embedding.weight[1:])
        self.lexical_embedding.weight.data[0].fill_(0)
self.lexical_norm = nn.LayerNorm(hidden_dim)
self.lexical_dropout = nn.Dropout(0.2)
self.residual_proj = nn.Linear(hidden_dim, hidden_dim)
        # Two-layer MLP head producing one logit per node for binary
        # classification (num_classes is stored, but the head is fixed to 1 output).
        self.node_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.2), nn.Linear(hidden_dim, 1)
        )
def _process_features(
self, lexical_indices: torch.Tensor, lexical_values: torch.Tensor, bert_features: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Process lexical features."""
embedded = self.lexical_embedding(lexical_indices)
weighted = embedded * lexical_values.unsqueeze(-1)
aggregated = weighted.sum(dim=1)
processed = self.lexical_dropout(self.lexical_norm(aggregated))
return processed
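    # Note: the weighted-sum pooling above matches nn.EmbeddingBag(mode="sum")
    # with per_sample_weights, which could serve as a drop-in replacement if
    # memory becomes a concern.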
def forward(self, lexical_indices, lexical_values, edge_index, bert_features=None, edge_attr=None):
"""Forward pass (implemented in subclasses)."""
raise NotImplementedError("Subclasses must implement forward method")
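    # A subclass forward would typically look like this (a sketch; ``gnn_layer``
    # is a hypothetical message-passing layer, not defined here):
    #     x = self._process_features(lexical_indices, lexical_values, bert_features)
    #     x = x + self.residual_proj(gnn_layer(x, edge_index))
    #     return self.node_classifier(x)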
    def training_step(self, batch, batch_idx):
        # squeeze(-1) (rather than squeeze()) keeps the node dimension intact
        # when a batch contains a single node.
        node_predictions = self(
            batch.lexical_indices,
            batch.lexical_values,
            batch.edge_index,
            None,
            batch.edge_attr if hasattr(batch, "edge_attr") else None,
        ).squeeze(-1)
        # Restrict the loss and metrics to nodes that carry gold labels.
        valid_mask = batch.valid_mask
        valid_predictions = node_predictions[valid_mask]
        valid_targets = batch.y[valid_mask]
loss = self._compute_bce_loss(valid_predictions, valid_targets, stage="train")
with torch.no_grad():
pred_probs = torch.sigmoid(valid_predictions)
pred_binary = (pred_probs > 0.5).float()
correct = (pred_binary == valid_targets).sum()
accuracy = correct / valid_targets.numel()
error_rate = 1.0 - accuracy
self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True)
self.log("train_error", error_rate, prog_bar=True, on_step=True, on_epoch=True)
if self.trainer and self.trainer.optimizers:
current_lr = self.trainer.optimizers[0].param_groups[0]["lr"]
self.log("learning_rate", current_lr, on_step=True, on_epoch=False)
return loss
    def validation_step(self, batch, batch_idx):
        node_predictions = self(
            batch.lexical_indices,
            batch.lexical_values,
            batch.edge_index,
            None,
            batch.edge_attr if hasattr(batch, "edge_attr") else None,
        ).squeeze(-1)
valid_mask = batch.valid_mask
valid_predictions = node_predictions[valid_mask]
valid_targets = batch.y[valid_mask]
loss = self._compute_bce_loss(valid_predictions, valid_targets, stage="val")
with torch.no_grad():
pred_probs = torch.sigmoid(valid_predictions)
pred_binary = (pred_probs > 0.5).float()
correct = (pred_binary == valid_targets).sum()
accuracy = correct / valid_targets.numel()
error_rate = 1.0 - accuracy
self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
self.log("val_error", error_rate, prog_bar=True, on_step=True, on_epoch=True)
self.log("val_loss_epoch", loss, on_step=False, on_epoch=True)
self.log("val_error_epoch", error_rate, on_step=False, on_epoch=True)
return loss
def configure_optimizers(self):
"""Configure optimizer."""
        optimizer_config = (getattr(self, "training_config", {}) or {}).get("optimizer", {})
optimizer_type = optimizer_config.get("type", "adamw")
if optimizer_type == "adamw":
optimizer = torch.optim.AdamW(
self.parameters(), lr=self.learning_rate, weight_decay=optimizer_config.get("weight_decay", 0.01)
)
else:
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
# Optional warmup scheduler (linear warmup to base LR)
tc = getattr(self, "training_config", {}) or {}
warmup_steps = int(tc.get("warmup_steps", 0) or 0)
warmup_start_lr = float(tc.get("warmup_start_lr", 0.0) or 0.0)
if warmup_steps > 0 and self.learning_rate > 0.0:
start_factor = max(0.0, min(1.0, warmup_start_lr / float(self.learning_rate)))
def lr_lambda(step: int):
if step <= 0:
return start_factor
if step < warmup_steps:
return start_factor + (1.0 - start_factor) * (step / float(warmup_steps))
return 1.0
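            # Worked example: with base lr 1e-3, warmup_start_lr=1e-4 and
            # warmup_steps=1000, start_factor is 0.1; at step 500 the factor is
            # 0.1 + 0.9 * 500/1000 = 0.55, i.e. an effective lr of 5.5e-4.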
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"interval": "step",
"frequency": 1,
"name": "linear_warmup",
},
}
return {"optimizer": optimizer}
    def test_step(self, batch, batch_idx):
        node_predictions = self(
            batch.lexical_indices,
            batch.lexical_values,
            batch.edge_index,
            None,
            batch.edge_attr if hasattr(batch, "edge_attr") else None,
        ).squeeze(-1)
valid_mask = batch.valid_mask
valid_predictions = node_predictions[valid_mask]
valid_targets = batch.y[valid_mask]
with torch.no_grad():
pred_probs = torch.sigmoid(valid_predictions)
pred_binary = (pred_probs > 0.5).float()
correct = (pred_binary == valid_targets).sum()
accuracy = correct / valid_targets.numel()
error_rate = 1.0 - accuracy
self.log("test_error", error_rate, on_step=False, on_epoch=True)
self.log("test_accuracy", accuracy, on_step=False, on_epoch=True)
return error_rate
    def _compute_bce_loss(self, logits: torch.Tensor, targets: torch.Tensor, stage: str = "train") -> torch.Tensor:
        """BCE-with-logits loss with optional label smoothing and pos_weight.

        - label_smoothing: smooth targets toward 0.5 by epsilon.
        - pos_weight: handle class imbalance using a per-batch neg/pos ratio,
          clamped for robustness.

        ``stage`` is accepted for API symmetry and currently unused.
        """
        loss_cfg = (getattr(self, "training_config", {}) or {}).get("loss", {})
        eps = float(loss_cfg.get("label_smoothing", 0.0) or 0.0)
use_pos_weight = bool(loss_cfg.get("use_pos_weight", True))
# Compute pos_weight from unsmoothed targets
pos = torch.clamp(targets.sum(), min=0.0)
total = torch.tensor(targets.numel(), device=targets.device, dtype=targets.dtype)
neg = total - pos
pos_weight = None
if use_pos_weight and pos > 0 and neg > 0:
# pos_weight = neg/pos; clamp to avoid extreme values
pw = (neg / pos).detach()
pw = torch.clamp(pw, 0.5, 50.0) # safety bounds
pos_weight = pw
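        # Worked example: 10 positive and 90 negative targets in a batch give
        # pos_weight = 90 / 10 = 9.0, well inside the [0.5, 50.0] clamp.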
# Apply label smoothing to targets: y' = (1-eps)*y + 0.5*eps
if eps > 0.0:
targets = (1.0 - eps) * targets + 0.5 * eps
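            # Worked example: eps = 0.1 maps a positive target 1.0 to 0.95 and
            # a negative target 0.0 to 0.05.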
loss = F.binary_cross_entropy_with_logits(
logits,
targets,
pos_weight=pos_weight,
)
return loss
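

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the training pipeline): wire a
    # trivial subclass whose forward ignores the graph structure, then run a
    # random batch through it. All tensor shapes here are illustrative.
    class _IdentityGNN(BaseMecariGNN):
        def forward(self, lexical_indices, lexical_values, edge_index, bert_features=None, edge_attr=None):
            x = self._process_features(lexical_indices, lexical_values, bert_features)
            x = x + self.residual_proj(x)  # stand-in for a real message-passing layer
            return self.node_classifier(x)

    model = _IdentityGNN(hidden_dim=32, lexical_feature_dim=1000)
    num_nodes, max_features = 8, 5
    lexical_indices = torch.randint(0, 1000, (num_nodes, max_features))
    lexical_values = torch.rand(num_nodes, max_features)
    edge_index = torch.randint(0, num_nodes, (2, 16))
    logits = model(lexical_indices, lexical_values, edge_index)
    print(logits.shape)  # expected: torch.Size([8, 1])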