from __future__ import annotations import json import random from collections import Counter, defaultdict from pathlib import Path import torch from torch import nn from torch.utils.data import DataLoader, Dataset from .predictor import DDIEmbeddingMLP, canonical_pair_key, normalize_name BASE_DIR = Path(__file__).resolve().parents[2] DATA_PATH = BASE_DIR / 'data' / 'processed' / 'ddinter_combined.parquet' MODEL_DIR = BASE_DIR / 'models' MODEL_PATH = MODEL_DIR / 'ddi_mlp_best.pt' LABEL_NAMES = ['unknown', 'minor', 'moderate', 'major'] LABEL_TO_INDEX = {label: index for index, label in enumerate(LABEL_NAMES)} class PairDataset(Dataset): def __init__(self, examples: list[tuple[int, int, int]]) -> None: self.examples = examples def __len__(self) -> int: return len(self.examples) def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: drug_a_id, drug_b_id, label_id = self.examples[index] return ( torch.tensor(drug_a_id, dtype=torch.long), torch.tensor(drug_b_id, dtype=torch.long), torch.tensor(label_id, dtype=torch.long), ) def load_and_aggregate_dataset() -> list[dict[str, str]]: from preprocessing.artifact_manager import manager import pandas as pd df = manager.load_artifact('ddinter_combined') pair_records: dict[tuple[str, str], Counter] = defaultdict(Counter) support_records: dict[tuple[str, str], int] = defaultdict(int) canonical_names: dict[tuple[str, str], tuple[str, str]] = {} for _, row in df.iterrows(): try: drug_a = str(row.get('canonical_drug_a') or row.get('Drug_A', '')).strip() drug_b = str(row.get('canonical_drug_b') or row.get('Drug_B', '')).strip() severity = str(row.get('Level') or row.get('level', 'Unknown')).strip().lower() if severity not in LABEL_NAMES: severity = 'unknown' key = canonical_pair_key(drug_a, drug_b) pair_records[key][severity] += 1 support_records[key] += 1 canonical_names.setdefault(key, (drug_a, drug_b)) except Exception: continue examples: list[dict[str, str]] = [] for key, counter in pair_records.items(): severity = max(counter.items(), key=lambda item: (item[1], LABEL_TO_INDEX.get(item[0], 0)))[0] drug_a, drug_b = canonical_names[key] examples.append( { 'drug_a': drug_a, 'drug_b': drug_b, 'severity': severity, 'support_count': str(support_records[key]), } ) return examples def build_vocabulary(examples: list[dict[str, str]]) -> dict[str, int]: vocab: dict[str, int] = {} for example in examples: for drug_name in (example['drug_a'], example['drug_b']): normalized = normalize_name(drug_name) if normalized not in vocab: vocab[normalized] = len(vocab) + 1 return vocab def encode_examples(examples: list[dict[str, str]], vocab: dict[str, int]) -> list[tuple[int, int, int]]: encoded_examples: list[tuple[int, int, int]] = [] for example in examples: drug_a_id = vocab.get(normalize_name(example['drug_a']), 0) drug_b_id = vocab.get(normalize_name(example['drug_b']), 0) label_id = LABEL_TO_INDEX.get(example['severity'], 0) encoded_examples.append((drug_a_id, drug_b_id, label_id)) return encoded_examples def compute_class_weights(labels: list[int]) -> torch.Tensor: counts = Counter(labels) total = sum(counts.values()) weights = [] for index in range(len(LABEL_NAMES)): class_count = max(counts.get(index, 1), 1) weight = total / (len(LABEL_NAMES) * class_count) weights.append(weight) return torch.tensor(weights, dtype=torch.float32) def split_examples(examples: list[tuple[int, int, int]], seed: int = 42) -> tuple[list, list]: shuffled = examples[:] random.Random(seed).shuffle(shuffled) split_index = max(1, int(len(shuffled) * 0.9)) return shuffled[:split_index], shuffled[split_index:] def evaluate(model: nn.Module, dataloader: DataLoader, loss_fn: nn.Module) -> tuple[float, float]: model.eval() total_loss = 0.0 total_correct = 0 total_items = 0 with torch.no_grad(): for drug_a_ids, drug_b_ids, labels in dataloader: logits = model(drug_a_ids, drug_b_ids) loss = loss_fn(logits, labels) predictions = torch.argmax(logits, dim=-1) total_loss += float(loss.item()) * labels.size(0) total_correct += int((predictions == labels).sum().item()) total_items += int(labels.size(0)) average_loss = total_loss / max(total_items, 1) accuracy = total_correct / max(total_items, 1) return average_loss, accuracy def train() -> dict[str, object]: random.seed(42) torch.manual_seed(42) examples = load_and_aggregate_dataset() vocab = build_vocabulary(examples) encoded_examples = encode_examples(examples, vocab) train_examples, valid_examples = split_examples(encoded_examples) train_labels = [label_id for _, _, label_id in train_examples] class_weights = compute_class_weights(train_labels) train_dataset = PairDataset(train_examples) valid_dataset = PairDataset(valid_examples) train_loader = DataLoader(train_dataset, batch_size=4096, shuffle=True) valid_loader = DataLoader(valid_dataset, batch_size=4096, shuffle=False) model = DDIEmbeddingMLP( vocab_size=len(vocab) + 1, embedding_dim=64, hidden_dim=128, num_classes=len(LABEL_NAMES), dropout=0.2, ) optimizer = torch.optim.AdamW(model.parameters(), lr=2e-3, weight_decay=1e-4) loss_fn = nn.CrossEntropyLoss(weight=class_weights) best_state = None best_accuracy = -1.0 history: list[dict[str, float]] = [] for epoch in range(4): model.train() running_loss = 0.0 running_correct = 0 running_items = 0 for drug_a_ids, drug_b_ids, labels in train_loader: optimizer.zero_grad(set_to_none=True) logits = model(drug_a_ids, drug_b_ids) loss = loss_fn(logits, labels) loss.backward() optimizer.step() predictions = torch.argmax(logits, dim=-1) running_loss += float(loss.item()) * labels.size(0) running_correct += int((predictions == labels).sum().item()) running_items += int(labels.size(0)) train_loss = running_loss / max(running_items, 1) train_accuracy = running_correct / max(running_items, 1) valid_loss, valid_accuracy = evaluate(model, valid_loader, loss_fn) history.append( { 'epoch': float(epoch + 1), 'train_loss': float(train_loss), 'train_accuracy': float(train_accuracy), 'valid_loss': float(valid_loss), 'valid_accuracy': float(valid_accuracy), } ) if valid_accuracy >= best_accuracy: best_accuracy = valid_accuracy best_state = {key: value.cpu() for key, value in model.state_dict().items()} print( f'epoch={epoch + 1} train_loss={train_loss:.4f} train_acc={train_accuracy:.4f} ' f'valid_loss={valid_loss:.4f} valid_acc={valid_accuracy:.4f}' ) if best_state is None: best_state = {key: value.cpu() for key, value in model.state_dict().items()} MODEL_DIR.mkdir(parents=True, exist_ok=True) checkpoint = { 'model_version': 'medcare-ddi-mlp-v1', 'embedding_dim': 64, 'hidden_dim': 128, 'label_names': LABEL_NAMES, 'label_to_index': LABEL_TO_INDEX, 'index_to_label': {index: label for index, label in enumerate(LABEL_NAMES)}, 'drug_vocab': vocab, 'model_state_dict': best_state, 'training_history': history, 'best_validation_accuracy': float(best_accuracy), 'dataset_size': len(encoded_examples), 'vocab_size': len(vocab), } torch.save(checkpoint, MODEL_PATH) summary_path = MODEL_DIR / 'ddi_mlp_best.summary.json' summary_path.write_text( json.dumps( { 'model_version': checkpoint['model_version'], 'best_validation_accuracy': checkpoint['best_validation_accuracy'], 'dataset_size': checkpoint['dataset_size'], 'vocab_size': checkpoint['vocab_size'], 'training_history': history, }, indent=2, ), encoding='utf-8', ) return checkpoint if __name__ == '__main__': train()