Spaces:
Running
Running
| from __future__ import annotations | |
| import json | |
| import random | |
| from collections import Counter, defaultdict | |
| from pathlib import Path | |
| import torch | |
| from torch import nn | |
| from torch.utils.data import DataLoader, Dataset | |
| from .predictor import DDIEmbeddingMLP, canonical_pair_key, normalize_name | |
# Project root: the directory two levels above the one containing this file
# (presumably the repository root — confirm against the package layout).
BASE_DIR = Path(__file__).resolve().parents[2]
# Combined DDInter parquet; the loader in this module goes through the
# artifact manager rather than reading this path directly.
DATA_PATH = BASE_DIR.joinpath('data', 'processed', 'ddinter_combined.parquet')
MODEL_DIR = BASE_DIR.joinpath('models')
MODEL_PATH = MODEL_DIR.joinpath('ddi_mlp_best.pt')
# Severity classes. Order is significant: list position is the class index,
# and pair aggregation breaks vote ties in favour of the higher index.
LABEL_NAMES = ['unknown', 'minor', 'moderate', 'major']
LABEL_TO_INDEX = dict(zip(LABEL_NAMES, range(len(LABEL_NAMES))))
class PairDataset(Dataset):
    """Map-style dataset over pre-encoded ``(drug_a_id, drug_b_id, label_id)`` triples.

    Each item is returned as a tuple of three 0-dim ``torch.long`` tensors.
    """

    def __init__(self, examples: list[tuple[int, int, int]]) -> None:
        # Kept by reference; triples are neither copied nor validated here.
        self.examples = examples

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Unpack first so a malformed (non-triple) entry still fails loudly.
        first_drug, second_drug, target = self.examples[index]
        return tuple(
            torch.tensor(component, dtype=torch.long)
            for component in (first_drug, second_drug, target)
        )
def load_and_aggregate_dataset() -> list[dict[str, str]]:
    """Load the combined DDInter interactions and collapse them to one example per drug pair.

    Rows are grouped by ``canonical_pair_key(drug_a, drug_b)``. That key is
    presumably order-insensitive, so A-B and B-A merge; confirm in ``predictor``.
    For each group, the severity is the most frequent lower-cased ``Level``
    value. Ties are broken in favour of the label with the higher index in
    ``LABEL_NAMES``. Levels that are missing or not in ``LABEL_NAMES`` count
    as ``'unknown'``.

    Drug names come from ``canonical_drug_a``/``canonical_drug_b`` and fall
    back to ``Drug_A``/``Drug_B``. Missing cells (``None``, NaN, ``pd.NA``) and
    blank strings are treated as absent, so they never become a literal
    ``'nan'`` drug. Some rows are skipped and their count is printed:

    - rows still lacking either name;
    - rows whose parsing or pair-key construction raises.

    Returns:
        One dict per pair, all values strings:

        - ``drug_a`` / ``drug_b``: the names from the first row seen for that pair.
        - ``severity``: a member of ``LABEL_NAMES``.
        - ``support_count``: the number of rows that contributed to the pair.
    """
    from preprocessing.artifact_manager import manager
    import pandas as pd

    def _text(row: dict[str, object], *columns: str, default: str = '') -> str:
        """Return the first present, non-blank value among ``columns`` as a stripped string."""
        for column in columns:
            value = row.get(column)
            if value is None:
                continue
            # Missing cells typically arrive as NaN (truthy, stringifies to 'nan')
            # or pd.NA (whose truth value raises), so test for them explicitly
            # instead of relying on `a or b`.
            if not isinstance(value, str) and pd.api.types.is_scalar(value) and pd.isna(value):
                continue
            text = str(value).strip()
            if text:
                return text
        return default

    df = manager.load_artifact('ddinter_combined')
    pair_records: dict[tuple[str, str], Counter] = defaultdict(Counter)
    canonical_names: dict[tuple[str, str], tuple[str, str]] = {}
    skipped_rows = 0

    # to_dict('records') yields plain dicts; much cheaper than building a
    # Series per row with iterrows().
    for row in df.to_dict(orient='records'):
        try:
            drug_a = _text(row, 'canonical_drug_a', 'Drug_A')
            drug_b = _text(row, 'canonical_drug_b', 'Drug_B')
            if not drug_a or not drug_b:
                skipped_rows += 1
                continue
            severity = _text(row, 'Level', 'level', default='unknown').lower()
            if severity not in LABEL_TO_INDEX:
                severity = 'unknown'
            key = canonical_pair_key(drug_a, drug_b)
        except Exception:
            # Best-effort load: one malformed row must not abort the whole
            # dataset, but it is counted and reported instead of vanishing.
            skipped_rows += 1
            continue
        pair_records[key][severity] += 1
        # First orientation seen for a pair is kept as its display names.
        canonical_names.setdefault(key, (drug_a, drug_b))

    if skipped_rows:
        print(f'load_and_aggregate_dataset: skipped {skipped_rows} rows without usable drug names or pair keys')

    examples: list[dict[str, str]] = []
    for key, counter in pair_records.items():
        # Majority vote; on equal counts prefer the higher LABEL_TO_INDEX entry.
        severity = max(counter.items(), key=lambda item: (item[1], LABEL_TO_INDEX.get(item[0], 0)))[0]
        drug_a, drug_b = canonical_names[key]
        examples.append(
            {
                'drug_a': drug_a,
                'drug_b': drug_b,
                'severity': severity,
                'support_count': str(sum(counter.values())),
            }
        )
    return examples
