# mist-28M-solvent-properties / modeling_mist_multitask.py
# Uploaded by anoushka2000 using huggingface_hub (commit 371f70d, verified).
import json
import logging
import math
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import IterableDataset
from smirk import SmirkTokenizerFast
from torch import nn
from torch.masked import MaskedTensor, masked_tensor
from transformers import (AutoConfig, AutoModel, AutoTokenizer,
DataCollatorWithPadding, PretrainedConfig,
PreTrainedModel)
# Optional mapping from an encoder ``model_type`` string to the type name that
# should be handed to AutoConfig instead; consulted by ``build_encoder``.
# Empty here, so every model type is used as-is.
MODEL_TYPE_ALIASES = {}
# NOTE(review): not referenced by any definition visible in this file;
# presumably the label value to be ignored by a loss function -- confirm.
IGNORE_INDEX = -100
# Make the name "SmirkTokenizer" (the default ``tokenizer_class`` of the configs
# below) resolvable through AutoTokenizer, backed by SmirkTokenizerFast.
AutoTokenizer.register("SmirkTokenizer", fast_tokenizer_class=SmirkTokenizerFast)
def build_encoder(enc_dict: Dict[str, Any]):
    """Create an encoder model (architecture only) from a serialized config.

    The config is rebuilt from ``enc_dict['model_type']`` when that entry is
    truthy (after translating it through ``MODEL_TYPE_ALIASES``); otherwise it
    is loaded from ``enc_dict['_name_or_path']``.

    Args:
        enc_dict: Encoder configuration as a plain dictionary.

    Returns:
        The model produced by ``AutoModel.from_config`` for the resolved config.

    Raises:
        KeyError: If neither ``model_type`` nor ``_name_or_path`` is usable.
    """
    model_type = enc_dict.get("model_type")
    if model_type:
        resolved_type = MODEL_TYPE_ALIASES.get(model_type, model_type)
        enc_cfg = AutoConfig.for_model(resolved_type).from_dict(enc_dict)
    else:
        source = enc_dict.get("_name_or_path")
        if not source:
            raise KeyError("encoder config missing 'model_type' or '_name_or_path'")
        enc_cfg = AutoConfig.from_pretrained(source)

    # Only configs that expose the switch get their pooling layer disabled.
    if hasattr(enc_cfg, "add_pooling_layer"):
        enc_cfg.add_pooling_layer = False
    return AutoModel.from_config(enc_cfg)
class AbstractNormalizer(torch.nn.Module):
    """Base class for invertible target normalizers.

    Direction convention used throughout this file: ``inverse`` applies the
    normalization (raw targets -> normalized space) and ``forward`` removes it
    (normalized values -> original scale).

    Args:
        num_outputs: Number of output channels handled by the normalizer, or
            ``None`` when it is not tied to a channel count.
    """

    def __init__(self, num_outputs=None):
        super().__init__()
        self.num_outputs = num_outputs

    def forward(self, x):
        """Remove normalization"""
        raise NotImplementedError

    def inverse(self, x):
        """Apply normalization"""
        raise NotImplementedError

    def _fit(self, x):
        """Fit the normalization parameters"""
        raise NotImplementedError

    def to_config(self):
        """Return a dict with the class name and ``num_outputs`` of this normalizer."""
        return {'class': self.__class__.__name__, 'num_outputs': self.num_outputs}

    def leader_fit(self, ds, rank, broadcast):
        """Fit on rank 0 only and load the broadcast state on every rank.

        Args:
            ds: Dataset forwarded to :meth:`fit` on rank 0.
            rank: Rank of the calling process; only ``0`` performs the fit.
            broadcast: Callable that receives the local state (``None`` on
                non-zero ranks) and returns the state to load.
        """
        state = self.fit(ds) if rank == 0 else None
        state = broadcast(state)
        self.load_state_dict(state)

    def fit(self, ds, name='target'):
        """Fit the normalization parameters on dataset

        Reads the ``name`` and ``f'{name}_mask'`` columns, stacks them into a
        masked tensor and delegates to :meth:`_fit`.

        Args:
            ds: Either an ``IterableDataset`` (iterated row by row) or an
                object that supports column access ``ds[column]``.
            name: Name of the target column.

        Returns:
            Whatever :meth:`_fit` returns (the state dict in the subclasses
            defined in this file).
        """
        mask_name = f'{name}_mask'
        if isinstance(ds, IterableDataset):
            targets = []
            masks = []
            for row in ds:
                # Rows may hold plain Python lists/scalars rather than tensors;
                # ``as_tensor`` accepts both, so ``torch.stack`` below works.
                targets.append(torch.as_tensor(row[name]))
                masks.append(torch.as_tensor(row[mask_name]))
        else:
            targets = [torch.as_tensor(value) for value in ds[name]]
            masks = [torch.as_tensor(value) for value in ds[mask_name]]
        target = masked_tensor(torch.stack(targets), torch.stack(masks))
        return self._fit(target)

    @classmethod
    def get(cls, transform, num_outputs):
        """Build a normalizer from its serialized name(s).

        Args:
            transform: A transform name (short alias or class name), or a list
                with one name per output channel.
            num_outputs: Number of output channels; must equal
                ``len(transform)`` when a list is given.

        Returns:
            A ``ChannelWiseTransform`` for a list, the matching normalizer for
            a known name, and an ``IdentityTransform`` for anything else.
        """
        if isinstance(transform, list):
            assert len(transform) == num_outputs, 'need exactly one transform per output channel'
            # Every channel gets its own single-output normalizer.
            return ChannelWiseTransform([cls.get(t, 1) for t in transform])
        if transform in ['standardize', Standardize.__name__]:
            return Standardize(num_outputs)
        if transform in ['power_transform', PowerTransform.__name__]:
            return PowerTransform(num_outputs)
        if transform in ['log_transform', LogTransform.__name__]:
            return LogTransform(num_outputs)
        if transform in ['max_scale', MaxScaleTransform.__name__]:
            return MaxScaleTransform(num_outputs)
        return IdentityTransform()
