Feature Extraction
Transformers
Safetensors
English
mist_multitask
mist
chemistry
molecular-property-prediction
custom_code
Instructions to use mist-models/mist-28M-solvent-properties with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use mist-models/mist-28M-solvent-properties with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="mist-models/mist-28M-solvent-properties", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("mist-models/mist-28M-solvent-properties", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import json | |
| import logging | |
| import math | |
| from pathlib import Path | |
| from typing import Any, Callable, Dict, List, Optional, Union | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from datasets import IterableDataset | |
| from smirk import SmirkTokenizerFast | |
| from torch import nn | |
| from torch.masked import MaskedTensor, masked_tensor | |
| from transformers import (AutoConfig, AutoModel, AutoTokenizer, | |
| DataCollatorWithPadding, PretrainedConfig, | |
| PreTrainedModel) | |
# Mapping consulted by build_encoder: an encoder "model_type" found here is
# replaced by the mapped name before AutoConfig.for_model is called. It is
# empty, so model types are currently used unchanged.
MODEL_TYPE_ALIASES = {}
# Not referenced anywhere else in the visible source.
# NOTE(review): presumably the label value that loss functions skip — confirm.
IGNORE_INDEX = -100
# Import-time side effect: registers SmirkTokenizerFast with AutoTokenizer under
# the name "SmirkTokenizer", which is the default `tokenizer_class` of the
# config classes defined in this file.
AutoTokenizer.register("SmirkTokenizer", fast_tokenizer_class=SmirkTokenizerFast)
def build_encoder(enc_dict: Dict[str, Any]):
    """Instantiate an (untrained) encoder from a serialized config dict.

    The config is resolved from ``enc_dict["model_type"]`` when that entry is
    truthy (after an optional lookup in ``MODEL_TYPE_ALIASES``), otherwise from
    ``enc_dict["_name_or_path"]`` via ``AutoConfig.from_pretrained``.

    Args:
        enc_dict: Encoder configuration as a plain dictionary.

    Returns:
        The model produced by ``AutoModel.from_config``.

    Raises:
        KeyError: If neither ``model_type`` nor ``_name_or_path`` is set.
    """
    model_type = enc_dict.get("model_type")
    name_or_path = enc_dict.get("_name_or_path")

    # Guard clause: nothing to build the config from.
    if not model_type and not name_or_path:
        raise KeyError("encoder config missing 'model_type' or '_name_or_path'")

    if model_type:
        resolved_type = MODEL_TYPE_ALIASES.get(model_type, model_type)
        config = AutoConfig.for_model(resolved_type).from_dict(enc_dict)
    else:
        config = AutoConfig.from_pretrained(name_or_path)

    # Configs that expose a pooling-layer switch get it turned off.
    if hasattr(config, "add_pooling_layer"):
        config.add_pooling_layer = False

    return AutoModel.from_config(config)
class AbstractNormalizer(torch.nn.Module):
    """Base class for target normalizers.

    Subclasses implement ``forward`` (remove normalization), ``inverse``
    (apply normalization) and ``_fit`` (estimate parameters from a masked
    target tensor and return the resulting ``state_dict``).
    """

    def __init__(self, num_outputs=None):
        super().__init__()
        self.num_outputs = num_outputs

    def forward(self, x):
        """Remove normalization"""
        raise NotImplementedError

    def inverse(self, x):
        """Apply normalization"""
        raise NotImplementedError

    def _fit(self, x):
        """Fit the normalization parameters"""
        raise NotImplementedError

    def to_config(self):
        """Return the ``{'class', 'num_outputs'}`` dict understood by :meth:`get`."""
        return {'class': self.__class__.__name__, 'num_outputs': self.num_outputs}

    def leader_fit(self, ds, rank, broadcast):
        """Fit on rank 0 only and load the broadcast state on every rank.

        Args:
            ds: Dataset passed to :meth:`fit` on rank 0.
            rank: Rank of the calling process; only ``0`` performs the fit.
            broadcast: Callable that takes the local state (``None`` on
                non-zero ranks) and returns the state to load.
        """
        state = None
        if rank == 0:
            state = self.fit(ds)
        state = broadcast(state)
        self.load_state_dict(state)

    def fit(self, ds, name='target'):
        """Fit the normalization parameters on dataset

        Reads the ``name`` and ``f'{name}_mask'`` entries of ``ds``, stacks
        them, wraps them in a masked tensor and delegates to :meth:`_fit`.

        Returns:
            Whatever :meth:`_fit` returns (the state dict in the visible
            subclasses).
        """
        if isinstance(ds, IterableDataset):
            # Iterable datasets have no column access: collect row by row.
            target = []
            mask = []
            for x in ds:
                target.append(x[name])
                mask.append(x[f'{name}_mask'])
            target = torch.stack(target)
            mask = torch.stack(mask)
        else:
            target = torch.stack([torch.tensor(x) for x in ds[name]])
            mask = torch.stack([torch.tensor(x) for x in ds[f'{name}_mask']])
        target = masked_tensor(target, mask)
        state = self._fit(target)
        return state

    @classmethod
    def get(cls, transform, num_outputs):
        """Build a normalizer from its configured name(s).

        This must be a classmethod: callers invoke it as
        ``AbstractNormalizer.get(transform, num_outputs)``.

        Args:
            transform: A transform name (snake-case alias or class name), or a
                list of such names, one per output channel.
            num_outputs: Number of outputs the normalizer handles.

        Returns:
            A ``ChannelWiseTransform`` for a list, the matching normalizer for
            a known name, and ``IdentityTransform`` for anything else.

        Raises:
            ValueError: If ``transform`` is a list whose length differs from
                ``num_outputs``.
        """
        if isinstance(transform, list):
            if len(transform) != num_outputs:
                raise ValueError(
                    f'expected {num_outputs} channel transforms, got {len(transform)}'
                )
            # Each channel gets its own single-output normalizer.
            return ChannelWiseTransform([cls.get(t, 1) for t in transform])
        elif transform in ['standardize', Standardize.__name__]:
            return Standardize(num_outputs)
        elif transform in ['power_transform', PowerTransform.__name__]:
            return PowerTransform(num_outputs)
        elif transform in ['log_transform', LogTransform.__name__]:
            return LogTransform(num_outputs)
        elif transform in ['max_scale', MaxScaleTransform.__name__]:
            return MaxScaleTransform(num_outputs)
        else:
            return IdentityTransform()
class BiPairwiseBlock(nn.Module):
    """Symmetric pairwise feature block: a bilinear term plus a linear term.

    For an input of shape ``(..., L, d_model)`` the output has shape
    ``(..., L, L, d_model)`` and is symmetric in its two ``L`` axes.
    """

    def __init__(self, d_model, bias=True, device=None, dtype=None):
        super().__init__()
        tensor_opts = dict(device=device, dtype=dtype)
        square = (d_model, d_model)
        self.bi_weight = nn.Parameter(torch.empty(square, **tensor_opts))
        self.lin_weight = nn.Parameter(torch.empty(square, **tensor_opts))
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.empty(d_model, **tensor_opts))
        self.reset_parameters()
        # Symmetrize the gradient so that updates keep bi_weight symmetric.
        self.bi_weight.register_hook(lambda g: 0.5 * (g + g.T))

    def reset_parameters(self):
        """Xavier-initialize both weights, symmetrize ``bi_weight``, init bias."""
        relu_gain = nn.init.calculate_gain('relu')
        nn.init.xavier_normal_(self.lin_weight, gain=relu_gain)
        nn.init.xavier_normal_(self.bi_weight, gain=relu_gain)
        with torch.no_grad():
            symmetric = 0.5 * (self.bi_weight + self.bi_weight.T)
            self.bi_weight.copy_(symmetric)
        if self.bias is None:
            return
        limit = 1 / math.sqrt(self.bias.size(0))
        nn.init.uniform_(self.bias, -limit, limit)

    def forward(self, x):
        """Return the symmetric pairwise features for ``x``."""
        bilinear = torch.einsum('...ld,df,...rf->...lrf', x, self.bi_weight, x)
        # Average with the (l, r)-swapped tensor to enforce symmetry.
        bilinear = 0.5 * (bilinear + bilinear.transpose(-3, -2))
        pair_sum = x.unsqueeze(-2) + x.unsqueeze(-3)
        linear = F.linear(pair_sum, self.lin_weight, self.bias)
        return bilinear + linear