def build_vocabulary(examples: list[dict[str, str]]) -> dict[str, int]:
    """Map every distinct normalized drug name to a contiguous integer id.

    Ids start at 1 and follow first-seen order: example by example, with
    ``drug_a`` visited before ``drug_b``. Id 0 is never assigned; this
    module's encoder uses it for names missing from the vocabulary.
    """
    vocab: dict[str, int] = {}
    normalized_names = (
        normalize_name(name)
        for example in examples
        for name in (example['drug_a'], example['drug_b'])
    )
    for name in normalized_names:
        # The id expression is evaluated before insertion, so a new name
        # receives the next free id; existing names are left untouched.
        vocab.setdefault(name, len(vocab) + 1)
    return vocab
def encode_examples(examples: list[dict[str, str]], vocab: dict[str, int]) -> list[tuple[int, int, int]]:
    """Convert aggregated pair dicts into ``(drug_a_id, drug_b_id, label_id)`` triples.

    Drug names are normalized before lookup. Names absent from ``vocab`` map
    to id 0, and severities absent from ``LABEL_TO_INDEX`` map to label 0.
    """

    def drug_id(name: str) -> int:
        return vocab.get(normalize_name(name), 0)

    return [
        (
            drug_id(example['drug_a']),
            drug_id(example['drug_b']),
            LABEL_TO_INDEX.get(example['severity'], 0),
        )
        for example in examples
    ]
def compute_class_weights(labels: list[int]) -> torch.Tensor:
    """Return inverse-frequency class weights as a float32 tensor.

    The weight for class ``i`` is ``total / (num_classes * count_i)``, where
    ``num_classes = len(LABEL_NAMES)``. A class absent from ``labels`` is
    treated as having count 1, so the division is always defined. If
    ``labels`` is empty, every weight is 0.0.
    """
    frequency = Counter(labels)
    total = sum(frequency.values())
    num_classes = len(LABEL_NAMES)
    return torch.tensor(
        [
            total / (num_classes * max(frequency.get(class_index, 1), 1))
            for class_index in range(num_classes)
        ],
        dtype=torch.float32,
    )