class BiPairwiseBlock(nn.Module):
    """Symmetric pairwise feature block combining a bilinear and a linear term.

    For an input of shape ``(..., L, d_model)`` the output has shape
    ``(..., L, L, d_model)`` and is symmetric in its two ``L`` axes.

    Args:
        d_model: Feature dimension of the input and output.
        bias: Whether the linear term has a learnable bias.
        device: Device on which the parameters are created.
        dtype: Data type of the parameters.
    """

    def __init__(self, d_model, bias=True, device=None, dtype=None):
        super().__init__()
        square = (d_model, d_model)
        self.bi_weight = nn.Parameter(torch.empty(square, device=device, dtype=dtype))
        self.lin_weight = nn.Parameter(torch.empty(square, device=device, dtype=dtype))
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.empty(d_model, device=device, dtype=dtype))
        self.reset_parameters()
        # Symmetrize incoming gradients so gradient updates do not break the
        # symmetry established in ``reset_parameters``.
        self.bi_weight.register_hook(self._symmetrize)

    @staticmethod
    def _symmetrize(grad):
        """Return the symmetric part of a square matrix."""
        return 0.5 * (grad + grad.T)

    def reset_parameters(self):
        """(Re)initialize the weights; ``bi_weight`` is made symmetric."""
        relu_gain = nn.init.calculate_gain('relu')
        nn.init.xavier_normal_(self.lin_weight, gain=relu_gain)
        nn.init.xavier_normal_(self.bi_weight, gain=relu_gain)
        with torch.no_grad():
            self.bi_weight.copy_(0.5 * (self.bi_weight + self.bi_weight.T))
        if self.bias is None:
            return
        limit = 1 / math.sqrt(self.bias.size(0))
        nn.init.uniform_(self.bias, -limit, limit)

    def forward(self, x):
        """Compute pairwise features for every pair of positions in ``x``."""
        bilinear = torch.einsum('...ld,df,...rf->...lrf', x, self.bi_weight, x)
        # Average with the swapped-pair result so the output is symmetric.
        bilinear = 0.5 * (bilinear + bilinear.transpose(-3, -2))
        pair_sum = x.unsqueeze(-2) + x.unsqueeze(-3)
        linear = F.linear(pair_sum, self.lin_weight, self.bias)
        return bilinear + linear
class ChannelWiseTransform(AbstractNormalizer):
    """Apply a separate normalizer to each column of a 2-D input.

    Column ``i`` (kept 2-D as ``x[:, [i]]``) is handled by ``transforms[i]``
    and the per-column results are concatenated along dimension 1.

    Args:
        transforms: One normalizer per output channel.
    """

    def __init__(self, transforms):
        super().__init__(len(transforms))
        self.transforms = torch.nn.ModuleList(transforms)

    def to_config(self):
        """Return the per-channel class names together with ``num_outputs``."""
        class_names = []
        for transform in self.transforms:
            class_names.append(transform.__class__.__name__)
        return {'class': class_names, 'num_outputs': self.num_outputs}

    def inverse(self, x):
        """Run every channel through its transform's ``inverse``."""
        columns = []
        for idx, transform in enumerate(self.transforms):
            columns.append(transform.inverse(x[:, [idx]]))
        return torch.cat(columns, dim=1)

    def forward(self, x):
        """Run every channel through its transform's ``forward``."""
        columns = []
        for idx, transform in enumerate(self.transforms):
            columns.append(transform.forward(x[:, [idx]]))
        return torch.cat(columns, dim=1)

    def _fit(self, x):
        """Fit each channel's transform on its own column; return the state dict."""
        for idx, transform in enumerate(self.transforms):
            transform._fit(x[:, [idx]])
        return self.state_dict()
class IdentityTransform(AbstractNormalizer):
    """Normalizer that leaves values untouched in both directions."""

    def forward(self, x):
        """Return ``x`` unchanged."""
        return x

    def inverse(self, x):
        """Return ``x`` unchanged."""
        return x

    def _fit(self, x):
        """Nothing to fit; return the current state dict."""
        state = self.state_dict()
        return state
