import json
import logging
import math
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import IterableDataset
from smirk import SmirkTokenizerFast
from torch import nn
from torch.masked import MaskedTensor, masked_tensor
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    DataCollatorWithPadding,
    PretrainedConfig,
    PreTrainedModel,
)

# Optional mapping from a custom ``model_type`` to the base type AutoConfig knows.
MODEL_TYPE_ALIASES = {}
IGNORE_INDEX = -100

# Make "SmirkTokenizer" resolvable through AutoTokenizer.
AutoTokenizer.register("SmirkTokenizer", fast_tokenizer_class=SmirkTokenizerFast)


def build_encoder(enc_dict: Dict[str, Any]):
    """Instantiate an (untrained) encoder from a serialized encoder config.

    Args:
        enc_dict: Encoder config as a dict. ``model_type`` is preferred (after
            translation through ``MODEL_TYPE_ALIASES``); otherwise
            ``_name_or_path`` is used to fetch the config.

    Returns:
        The model built by ``AutoModel.from_config``.

    Raises:
        KeyError: If neither ``model_type`` nor ``_name_or_path`` is set.
    """
    mtype = enc_dict.get("model_type")
    if mtype:
        base = MODEL_TYPE_ALIASES.get(mtype, mtype)
        cfg_cls = AutoConfig.for_model(base)
        enc_cfg = cfg_cls.from_dict(enc_dict)
    elif enc_dict.get("_name_or_path"):
        enc_cfg = AutoConfig.from_pretrained(enc_dict["_name_or_path"])
    else:
        raise KeyError("encoder config missing 'model_type' or '_name_or_path'")
    # The task heads read ``last_hidden_state``; skip the pooler when the
    # config exposes that switch.
    if hasattr(enc_cfg, "add_pooling_layer"):
        enc_cfg.add_pooling_layer = False
    return AutoModel.from_config(enc_cfg)


class AbstractNormalizer(torch.nn.Module):
    """Base class for target normalizers.

    Convention used by every subclass in this file: ``inverse`` maps raw
    targets into normalized space, ``forward`` maps normalized model outputs
    back to raw space.
    """

    def __init__(self, num_outputs=None):
        super().__init__()
        self.num_outputs = num_outputs

    def forward(self, x):
        """Remove normalization"""
        raise NotImplementedError

    def inverse(self, x):
        """Apply normalization"""
        raise NotImplementedError

    def _fit(self, x):
        """Fit the normalization parameters"""
        raise NotImplementedError

    def to_config(self):
        """Return the dict that ``AbstractNormalizer.get`` needs to rebuild this object."""
        return {'class': self.__class__.__name__, 'num_outputs': self.num_outputs}

    def leader_fit(self, ds, rank, broadcast):
        """Fit on rank 0 only, then load the broadcast state on every rank.

        Args:
            ds: Dataset handed to ``fit``.
            rank: Rank of the calling process; only rank 0 fits.
            broadcast: Callable that takes the local state (``None`` off the
                leader) and returns the state to load.
        """
        state = None
        if rank == 0:
            state = self.fit(ds)
        state = broadcast(state)
        self.load_state_dict(state)

    def fit(self, ds, name='target'):
        """Fit the normalization parameters on dataset

        Reads the ``name`` column and its ``{name}_mask`` companion, wraps
        them in a masked tensor and delegates to ``_fit``.

        Returns:
            Whatever ``_fit`` returns (the state dict in this file's subclasses).
        """
        mask_name = f'{name}_mask'
        if isinstance(ds, IterableDataset):
            # Streaming datasets can only be consumed row by row.
            target = []
            mask = []
            for row in ds:
                # as_tensor accepts rows that are tensors as well as plain
                # Python lists / numbers, which torch.stack alone would reject.
                target.append(torch.as_tensor(row[name]))
                mask.append(torch.as_tensor(row[mask_name]))
            target = torch.stack(target)
            mask = torch.stack(mask)
        else:
            target = torch.stack([torch.as_tensor(x) for x in ds[name]])
            mask = torch.stack([torch.as_tensor(x) for x in ds[mask_name]])
        target = masked_tensor(target, mask)
        return self._fit(target)

    @classmethod
    def get(cls, transform, num_outputs):
        """Build a normalizer from its config name.

        Args:
            transform: A short name (``'standardize'``, ``'power_transform'``,
                ``'log_transform'``, ``'max_scale'``), a class name, or a list
                with one such entry per output channel.
            num_outputs: Number of output channels.

        Returns:
            The matching normalizer; any unrecognized value yields an
            ``IdentityTransform``.
        """
        if isinstance(transform, list):
            assert len(transform) == num_outputs
            return ChannelWiseTransform([cls.get(t, 1) for t in transform])
        if transform in ['standardize', Standardize.__name__]:
            return Standardize(num_outputs)
        if transform in ['power_transform', PowerTransform.__name__]:
            return PowerTransform(num_outputs)
        if transform in ['log_transform', LogTransform.__name__]:
            return LogTransform(num_outputs)
        if transform in ['max_scale', MaxScaleTransform.__name__]:
            return MaxScaleTransform(num_outputs)
        return IdentityTransform()


