"""
Model architecture and loading utilities.
This file contains the exact same model architecture used during training.
"""
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

logger = logging.getLogger(__name__)


class HierarchicalChemBERTa(nn.Module):
    """
    ChemBERTa encoder + Hierarchical classification heads.
    MUST match the training architecture exactly.

    The layer order inside every ``nn.Sequential`` and the attribute/buffer
    names (``encoder``, ``shared``, ``pathway_head``, ``superclass_head``,
    ``class_head``, ``p2s_mask``, ``s2c_mask``) determine the ``state_dict``
    keys, so they must not be renamed or reordered.

    Args:
        model_name: Identifier passed to ``AutoModel.from_pretrained``.
        num_pathways: Output size of the pathway head.
        num_superclasses: Output size of the superclass head.
        num_classes: Output size of the class head.
        hidden_dim: Width of the shared projection and of the head MLPs.
        dropout: Dropout probability used in every ``nn.Dropout``.
        pathway_to_superclass: Optional tensor registered as buffer
            ``p2s_mask``. ``forward`` right-multiplies the pathway
            probabilities by it, so it must have shape
            ``(num_pathways, num_superclasses)``.
        superclass_to_class: Optional tensor registered as buffer
            ``s2c_mask``; by the same argument its shape must be
            ``(num_superclasses, num_classes)``.
    """

    def __init__(self, model_name, num_pathways, num_superclasses, num_classes,
                 hidden_dim=512, dropout=0.3,
                 pathway_to_superclass=None, superclass_to_class=None):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        enc_dim = self.encoder.config.hidden_size

        # Projection shared by all three heads.
        self.shared = nn.Sequential(
            nn.Linear(enc_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        self.pathway_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_pathways),
        )

        # Input = shared features concatenated with pathway probabilities.
        self.superclass_head = nn.Sequential(
            nn.Linear(hidden_dim + num_pathways, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_superclasses),
        )

        # Input = shared features + pathway probs + superclass probs.
        self.class_head = nn.Sequential(
            nn.Linear(hidden_dim + num_pathways + num_superclasses, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )

        # Buffers follow the module across devices and are part of the
        # state_dict; when no mask is supplied a plain ``None`` attribute is
        # stored instead, so nothing extra is expected in the checkpoint.
        if pathway_to_superclass is not None:
            self.register_buffer('p2s_mask', pathway_to_superclass)
        else:
            self.p2s_mask = None

        if superclass_to_class is not None:
            self.register_buffer('s2c_mask', superclass_to_class)
        else:
            self.s2c_mask = None

    def forward(self, input_ids, attention_mask, apply_mask=True):
        """
        Compute logits for the three hierarchy levels.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.
            apply_mask: When true and the corresponding mask buffer exists,
                add ``log(parent_probs @ mask)`` to the child-level logits.

        Returns:
            Dict with keys ``'pathway_logits'``, ``'superclass_logits'`` and
            ``'class_logits'``. The superclass and class logits already
            include the log-mask term when masking was applied.
        """
        encoder_output = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # Hidden state at sequence position 0 is used as the molecule embedding.
        cls_emb = encoder_output.last_hidden_state[:, 0, :]
        shared = self.shared(cls_emb)

        pathway_logits = self.pathway_head(shared)
        pathway_probs = F.softmax(pathway_logits, dim=-1)

        sc_input = torch.cat([shared, pathway_probs], dim=-1)
        superclass_logits = self.superclass_head(sc_input)
        if apply_mask and self.p2s_mask is not None:
            # NOTE(review): presumably the mask is 1 where a superclass belongs
            # to a pathway, making this a soft hierarchical prior - confirm
            # against the training code.
            valid_sc = torch.matmul(pathway_probs, self.p2s_mask)
            # Clamp before log so a zero entry becomes a large negative
            # penalty rather than -inf.
            superclass_logits = superclass_logits + torch.log(valid_sc.clamp(min=1e-8))
        superclass_probs = F.softmax(superclass_logits, dim=-1)

        cl_input = torch.cat([shared, pathway_probs, superclass_probs], dim=-1)
        class_logits = self.class_head(cl_input)
        if apply_mask and self.s2c_mask is not None:
            valid_cl = torch.matmul(superclass_probs, self.s2c_mask)
            class_logits = class_logits + torch.log(valid_cl.clamp(min=1e-8))

        return {
            'pathway_logits': pathway_logits,
            'superclass_logits': superclass_logits,
            'class_logits': class_logits,
        }


class NPClassifierService:
    """
    Service class that handles model loading and prediction.
    Loaded once at server startup, used for all requests.

    Args:
        model_path: Path of the checkpoint read with ``torch.load``.
        chemberta_name: Identifier used for both the encoder weights and the
            tokenizer (``from_pretrained``).
        device: Device string passed to ``torch.device``.
        max_length: Tokenizer padding/truncation length.
        top_n: Default number of ranked labels returned per level.
    """

    # (level name, key in the model output dict) in reporting order.
    _LEVELS = (
        ('pathway', 'pathway_logits'),
        ('superclass', 'superclass_logits'),
        ('class', 'class_logits'),
    )

    def __init__(self, model_path, chemberta_name, device='cpu',
                 max_length=256, top_n=5):
        self.device = torch.device(device)
        self.max_length = max_length
        self.top_n = top_n

        self.model = None
        self.tokenizer = None
        self.label_encoders = {}
        self.model_info = {}

        self._load_model(model_path, chemberta_name)

    def _load_model(self, model_path, chemberta_name):
        """Load model, tokenizer, and label encoders from the saved package.

        The checkpoint must contain the keys ``'model_config'``,
        ``'label_encoders'``, ``'hierarchy'`` and ``'model_state_dict'``;
        ``'metrics'`` is optional and copied into ``model_info`` if present.

        Raises:
            KeyError: If a required checkpoint key is missing.
        """
        logger.info("Loading model from %s...", model_path)

        # SECURITY: weights_only=False unpickles arbitrary Python objects and
        # can execute code embedded in the file. Only load checkpoints from a
        # trusted source. It is kept because the package stores non-tensor
        # objects (config, label lists) next to the weights.
        checkpoint = torch.load(model_path, map_location=self.device,
                                weights_only=False)

        model_config = checkpoint['model_config']
        label_encoders = checkpoint['label_encoders']
        hierarchy = checkpoint['hierarchy']

        logger.info("Model config: %s", model_config)
        logger.info(
            "Pathways: %s, Superclasses: %s, Classes: %s",
            model_config['num_pathways'],
            model_config['num_superclasses'],
            model_config['num_classes'],
        )

        # Arrays allow direct indexing with the numpy indices from argsort.
        self.label_encoders = {
            level: np.array(label_encoders[level])
            for level in ('pathway', 'superclass', 'class')
        }

        self.model_info = {
            'model_name': model_config['model_name'],
            'num_pathways': model_config['num_pathways'],
            'num_superclasses': model_config['num_superclasses'],
            'num_classes': model_config['num_classes'],
            'hidden_dim': model_config['hidden_dim'],
            'dropout': model_config['dropout'],
            'device': str(self.device),
            'pathway_classes': label_encoders['pathway'],
            'superclass_classes': label_encoders['superclass'],
        }
        if 'metrics' in checkpoint:
            self.model_info['metrics'] = checkpoint['metrics']

        p2s = hierarchy['pathway_to_superclass'].to(self.device)
        s2c = hierarchy['superclass_to_class'].to(self.device)

        self.model = HierarchicalChemBERTa(
            model_name=chemberta_name,
            num_pathways=model_config['num_pathways'],
            num_superclasses=model_config['num_superclasses'],
            num_classes=model_config['num_classes'],
            hidden_dim=model_config['hidden_dim'],
            dropout=model_config['dropout'],
            pathway_to_superclass=p2s,
            superclass_to_class=s2c,
        ).to(self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(chemberta_name)

        total_params = sum(p.numel() for p in self.model.parameters())
        self.model_info['total_params'] = f"{total_params:,}"

        logger.info("Model loaded successfully! (%s parameters)",
                    self.model_info['total_params'])
        logger.info("Device: %s", self.device)

    @torch.no_grad()
    def predict(self, smiles_list, top_n=None):
        """
        Predict Pathway, Superclass, Class for a list of SMILES.

        Args:
            smiles_list: Iterable of SMILES. A single ``str`` is treated as a
                one-element list rather than iterated character by character.
            top_n: Number of ranked labels per level; defaults to
                ``self.top_n``. Integer values below 1 are clamped to 1
                because the top-1 label is always reported.

        Returns:
            One dict per input, in input order. Successful entries have
            ``'status': 'success'`` and a ``'predictions'`` dict keyed by
            ``'pathway'``, ``'superclass'`` and ``'class'``. Failed entries
            (empty/missing SMILES, or any exception while processing that
            item) have ``'status': 'error'`` and an ``'error'`` message.
        """
        if top_n is None:
            top_n = self.top_n
        # 0 would leave no top-1 entry to report; a negative value would be
        # taken as a negative slice end and silently drop the lowest ranks.
        if isinstance(top_n, (int, np.integer)) and top_n < 1:
            top_n = 1

        if isinstance(smiles_list, str):
            smiles_list = [smiles_list]

        results = []
        for smi in smiles_list:
            try:
                smi_str = self._clean_smiles(smi)
                if not smi_str:
                    results.append(self._error_result(smi, "Empty SMILES string"))
                    continue
                results.append(self._predict_one(smi_str, top_n))
            except Exception as e:
                # Per-item boundary: one bad input must not fail the batch.
                # logger.exception keeps the traceback for debugging.
                logger.exception("Error predicting %s: %s", smi, e)
                results.append(self._error_result(smi, str(e)))

        return results

    @staticmethod
    def _clean_smiles(smi):
        """Return the stripped SMILES text, or '' for None / float NaN."""
        if smi is None:
            return ""
        # NaN is the only float unequal to itself. Without this check it
        # would be stringified to "nan" and classified as a molecule.
        if isinstance(smi, float) and smi != smi:
            return ""
        return str(smi).strip()

    def _predict_one(self, smi_str, top_n):
        """Tokenize one SMILES, run the model and build its success result."""
        encoding = self.tokenizer(
            smi_str,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)

        outputs = self.model(input_ids, attention_mask)

        predictions = {
            level: self._rank_level(
                outputs[logits_key], self.label_encoders[level], top_n
            )
            for level, logits_key in self._LEVELS
        }
        return {
            'smiles': smi_str,
            'status': 'success',
            'predictions': predictions,
        }

    @staticmethod
    def _rank_level(logits, le_classes, top_n):
        """Turn one level's logits (batch of 1) into ranked label predictions."""
        probs = F.softmax(logits, dim=-1).cpu().numpy()[0]
        top_indices = np.argsort(probs)[::-1][:top_n]

        top_predictions = [
            {
                'rank': rank,
                'label': str(le_classes[idx]),
                'confidence': round(float(probs[idx]), 6),
            }
            for rank, idx in enumerate(top_indices, start=1)
        ]
        best = top_predictions[0]
        return {
            'predicted': best['label'],
            'confidence': best['confidence'],
            'top_n': top_predictions,
        }

    def _error_result(self, smiles, error_msg):
        """Build the result dict reported for an input that was not predicted."""
        return {
            'smiles': str(smiles),
            'status': 'error',
            'error': error_msg,
            'predictions': {}
        }

    def get_model_info(self):
        """Return model metadata for the about page."""
        return self.model_info