class MISTFinetunedConfig(PretrainedConfig):
    """HF config for a single-task MIST wrapper.

    Args:
        encoder: Encoder configuration dict; a falsy value becomes ``{}``.
        task_network: Task-head settings dict; a falsy value becomes ``{}``.
        transform: Normalizer settings dict; a falsy value becomes ``{}``.
        channels: Output channel annotations, stored as given.
        tokenizer_class: Name of the tokenizer class.
        **kwargs: Passed on to ``PretrainedConfig``.
    """

    model_type = 'mist_finetuned'

    def __init__(
        self,
        encoder=None,
        task_network=None,
        transform=None,
        channels=None,
        tokenizer_class='SmirkTokenizer',
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.encoder = encoder if encoder else {}
        self.task_network = task_network if task_network else {}
        self.transform = transform if transform else {}
        self.channels = channels
        self.tokenizer_class = tokenizer_class
class MISTFinetuned(PreTrainedModel):
    """Single-task MIST model: encoder, prediction head and output normalizer.

    ``forward`` feeds the encoder's last hidden state to the task head and
    passes the result through ``transform.forward`` (the direction that removes
    normalization in this file's normalizers).
    """

    config_class = MISTFinetunedConfig

    def __init__(self, config):
        super().__init__(config)
        self.encoder = build_encoder_from_dict(config.encoder)
        head_cfg = config.task_network
        self.task_network = PredictionTaskHead(
            embed_dim=head_cfg['embed_dim'],
            output_size=head_cfg['output_size'],
            dropout=head_cfg['dropout'],
        )
        self.transform = AbstractNormalizer.get(config.transform['class'], config.transform['num_outputs'])
        self.channels = config.channels
        # Best effort: stays ``None`` when no tokenizer can be located yet.
        self.tokenizer = self._resolve_tokenizer()
        self.post_init()

    @classmethod
    def from_components(cls, encoder, task_network, transform, tokenizer=None, channels=None):
        """Assemble a model from already-built parts and copy their weights.

        Args:
            encoder: Encoder model exposing ``config`` and ``state_dict()``.
            task_network: Head exposing ``final.out_features`` and ``dropout1.p``.
            transform: Normalizer exposing ``to_config()`` and ``state_dict()``.
            tokenizer: Optional tokenizer to attach to the model.
            channels: Optional output channel annotations.

        Returns:
            A new model holding copies of the given weights.
        """
        tokenizer_class = type(tokenizer).__name__ if tokenizer is not None else 'SmirkTokenizer'
        cfg = MISTFinetunedConfig(
            encoder=encoder.config.to_dict(),
            task_network={
                'embed_dim': encoder.config.hidden_size,
                'output_size': task_network.final.out_features,
                'dropout': task_network.dropout1.p,
            },
            transform=transform.to_config(),
            channels=channels,
            tokenizer_class=tokenizer_class,
        )
        model = cls(cfg)
        # Non-strict: the rebuilt encoder may lack keys of the source encoder
        # (e.g. a pooling layer that was disabled when rebuilding it).
        model.encoder.load_state_dict(encoder.state_dict(), strict=False)
        model.task_network.load_state_dict(task_network.state_dict())
        model.transform.load_state_dict(transform.state_dict())
        model.tokenizer = tokenizer
        return model

    def forward(self, input_ids, attention_mask=None):
        """Return predictions after ``transform.forward`` for a tokenized batch."""
        hs = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state
        y = self.task_network(hs)
        return self.transform.forward(y)

    def _resolve_tokenizer(self, tokenizer=None):
        """Pick a tokenizer: explicit argument, attached one, or one loaded by path.

        Loading is only attempted for model paths containing ``'/'`` and is
        best effort: failures are logged at debug level and ``None`` is
        returned when nothing could be loaded.
        """
        if tokenizer is not None:
            return tokenizer
        attached = getattr(self, 'tokenizer', None)
        if attached is not None:
            return attached
        candidates = []
        for path in (self.name_or_path, getattr(self.config, '_name_or_path', None)):
            # Both attributes usually hold the same value; try each path once.
            if path and '/' in path and path not in candidates:
                candidates.append(path)
        for path in candidates:
            try:
                # NOTE: trust_remote_code=True executes code shipped with the
                # tokenizer repository; only load from sources you trust.
                return AutoTokenizer.from_pretrained(path, use_fast=True, trust_remote_code=True)
            except Exception:
                logging.debug('Could not load a tokenizer from %r', path, exc_info=True)
        return None

    def _tokenize(self, smi, tokenizer=None):
        """Tokenize and pad ``smi``; return ``(input_ids, attention_mask)`` on the model device."""
        tok = self._resolve_tokenizer(tokenizer)
        if tok is None:
            raise ValueError('No tokenizer available: pass `tokenizer=` or set `model.tokenizer`.')
        batch = DataCollatorWithPadding(tok)(tok(smi))
        return batch['input_ids'].to(self.device), batch['attention_mask'].to(self.device)

    def embed(self, smi, tokenizer=None):
        """Return the encoder's first-token hidden state for ``smi`` on the CPU.

        Args:
            smi: Input accepted by the tokenizer.
            tokenizer: Optional tokenizer overriding the attached one.
        """
        input_ids, attention_mask = self._tokenize(smi, tokenizer)
        with torch.inference_mode():
            hs = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
        return hs.to('cpu')

    def predict(self, smi, return_dict=True, tokenizer=None):
        """Predict for ``smi`` and optionally annotate the result per channel.

        Args:
            smi: Input accepted by the tokenizer.
            return_dict: When true and channels are configured, return the
                annotated dict from ``annotate_prediction``; otherwise the raw
                CPU tensor.
            tokenizer: Optional tokenizer overriding the attached one.
        """
        input_ids, attention_mask = self._tokenize(smi, tokenizer)
        with torch.inference_mode():
            out = self(input_ids=input_ids, attention_mask=attention_mask).cpu()
        if self.channels is None or not return_dict:
            return out
        return annotate_prediction(out, maybe_get_annotated_channels(self.channels))

    def save_pretrained(self, save_directory, **kwargs):
        """Save the model and, when one is attached, its tokenizer."""
        super().save_pretrained(save_directory, **kwargs)
        if getattr(self, 'tokenizer', None) is not None:
            self.tokenizer.save_pretrained(save_directory)
class MaxScaleTransform(AbstractNormalizer):
    """
    Scale a single output channel by a fixed maximum value.

    ``inverse`` divides by ``max`` and ``forward`` multiplies by it.

    NOTE(review): ``max`` is the constructor argument ``mx`` stored as a plain
    attribute; ``_fit`` does not update it, ``eps`` is validated but not used
    in the arithmetic, and ``AbstractNormalizer.get`` passes ``num_outputs`` as
    ``mx``. Confirm this is intended.
    """

    def __init__(self, mx, eps=1e-08):
        # This transform always handles exactly one output channel.
        super().__init__(1)
        self.max = mx
        self.eps = float(eps)
        assert self.eps >= 0

    def forward(self, x):
        """Multiply ``x`` by ``max``."""
        return self.max * x

    def inverse(self, x):
        """Divide ``x`` by ``max``."""
        return x / self.max

    def _fit(self, target):
        """Return the current state dict; ``target`` is not inspected."""
        state = self.state_dict()
        return state
class PairwiseMLP(nn.Module):
    """Symmetric pairwise features from an MLP over concatenated token pairs.

    For an input of shape ``(..., N, d_model)`` every ordered pair of positions
    is concatenated, passed through the MLP, and averaged with the result of
    the swapped pair, giving a ``(..., N, N, d_model)`` output that is
    symmetric in its two ``N`` axes.

    Args:
        d_model: Feature dimension of the input and output.
        dropout: Dropout probability inside the MLP.
        device: Device on which the linear layers are created.
        dtype: Data type of the linear layers.
    """

    def __init__(self, d_model, dropout=0.2, device=None, dtype=None):
        super().__init__()
        # Forward the factory arguments instead of silently ignoring them.
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.mlp = nn.Sequential(
            nn.Linear(2 * d_model, d_model, **factory_kwargs),
            nn.Dropout(dropout),
            nn.GELU(),
            nn.Linear(d_model, d_model, **factory_kwargs),
            nn.GELU(),
        )

    def forward(self, x):
        """Return symmetric pairwise features for ``x`` of shape ``(..., N, d_model)``."""
        n_tokens = x.shape[-2]
        # Keep any leading (batch) dimensions as they are; only the singleton
        # pair axis is expanded, so no data is copied before the concatenation.
        lead = [-1] * (x.dim() - 2)
        left = x.unsqueeze(-2).expand(*lead, n_tokens, n_tokens, -1)
        right = x.unsqueeze(-3).expand(*lead, n_tokens, n_tokens, -1)
        y = self.mlp(torch.cat([left, right], dim=-1))
        return 0.5 * (y + y.transpose(-3, -2))
class PowerTransform(AbstractNormalizer):
    """
    Apply a power transform (Yeo-Johnson) featurewise to make data more Gaussian-like.
    Followed by applying a zero-mean, unit-variance normalization to the
    transformed output.

    ``inverse`` applies the transform (raw -> normalized); ``forward`` undoes
    it (normalized -> raw).

    Args:
        num_outputs: Number of output channels; one lambda, mean and std each.
        eps: Non-negative constant added to the fitted standard deviation.
    """

    def __init__(self, num_outputs, eps=1e-08):
        super().__init__(num_outputs)
        self.register_buffer('lmbdas', torch.zeros(num_outputs))
        self.register_buffer('mean', torch.zeros(num_outputs))
        self.register_buffer('std', torch.zeros(num_outputs))
        self.eps = float(eps)
        assert 0 <= self.eps

    def _yeo_johnson_transform(self, x, lmbda):
        """
        Return transformed input x following Yeo-Johnson transform with
        parameter lambda.
        Adapted from
        https://github.com/scikit-learn/scikit-learn/blob/fbb32eae5/sklearn/preprocessing/_data.py#L3354
        """
        x_out = x.clone()
        eps = torch.finfo(x.dtype).eps
        pos = x >= 0
        # Non-negative inputs: log1p in the lambda == 0 limit, power form otherwise.
        if abs(lmbda) < eps:
            x_out[pos] = torch.log1p(x[pos])
        else:
            x_out[pos] = (torch.pow(x[pos] + 1, lmbda) - 1) / lmbda
        # Negative inputs: power form unless lambda == 2, where it reduces to a log.
        if abs(lmbda - 2) > eps:
            x_out[~pos] = -(torch.pow(-x[~pos] + 1, 2 - lmbda) - 1) / (2 - lmbda)
        else:
            x_out[~pos] = -torch.log1p(-x[~pos])
        return x_out

    def _yeo_johnson_inverse_transform(self, x, lmbda):
        """
        Return inverse-transformed input x following Yeo-Johnson inverse
        transform with parameter lambda.
        Adapted from
        https://github.com/scikit-learn/scikit-learn/blob/fbb32eae5/sklearn/preprocessing/_data.py#L3383
        """
        x_out = x.clone()
        pos = x >= 0
        eps = torch.finfo(x.dtype).eps
        if abs(lmbda) < eps:
            x_out[pos] = torch.exp(x[pos]) - 1
        else:
            x_out[pos] = torch.pow(x[pos] * lmbda + 1, 1 / lmbda) - 1
        if abs(lmbda - 2) > eps:
            x_out[~pos] = 1 - torch.pow(-(2 - lmbda) * x[~pos] + 1, 1 / (2 - lmbda))
        else:
            x_out[~pos] = 1 - torch.exp(-x[~pos])
        return x_out

    def forward(self, x):
        """Undo standardization, then the Yeo-Johnson transform, per column."""
        x = self.std * x + self.mean
        x_out = torch.zeros_like(x)
        for i in range(self.num_outputs):
            x_out[:, i] = self._yeo_johnson_inverse_transform(x[:, i], self.lmbdas[i])
        return x_out

    def inverse(self, x):
        """Apply the Yeo-Johnson transform per column, then standardize."""
        x_out = torch.zeros_like(x)
        for i in range(self.num_outputs):
            x_out[:, i] = self._yeo_johnson_transform(x[:, i], self.lmbdas[i])
        x_out = (x_out - self.mean) / self.std
        return x_out

    def _fit(self, target):
        """Fit lambdas with scikit-learn, then the mean/std of the transformed data.

        Args:
            target: Masked tensor of targets; its underlying data is used.

        Returns:
            The state dict holding ``lmbdas``, ``mean`` and ``std``.
        """
        # Imported lazily so scikit-learn is only required when fitting.
        from sklearn.preprocessing import PowerTransformer as _PowerTransformer
        transformer = _PowerTransformer(method='yeo-johnson', standardize=False)
        # ``.numpy()`` needs a detached CPU tensor.
        raw = target.get_data().detach().cpu().numpy()
        transformed = torch.tensor(transformer.fit_transform(raw))
        # Match the dtype/device of the registered buffers, as for mean/std;
        # scikit-learn returns float64 arrays.
        self.lmbdas = torch.tensor(transformer.lambdas_).to(self.lmbdas)
        self.mean = transformed.mean(0).to(self.mean)
        self.std = transformed.std(0).to(self.std) + self.eps
        return self.state_dict()
class PredictionTaskHead(nn.Module):
    """Two-layer residual MLP head followed by a linear output layer.

    Inputs with more than two dimensions are reduced to their first position
    along dimension 1 (``emb[:, 0, :]``) before the MLP is applied.

    Args:
        embed_dim: Width of the input embedding and of both hidden layers.
        output_size: Number of output values.
        dropout: Dropout probability after each hidden linear layer.
    """

    def __init__(self, embed_dim, output_size=1, dropout=0.2):
        super().__init__()
        self.desc_skip_connection = True
        # First hidden stage.
        self.fc1 = nn.Linear(embed_dim, embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.relu1 = nn.GELU()
        # Second hidden stage.
        self.fc2 = nn.Linear(embed_dim, embed_dim)
        self.dropout2 = nn.Dropout(dropout)
        self.relu2 = nn.GELU()
        # Output projection.
        self.final = nn.Linear(embed_dim, output_size)

    def forward(self, emb):
        """Map an embedding (or a sequence of embeddings) to the outputs."""
        pooled = emb[:, 0, :] if emb.ndim > 2 else emb
        use_skip = self.desc_skip_connection is True

        hidden = self.relu1(self.dropout1(self.fc1(pooled)))
        if use_skip:
            hidden = hidden + pooled

        out = self.relu2(self.dropout2(self.fc2(hidden)))
        if use_skip:
            out = out + hidden
        return self.final(out)
class Standardize(AbstractNormalizer):
    """Per-channel standardization using a fitted mean and standard deviation.

    ``inverse`` computes ``(x - mean) / std``; ``forward`` computes
    ``std * x + mean``.

    Args:
        num_outputs: Number of output channels.
        eps: Non-negative constant added to the fitted standard deviation.
    """

    def __init__(self, num_outputs, eps=1e-08):
        super().__init__(num_outputs)
        for buffer_name in ('mean', 'std'):
            self.register_buffer(buffer_name, torch.zeros(num_outputs))
        self.eps = float(eps)
        assert self.eps >= 0

    def forward(self, x):
        """Map standardized values back to the original scale."""
        return self.std * x + self.mean

    def inverse(self, x):
        """Standardize ``x`` with the stored mean and std."""
        return (x - self.mean) / self.std

    def fit(self, ds, name='target'):
        """Fit mean and std in a single pass over the rows of ``ds``.

        Each row provides ``row[name]`` and ``row[f'{name}_mask']``; running
        per-channel count, mean and sum-of-squares accumulators are updated
        row by row using only the entries the mask marks as valid.

        Returns:
            The state dict after fitting.
        """
        width = self.num_outputs
        assert width is not None
        running_mean = torch.zeros(width)
        running_m2 = torch.zeros(width)
        count = torch.zeros(width, dtype=torch.int)
        mask_name = f'{name}_mask'
        for row in ds:
            values = torch.tensor(row[name])
            valid = torch.tensor(row[mask_name])
            masked = masked_tensor(values, valid)
            count += valid.view(-1, width).sum(0)
            row_sum = masked.view(-1, width).sum(0)
            delta = row_sum - running_mean
            # Channels without a valid entry in this row contribute zero.
            running_mean += (delta / count).get_data().masked_fill(~delta.get_mask(), 0)
            delta_after = row_sum - running_mean
            running_m2 += (delta * delta_after).get_data().masked_fill(~delta.get_mask(), 0)
        self.mean = running_mean.to(self.mean)
        self.std = (running_m2 / count).sqrt().to(self.std) + self.eps
        # NaN statistics fall back to the identity (mean 0, std 1).
        self.mean[self.mean.isnan()] = 0
        self.std[self.std.isnan()] = 1
        logging.debug('Fitted %s', self.state_dict())
        return self.state_dict()

    def _fit(self, target):
        """Fit mean and std from a fully materialized masked target tensor."""
        fitted_mean = target.mean(0).get_data()
        fitted_std = target.std(0).get_data()
        self.mean = fitted_mean.to(self.mean)
        self.std = fitted_std.to(self.std) + self.eps
        return self.state_dict()

    def load_state_dict(self, state_dict, strict=True, assign=False):
        """Load ``mean``/``std``, accepting the ``transform.``-prefixed key names.

        With ``assign=True`` the ``mean`` and ``std`` entries are registered
        directly as buffers and ``None`` is returned; otherwise the parent
        implementation is used and its result returned.
        """
        if 'transform.mean' in state_dict:
            # Work on a copy so the caller's mapping is left untouched.
            state_dict = state_dict.copy()
            for key in ('mean', 'std'):
                state_dict[key] = state_dict.pop(f'transform.{key}')
        if not assign:
            return super().load_state_dict(state_dict, strict=strict, assign=False)
        for key, value in state_dict.items():
            if key in ('mean', 'std'):
                self.register_buffer(key, value)
        return None
class LogTransform(Standardize):
    """Standardization carried out in log space.

    ``inverse`` standardizes ``log(x)``; ``forward`` exponentiates the
    de-standardized value.
    """

    def forward(self, x):
        """Undo the standardization, then exponentiate."""
        destandardized = super().forward(x)
        return torch.exp(destandardized)

    def inverse(self, x):
        """Take the logarithm, then standardize."""
        logged = torch.log(x)
        return super().inverse(logged)

    def _fit(self, target):
        """Fit mean and std on the logarithm of ``target``."""
        logged = torch.log(target)
        return super()._fit(logged)
class TokenPairwiseDistance(nn.Module):
    """Predict a non-negative value for every pair of tokens.

    Token states first pass through a Transformer encoder, then through a
    ``PairwiseMLP`` and a residual MLP that ends in one scalar per pair.

    Args:
        embed_dim: Token embedding width.
        dropout: Dropout probability used in all sub-modules.
        num_attention_heads: Attention heads per encoder layer.
        num_layers: Number of encoder layers.
        activation: Accepted but not used by the current implementation.
        ff_ratio: Feed-forward width as a multiple of ``embed_dim``.
    """

    def __init__(self, embed_dim, dropout=0.2, num_attention_heads=1, num_layers=1, activation='relu', ff_ratio=2):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_attention_heads,
            dim_feedforward=ff_ratio * embed_dim,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )
        self.interaction = nn.TransformerEncoder(encoder_layer, num_layers)
        self.pairwise_distance = PairwiseMLP(embed_dim, dropout)
        self.distance1 = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.Dropout(dropout),
            nn.GELU(),
        )
        self.distance2 = nn.Linear(embed_dim, 1)

    def forward(self, hs):
        """Return one non-negative value per token pair for hidden states ``hs``."""
        mixed = self.interaction(hs)
        # NOTE(review): the extent of this autocast region was reconstructed
        # from source without indentation -- confirm which statements belong
        # inside it.
        with torch.autocast('cuda', dtype=torch.float32):
            pair_features = self.pairwise_distance(mixed)
            refined = self.distance1(pair_features) + pair_features
            raw = self.distance2(refined).squeeze(-1)
        return F.relu(F.elu(raw) + 1)
