# cabin_npcbert / model.py
# snehasis19's picture
# Upload 61 files
# 52eda4b verified
"""
Model architecture and loading utilities.
This file contains the exact same model architecture used during training.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
import numpy as np
import logging
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class HierarchicalChemBERTa(nn.Module):
    """
    Pretrained encoder followed by three chained classification heads.

    The encoder is loaded with ``AutoModel.from_pretrained(model_name)``.
    Its position-0 hidden state goes through a shared projection, then:

    * the pathway head sees the shared features;
    * the superclass head sees the shared features + pathway probabilities;
    * the class head sees the shared features + pathway and superclass
      probabilities.

    MUST match the training architecture exactly. The order of the layers
    inside every ``nn.Sequential`` is what determines the parameter names in
    the state dict, so it must not be rearranged.

    Args:
        model_name: Identifier or path handed to ``AutoModel.from_pretrained``.
        num_pathways: Output size of the pathway head.
        num_superclasses: Output size of the superclass head.
        num_classes: Output size of the class head.
        hidden_dim: Width of the shared projection (default 512).
        dropout: Dropout probability used in every stage (default 0.3).
        pathway_to_superclass: Optional tensor registered as buffer
            ``p2s_mask``; it is right-multiplied onto the pathway
            probabilities, so presumably (num_pathways, num_superclasses)
            -- confirm against the training code.
        superclass_to_class: Optional tensor registered as buffer
            ``s2c_mask``; used the same way with superclass probabilities.
    """

    def __init__(self, model_name, num_pathways, num_superclasses, num_classes,
                 hidden_dim=512, dropout=0.3,
                 pathway_to_superclass=None, superclass_to_class=None):
        super().__init__()

        self.encoder = AutoModel.from_pretrained(model_name)
        encoder_width = self.encoder.config.hidden_size
        half_dim = hidden_dim // 2
        stage = self._stage

        # Submodules are created in the same order as before (shared,
        # pathway, superclass, class) so random initialisation is unchanged.
        self.shared = nn.Sequential(
            *stage(encoder_width, hidden_dim, dropout),
        )
        self.pathway_head = nn.Sequential(
            *stage(hidden_dim, half_dim, dropout),
            nn.Linear(half_dim, num_pathways),
        )
        self.superclass_head = nn.Sequential(
            *stage(hidden_dim + num_pathways, hidden_dim, dropout),
            *stage(hidden_dim, half_dim, dropout),
            nn.Linear(half_dim, num_superclasses),
        )
        self.class_head = nn.Sequential(
            *stage(hidden_dim + num_pathways + num_superclasses,
                   hidden_dim, dropout),
            *stage(hidden_dim, hidden_dim, dropout),
            nn.Linear(hidden_dim, num_classes),
        )

        # A supplied mask becomes a buffer (follows .to(), saved in the state
        # dict); a missing one is stored as a plain None attribute.
        for buffer_name, mask in (
            ('p2s_mask', pathway_to_superclass),
            ('s2c_mask', superclass_to_class),
        ):
            if mask is None:
                setattr(self, buffer_name, None)
            else:
                self.register_buffer(buffer_name, mask)

    @staticmethod
    def _stage(in_features, out_features, dropout):
        """Return the layers of one Linear -> LayerNorm -> GELU -> Dropout stage."""
        return [
            nn.Linear(in_features, out_features),
            nn.LayerNorm(out_features),
            nn.GELU(),
            nn.Dropout(dropout),
        ]

    @staticmethod
    def _add_log_prior(logits, parent_probs, mask):
        """Add log(parent_probs @ mask), floored at 1e-8, to ``logits``."""
        support = torch.matmul(parent_probs, mask)
        return logits + torch.log(support.clamp(min=1e-8))

    def forward(self, input_ids, attention_mask, apply_mask=True):
        """
        Compute logits for the three hierarchy levels.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.
            apply_mask: When True, and the corresponding mask buffer is set,
                the superclass/class logits get the log of the
                parent-probability-weighted mask added to them.

        Returns:
            dict with keys ``'pathway_logits'``, ``'superclass_logits'`` and
            ``'class_logits'``.
        """
        hidden_states = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        # Only the first position of the sequence axis is used.
        features = self.shared(hidden_states[:, 0, :])

        pathway_logits = self.pathway_head(features)
        pathway_probs = F.softmax(pathway_logits, dim=-1)

        superclass_logits = self.superclass_head(
            torch.cat([features, pathway_probs], dim=-1)
        )
        if apply_mask and self.p2s_mask is not None:
            superclass_logits = self._add_log_prior(
                superclass_logits, pathway_probs, self.p2s_mask
            )
        # The class head consumes the (possibly adjusted) superclass
        # distribution.
        superclass_probs = F.softmax(superclass_logits, dim=-1)

        class_logits = self.class_head(
            torch.cat([features, pathway_probs, superclass_probs], dim=-1)
        )
        if apply_mask and self.s2c_mask is not None:
            class_logits = self._add_log_prior(
                class_logits, superclass_probs, self.s2c_mask
            )

        return {
            'pathway_logits': pathway_logits,
            'superclass_logits': superclass_logits,
            'class_logits': class_logits,
        }
class NPClassifierService:
    """
    Service class that handles model loading and prediction.

    Loaded once at server startup, used for all requests. The constructor
    reads the checkpoint, rebuilds ``HierarchicalChemBERTa`` from the stored
    configuration, and loads the tokenizer; ``predict`` then classifies
    SMILES strings one at a time.
    """

    # (result key, model output key) for each hierarchy level, in output order.
    _LEVELS = (
        ('pathway', 'pathway_logits'),
        ('superclass', 'superclass_logits'),
        ('class', 'class_logits'),
    )

    def __init__(self, model_path, chemberta_name, device='cpu',
                 max_length=256, top_n=5):
        """
        Args:
            model_path: Path of the checkpoint passed to ``torch.load``.
            chemberta_name: Name or path used to load both the encoder and
                the tokenizer via ``from_pretrained``.
            device: Anything accepted by ``torch.device`` (default ``'cpu'``).
            max_length: Length every tokenized input is padded/truncated to.
            top_n: Default number of ranked predictions returned per level.
        """
        self.device = torch.device(device)
        self.max_length = max_length
        self.top_n = top_n
        self.model = None
        self.tokenizer = None
        self.label_encoders = {}
        self.model_info = {}
        self._load_model(model_path, chemberta_name)

    def _load_model(self, model_path, chemberta_name):
        """
        Load model, tokenizer, and label encoders from the saved package.

        The checkpoint must provide ``'model_config'``, ``'label_encoders'``,
        ``'hierarchy'`` and ``'model_state_dict'``; ``'metrics'`` is optional
        and copied into ``model_info`` when present.

        Raises:
            KeyError: If a required checkpoint entry is missing.
        """
        logger.info("Loading model from %s...", model_path)
        # SECURITY NOTE: weights_only=False makes torch.load unpickle
        # arbitrary Python objects, which can execute code. It is kept because
        # the checkpoint bundles non-tensor payloads (config, label lists),
        # so only load checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location=self.device,
                                weights_only=False)
        model_config = checkpoint['model_config']
        label_encoders = checkpoint['label_encoders']
        hierarchy = checkpoint['hierarchy']

        logger.info("Model config: %s", model_config)
        logger.info(
            "Pathways: %s, Superclasses: %s, Classes: %s",
            model_config['num_pathways'],
            model_config['num_superclasses'],
            model_config['num_classes'],
        )

        # Arrays allow index lookup with the integer indices from argsort.
        self.label_encoders = {
            'pathway': np.array(label_encoders['pathway']),
            'superclass': np.array(label_encoders['superclass']),
            'class': np.array(label_encoders['class']),
        }

        # 'model_name' is the name recorded in the checkpoint; the encoder
        # itself is instantiated from `chemberta_name` below.
        self.model_info = {
            'model_name': model_config['model_name'],
            'num_pathways': model_config['num_pathways'],
            'num_superclasses': model_config['num_superclasses'],
            'num_classes': model_config['num_classes'],
            'hidden_dim': model_config['hidden_dim'],
            'dropout': model_config['dropout'],
            'device': str(self.device),
            'pathway_classes': label_encoders['pathway'],
            'superclass_classes': label_encoders['superclass'],
        }
        if 'metrics' in checkpoint:
            self.model_info['metrics'] = checkpoint['metrics']

        p2s = hierarchy['pathway_to_superclass'].to(self.device)
        s2c = hierarchy['superclass_to_class'].to(self.device)

        self.model = HierarchicalChemBERTa(
            model_name=chemberta_name,
            num_pathways=model_config['num_pathways'],
            num_superclasses=model_config['num_superclasses'],
            num_classes=model_config['num_classes'],
            hidden_dim=model_config['hidden_dim'],
            dropout=model_config['dropout'],
            pathway_to_superclass=p2s,
            superclass_to_class=s2c,
        ).to(self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        # eval() switches the dropout layers off for inference.
        self.model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(chemberta_name)

        total_params = sum(p.numel() for p in self.model.parameters())
        self.model_info['total_params'] = f"{total_params:,}"
        logger.info("Model loaded successfully! (%s parameters)",
                    self.model_info['total_params'])
        logger.info("Device: %s", self.device)

    @torch.no_grad()
    def predict(self, smiles_list, top_n=None):
        """
        Predict Pathway, Superclass, Class for a list of SMILES.

        Args:
            smiles_list: Iterable of SMILES. A single ``str`` is treated as a
                one-element list instead of being iterated per character.
            top_n: Number of ranked candidates per level; defaults to
                ``self.top_n``. Values <= 0 give an empty ``'top_n'`` list
                while ``'predicted'`` is still reported.

        Returns:
            One dict per input, in input order. Successful entries have
            ``'status': 'success'`` and a ``'predictions'`` dict keyed by
            ``'pathway'``, ``'superclass'`` and ``'class'``; failed entries
            have ``'status': 'error'`` and an ``'error'`` message. A failure
            on one molecule never aborts the rest of the batch.
        """
        if top_n is None:
            top_n = self.top_n
        if isinstance(smiles_list, str):
            smiles_list = [smiles_list]

        results = []
        for smi in smiles_list:
            try:
                # None would otherwise become the literal text "None" and be
                # classified as if it were a molecule.
                smi_str = '' if smi is None else str(smi).strip()
                if not smi_str:
                    results.append(
                        self._error_result(smi, "Empty SMILES string"))
                    continue
                results.append(self._predict_one(smi_str, top_n))
            except Exception as e:
                # Deliberate best-effort boundary: record the failure (with
                # traceback) and carry on with the remaining inputs.
                logger.exception("Error predicting %s: %s", smi, e)
                results.append(self._error_result(smi, str(e)))
        return results

    def _predict_one(self, smi_str, top_n):
        """Tokenize one non-empty SMILES, run the model, build its result."""
        encoding = self.tokenizer(
            smi_str,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)
        outputs = self.model(input_ids, attention_mask)

        predictions = {}
        for level, logits_key in self._LEVELS:
            # [0]: the batch holds exactly one molecule.
            probs = F.softmax(outputs[logits_key], dim=-1).cpu().numpy()[0]
            predictions[level] = self._rank_level(
                probs, self.label_encoders[level], top_n)

        return {
            'smiles': smi_str,
            'status': 'success',
            'predictions': predictions,
        }

    @staticmethod
    def _rank_level(probs, le_classes, top_n):
        """Return the best label and the ``top_n`` ranked labels for one level."""
        order = np.argsort(probs)[::-1]
        # Take the winner from the full ordering so it does not depend on how
        # many ranked candidates were requested.
        best = order[0]
        top_predictions = [
            {
                'rank': rank,
                'label': str(le_classes[idx]),
                'confidence': round(float(probs[idx]), 6),
            }
            for rank, idx in enumerate(order[:max(top_n, 0)], start=1)
        ]
        return {
            'predicted': str(le_classes[best]),
            'confidence': round(float(probs[best]), 6),
            'top_n': top_predictions,
        }

    def _error_result(self, smiles, error_msg):
        """Build the result entry for an input that could not be predicted."""
        return {
            'smiles': str(smiles),
            'status': 'error',
            'error': error_msg,
            'predictions': {}
        }

    def get_model_info(self):
        """Return model metadata for the about page."""
        return self.model_info