class ChannelWiseTransform(AbstractNormalizer):
    """Apply a separate normalizer to each column of a 2-D target tensor."""

    def __init__(self, transforms):
        super().__init__(len(transforms))
        self.transforms = torch.nn.ModuleList(transforms)

    def to_config(self):
        """Return the config with ``'class'`` as a list of per-channel class names."""
        class_names = [type(t).__name__ for t in self.transforms]
        return {'class': class_names, 'num_outputs': self.num_outputs}

    def _per_channel(self, method_name, x):
        # Run the named method of each transform on its own column; the column
        # is selected with a list index so it keeps its second axis.
        columns = []
        for idx, transform in enumerate(self.transforms):
            columns.append(getattr(transform, method_name)(x[:, [idx]]))
        return torch.cat(columns, dim=1)

    def inverse(self, x):
        """Apply each channel's ``inverse`` and concatenate along dim 1."""
        return self._per_channel('inverse', x)

    def forward(self, x):
        """Apply each channel's ``forward`` and concatenate along dim 1."""
        return self._per_channel('forward', x)

    def _fit(self, x):
        """Fit every channel transform on its column and return the state dict."""
        for idx, transform in enumerate(self.transforms):
            column = x[:, [idx]]
            transform._fit(column)
        return self.state_dict()
class IdentityTransform(AbstractNormalizer):
    """Normalizer that leaves its input untouched in both directions."""

    def inverse(self, x):
        """Return ``x`` unchanged."""
        return x

    def forward(self, x):
        """Return ``x`` unchanged."""
        return x

    def _fit(self, x):
        """Nothing to estimate; return the current state dict."""
        state = self.state_dict()
        return state