def split_examples(
    examples: list[tuple[int, int, int]],
    seed: int = 42,
    train_fraction: float = 0.9,
) -> tuple[list[tuple[int, int, int]], list[tuple[int, int, int]]]:
    """Deterministically shuffle ``examples`` and split them into train/validation lists.

    Args:
        examples: Encoded examples. The input list is not modified.
        seed: Seed for a private ``random.Random``. This makes the split
            reproducible without touching the global RNG state.
        train_fraction: Fraction of examples assigned to the training split,
            rounded down, in ``(0, 1]``. Defaults to 0.9, the previously
            hard-coded value.

    Returns:
        ``(train, valid)``. The training split always receives at least one
        slot, so with a single example the validation split is empty.

    Raises:
        ValueError: If ``train_fraction`` is not in ``(0, 1]``.
    """
    if not 0.0 < train_fraction <= 1.0:
        raise ValueError(f'train_fraction must be in (0, 1], got {train_fraction!r}')
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)
    split_index = max(1, int(len(shuffled) * train_fraction))
    return shuffled[:split_index], shuffled[split_index:]
def evaluate(model: nn.Module, dataloader: DataLoader, loss_fn: nn.Module) -> tuple[float, float]:
    """Score ``model`` on ``dataloader`` without tracking gradients.

    The model is switched to eval mode and left in it. Each batch's scalar
    ``loss_fn`` output is weighted by the batch size before averaging.
    Accuracy is the fraction of rows whose arg-max logit equals the label.

    Returns:
        ``(average_loss, accuracy)``, or ``(0.0, 0.0)`` for an empty loader.
    """
    model.eval()
    loss_sum = 0.0
    hits = 0
    seen = 0
    with torch.no_grad():
        for first_ids, second_ids, targets in dataloader:
            scores = model(first_ids, second_ids)
            batch_size = int(targets.size(0))
            loss_sum += float(loss_fn(scores, targets).item()) * batch_size
            hits += int((scores.argmax(dim=-1) == targets).sum().item())
            seen += batch_size
    # Guard against division by zero when the loader yields no batches.
    denominator = max(seen, 1)
    return loss_sum / denominator, hits / denominator
def _snapshot_state(model: nn.Module) -> dict[str, torch.Tensor]:
    """Return a detached CPU copy of ``model``'s state dict that later training cannot mutate."""
    # state_dict() tensors share storage with the live parameters, and
    # Tensor.cpu() returns the same tensor when it is already on the CPU.
    # Without clone() the optimizer's in-place updates in later epochs
    # would silently overwrite the "best" snapshot.
    return {key: value.detach().cpu().clone() for key, value in model.state_dict().items()}


def _train_one_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    loss_fn: nn.Module,
) -> tuple[float, float]:
    """Run one optimisation pass over ``dataloader``; return ``(mean loss, accuracy)``.

    The metrics use the logits computed before each optimizer step, and each
    batch's loss is weighted by its size. Both metrics are 0.0 for an empty loader.
    """
    model.train()
    running_loss = 0.0
    running_correct = 0
    running_items = 0
    for drug_a_ids, drug_b_ids, labels in dataloader:
        optimizer.zero_grad(set_to_none=True)
        logits = model(drug_a_ids, drug_b_ids)
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        predictions = torch.argmax(logits, dim=-1)
        running_loss += float(loss.item()) * labels.size(0)
        running_correct += int((predictions == labels).sum().item())
        running_items += int(labels.size(0))
    denominator = max(running_items, 1)
    return running_loss / denominator, running_correct / denominator