class BiPairwiseBlock(nn.Module):
    """Symmetric pairwise features: a bilinear term plus a linear term.

    For an input of shape ``(..., L, d_model)`` the output has shape
    ``(..., L, L, d_model)``.
    """

    def __init__(self, d_model, bias=True, device=None, dtype=None):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.bi_weight = nn.Parameter(torch.empty((d_model, d_model), **factory_kwargs))
        self.lin_weight = nn.Parameter(torch.empty((d_model, d_model), **factory_kwargs))
        if bias:
            self.bias = nn.Parameter(torch.empty(d_model, **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        # Symmetrize the gradient so a symmetric bi_weight stays symmetric
        # under gradient updates.
        self.bi_weight.register_hook(lambda grad: 0.5 * (grad + grad.T))

    def reset_parameters(self):
        """Xavier-initialize both weights, symmetrize ``bi_weight``, init the bias."""
        nn.init.xavier_normal_(self.lin_weight, gain=nn.init.calculate_gain('relu'))
        nn.init.xavier_normal_(self.bi_weight, gain=nn.init.calculate_gain('relu'))
        with torch.no_grad():
            self.bi_weight.copy_(0.5 * (self.bi_weight + self.bi_weight.T))
        if self.bias is not None:
            bound = 1 / math.sqrt(self.bias.size(0))
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        """Return pairwise features symmetric in the two sequence axes."""
        y_bi = torch.einsum('...ld,df,...rf->...lrf', x, self.bi_weight, x)
        # Average with the transpose over the two pair axes.
        y_bi = 0.5 * (y_bi + y_bi.transpose(-3, -2))
        x_linear = x.unsqueeze(-2) + x.unsqueeze(-3)
        return y_bi + F.linear(x_linear, self.lin_weight, self.bias)


class ChannelWiseTransform(AbstractNormalizer):
    """Apply one single-output normalizer per column of the target."""

    def __init__(self, transforms):
        super().__init__(len(transforms))
        self.transforms = torch.nn.ModuleList(transforms)

    def to_config(self):
        """Like the base ``to_config`` but ``class`` is a list, one name per channel."""
        return {
            'class': [t.__class__.__name__ for t in self.transforms],
            'num_outputs': self.num_outputs,
        }

    def inverse(self, x):
        # x[:, [idx]] keeps the column axis so each transform sees (N, 1).
        return torch.cat(
            [transform.inverse(x[:, [idx]]) for (idx, transform) in enumerate(self.transforms)],
            dim=1,
        )

    def forward(self, x):
        return torch.cat(
            [transform.forward(x[:, [idx]]) for (idx, transform) in enumerate(self.transforms)],
            dim=1,
        )

    def _fit(self, x):
        for (idx, transform) in enumerate(self.transforms):
            transform._fit(x[:, [idx]])
        return self.state_dict()


class IdentityTransform(AbstractNormalizer):
    """No-op normalizer."""

    def inverse(self, x):
        return x

    def forward(self, x):
        return x

    def _fit(self, x):
        return self.state_dict()


class MISTFinetunedConfig(PretrainedConfig):
    """HF config for a single-task MIST wrapper."""

    model_type = 'mist_finetuned'

    def __init__(self, encoder=None, task_network=None, transform=None, channels=None,
                 tokenizer_class='SmirkTokenizer', **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder or {}
        self.task_network = task_network or {}
        self.transform = transform or {}
        self.channels = channels
        self.tokenizer_class = tokenizer_class


class MISTFinetuned(PreTrainedModel):
    """Encoder + one ``PredictionTaskHead`` + one output normalizer."""

    config_class = MISTFinetunedConfig

    def __init__(self, config):
        super().__init__(config)
        self.encoder = build_encoder_from_dict(config.encoder)
        tn = config.task_network
        self.task_network = PredictionTaskHead(
            embed_dim=tn['embed_dim'],
            output_size=tn['output_size'],
            dropout=tn['dropout'],
        )
        self.transform = AbstractNormalizer.get(
            config.transform['class'], config.transform['num_outputs']
        )
        self.channels = config.channels
        self.tokenizer = self._resolve_tokenizer()
        self.post_init()

    @classmethod
    def from_components(cls, encoder, task_network, transform, tokenizer=None, channels=None):
        """Assemble a wrapper from already-built parts, copying their weights.

        Args:
            encoder: Encoder model; its ``config`` is serialized into the wrapper config.
            task_network: Head exposing ``final`` and ``dropout1`` (as ``PredictionTaskHead`` does).
            transform: Fitted normalizer.
            tokenizer: Optional tokenizer to attach to the model.
            channels: Optional channel names / annotation dicts for ``predict``.
        """
        tokenizer_class = type(tokenizer).__name__ if tokenizer is not None else 'SmirkTokenizer'
        cfg = MISTFinetunedConfig(
            encoder=encoder.config.to_dict(),
            task_network={
                'embed_dim': encoder.config.hidden_size,
                'output_size': task_network.final.out_features,
                'dropout': task_network.dropout1.p,
            },
            transform=transform.to_config(),
            channels=channels,
            tokenizer_class=tokenizer_class,
        )
        model = cls(cfg)
        # strict=False: the rebuilt encoder may lack keys of the source
        # encoder (the pooling layer is disabled on rebuild when possible).
        model.encoder.load_state_dict(encoder.state_dict(), strict=False)
        model.task_network.load_state_dict(task_network.state_dict())
        model.transform.load_state_dict(transform.state_dict())
        model.tokenizer = tokenizer
        return model

    def forward(self, input_ids, attention_mask=None):
        """Return de-normalized predictions for a tokenized batch."""
        hs = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state
        y = self.task_network(hs)
        return self.transform.forward(y)

    def _resolve_tokenizer(self, tokenizer=None):
        """Pick a tokenizer: explicit argument, then ``self.tokenizer``, then the hub.

        Returns:
            A tokenizer, or ``None`` when none could be found (best effort).
        """
        if tokenizer is not None:
            return tokenizer
        attached = getattr(self, 'tokenizer', None)
        if attached is not None:
            return attached
        tried = set()
        for path in (self.name_or_path, getattr(self.config, '_name_or_path', None)):
            if not path or '/' not in path or path in tried:
                continue
            tried.add(path)
            try:
                # NOTE(security): trust_remote_code=True lets the referenced
                # repository execute its own code; only point this at trusted sources.
                return AutoTokenizer.from_pretrained(path, use_fast=True, trust_remote_code=True)
            except Exception as err:
                # Best effort: stay silent for callers but leave a trace.
                logging.debug('Could not load a tokenizer from %r: %s', path, err)
        return None

    def _encode_smiles(self, smi, tokenizer=None):
        """Tokenize and pad ``smi``; return ``(input_ids, attention_mask)`` on the model device.

        Raises:
            ValueError: If no tokenizer was passed, attached, or resolvable.
        """
        tok = self._resolve_tokenizer(tokenizer)
        if tok is None:
            raise ValueError(
                'No tokenizer available: pass `tokenizer=` or set `model.tokenizer`.'
            )
        batch = DataCollatorWithPadding(tok)(tok(smi))
        return batch['input_ids'].to(self.device), batch['attention_mask'].to(self.device)

    def embed(self, smi, tokenizer=None):
        """Return the encoder hidden state of the first token for each input, on CPU."""
        input_ids, attention_mask = self._encode_smiles(smi, tokenizer)
        with torch.inference_mode():
            hs = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
        return hs.to('cpu')

    def predict(self, smi, return_dict=True, tokenizer=None):
        """Predict for raw inputs.

        Returns:
            The CPU output tensor when ``channels`` is unset or ``return_dict``
            is false; otherwise a dict keyed by channel name (see
            ``annotate_prediction``).
        """
        input_ids, attention_mask = self._encode_smiles(smi, tokenizer)
        with torch.inference_mode():
            out = self(input_ids=input_ids, attention_mask=attention_mask).cpu()
        if self.channels is None or not return_dict:
            return out
        return annotate_prediction(out, maybe_get_annotated_channels(self.channels))

    def save_pretrained(self, save_directory, **kwargs):
        """Save the model and, when one is attached, its tokenizer alongside it."""
        super().save_pretrained(save_directory, **kwargs)
        if getattr(self, 'tokenizer', None) is not None:
            self.tokenizer.save_pretrained(save_directory)


class MaxScaleTransform(AbstractNormalizer):
    """
    Divide by maximum value in training dataset.
    """

    # NOTE(review): ``AbstractNormalizer.get`` constructs this class as
    # ``MaxScaleTransform(num_outputs)``, so ``mx`` receives the output count;
    # ``max`` is a plain attribute (absent from the state dict) and ``_fit``
    # never updates it. Confirm the intended behavior before relying on it.
    def __init__(self, mx, eps=1e-08):
        super().__init__(1)
        self.num_outputs = 1
        self.max = mx
        self.eps = float(eps)
        assert 0 <= self.eps

    def forward(self, x):
        x_out = self.max * x
        return x_out

    def inverse(self, x):
        x_out = x / self.max
        return x_out

    def _fit(self, target):
        return self.state_dict()


class PairwiseMLP(nn.Module):
    """MLP over concatenated token pairs, symmetrized over the pair axes.

    Maps ``(B, N, d_model)`` to ``(B, N, N, d_model)``.
    """

    def __init__(self, d_model, dropout=0.2, device=None, dtype=None):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.mlp = nn.Sequential(
            nn.Linear(2 * d_model, d_model, **factory_kwargs),
            nn.Dropout(dropout),
            nn.GELU(),
            nn.Linear(d_model, d_model, **factory_kwargs),
            nn.GELU(),
        )

    def forward(self, x):
        (_, N, _) = x.shape
        x_l = x.unsqueeze(-2).expand(-1, N, N, -1)
        x_r = x.unsqueeze(-3).expand(-1, N, N, -1)
        x_pw = torch.cat([x_l, x_r], dim=-1)
        y = self.mlp(x_pw)
        # Average with the transpose so the result does not depend on pair order.
        return 0.5 * (y + y.transpose(1, 2))


class PowerTransform(AbstractNormalizer):
    """
    Apply a power transform (Yeo-Johnson) featurewise to make data more
    Gaussian-like, followed by a zero-mean, unit-variance standardization of
    the transformed values.
    """

    def __init__(self, num_outputs, eps=1e-08):
        super().__init__(num_outputs)
        self.num_outputs = num_outputs
        self.register_buffer('lmbdas', torch.zeros(num_outputs))
        self.register_buffer('mean', torch.zeros(num_outputs))
        self.register_buffer('std', torch.zeros(num_outputs))
        self.eps = float(eps)
        assert 0 <= self.eps

    def _yeo_johnson_transform(self, x, lmbda):
        """
        Return transformed input x following Yeo-Johnson transform with
        parameter lambda.
        Adapted from
        https://github.com/scikit-learn/scikit-learn/blob/fbb32eae5/sklearn/preprocessing/_data.py#L3354
        """
        x_out = x.clone()
        eps = torch.finfo(x.dtype).eps
        pos = x >= 0
        # Non-negative inputs.
        if abs(lmbda) < eps:
            x_out[pos] = torch.log1p(x[pos])
        else:
            x_out[pos] = (torch.pow(x[pos] + 1, lmbda) - 1) / lmbda
        # Negative inputs.
        if abs(lmbda - 2) > eps:
            x_out[~pos] = -(torch.pow(-x[~pos] + 1, 2 - lmbda) - 1) / (2 - lmbda)
        else:
            x_out[~pos] = -torch.log1p(-x[~pos])
        return x_out

    def _yeo_johnson_inverse_transform(self, x, lmbda):
        """
        Return inverse-transformed input x following Yeo-Johnson inverse
        transform with parameter lambda.
        Adapted from
        https://github.com/scikit-learn/scikit-learn/blob/fbb32eae5/sklearn/preprocessing/_data.py#L3383
        """
        x_out = x.clone()
        pos = x >= 0
        eps = torch.finfo(x.dtype).eps
        if abs(lmbda) < eps:
            x_out[pos] = torch.exp(x[pos]) - 1
        else:
            x_out[pos] = torch.pow(x[pos] * lmbda + 1, 1 / lmbda) - 1
        if abs(lmbda - 2) > eps:
            x_out[~pos] = 1 - torch.pow(-(2 - lmbda) * x[~pos] + 1, 1 / (2 - lmbda))
        else:
            x_out[~pos] = 1 - torch.exp(-x[~pos])
        return x_out

    def forward(self, x):
        """Undo the standardization, then the power transform, column by column."""
        x = self.std * x + self.mean
        x_out = torch.zeros_like(x)
        for i in range(self.num_outputs):
            x_out[:, i] = self._yeo_johnson_inverse_transform(x[:, i], self.lmbdas[i])
        return x_out

    def inverse(self, x):
        """Apply the power transform column by column, then standardize."""
        x_out = torch.zeros_like(x)
        for i in range(self.num_outputs):
            x_out[:, i] = self._yeo_johnson_transform(x[:, i], self.lmbdas[i])
        x_out = (x_out - self.mean) / self.std
        return x_out

    def _fit(self, target):
        """Estimate lambdas with scikit-learn, then mean/std of the transformed data."""
        from sklearn.preprocessing import PowerTransformer as _PowerTransformer

        transformer = _PowerTransformer(method='yeo-johnson', standardize=False)
        # NOTE(review): get_data() drops the mask, so masked entries also
        # enter the fit — confirm this is intended.
        target = torch.tensor(transformer.fit_transform(target.get_data().numpy()))
        # Match the registered buffer's dtype/device, as done for mean and std.
        self.lmbdas = torch.tensor(transformer.lambdas_).to(self.lmbdas)
        self.mean = target.mean(0).to(self.mean)
        self.std = target.std(0).to(self.std) + self.eps
        return self.state_dict()


class PredictionTaskHead(nn.Module):
    """Two residual fully-connected blocks followed by a linear output layer."""

    def __init__(self, embed_dim, output_size=1, dropout=0.2):
        super().__init__()
        self.desc_skip_connection = True
        self.fc1 = nn.Linear(embed_dim, embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.relu1 = nn.GELU()
        self.fc2 = nn.Linear(embed_dim, embed_dim)
        self.dropout2 = nn.Dropout(dropout)
        self.relu2 = nn.GELU()
        self.final = nn.Linear(embed_dim, output_size)

    def forward(self, emb):
        # Inputs with a sequence axis are reduced to their first position.
        if emb.ndim > 2:
            emb = emb[:, 0, :]
        x_out = self.fc1(emb)
        x_out = self.dropout1(x_out)
        x_out = self.relu1(x_out)
        if self.desc_skip_connection is True:
            x_out = x_out + emb
        z = self.fc2(x_out)
        z = self.dropout2(z)
        z = self.relu2(z)
        if self.desc_skip_connection is True:
            z = self.final(z + x_out)
        else:
            z = self.final(z)
        return z


class Standardize(AbstractNormalizer):
    """Zero-mean / unit-variance normalizer with ``mean`` and ``std`` buffers."""

    def __init__(self, num_outputs, eps=1e-08):
        super().__init__(num_outputs)
        self.register_buffer('mean', torch.zeros(num_outputs))
        self.register_buffer('std', torch.zeros(num_outputs))
        self.eps = float(eps)
        assert 0 <= self.eps

    def forward(self, x):
        return self.std * x + self.mean

    def inverse(self, x):
        return (x - self.mean) / self.std

    def fit(self, ds, name='target'):
        """Fit mean/std in one pass over ``ds`` with a running (Welford-style) update.

        Masked entries do not contribute. Channels that end up NaN (e.g.
        never observed) fall back to mean 0 / std 1.
        NOTE(review): the update looks like it assumes each row carries a
        single sample of ``num_outputs`` values — confirm against the data.
        """
        num_outputs = self.num_outputs
        assert num_outputs is not None
        mean = torch.zeros(num_outputs)
        m2 = torch.zeros(num_outputs)
        n = torch.zeros(num_outputs, dtype=torch.int)
        for row in ds:
            target = torch.tensor(row[name])
            mask = torch.tensor(row[f'{name}_mask'])
            x = masked_tensor(target, mask)
            n += mask.view(-1, num_outputs).sum(0)
            xs = x.view(-1, num_outputs).sum(0)
            delta = xs - mean
            # Zero the update wherever this row's value is masked out.
            mean += (delta / n).get_data().masked_fill(~delta.get_mask(), 0)
            delta2 = xs - mean
            m2 += (delta * delta2).get_data().masked_fill(~delta.get_mask(), 0)
        self.mean = mean.to(self.mean)
        self.std = (m2 / n).sqrt().to(self.std) + self.eps
        self.mean[self.mean.isnan()] = 0
        self.std[self.std.isnan()] = 1
        logging.debug('Fitted %s', self.state_dict())
        return self.state_dict()

    def _fit(self, target):
        self.mean = target.mean(0).get_data().to(self.mean)
        self.std = target.std(0).get_data().to(self.std) + self.eps
        return self.state_dict()

    def load_state_dict(self, state_dict, strict=True, assign=False):
        """Load ``mean``/``std``, also accepting the ``transform.``-prefixed key names.

        With ``assign=True`` the given tensors are registered as the buffers
        directly and ``None`` is returned; otherwise the result of the parent
        ``load_state_dict`` is returned.
        """
        if 'transform.mean' in state_dict:
            # Copy first so the caller's dict is left untouched.
            state_dict = state_dict.copy()
            state_dict['mean'] = state_dict.pop('transform.mean')
            state_dict['std'] = state_dict.pop('transform.std')
        if assign:
            for (key, value) in state_dict.items():
                if key in ['mean', 'std']:
                    self.register_buffer(key, value)
            result = None
        else:
            result = super().load_state_dict(state_dict, strict=strict, assign=False)
        return result


class LogTransform(Standardize):
    """Standardize in log space: ``inverse`` takes the log first, ``forward`` exponentiates last."""

    def forward(self, x):
        return torch.exp(super().forward(x))

    def inverse(self, x):
        return super().inverse(torch.log(x))

    def _fit(self, target):
        return super()._fit(torch.log(target))


class TokenPairwiseDistance(nn.Module):
    """Predict a non-negative scalar for every pair of tokens."""

    def __init__(self, embed_dim, dropout=0.2, num_attention_heads=1, num_layers=1,
                 activation='relu', ff_ratio=2):
        super().__init__()
        enc_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_attention_heads,
            dim_feedforward=ff_ratio * embed_dim,
            dropout=dropout,
            # Honor the constructor argument; 'relu' is also the layer default.
            activation=activation,
            batch_first=True,
            norm_first=True,
        )
        self.interaction = nn.TransformerEncoder(enc_layer, num_layers)
        self.pairwise_distance = PairwiseMLP(embed_dim, dropout)
        self.distance1 = nn.Sequential(
            nn.Linear(embed_dim, embed_dim), nn.Dropout(dropout), nn.GELU()
        )
        self.distance2 = nn.Linear(embed_dim, 1)

    def forward(self, hs):
        hs = self.interaction(hs)
        # NOTE(review): the extent of this autocast region was reconstructed
        # from a source without indentation — confirm which statements belong in it.
        with torch.autocast('cuda', dtype=torch.float32):
            pw_dist = self.pairwise_distance(hs)
            d = self.distance1(pw_dist) + pw_dist
            d = self.distance2(d).squeeze(-1)
        return F.relu(F.elu(d) + 1)


class TokenTaskHead(nn.Module):
    """Three-layer MLP applied to the embeddings as given (no pooling)."""

    def __init__(self, embed_dim, output_size=1, dropout=0.2):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.Dropout(dropout),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
            nn.Dropout(dropout),
            nn.GELU(),
            nn.Linear(embed_dim, output_size),
        )

    def forward(self, emb):
        return self.layers(emb)


def annotate_prediction(y, channels):
    """Split prediction columns into a dict keyed by channel name.

    Args:
        y: Predictions indexed as ``y[:, idx]``, one column per channel.
        channels: Iterable of dicts, each with a ``'name'`` key; the remaining
            keys are copied next to the values.

    Returns:
        ``{name: {'value': column, **other_channel_fields}}``.
    """
    out = {}
    for (idx, chn) in enumerate(channels):
        channel_info = {f: v for (f, v) in chn.items() if f != 'name'}
        out[chn['name']] = {'value': y[:, idx], **channel_info}
    return out


def build_encoder_from_dict(enc_dict):
    """Instantiate an (untrained) encoder from a serialized encoder config.

    Raises:
        KeyError: If the dict has neither ``model_type`` nor ``_name_or_path``.
    """
    if 'model_type' in enc_dict:
        cfg_cls = AutoConfig.for_model(enc_dict['model_type'])
        enc_cfg = cfg_cls.from_dict(enc_dict, strict=False)
    elif '_name_or_path' in enc_dict:
        enc_cfg = AutoConfig.from_pretrained(enc_dict['_name_or_path'], strict=False)
    else:
        raise KeyError("Encoder config is missing 'model_type' and '_name_or_path'.")
    if hasattr(enc_cfg, 'add_pooling_layer'):
        enc_cfg.add_pooling_layer = False
    return AutoModel.from_config(enc_cfg)


def maybe_get_annotated_channels(channels):
    """Yield channel dicts, expanding bare string names into minimal dicts."""
    for chn in channels:
        if isinstance(chn, str):
            yield {'name': chn, 'description': None, 'unit': None}
        else:
            yield chn


class MISTMultiTaskConfig(PretrainedConfig):
    """HuggingFace config for a multi-task MIST wrapper."""

    model_type = 'mist_multitask'

    def __init__(self, encoder=None, task_networks=None, transforms=None, channels=None,
                 tokenizer_class='SmirkTokenizer', **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder or {}
        self.task_networks = task_networks or []
        self.transforms = transforms or []
        self.channels = channels
        self.tokenizer_class = tokenizer_class


class MISTMultiTask(PreTrainedModel):
    """Shared encoder with several ``PredictionTaskHead`` / normalizer pairs."""

    config_class = MISTMultiTaskConfig

    def __init__(self, config):
        super().__init__(config)
        self.encoder = build_encoder_from_dict(config.encoder)
        self.task_networks = nn.ModuleList([
            PredictionTaskHead(
                embed_dim=tn['embed_dim'],
                output_size=tn['output_size'],
                dropout=tn['dropout'],
            )
            for tn in config.task_networks
        ])
        self.transforms = nn.ModuleList([
            AbstractNormalizer.get(tf_cfg['class'], tf_cfg['num_outputs'])
            for tf_cfg in config.transforms
        ])
        assert len(self.task_networks) == len(self.transforms), 'task_networks and transforms must align'
        self.channels = config.channels
        self.tokenizer = self._resolve_tokenizer()
        self.post_init()

    @classmethod
    def from_components(cls, encoder, task_networks, transforms, tokenizer=None, channels=None):
        """Assemble a wrapper from already-built parts, copying their weights.

        ``task_networks`` and ``transforms`` are paired by position.
        """
        tokenizer_class = type(tokenizer).__name__ if tokenizer is not None else 'SmirkTokenizer'
        cfg = MISTMultiTaskConfig(
            encoder=encoder.config.to_dict(),
            task_networks=[
                {
                    'embed_dim': encoder.config.hidden_size,
                    'output_size': tn.final.out_features,
                    'dropout': tn.dropout1.p,
                }
                for tn in task_networks
            ],
            transforms=[tf.to_config() for tf in transforms],
            channels=channels,
            tokenizer_class=tokenizer_class,
        )
        model = cls(cfg)
        # strict=False: the rebuilt encoder may lack keys of the source
        # encoder (the pooling layer is disabled on rebuild when possible).
        model.encoder.load_state_dict(encoder.state_dict(), strict=False)
        for (dst, src) in zip(model.task_networks, task_networks):
            dst.load_state_dict(src.state_dict())
        for (dst, src) in zip(model.transforms, transforms):
            dst.load_state_dict(src.state_dict())
        model.tokenizer = tokenizer
        return model

    def forward(self, input_ids, attention_mask=None):
        """Return every head's de-normalized output, concatenated on the last axis."""
        hs = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state
        outs = [tf.forward(tn(hs)) for (tn, tf) in zip(self.task_networks, self.transforms)]
        return torch.cat(outs, dim=-1)

    def _resolve_tokenizer(self, tokenizer=None):
        """Pick a tokenizer: explicit argument, then ``self.tokenizer``, then the hub.

        Returns:
            A tokenizer, or ``None`` when none could be found (best effort).
        """
        if tokenizer is not None:
            return tokenizer
        attached = getattr(self, 'tokenizer', None)
        if attached is not None:
            return attached
        tried = set()
        for path in (self.name_or_path, getattr(self.config, '_name_or_path', None)):
            if not path or '/' not in path or path in tried:
                continue
            tried.add(path)
            try:
                # NOTE(security): trust_remote_code=True lets the referenced
                # repository execute its own code; only point this at trusted sources.
                return AutoTokenizer.from_pretrained(path, use_fast=True, trust_remote_code=True)
            except Exception as err:
                # Best effort: stay silent for callers but leave a trace.
                logging.debug('Could not load a tokenizer from %r: %s', path, err)
        return None

    def _encode_smiles(self, smi, tokenizer=None):
        """Tokenize and pad ``smi``; return ``(input_ids, attention_mask)`` on the model device.

        Raises:
            ValueError: If no tokenizer was passed, attached, or resolvable.
        """
        tok = self._resolve_tokenizer(tokenizer)
        if tok is None:
            raise ValueError(
                'No tokenizer available: pass `tokenizer=` or set `model.tokenizer`.'
            )
        batch = DataCollatorWithPadding(tok)(tok(smi))
        return batch['input_ids'].to(self.device), batch['attention_mask'].to(self.device)

    def predict(self, smi, tokenizer=None):
        """Predict for raw inputs.

        Returns:
            The CPU output tensor when ``channels`` is unset; otherwise a dict
            keyed by channel name (see ``annotate_prediction``).
        """
        input_ids, attention_mask = self._encode_smiles(smi, tokenizer)
        with torch.inference_mode():
            # Only pass what forward() accepts; the collated batch may hold more keys.
            out = self(input_ids=input_ids, attention_mask=attention_mask).cpu()
        if self.channels is None:
            return out
        # Accept bare string channel names as well as annotation dicts.
        return annotate_prediction(out, maybe_get_annotated_channels(self.channels))

    def embed(self, smi, tokenizer=None):
        """Return the encoder hidden state of the first token for each input, on CPU."""
        input_ids, attention_mask = self._encode_smiles(smi, tokenizer)
        with torch.inference_mode():
            hs = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
        return hs.to('cpu')

    def save_pretrained(self, save_directory, **kwargs):
        """Save the model and, when one is attached, its tokenizer alongside it."""
        super().save_pretrained(save_directory, **kwargs)
        if getattr(self, 'tokenizer', None) is not None:
            self.tokenizer.save_pretrained(save_directory)