class MISTFinetunedConfig(PretrainedConfig):
    """HF config for a single-task MIST wrapper.

    Stores the encoder config dict, the task-network hyperparameters, the
    output-transform config, optional channel metadata and the tokenizer
    class name. Missing dict-valued entries default to empty dicts.
    """

    model_type = 'mist_finetuned'

    def __init__(self, encoder=None, task_network=None, transform=None, channels=None, tokenizer_class='SmirkTokenizer', **kwargs):
        super().__init__(**kwargs)
        # Falsy values (None, {}) are normalized to a fresh empty dict.
        self.encoder = encoder if encoder else {}
        self.task_network = task_network if task_network else {}
        self.transform = transform if transform else {}
        self.channels = channels
        self.tokenizer_class = tokenizer_class
class MISTFinetuned(PreTrainedModel):
    """Single-task MIST model: encoder -> prediction head -> output transform.

    The tokenizer is resolved on a best-effort basis at construction time and
    may be ``None``; ``embed`` and ``predict`` accept an explicit ``tokenizer``
    to use instead.
    """

    config_class = MISTFinetunedConfig

    def __init__(self, config):
        super().__init__(config)
        self.encoder = build_encoder_from_dict(config.encoder)
        tn = config.task_network
        self.task_network = PredictionTaskHead(
            embed_dim=tn['embed_dim'],
            output_size=tn['output_size'],
            dropout=tn['dropout'],
        )
        self.transform = AbstractNormalizer.get(
            config.transform['class'], config.transform['num_outputs']
        )
        self.channels = config.channels
        self.tokenizer = self._resolve_tokenizer()
        self.post_init()

    @classmethod
    def from_components(cls, encoder, task_network, transform, tokenizer=None, channels=None):
        """Assemble a model from already-trained parts.

        Builds a config from the components, constructs the model, then copies
        the component weights in (the encoder non-strictly).

        Args:
            encoder: Trained encoder exposing ``config`` and ``state_dict()``.
            task_network: Trained head; ``final.out_features`` and
                ``dropout1.p`` are read for the config.
            transform: Fitted normalizer providing ``to_config()``.
            tokenizer: Optional tokenizer attached to the returned model.
            channels: Optional channel metadata stored in the config.

        Returns:
            The assembled model.
        """
        tokenizer_class = (
            type(tokenizer).__name__ if tokenizer is not None else 'SmirkTokenizer'
        )
        cfg = MISTFinetunedConfig(
            encoder=encoder.config.to_dict(),
            task_network={
                'embed_dim': encoder.config.hidden_size,
                'output_size': task_network.final.out_features,
                'dropout': task_network.dropout1.p,
            },
            transform=transform.to_config(),
            channels=channels,
            tokenizer_class=tokenizer_class,
        )
        model = cls(cfg)
        model.encoder.load_state_dict(encoder.state_dict(), strict=False)
        model.task_network.load_state_dict(task_network.state_dict())
        model.transform.load_state_dict(transform.state_dict())
        model.tokenizer = tokenizer
        return model

    def forward(self, input_ids, attention_mask=None):
        """Return de-normalized predictions for a tokenized batch."""
        hs = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state
        y = self.task_network(hs)
        return self.transform.forward(y)

    def _resolve_tokenizer(self, tokenizer=None):
        """Pick a tokenizer: explicit argument, then attached, then hub lookup.

        Hub lookup is only attempted for names containing ``'/'`` and is
        best-effort: failures are logged at debug level and ``None`` is
        returned when nothing could be loaded.
        """
        if tokenizer is not None:
            return tokenizer
        attached = getattr(self, 'tokenizer', None)
        if attached is not None:
            return attached
        tried = set()
        for path in (self.name_or_path, getattr(self.config, '_name_or_path', None)):
            if not path or '/' not in path or path in tried:
                continue
            tried.add(path)
            try:
                # NOTE: trust_remote_code=True executes code from the repository.
                return AutoTokenizer.from_pretrained(path, use_fast=True, trust_remote_code=True)
            except Exception as err:  # best-effort: the model stays usable without a tokenizer
                logging.debug('Could not load tokenizer from %s: %s', path, err)
        return None

    def _encode_smiles(self, smi, tokenizer, device):
        # Tokenize and pad SMILES, returning (input_ids, attention_mask) on `device`.
        tok = self._resolve_tokenizer(tokenizer)
        if tok is None:
            raise ValueError(
                'No tokenizer available: pass `tokenizer=` or set `model.tokenizer`.'
            )
        if isinstance(smi, str):
            # A bare string would be tokenized without a batch axis.
            smi = [smi]
        batch = DataCollatorWithPadding(tok)(tok(smi))
        return batch['input_ids'].to(device), batch['attention_mask'].to(device)

    def embed(self, smi, tokenizer=None):
        """Return the encoder's first-token hidden state for each SMILES, on CPU.

        Args:
            smi: A SMILES string or a list of SMILES strings.
            tokenizer: Optional tokenizer overriding the attached one.

        Raises:
            ValueError: If no tokenizer can be resolved.
        """
        input_ids, attention_mask = self._encode_smiles(smi, tokenizer, self.device)
        with torch.inference_mode():
            hs = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
        return hs.to('cpu')

    def predict(self, smi, return_dict=True, tokenizer=None):
        """Predict properties for SMILES input.

        Args:
            smi: A SMILES string or a list of SMILES strings.
            return_dict: When true and channels are configured, return a dict
                keyed by channel name; otherwise return the raw tensor.
            tokenizer: Optional tokenizer overriding the attached one.

        Returns:
            A CPU tensor, or the output of ``annotate_prediction``.

        Raises:
            ValueError: If no tokenizer can be resolved.
        """
        input_ids, attention_mask = self._encode_smiles(smi, tokenizer, self.encoder.device)
        with torch.inference_mode():
            out = self(input_ids=input_ids, attention_mask=attention_mask).cpu()
        if self.channels is None or not return_dict:
            return out
        return annotate_prediction(out, maybe_get_annotated_channels(self.channels))

    def save_pretrained(self, save_directory, **kwargs):
        """Save the model and, when one is attached, its tokenizer."""
        super().save_pretrained(save_directory, **kwargs)
        if getattr(self, 'tokenizer', None) is not None:
            self.tokenizer.save_pretrained(save_directory)