class TokenTaskHead(nn.Module):
    """MLP head applied to embeddings without any pooling.

    Two ``Linear -> Dropout -> GELU`` stages of width ``embed_dim`` are
    followed by a linear projection to ``output_size``.

    Args:
        embed_dim: Width of the input and of the hidden layers.
        output_size: Number of output values.
        dropout: Dropout probability after each hidden linear layer.
    """

    def __init__(self, embed_dim, output_size=1, dropout=0.2):
        super().__init__()
        stages = []
        for _ in range(2):
            stages.extend([nn.Linear(embed_dim, embed_dim), nn.Dropout(dropout), nn.GELU()])
        stages.append(nn.Linear(embed_dim, output_size))
        self.layers = nn.Sequential(*stages)

    def forward(self, emb):
        """Apply the MLP to ``emb``."""
        out = self.layers(emb)
        return out
def annotate_prediction(y, channels):
    """Split a 2-D prediction into per-channel entries keyed by channel name.

    Args:
        y: Predictions indexable as ``y[:, idx]``, one column per channel.
        channels: Iterable of dicts, each with a ``'name'`` entry plus any
            further metadata.

    Returns:
        Dict mapping each channel name to ``{'value': column, **metadata}``,
        where ``metadata`` is the channel dict without its ``'name'`` entry.
    """
    annotated = {}
    for column, channel in enumerate(channels):
        metadata = {field: value for field, value in channel.items() if field != 'name'}
        entry = {'value': y[:, column]}
        entry.update(metadata)
        annotated[channel['name']] = entry
    return annotated