def train() -> dict[str, object]:
    """Train the drug-pair severity MLP and persist the best checkpoint.

    Steps:

    1. Seed the global ``random`` and torch RNGs.
    2. Load and aggregate the DDInter pairs.
    3. Build the drug vocabulary and encode the pairs.
    4. Split the pairs with ``split_examples``.
    5. Train ``DDIEmbeddingMLP`` for a fixed number of epochs with AdamW and
       class-weighted cross-entropy.

    The weights from the epoch with the highest validation accuracy are kept;
    later epochs win ties.

    Side effects:
        - Prints per-epoch metrics.
        - Creates ``MODEL_DIR`` if needed.
        - Writes the checkpoint to ``MODEL_PATH`` via ``torch.save``.
        - Writes a JSON summary to ``MODEL_DIR / 'ddi_mlp_best.summary.json'``.

    Returns:
        The checkpoint dict that was saved.
    """
    # Single source of truth for hyperparameters recorded in the checkpoint.
    seed = 42
    embedding_dim = 64
    hidden_dim = 128
    dropout = 0.2
    batch_size = 4096
    num_epochs = 4

    random.seed(seed)
    torch.manual_seed(seed)

    examples = load_and_aggregate_dataset()
    vocab = build_vocabulary(examples)
    encoded_examples = encode_examples(examples, vocab)
    train_examples, valid_examples = split_examples(encoded_examples, seed=seed)
    class_weights = compute_class_weights([label_id for _, _, label_id in train_examples])

    train_loader = DataLoader(PairDataset(train_examples), batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(PairDataset(valid_examples), batch_size=batch_size, shuffle=False)

    model = DDIEmbeddingMLP(
        vocab_size=len(vocab) + 1,  # +1 because vocabulary ids start at 1; id 0 marks unknown drugs
        embedding_dim=embedding_dim,
        hidden_dim=hidden_dim,
        num_classes=len(LABEL_NAMES),
        dropout=dropout,
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-3, weight_decay=1e-4)
    loss_fn = nn.CrossEntropyLoss(weight=class_weights)

    best_state: dict[str, torch.Tensor] | None = None
    best_accuracy = -1.0
    history: list[dict[str, float]] = []
    for epoch in range(1, num_epochs + 1):
        train_loss, train_accuracy = _train_one_epoch(model, train_loader, optimizer, loss_fn)
        valid_loss, valid_accuracy = evaluate(model, valid_loader, loss_fn)
        history.append(
            {
                'epoch': float(epoch),
                'train_loss': float(train_loss),
                'train_accuracy': float(train_accuracy),
                'valid_loss': float(valid_loss),
                'valid_accuracy': float(valid_accuracy),
            }
        )
        if valid_accuracy >= best_accuracy:
            best_accuracy = valid_accuracy
            best_state = _snapshot_state(model)
        print(
            f'epoch={epoch} train_loss={train_loss:.4f} train_acc={train_accuracy:.4f} '
            f'valid_loss={valid_loss:.4f} valid_acc={valid_accuracy:.4f}'
        )
    if best_state is None:
        # Only reachable when no epoch ran; fall back to the current weights.
        best_state = _snapshot_state(model)

    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    checkpoint = {
        'model_version': 'medcare-ddi-mlp-v1',
        'embedding_dim': embedding_dim,
        'hidden_dim': hidden_dim,
        'label_names': LABEL_NAMES,
        'label_to_index': LABEL_TO_INDEX,
        'index_to_label': dict(enumerate(LABEL_NAMES)),
        'drug_vocab': vocab,
        'model_state_dict': best_state,
        'training_history': history,
        'best_validation_accuracy': float(best_accuracy),
        'dataset_size': len(encoded_examples),
        'vocab_size': len(vocab),
    }
    torch.save(checkpoint, MODEL_PATH)

    summary_path = MODEL_DIR / 'ddi_mlp_best.summary.json'
    summary_path.write_text(
        json.dumps(
            {
                'model_version': checkpoint['model_version'],
                'best_validation_accuracy': checkpoint['best_validation_accuracy'],
                'dataset_size': checkpoint['dataset_size'],
                'vocab_size': checkpoint['vocab_size'],
                'training_history': history,
            },
            indent=2,
        ),
        encoding='utf-8',
    )
    return checkpoint
if __name__ == '__main__':
    # Command-line entry point: run the full training pipeline. The returned
    # checkpoint dict is discarded because train() already writes it to disk.
    train()