class MaxScaleTransform(AbstractNormalizer):
    """
    Divide by maximum value in training dataset.

    ``forward`` multiplies by ``self.max`` and ``inverse`` divides by it. The
    scale is whatever is passed as ``mx``; ``_fit`` does not update it.

    NOTE(review): ``max`` is a plain attribute rather than a buffer, so it is
    not part of ``state_dict()``, and ``AbstractNormalizer.get`` passes
    ``num_outputs`` as ``mx`` — confirm this is intended. ``eps`` is stored
    but not used by any method of this class.
    """

    def __init__(self, mx, eps=1e-08):
        # This transform always handles exactly one output.
        super().__init__(1)
        self.max = mx
        self.eps = float(eps)
        assert self.eps >= 0

    def forward(self, x):
        """Remove normalization: scale ``x`` up by ``self.max``."""
        return self.max * x

    def inverse(self, x):
        """Apply normalization: scale ``x`` down by ``self.max``."""
        return x / self.max

    def _fit(self, target):
        """Return the current state dict; ``target`` is not inspected."""
        return self.state_dict()
class PairwiseMLP(nn.Module):
    """Symmetric pairwise features from an MLP over concatenated token pairs.

    Args:
        d_model: Feature size of each token.
        dropout: Dropout probability inside the MLP.
        device: Device for the linear layers (forwarded to ``nn.Linear``).
        dtype: Dtype for the linear layers (forwarded to ``nn.Linear``).
    """

    def __init__(self, d_model, dropout=0.2, device=None, dtype=None):
        super().__init__()
        # Previously accepted but ignored; now honoured like in BiPairwiseBlock.
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.mlp = nn.Sequential(
            nn.Linear(2 * d_model, d_model, **factory_kwargs),
            nn.Dropout(dropout),
            nn.GELU(),
            nn.Linear(d_model, d_model, **factory_kwargs),
            nn.GELU(),
        )

    def forward(self, x):
        """Map ``(B, N, d_model)`` to symmetric ``(B, N, N, d_model)`` features."""
        (_, N, _) = x.shape
        # Broadcast every token against every other token without copying.
        x_l = x.unsqueeze(-2).expand(-1, N, N, -1)
        x_r = x.unsqueeze(-3).expand(-1, N, N, -1)
        x_pw = torch.cat([x_l, x_r], dim=-1)
        y = self.mlp(x_pw)
        # The MLP is order-sensitive; average with the swapped pair for symmetry.
        return 0.5 * (y + y.transpose(1, 2))
