Spaces:
Sleeping
Sleeping
| """ | |
| Model architecture and loading utilities. | |
| This file contains the exact same model architecture used during training. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import AutoModel, AutoTokenizer | |
| import numpy as np | |
| import logging | |
# Module-level logger shared by the model-loading and prediction code in this file.
logger = logging.getLogger(__name__)
class HierarchicalChemBERTa(nn.Module):
    """
    ChemBERTa encoder followed by three cascaded classification heads
    (pathway -> superclass -> class).

    Each head receives the shared features plus the softmax probabilities of
    every coarser level. Optionally, each head's logits are biased by a
    log-prior derived from a parent->child compatibility matrix.

    The attribute names and the position of every layer inside each
    ``nn.Sequential`` define the ``state_dict`` keys. This layout MUST match
    the training architecture exactly.
    """

    @staticmethod
    def _dense_stage(in_dim, out_dim, dropout):
        """Return the [Linear, LayerNorm, GELU, Dropout] layers used by every stage."""
        return [
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        ]

    @staticmethod
    def _log_prior(parent_probs, mask):
        """Log of the child-level mass implied by ``parent_probs @ mask``, floored at 1e-8."""
        return torch.log(torch.matmul(parent_probs, mask).clamp(min=1e-8))

    def __init__(self, model_name, num_pathways, num_superclasses, num_classes,
                 hidden_dim=512, dropout=0.3,
                 pathway_to_superclass=None, superclass_to_class=None):
        """
        Args:
            model_name: Name/path handed to ``AutoModel.from_pretrained``.
            num_pathways, num_superclasses, num_classes: Output sizes of the
                three heads.
            hidden_dim: Width of the shared projection and of the hidden
                layers in the heads.
            dropout: Dropout probability used after every hidden layer.
            pathway_to_superclass: Optional tensor registered as buffer
                ``p2s_mask``. It is right-multiplied with the pathway
                probabilities in ``forward``.
            superclass_to_class: Optional tensor registered as buffer
                ``s2c_mask``. It is right-multiplied with the superclass
                probabilities in ``forward``.
        """
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        enc_dim = self.encoder.config.hidden_size
        half = hidden_dim // 2
        stage = self._dense_stage

        # Layers are created in the same order as in training so that
        # parameter initialisation and Sequential indices line up.
        self.shared = nn.Sequential(*stage(enc_dim, hidden_dim, dropout))

        self.pathway_head = nn.Sequential(
            *stage(hidden_dim, half, dropout),
            nn.Linear(half, num_pathways),
        )

        sc_in = hidden_dim + num_pathways
        self.superclass_head = nn.Sequential(
            *stage(sc_in, hidden_dim, dropout),
            *stage(hidden_dim, half, dropout),
            nn.Linear(half, num_superclasses),
        )

        cl_in = sc_in + num_superclasses
        self.class_head = nn.Sequential(
            *stage(cl_in, hidden_dim, dropout),
            *stage(hidden_dim, hidden_dim, dropout),
            nn.Linear(hidden_dim, num_classes),
        )

        # A provided mask becomes a buffer, so it is saved in and restored
        # from the state_dict. A missing one is a plain None attribute.
        for buf_name, mask in (('p2s_mask', pathway_to_superclass),
                               ('s2c_mask', superclass_to_class)):
            if mask is None:
                setattr(self, buf_name, None)
            else:
                self.register_buffer(buf_name, mask)

    def forward(self, input_ids, attention_mask, apply_mask=True):
        """
        Run the encoder and the three heads.

        Args:
            input_ids, attention_mask: Passed straight to the encoder.
            apply_mask: When True, add the hierarchy log-prior for each level
                whose mask is present.

        Returns:
            dict with 'pathway_logits', 'superclass_logits' and 'class_logits'.
        """
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # First token of the last hidden layer serves as the sequence embedding.
        features = self.shared(encoded.last_hidden_state[:, 0, :])

        pathway_logits = self.pathway_head(features)
        pathway_probs = F.softmax(pathway_logits, dim=-1)

        superclass_logits = self.superclass_head(
            torch.cat([features, pathway_probs], dim=-1)
        )
        if apply_mask and self.p2s_mask is not None:
            superclass_logits = superclass_logits + self._log_prior(pathway_probs, self.p2s_mask)
        superclass_probs = F.softmax(superclass_logits, dim=-1)

        class_logits = self.class_head(
            torch.cat([features, pathway_probs, superclass_probs], dim=-1)
        )
        if apply_mask and self.s2c_mask is not None:
            class_logits = class_logits + self._log_prior(superclass_probs, self.s2c_mask)

        return {
            'pathway_logits': pathway_logits,
            'superclass_logits': superclass_logits,
            'class_logits': class_logits,
        }
class NPClassifierService:
    """
    Service class that handles model loading and prediction.

    Loaded once at server startup and used for all requests. ``__init__``
    reads the checkpoint and builds the model and tokenizer. ``predict``
    only runs inference.
    """

    def __init__(self, model_path, chemberta_name, device='cpu',
                 max_length=256, top_n=5):
        """
        Build the service and eagerly load everything it needs.

        Args:
            model_path: Checkpoint file passed to ``torch.load``.
            chemberta_name: Name/path passed to ``from_pretrained`` for both
                the encoder and the tokenizer.
            device: Anything ``torch.device`` accepts, e.g. ``'cpu'``.
            max_length: Tokenizer padding/truncation length.
            top_n: Default number of ranked labels returned per level.
        """
        self.device = torch.device(device)
        self.max_length = max_length
        self.top_n = top_n
        self.model = None
        self.tokenizer = None
        self.label_encoders = {}
        self.model_info = {}
        self._load_model(model_path, chemberta_name)

    def _load_model(self, model_path, chemberta_name):
        """
        Load model, tokenizer, and label encoders from the saved package.

        The checkpoint must be a dict containing ``model_config``,
        ``label_encoders``, ``hierarchy`` and ``model_state_dict``.
        ``metrics`` is optional and is copied into ``model_info`` when present.
        """
        logger.info(f"Loading model from {model_path}...")
        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code. Only load checkpoints you trust.
        # Never load a user-supplied path or upload.
        checkpoint = torch.load(model_path, map_location=self.device,
                                weights_only=False)
        model_config = checkpoint['model_config']
        label_encoders = checkpoint['label_encoders']
        hierarchy = checkpoint['hierarchy']
        logger.info(f"Model config: {model_config}")
        logger.info(f"Pathways: {model_config['num_pathways']}, "
                    f"Superclasses: {model_config['num_superclasses']}, "
                    f"Classes: {model_config['num_classes']}")
        # Index -> label lookup tables: entry i is the label of logit i.
        self.label_encoders = {
            'pathway': np.array(label_encoders['pathway']),
            'superclass': np.array(label_encoders['superclass']),
            'class': np.array(label_encoders['class']),
        }
        self.model_info = {
            'model_name': model_config['model_name'],
            'num_pathways': model_config['num_pathways'],
            'num_superclasses': model_config['num_superclasses'],
            'num_classes': model_config['num_classes'],
            'hidden_dim': model_config['hidden_dim'],
            'dropout': model_config['dropout'],
            'device': str(self.device),
            'pathway_classes': label_encoders['pathway'],
            'superclass_classes': label_encoders['superclass'],
        }
        if 'metrics' in checkpoint:
            self.model_info['metrics'] = checkpoint['metrics']
        # Hierarchy matrices used by HierarchicalChemBERTa.forward to bias
        # superclass/class logits toward labels compatible with their parent.
        p2s = hierarchy['pathway_to_superclass'].to(self.device)
        s2c = hierarchy['superclass_to_class'].to(self.device)
        # NOTE: the encoder is built from ``chemberta_name``. The
        # ``model_name`` stored in the checkpoint is only reported in
        # model_info.
        self.model = HierarchicalChemBERTa(
            model_name=chemberta_name,
            num_pathways=model_config['num_pathways'],
            num_superclasses=model_config['num_superclasses'],
            num_classes=model_config['num_classes'],
            hidden_dim=model_config['hidden_dim'],
            dropout=model_config['dropout'],
            pathway_to_superclass=p2s,
            superclass_to_class=s2c,
        ).to(self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()  # switch Dropout etc. to inference behaviour
        self.tokenizer = AutoTokenizer.from_pretrained(chemberta_name)
        total_params = sum(p.numel() for p in self.model.parameters())
        self.model_info['total_params'] = f"{total_params:,}"
        logger.info(f"Model loaded successfully! ({total_params:,} parameters)")
        logger.info(f"Device: {self.device}")

    def predict(self, smiles_list, top_n=None):
        """
        Predict Pathway, Superclass, Class for a list of SMILES.

        Args:
            smiles_list: Iterable of SMILES strings. A single ``str`` is
                treated as one molecule rather than iterated character by
                character. ``None`` entries are reported as empty input.
            top_n: Number of ranked labels to return per level. Defaults to
                ``self.top_n``. Values <= 0 give an empty ``top_n`` list, but
                ``predicted``/``confidence`` are still filled in.

        Returns:
            One dict per input, in input order. Successful entries have
            ``status='success'`` and per-level ``predicted``, ``confidence``
            and ``top_n`` fields. Failed entries have ``status='error'`` and
            an ``error`` message. Per-molecule failures are captured here
            rather than raised.
        """
        if top_n is None:
            top_n = self.top_n
        if isinstance(smiles_list, str):
            smiles_list = [smiles_list]

        results = []
        for smi in smiles_list:
            try:
                smi_str = '' if smi is None else str(smi).strip()
                if not smi_str:
                    results.append(self._error_result(smi, "Empty SMILES string"))
                    continue
                results.append(self._predict_one(smi_str, top_n))
            except Exception as e:
                # Per-molecule boundary: log with traceback and keep going,
                # so one bad SMILES does not fail the whole request.
                logger.exception(f"Error predicting {smi}: {e}")
                results.append(self._error_result(smi, str(e)))
        return results

    def _predict_one(self, smi_str, top_n):
        """Tokenize one non-empty SMILES, run the model, and build its success result."""
        encoding = self.tokenizer(
            smi_str,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)
        # Inference only. Without no_grad the logits carry autograd history,
        # which wastes memory and makes ``.numpy()`` raise
        # "Can't call numpy() on Tensor that requires grad".
        with torch.no_grad():
            outputs = self.model(input_ids, attention_mask)

        result = {
            'smiles': smi_str,
            'status': 'success',
            'predictions': {}
        }
        for level in ('pathway', 'superclass', 'class'):
            result['predictions'][level] = self._rank_level(
                outputs[f'{level}_logits'], self.label_encoders[level], top_n
            )
        return result

    @staticmethod
    def _rank_level(logits, classes, top_n):
        """Convert one level's logits for a single-item batch into its prediction dict."""
        probs = F.softmax(logits, dim=-1).detach().cpu().numpy()[0]
        order = np.argsort(probs)[::-1]
        # The best label does not depend on top_n, so top_n <= 0 cannot
        # break ``predicted``.
        best = order[0]
        top_predictions = [
            {
                'rank': rank,
                'label': str(classes[idx]),
                'confidence': round(float(probs[idx]), 6),
            }
            for rank, idx in enumerate(order[:max(top_n, 0)], start=1)
        ]
        return {
            'predicted': str(classes[best]),
            'confidence': round(float(probs[best]), 6),
            'top_n': top_predictions,
        }

    def _error_result(self, smiles, error_msg):
        """Build the result dict for an input that could not be predicted."""
        return {
            'smiles': str(smiles),
            'status': 'error',
            'error': error_msg,
            'predictions': {}
        }

    def get_model_info(self):
        """Return model metadata for the about page (the live ``model_info`` dict, not a copy)."""
        return self.model_info