def build_encoder_from_dict(enc_dict):
    """Create an encoder model (architecture only) from a serialized config dict.

    The config is rebuilt from ``enc_dict['model_type']`` when that entry is
    set; otherwise it is loaded from ``enc_dict['_name_or_path']``. An empty
    or ``None`` entry counts as missing, so a blank ``model_type`` still falls
    back to ``_name_or_path`` (the same rule ``build_encoder`` applies).

    Args:
        enc_dict: Encoder configuration as a plain dictionary.

    Returns:
        The model produced by ``AutoModel.from_config`` for the resolved config.

    Raises:
        KeyError: If neither ``model_type`` nor ``_name_or_path`` is usable.
    """
    model_type = enc_dict.get('model_type')
    name_or_path = enc_dict.get('_name_or_path')
    if model_type:
        enc_cfg = AutoConfig.for_model(model_type).from_dict(enc_dict, strict=False)
    elif name_or_path:
        enc_cfg = AutoConfig.from_pretrained(name_or_path, strict=False)
    else:
        raise KeyError("Encoder config is missing 'model_type' and '_name_or_path.")
    # Only configs that expose the switch get their pooling layer disabled.
    if hasattr(enc_cfg, 'add_pooling_layer'):
        enc_cfg.add_pooling_layer = False
    return AutoModel.from_config(enc_cfg)