class PowerTransform(AbstractNormalizer):
    """
    Apply a power transform (Yeo-Johnson) featurewise to make data more Gaussian-like.
    Followed by applying a zero-mean, unit-variance normalization to the
    transformed output.

    ``inverse`` applies the normalization (Yeo-Johnson, then standardize) and
    ``forward`` removes it (de-standardize, then inverse Yeo-Johnson).
    """

    def __init__(self, num_outputs, eps=1e-08):
        super().__init__(num_outputs)
        # One lambda / mean / std per output; zeros until fitted or loaded.
        self.register_buffer('lmbdas', torch.zeros(num_outputs))
        self.register_buffer('mean', torch.zeros(num_outputs))
        self.register_buffer('std', torch.zeros(num_outputs))
        self.eps = float(eps)
        assert 0 <= self.eps

    def _yeo_johnson_transform(self, x, lmbda):
        """
        Return transformed input x following Yeo-Johnson transform with
        parameter lambda.
        Adapted from
        https://github.com/scikit-learn/scikit-learn/blob/fbb32eae5/sklearn/preprocessing/_data.py#L3354
        """
        x_out = x.clone()
        eps = torch.finfo(x.dtype).eps
        pos = x >= 0
        # Non-negative inputs: log1p when lambda is (numerically) zero.
        if abs(lmbda) < eps:
            x_out[pos] = torch.log1p(x[pos])
        else:
            x_out[pos] = (torch.pow(x[pos] + 1, lmbda) - 1) / lmbda
        # Negative inputs: the special case sits at lambda == 2.
        if abs(lmbda - 2) > eps:
            x_out[~pos] = -(torch.pow(-x[~pos] + 1, 2 - lmbda) - 1) / (2 - lmbda)
        else:
            x_out[~pos] = -torch.log1p(-x[~pos])
        return x_out

    def _yeo_johnson_inverse_transform(self, x, lmbda):
        """
        Return inverse-transformed input x following Yeo-Johnson inverse
        transform with parameter lambda.
        Adapted from
        https://github.com/scikit-learn/scikit-learn/blob/fbb32eae5/sklearn/preprocessing/_data.py#L3383
        """
        x_out = x.clone()
        pos = x >= 0
        eps = torch.finfo(x.dtype).eps
        if abs(lmbda) < eps:
            x_out[pos] = torch.exp(x[pos]) - 1
        else:
            x_out[pos] = torch.pow(x[pos] * lmbda + 1, 1 / lmbda) - 1
        if abs(lmbda - 2) > eps:
            x_out[~pos] = 1 - torch.pow(-(2 - lmbda) * x[~pos] + 1, 1 / (2 - lmbda))
        else:
            x_out[~pos] = 1 - torch.exp(-x[~pos])
        return x_out

    def forward(self, x):
        """Remove normalization: de-standardize, then invert Yeo-Johnson per column."""
        x = self.std * x + self.mean
        x_out = torch.zeros_like(x)
        for i in range(self.num_outputs):
            x_out[:, i] = self._yeo_johnson_inverse_transform(x[:, i], self.lmbdas[i])
        return x_out

    def inverse(self, x):
        """Apply normalization: Yeo-Johnson per column, then standardize."""
        x_out = torch.zeros_like(x)
        for i in range(self.num_outputs):
            x_out[:, i] = self._yeo_johnson_transform(x[:, i], self.lmbdas[i])
        x_out = (x_out - self.mean) / self.std
        return x_out

    def _fit(self, target):
        """Estimate lambdas with scikit-learn, then the mean/std of the result.

        Args:
            target: Masked tensor of targets.

        Returns:
            The module's state dict after fitting.
        """
        # Imported lazily so scikit-learn is only needed when fitting.
        from sklearn.preprocessing import PowerTransformer as _PowerTransformer
        transformer = _PowerTransformer(method='yeo-johnson', standardize=False)
        # NOTE(review): get_data() drops the mask, so masked entries appear to
        # take part in the fit — confirm this is intended.
        target = torch.tensor(transformer.fit_transform(target.get_data().numpy()))
        # Match the registered buffer's dtype and device, as done for mean/std;
        # torch.tensor(lambdas_) alone would carry NumPy's dtype onto the buffer.
        self.lmbdas = torch.tensor(transformer.lambdas_).to(self.lmbdas)
        self.mean = target.mean(0).to(self.mean)
        self.std = target.std(0).to(self.std) + self.eps
        return self.state_dict()
class PredictionTaskHead(nn.Module):
    """Two hidden layers with optional skip connections and a linear output.

    A 3-D (or higher-rank) input is reduced to its first position along
    dim 1 before the layers are applied.
    """

    def __init__(self, embed_dim, output_size=1, dropout=0.2):
        super().__init__()
        # Skip connections are enabled unless this flag is changed afterwards.
        self.desc_skip_connection = True

        self.fc1 = nn.Linear(embed_dim, embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.relu1 = nn.GELU()

        self.fc2 = nn.Linear(embed_dim, embed_dim)
        self.dropout2 = nn.Dropout(dropout)
        self.relu2 = nn.GELU()

        self.final = nn.Linear(embed_dim, output_size)

    def forward(self, emb):
        """Return predictions of shape ``(batch, output_size)``."""
        pooled = emb[:, 0, :] if emb.ndim > 2 else emb
        # Only the literal ``True`` enables the residual paths.
        use_skip = self.desc_skip_connection is True

        hidden = self.relu1(self.dropout1(self.fc1(pooled)))
        if use_skip:
            hidden = hidden + pooled

        refined = self.relu2(self.dropout2(self.fc2(hidden)))
        if use_skip:
            refined = refined + hidden

        return self.final(refined)
class Standardize(AbstractNormalizer):
    """Zero-mean / unit-variance normalizer with per-output ``mean`` and ``std``.

    ``inverse`` applies the normalization and ``forward`` removes it.
    """

    def __init__(self, num_outputs, eps=1e-08):
        super().__init__(num_outputs)
        # Registered as buffers so they are part of state_dict(); zeros until
        # fitted or loaded.
        self.register_buffer('mean', torch.zeros(num_outputs))
        self.register_buffer('std', torch.zeros(num_outputs))
        # eps is added to the fitted std.
        self.eps = float(eps)
        assert 0 <= self.eps

    def forward(self, x):
        """Remove normalization: ``std * x + mean``."""
        return self.std * x + self.mean

    def inverse(self, x):
        """Apply normalization: ``(x - mean) / std``."""
        return (x - self.mean) / self.std

    def fit(self, ds, name='target'):
        """Fit mean and std by iterating ``ds`` row by row.

        Overrides the base-class ``fit``: the statistics are accumulated
        incrementally here and ``_fit`` is not called.

        Args:
            ds: Iterable of rows providing ``row[name]`` and
                ``row[f'{name}_mask']``.
            name: Key of the target entry in each row.

        Returns:
            The module's state dict after fitting.
        """
        num_outputs = self.num_outputs
        assert num_outputs is not None
        # Running mean, running sum of squared deviations, and per-output
        # count of unmasked entries.
        mean = torch.zeros(num_outputs)
        m2 = torch.zeros(num_outputs)
        n = torch.zeros(num_outputs, dtype=torch.int)
        for row in ds:
            target = torch.tensor(row[name])
            mask = torch.tensor(row[f'{name}_mask'])
            x = masked_tensor(target, mask)
            n += mask.view(-1, num_outputs).sum(0)
            # NOTE(review): xs is a per-output *sum* over the row; the update
            # below matches a one-sample-at-a-time running mean only if each
            # row holds a single sample — TODO confirm.
            xs = x.view(-1, num_outputs).sum(0)
            delta = xs - mean
            # Masked outputs contribute zero to the update.
            mean += (delta / n).get_data().masked_fill(~delta.get_mask(), 0)
            delta2 = xs - mean
            m2 += (delta * delta2).get_data().masked_fill(~delta.get_mask(), 0)
        self.mean = mean.to(self.mean)
        self.std = (m2 / n).sqrt().to(self.std) + self.eps
        # NaN statistics (e.g. 0 / 0 when an output was never observed) fall
        # back to the identity mapping: mean 0, std 1.
        self.mean[self.mean.isnan()] = 0
        self.std[self.std.isnan()] = 1
        logging.debug('Fitted %s', self.state_dict())
        return self.state_dict()

    def _fit(self, target):
        """Fit mean and std from a full masked target tensor along dim 0."""
        self.mean = target.mean(0).get_data().to(self.mean)
        self.std = target.std(0).get_data().to(self.std) + self.eps
        return self.state_dict()

    def load_state_dict(self, state_dict, strict=True, assign=False):
        """Load statistics, accepting ``transform.``-prefixed keys.

        With ``assign=True`` the ``mean``/``std`` entries are re-registered as
        buffers directly and ``None`` is returned; ``strict`` is not consulted
        on that path. Otherwise the base implementation is used and its result
        returned.
        """
        if 'transform.mean' in state_dict:
            # Copy first so the caller's dict is not modified by the renaming.
            state_dict = state_dict.copy()
            state_dict['mean'] = state_dict.pop('transform.mean')
            state_dict['std'] = state_dict.pop('transform.std')
        if assign:
            for (key, value) in state_dict.items():
                if key in ['mean', 'std']:
                    self.register_buffer(key, value)
            result = None
        else:
            result = super().load_state_dict(state_dict, strict=strict, assign=False)
        return result
class LogTransform(Standardize):
    """Standardize targets in log space.

    ``inverse`` takes the log and then standardizes; ``forward`` undoes the
    standardization and exponentiates.
    """

    def forward(self, x):
        """Remove normalization: ``exp(std * x + mean)``."""
        return torch.exp(super().forward(x))

    def inverse(self, x):
        """Apply normalization: ``(log(x) - mean) / std``."""
        return super().inverse(torch.log(x))

    def fit(self, ds, name='target'):
        """Fit the log-space statistics on a dataset.

        ``Standardize.fit`` accumulates its statistics on the raw values and
        never calls ``_fit``, which would leave ``mean``/``std`` in the wrong
        space for this class. The base-class ``fit`` is used instead: it
        builds the masked target tensor and dispatches to :meth:`_fit`.
        """
        return AbstractNormalizer.fit(self, ds, name=name)

    def _fit(self, target):
        """Fit mean and std on ``log(target)`` and return the state dict."""
        return super()._fit(torch.log(target))
class TokenPairwiseDistance(nn.Module):
    """Predict a non-negative value for every pair of tokens.

    Token states pass through a transformer encoder, are combined pairwise by
    ``PairwiseMLP`` and are reduced to one scalar per pair.

    NOTE(review): the ``activation`` argument is accepted but not used.
    """

    def __init__(self, embed_dim, dropout=0.2, num_attention_heads=1, num_layers=1, activation='relu', ff_ratio=2):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_attention_heads,
            dim_feedforward=ff_ratio * embed_dim,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )
        self.interaction = nn.TransformerEncoder(layer, num_layers)
        self.pairwise_distance = PairwiseMLP(embed_dim, dropout)
        self.distance1 = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.Dropout(dropout),
            nn.GELU(),
        )
        self.distance2 = nn.Linear(embed_dim, 1)

    def forward(self, hs):
        """Return one value per token pair for hidden states ``hs``."""
        mixed = self.interaction(hs)
        # NOTE(review): presumably meant to keep the pairwise part in float32
        # under mixed precision — confirm float32 is accepted by CUDA autocast.
        with torch.autocast('cuda', dtype=torch.float32):
            pairs = self.pairwise_distance(mixed)
            residual = self.distance1(pairs) + pairs
            raw = self.distance2(residual).squeeze(-1)
            return F.relu(F.elu(raw) + 1)