def maybe_get_annotated_channels(channels):
    """Lazily yield every channel as an annotation dict.

    String entries are expanded to ``{'name': <str>, 'description': None,
    'unit': None}``; every other entry is yielded unchanged.

    Args:
        channels: Iterable of channel names and/or channel objects.
    """
    for channel in channels:
        is_bare_name = isinstance(channel, str)
        yield {'name': channel, 'description': None, 'unit': None} if is_bare_name else channel
class MISTMultiTaskConfig(PretrainedConfig):
    """HuggingFace config for a multi-task MIST wrapper.

    Args:
        encoder: Encoder configuration dict; a falsy value becomes ``{}``.
        task_networks: List of task-head settings; a falsy value becomes ``[]``.
        transforms: List of normalizer settings; a falsy value becomes ``[]``.
        channels: Output channel annotations, stored as given.
        tokenizer_class: Name of the tokenizer class.
        **kwargs: Passed on to ``PretrainedConfig``.
    """

    model_type = 'mist_multitask'

    def __init__(
        self,
        encoder=None,
        task_networks=None,
        transforms=None,
        channels=None,
        tokenizer_class='SmirkTokenizer',
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.encoder = encoder if encoder else {}
        self.task_networks = task_networks if task_networks else []
        self.transforms = transforms if transforms else []
        self.channels = channels
        self.tokenizer_class = tokenizer_class
class MISTMultiTask(PreTrainedModel):
    """Multi-task MIST model: one encoder shared by several prediction heads.

    Every head has its own normalizer; ``forward`` concatenates the outputs of
    all heads (after ``transform.forward``) along the last dimension.
    """

    config_class = MISTMultiTaskConfig

    def __init__(self, config):
        super().__init__(config)
        self.encoder = build_encoder_from_dict(config.encoder)
        self.task_networks = nn.ModuleList([
            PredictionTaskHead(embed_dim=tn['embed_dim'], output_size=tn['output_size'], dropout=tn['dropout'])
            for tn in config.task_networks
        ])
        self.transforms = nn.ModuleList([
            AbstractNormalizer.get(tf_cfg['class'], tf_cfg['num_outputs'])
            for tf_cfg in config.transforms
        ])
        assert len(self.task_networks) == len(self.transforms), 'task_networks and transforms must align'
        self.channels = config.channels
        # Best effort: stays ``None`` when no tokenizer can be located yet.
        self.tokenizer = self._resolve_tokenizer()
        self.post_init()

    @classmethod
    def from_components(cls, encoder, task_networks, transforms, tokenizer=None, channels=None):
        """Assemble a model from already-built parts and copy their weights.

        Args:
            encoder: Encoder model exposing ``config`` and ``state_dict()``.
            task_networks: Heads exposing ``final.out_features`` and ``dropout1.p``.
            transforms: Normalizers (one per head) exposing ``to_config()``
                and ``state_dict()``.
            tokenizer: Optional tokenizer to attach to the model.
            channels: Optional output channel annotations.

        Returns:
            A new model holding copies of the given weights.
        """
        tokenizer_class = type(tokenizer).__name__ if tokenizer is not None else 'SmirkTokenizer'
        cfg = MISTMultiTaskConfig(
            encoder=encoder.config.to_dict(),
            task_networks=[
                {
                    'embed_dim': encoder.config.hidden_size,
                    'output_size': tn.final.out_features,
                    'dropout': tn.dropout1.p,
                }
                for tn in task_networks
            ],
            transforms=[tf.to_config() for tf in transforms],
            channels=channels,
            tokenizer_class=tokenizer_class,
        )
        model = cls(cfg)
        # Non-strict: the rebuilt encoder may lack keys of the source encoder
        # (e.g. a pooling layer that was disabled when rebuilding it).
        model.encoder.load_state_dict(encoder.state_dict(), strict=False)
        for dst, src in zip(model.task_networks, task_networks):
            dst.load_state_dict(src.state_dict())
        for dst, src in zip(model.transforms, transforms):
            dst.load_state_dict(src.state_dict())
        model.tokenizer = tokenizer
        return model

    def forward(self, input_ids, attention_mask=None):
        """Return the concatenated predictions of all heads for a tokenized batch."""
        hs = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state
        outs = [tf.forward(tn(hs)) for tn, tf in zip(self.task_networks, self.transforms)]
        return torch.cat(outs, dim=-1)

    def _resolve_tokenizer(self, tokenizer=None):
        """Pick a tokenizer: explicit argument, attached one, or one loaded by path.

        Loading is only attempted for model paths containing ``'/'`` and is
        best effort: failures are logged at debug level and ``None`` is
        returned when nothing could be loaded.
        """
        if tokenizer is not None:
            return tokenizer
        attached = getattr(self, 'tokenizer', None)
        if attached is not None:
            return attached
        candidates = []
        for path in (self.name_or_path, getattr(self.config, '_name_or_path', None)):
            # Both attributes usually hold the same value; try each path once.
            if path and '/' in path and path not in candidates:
                candidates.append(path)
        for path in candidates:
            try:
                # NOTE: trust_remote_code=True executes code shipped with the
                # tokenizer repository; only load from sources you trust.
                return AutoTokenizer.from_pretrained(path, use_fast=True, trust_remote_code=True)
            except Exception:
                logging.debug('Could not load a tokenizer from %r', path, exc_info=True)
        return None

    def _tokenize(self, smi, tokenizer=None):
        """Tokenize and pad ``smi``; return ``(input_ids, attention_mask)`` on the model device."""
        tok = self._resolve_tokenizer(tokenizer)
        if tok is None:
            raise ValueError('No tokenizer available: pass `tokenizer=` or set `model.tokenizer`.')
        batch = DataCollatorWithPadding(tok)(tok(smi))
        return batch['input_ids'].to(self.device), batch['attention_mask'].to(self.device)

    def predict(self, smi, tokenizer=None, return_dict=True):
        """Predict for ``smi`` and optionally annotate the result per channel.

        Args:
            smi: Input accepted by the tokenizer.
            tokenizer: Optional tokenizer overriding the attached one.
            return_dict: When true and channels are configured, return the
                annotated dict from ``annotate_prediction``; otherwise the raw
                CPU tensor.
        """
        # Only the two tensors ``forward`` accepts are passed on; extra
        # tokenizer outputs would otherwise raise a TypeError.
        input_ids, attention_mask = self._tokenize(smi, tokenizer)
        with torch.inference_mode():
            out = self(input_ids=input_ids, attention_mask=attention_mask).cpu()
        if self.channels is None or not return_dict:
            return out
        # Channels may be bare names; expand them to annotation dicts first.
        return annotate_prediction(out, maybe_get_annotated_channels(self.channels))

    def embed(self, smi, tokenizer=None):
        """Return the encoder's first-token hidden state for ``smi`` on the CPU.

        Args:
            smi: Input accepted by the tokenizer.
            tokenizer: Optional tokenizer overriding the attached one.
        """
        input_ids, attention_mask = self._tokenize(smi, tokenizer)
        with torch.inference_mode():
            hs = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
        return hs.to('cpu')

    def save_pretrained(self, save_directory, **kwargs):
        """Save the model and, when one is attached, its tokenizer."""
        super().save_pretrained(save_directory, **kwargs)
        if getattr(self, 'tokenizer', None) is not None:
            self.tokenizer.save_pretrained(save_directory)