class TokenTaskHead(nn.Module):
    """Per-token MLP head: two hidden Linear/Dropout/GELU stages and a linear output."""

    def __init__(self, embed_dim, output_size=1, dropout=0.2):
        super().__init__()
        stack = [
            nn.Linear(embed_dim, embed_dim),
            nn.Dropout(dropout),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
            nn.Dropout(dropout),
            nn.GELU(),
            nn.Linear(embed_dim, output_size),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, emb):
        """Apply the head to ``emb``; only the last dimension changes size."""
        out = self.layers(emb)
        return out
def annotate_prediction(y, channels):
    """Attach channel metadata to the columns of a prediction array.

    Args:
        y: 2-D array-like supporting ``y[:, idx]``; column ``idx`` belongs to
            the ``idx``-th channel.
        channels: Iterable of dicts, each with a ``'name'`` entry plus any
            additional metadata fields.

    Returns:
        Dict mapping each channel name to ``{'value': column, **metadata}``,
        where metadata is every channel field except ``'name'``.
    """
    annotated = {}
    for column, channel in enumerate(channels):
        entry = {'value': y[:, column]}
        # Metadata is merged after 'value', so a metadata field of the same
        # name takes precedence.
        entry.update((field, val) for field, val in channel.items() if field != 'name')
        annotated[channel['name']] = entry
    return annotated
def build_encoder_from_dict(enc_dict):
    """Instantiate an (untrained) encoder from a serialized config dict.

    Args:
        enc_dict: Encoder configuration. The ``'model_type'`` key is used when
            present; otherwise ``'_name_or_path'`` is passed to
            ``AutoConfig.from_pretrained``.

    Returns:
        The model produced by ``AutoModel.from_config``.

    Raises:
        KeyError: If neither key is present.
    """
    if 'model_type' in enc_dict:
        cfg_cls = AutoConfig.for_model(enc_dict['model_type'])
        enc_cfg = cfg_cls.from_dict(enc_dict, strict=False)
    elif '_name_or_path' in enc_dict:
        enc_cfg = AutoConfig.from_pretrained(enc_dict['_name_or_path'], strict=False)
    else:
        # The closing quote after _name_or_path was missing in this message.
        raise KeyError("Encoder config is missing 'model_type' and '_name_or_path'.")
    # Configs that expose a pooling-layer switch get it turned off.
    if hasattr(enc_cfg, 'add_pooling_layer'):
        enc_cfg.add_pooling_layer = False
    return AutoModel.from_config(enc_cfg)
def maybe_get_annotated_channels(channels):
    """Lazily yield one channel dict per entry of ``channels``.

    A string entry is expanded to ``{'name': entry, 'description': None,
    'unit': None}``; any other entry is yielded unchanged.
    """
    for entry in channels:
        yield (
            {'name': entry, 'description': None, 'unit': None}
            if isinstance(entry, str)
            else entry
        )
class MISTMultiTaskConfig(PretrainedConfig):
    """HuggingFace config for a multi-task MIST wrapper.

    Stores the encoder config dict, one hyperparameter dict per task network,
    one transform config per task, optional channel metadata and the
    tokenizer class name.
    """

    model_type = 'mist_multitask'

    def __init__(self, encoder=None, task_networks=None, transforms=None, channels=None, tokenizer_class='SmirkTokenizer', **kwargs):
        super().__init__(**kwargs)
        # Falsy values are normalized to fresh empty containers.
        self.encoder = encoder if encoder else {}
        self.task_networks = task_networks if task_networks else []
        self.transforms = transforms if transforms else []
        self.channels = channels
        self.tokenizer_class = tokenizer_class
class MISTMultiTask(PreTrainedModel):
    """Multi-task MIST model: one shared encoder feeding several heads.

    Each head has its own output transform; the per-task outputs are
    concatenated along the last dimension. The tokenizer is resolved on a
    best-effort basis at construction time and may be ``None``; ``embed`` and
    ``predict`` accept an explicit ``tokenizer`` to use instead.
    """

    config_class = MISTMultiTaskConfig

    def __init__(self, config):
        super().__init__(config)
        self.encoder = build_encoder_from_dict(config.encoder)
        self.task_networks = nn.ModuleList([
            PredictionTaskHead(
                embed_dim=tn['embed_dim'],
                output_size=tn['output_size'],
                dropout=tn['dropout'],
            )
            for tn in config.task_networks
        ])
        self.transforms = nn.ModuleList([
            AbstractNormalizer.get(tf_cfg['class'], tf_cfg['num_outputs'])
            for tf_cfg in config.transforms
        ])
        # forward() zips the two lists, so a mismatch would silently drop tasks.
        if len(self.task_networks) != len(self.transforms):
            raise ValueError('task_networks and transforms must align')
        self.channels = config.channels
        self.tokenizer = self._resolve_tokenizer()
        self.post_init()

    @classmethod
    def from_components(cls, encoder, task_networks, transforms, tokenizer=None, channels=None):
        """Assemble a model from already-trained parts.

        Args:
            encoder: Trained encoder exposing ``config`` and ``state_dict()``.
            task_networks: Trained heads; ``final.out_features`` and
                ``dropout1.p`` of each are read for the config.
            transforms: Fitted normalizers, one per head, providing
                ``to_config()``.
            tokenizer: Optional tokenizer attached to the returned model.
            channels: Optional channel metadata stored in the config.

        Returns:
            The assembled model with the component weights copied in (the
            encoder non-strictly).
        """
        tokenizer_class = (
            type(tokenizer).__name__ if tokenizer is not None else 'SmirkTokenizer'
        )
        cfg = MISTMultiTaskConfig(
            encoder=encoder.config.to_dict(),
            task_networks=[
                {
                    'embed_dim': encoder.config.hidden_size,
                    'output_size': tn.final.out_features,
                    'dropout': tn.dropout1.p,
                }
                for tn in task_networks
            ],
            transforms=[tf.to_config() for tf in transforms],
            channels=channels,
            tokenizer_class=tokenizer_class,
        )
        model = cls(cfg)
        model.encoder.load_state_dict(encoder.state_dict(), strict=False)
        for (dst, src) in zip(model.task_networks, task_networks):
            dst.load_state_dict(src.state_dict())
        for (dst, src) in zip(model.transforms, transforms):
            dst.load_state_dict(src.state_dict())
        model.tokenizer = tokenizer
        return model

    def forward(self, input_ids, attention_mask=None):
        """Return all task outputs, de-normalized and concatenated on the last dim."""
        hs = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state
        outs = [tf.forward(tn(hs)) for (tn, tf) in zip(self.task_networks, self.transforms)]
        return torch.cat(outs, dim=-1)

    def _resolve_tokenizer(self, tokenizer=None):
        """Pick a tokenizer: explicit argument, then attached, then hub lookup.

        Hub lookup is only attempted for names containing ``'/'`` and is
        best-effort: failures are logged at debug level and ``None`` is
        returned when nothing could be loaded.
        """
        if tokenizer is not None:
            return tokenizer
        attached = getattr(self, 'tokenizer', None)
        if attached is not None:
            return attached
        tried = set()
        for path in (self.name_or_path, getattr(self.config, '_name_or_path', None)):
            if not path or '/' not in path or path in tried:
                continue
            tried.add(path)
            try:
                # NOTE: trust_remote_code=True executes code from the repository.
                return AutoTokenizer.from_pretrained(path, use_fast=True, trust_remote_code=True)
            except Exception as err:  # best-effort: the model stays usable without a tokenizer
                logging.debug('Could not load tokenizer from %s: %s', path, err)
        return None

    def _encode_smiles(self, smi, tokenizer):
        # Tokenize and pad SMILES, returning (input_ids, attention_mask) on self.device.
        tok = self._resolve_tokenizer(tokenizer)
        if tok is None:
            raise ValueError(
                'No tokenizer available: pass `tokenizer=` or set `model.tokenizer`.'
            )
        if isinstance(smi, str):
            # A bare string would be tokenized without a batch axis.
            smi = [smi]
        batch = DataCollatorWithPadding(tok)(tok(smi))
        return batch['input_ids'].to(self.device), batch['attention_mask'].to(self.device)

    def predict(self, smi, tokenizer=None, return_dict=True):
        """Predict all task outputs for SMILES input.

        Args:
            smi: A SMILES string or a list of SMILES strings.
            tokenizer: Optional tokenizer overriding the attached one.
            return_dict: When true and channels are configured, return a dict
                keyed by channel name; otherwise return the raw tensor.

        Returns:
            A CPU tensor, or the output of ``annotate_prediction``.

        Raises:
            ValueError: If no tokenizer can be resolved.
        """
        # Only the two tensors forward() accepts are passed on; any other key
        # the tokenizer emits would be rejected as an unexpected argument.
        input_ids, attention_mask = self._encode_smiles(smi, tokenizer)
        with torch.inference_mode():
            out = self(input_ids=input_ids, attention_mask=attention_mask).cpu()
        if self.channels is None or not return_dict:
            return out
        # Channels given as plain strings are expanded to dicts first.
        return annotate_prediction(out, maybe_get_annotated_channels(self.channels))

    def embed(self, smi, tokenizer=None):
        """Return the encoder's first-token hidden state for each SMILES, on CPU.

        Args:
            smi: A SMILES string or a list of SMILES strings.
            tokenizer: Optional tokenizer overriding the attached one.

        Raises:
            ValueError: If no tokenizer can be resolved.
        """
        input_ids, attention_mask = self._encode_smiles(smi, tokenizer)
        with torch.inference_mode():
            hs = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
        return hs.to('cpu')

    def save_pretrained(self, save_directory, **kwargs):
        """Save the model and, when one is attached, its tokenizer."""
        super().save_pretrained(save_directory, **kwargs)
        if getattr(self, 'tokenizer', None) is not None:
            self.tokenizer.save_pretrained(save_directory)