from datasets import IterableDataset
from pathlib import Path
from smirk import SmirkTokenizerFast
from torch.masked import MaskedTensor, masked_tensor
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    DataCollatorWithPadding,
    PreTrainedModel,
    PretrainedConfig,
)
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
import json
import logging
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


AutoTokenizer.register("SmirkTokenizer", fast_tokenizer_class=SmirkTokenizerFast)
MODEL_TYPE_ALIASES = {}


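# The registration above exposes SmirkTokenizerFast under the "SmirkTokenizer" name
# referenced by the configs in this module, so the AutoTokenizer fallbacks below can
# resolve it. A minimal usage sketch (the checkpoint path is a placeholder assumption,
# not something shipped with this module):
#
#   tok = AutoTokenizer.from_pretrained("path/to/smirk-checkpoint", use_fast=True)
#   enc = tok(["CCO", "c1ccccc1"])   # SMILES strings in, token ids out

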
def build_encoder(enc_dict: Dict[str, Any]):
    mtype = enc_dict.get("model_type")
    if mtype:
        base = MODEL_TYPE_ALIASES.get(mtype, mtype)
        cfg_cls = AutoConfig.for_model(base)
        enc_cfg = cfg_cls.from_dict(enc_dict)
    elif enc_dict.get("_name_or_path"):
        enc_cfg = AutoConfig.from_pretrained(enc_dict["_name_or_path"])
    else:
        raise KeyError("encoder config missing 'model_type' or '_name_or_path'")
    if hasattr(enc_cfg, "add_pooling_layer"):
        enc_cfg.add_pooling_layer = False
    return AutoModel.from_config(enc_cfg)


class MISTFinetunedConfig(PretrainedConfig):
    """HF config for a single-task MIST wrapper."""

    model_type = "mist_finetuned"

    def __init__(
        self,
        encoder: Optional[Dict[str, Any]] = None,
        task_network: Optional[Dict[str, Any]] = None,
        transform: Optional[Dict[str, Any]] = None,
        channels: Optional[List[Dict[str, Any]]] = None,
        tokenizer_class: Optional[str] = "SmirkTokenizer",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.encoder = encoder or {}
        self.task_network = task_network or {}
        self.transform = transform or {}
        self.channels = channels
        self.tokenizer_class = tokenizer_class


class MISTFinetuned(PreTrainedModel):
    config_class = MISTFinetunedConfig

    def __init__(self, config: MISTFinetunedConfig):
        super().__init__(config)
        self.encoder = build_encoder_from_dict(config.encoder)

        tn = config.task_network
        self.task_network = PredictionTaskHead(
            embed_dim=tn["embed_dim"],
            output_size=tn["output_size"],
            dropout=tn["dropout"],
        )
        self.transform = AbstractNormalizer.get(
            config.transform["class"], config.transform["num_outputs"]
        )
        self.channels = config.channels
        self.tokenizer = None
        self.post_init()

    @classmethod
    def from_components(
        cls,
        encoder: PreTrainedModel,
        task_network: nn.Module,
        transform: Any,
        tokenizer: Optional[Any] = None,
        channels: Optional[List[Dict[str, Any]]] = None,
    ) -> "MISTFinetuned":
        cfg = MISTFinetunedConfig(
            encoder=encoder.config.to_dict(),
            task_network={
                "embed_dim": encoder.config.hidden_size,
                "output_size": task_network.final.out_features,
                "dropout": task_network.dropout1.p,
            },
            transform=transform.to_config(),
            channels=channels,
            tokenizer_class=(
                getattr(tokenizer, "__class__", type("T", (), {})).__name__
                if tokenizer
                else "SmirkTokenizer"
            ),
        )
        model = cls(cfg)

        model.encoder.load_state_dict(encoder.state_dict(), strict=False)
        model.task_network.load_state_dict(task_network.state_dict())
        model.transform.load_state_dict(transform.state_dict())
        model.tokenizer = tokenizer
        return model

    def forward(self, input_ids, attention_mask=None):
        hs = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state
        y = self.task_network(hs)
        return self.transform.forward(y)

    def _resolve_tokenizer(self, tokenizer):
        if tokenizer is not None:
            return tokenizer
        if getattr(self, "tokenizer", None) is not None:
            return self.tokenizer
        try:
            return AutoTokenizer.from_pretrained(
                self.name_or_path, use_fast=True, trust_remote_code=True
            )
        except Exception:
            return AutoTokenizer.from_pretrained(
                self.config._name_or_path, use_fast=True, trust_remote_code=True
            )

    def embed(self, smi: List[str], tokenizer=None):
        tok = self._resolve_tokenizer(tokenizer)
        batch = tok(smi)
        batch = DataCollatorWithPadding(tok)(batch)
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        with torch.inference_mode():
            hs = self.encoder(
                input_ids, attention_mask=attention_mask
            ).last_hidden_state[:, 0, :]
        return hs.to("cpu")

    def predict(self, smi: List[str], return_dict: bool = True, tokenizer=None):
        tok = self._resolve_tokenizer(tokenizer)
        batch = tok(smi)
        collate_fn = DataCollatorWithPadding(tok)
        batch = collate_fn(batch)
        batch = {
            "input_ids": batch["input_ids"].to(self.encoder.device),
            "attention_mask": batch["attention_mask"].to(self.encoder.device),
        }
        with torch.inference_mode():
            out = self(**batch).cpu()
        if self.channels is None or not return_dict:
            return out
        return annotate_prediction(out, maybe_get_annotated_channels(self.channels))

    def save_pretrained(self, save_directory, **kwargs):
        super().save_pretrained(save_directory, **kwargs)
        if getattr(self, "tokenizer", None) is not None:
            self.tokenizer.save_pretrained(save_directory)


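# Minimal usage sketch for MISTFinetuned (the checkpoint path and SMILES strings are
# placeholder assumptions, not artifacts shipped with this module):
#
#   model = MISTFinetuned.from_pretrained("path/to/finetuned-mist")
#   preds = model.predict(["CCO", "c1ccccc1"])   # dict of {channel: {"value": ...}}
#   vecs = model.embed(["CCO", "c1ccccc1"])      # (2, hidden_size) first-token embeddings
#
# `predict` returns the raw output tensor instead of a dict when the config carries no
# channel metadata or when `return_dict=False`.

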
def maybe_get_annotated_channels(channels: List[Any]):
    for chn in channels:
        if isinstance(chn, str):
            yield {"name": chn, "description": None, "unit": None}
        else:
            yield chn


def annotate_prediction(
    y: torch.Tensor, channels: Iterable[Dict[str, str]]
) -> Dict[str, Dict[str, Any]]:
    out: Dict[str, Dict[str, Any]] = {}
    for idx, chn in enumerate(channels):
        channel_info = {f: v for f, v in chn.items() if f != "name"}
        out[chn["name"]] = {"value": y[:, idx], **channel_info}
    return out


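# Shape sketch for the two helpers above (the channel names are illustrative only):
#
#   y = torch.randn(4, 2)   # batch of 4 predictions, 2 channels
#   channels = ["logS", {"name": "logP", "unit": None, "description": None}]
#   annotate_prediction(y, maybe_get_annotated_channels(channels))
#   # -> {"logS": {"value": y[:, 0], "description": None, "unit": None},
#   #     "logP": {"value": y[:, 1], "unit": None, "description": None}}

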
def build_encoder_from_dict(enc_dict):
    if "model_type" in enc_dict:
        cfg_cls = AutoConfig.for_model(enc_dict["model_type"])
        enc_cfg = cfg_cls.from_dict(enc_dict, strict=False)
    elif "_name_or_path" in enc_dict:
        enc_cfg = AutoConfig.from_pretrained(enc_dict["_name_or_path"], strict=False)
    else:
        raise KeyError("Encoder config is missing 'model_type' and '_name_or_path'.")

    if hasattr(enc_cfg, "add_pooling_layer"):
        enc_cfg.add_pooling_layer = False

    return AutoModel.from_config(enc_cfg)


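# Sketch of the encoder dict the builders above expect (field values are illustrative):
#
#   enc_dict = {"model_type": "roberta", "hidden_size": 768, "num_hidden_layers": 12}
#   encoder = build_encoder_from_dict(enc_dict)   # randomly initialized AutoModel
#
# Without "model_type", a "_name_or_path" entry pointing at a saved config is used
# instead; with neither key, a KeyError is raised.

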
class MISTMultiTaskConfig(PretrainedConfig):
    """HuggingFace config for a multi-task MIST wrapper."""

    model_type = "mist_multitask"

    def __init__(
        self,
        encoder: Optional[Dict[str, Any]] = None,
        task_networks: Optional[List[Dict[str, Any]]] = None,
        transforms: Optional[List[Dict[str, Any]]] = None,
        channels: Optional[List[Dict[str, Any]]] = None,
        tokenizer_class: Optional[str] = "SmirkTokenizer",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.encoder = encoder or {}
        self.task_networks = task_networks or []
        self.transforms = transforms or []
        self.channels = channels
        self.tokenizer_class = tokenizer_class


class MISTMultiTask(PreTrainedModel):
    config_class = MISTMultiTaskConfig

    def __init__(self, config: MISTMultiTaskConfig):
        super().__init__(config)
        self.encoder = build_encoder_from_dict(config.encoder)

        self.task_networks = nn.ModuleList(
            [
                PredictionTaskHead(
                    embed_dim=tn["embed_dim"],
                    output_size=tn["output_size"],
                    dropout=tn["dropout"],
                )
                for tn in config.task_networks
            ]
        )
        self.transforms = nn.ModuleList(
            [
                AbstractNormalizer.get(tf_cfg["class"], tf_cfg["num_outputs"])
                for tf_cfg in config.transforms
            ]
        )

        assert len(self.task_networks) == len(
            self.transforms
        ), "task_networks and transforms must align"
        self.channels = config.channels
        self.tokenizer = None
        self.post_init()

    @classmethod
    def from_components(
        cls,
        encoder: PreTrainedModel,
        task_networks: List[nn.Module],
        transforms: List[Any],
        tokenizer: Optional[Any] = None,
        channels: Optional[List[Dict[str, Any]]] = None,
    ) -> "MISTMultiTask":
        cfg = MISTMultiTaskConfig(
            encoder=encoder.config.to_dict(),
            task_networks=[
                {
                    "embed_dim": encoder.config.hidden_size,
                    "output_size": tn.final.out_features,
                    "dropout": tn.dropout1.p,
                }
                for tn in task_networks
            ],
            transforms=[tf.to_config() for tf in transforms],
            channels=channels,
            tokenizer_class=(
                getattr(tokenizer, "__class__", type("T", (), {})).__name__
                if tokenizer
                else "SmirkTokenizer"
            ),
        )
        model = cls(cfg)
        model.encoder.load_state_dict(encoder.state_dict(), strict=False)
        for dst, src in zip(model.task_networks, task_networks):
            dst.load_state_dict(src.state_dict())
        for dst, src in zip(model.transforms, transforms):
            dst.load_state_dict(src.state_dict())
        model.tokenizer = tokenizer
        return model

    def forward(self, input_ids, attention_mask=None):
        hs = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state
        outs = []
        for tn, tf in zip(self.task_networks, self.transforms):
            outs.append(tf.forward(tn(hs)))
        return torch.cat(outs, dim=-1)

    def _resolve_tokenizer(self, tokenizer):
        if tokenizer is not None:
            return tokenizer
        if getattr(self, "tokenizer", None) is not None:
            return self.tokenizer
        try:
            return AutoTokenizer.from_pretrained(
                self.name_or_path, use_fast=True, trust_remote_code=True
            )
        except Exception:
            return AutoTokenizer.from_pretrained(
                self.config._name_or_path, use_fast=True, trust_remote_code=True
            )

    def predict(self, smi: List[str], tokenizer=None):
        tok = self._resolve_tokenizer(tokenizer)
        batch = tok(smi)
        batch = DataCollatorWithPadding(tok)(batch)
        inputs = {k: v.to(self.device) for k, v in batch.items()}
        with torch.inference_mode():
            out = self(**inputs).cpu()
        if self.channels is None:
            return out
        return annotate_prediction(out, maybe_get_annotated_channels(self.channels))

    def embed(self, smi: List[str], tokenizer=None):
        tok = self._resolve_tokenizer(tokenizer)
        batch = tok(smi)
        batch = DataCollatorWithPadding(tok)(batch)
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        with torch.inference_mode():
            hs = self.encoder(
                input_ids, attention_mask=attention_mask
            ).last_hidden_state[:, 0, :]
        return hs.to("cpu")

    def save_pretrained(self, save_directory, **kwargs):
        super().save_pretrained(save_directory, **kwargs)
        if getattr(self, "tokenizer", None) is not None:
            self.tokenizer.save_pretrained(save_directory)


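# Assembly sketch for the multi-task wrapper. The `encoder` and `tok` objects stand in
# for an existing training setup and are assumptions, not defined in this module; the
# normalizers would normally be fitted on the training targets first.
#
#   heads = [PredictionTaskHead(encoder.config.hidden_size, output_size=1) for _ in range(2)]
#   norms = [Standardize(1), LogTransform(1)]
#   model = MISTMultiTask.from_components(encoder, heads, norms, tokenizer=tok,
#                                         channels=["logS", "logP"])
#   model.save_pretrained("path/to/multitask-mist")   # also saves the tokenizer

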
class PredictionTaskHead(nn.Module):
    def __init__(
        self, embed_dim: int, output_size: int = 1, dropout: float = 0.2
    ) -> None:
        super().__init__()
        self.desc_skip_connection = True

        self.fc1 = nn.Linear(embed_dim, embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.relu1 = nn.GELU()
        self.fc2 = nn.Linear(embed_dim, embed_dim)
        self.dropout2 = nn.Dropout(dropout)
        self.relu2 = nn.GELU()
        self.final = nn.Linear(embed_dim, output_size)

    def forward(self, emb):
        if emb.ndim > 2:
            emb = emb[:, 0, :]
        x_out = self.fc1(emb)
        x_out = self.dropout1(x_out)
        x_out = self.relu1(x_out)

        if self.desc_skip_connection is True:
            x_out = x_out + emb

        z = self.fc2(x_out)
        z = self.dropout2(z)
        z = self.relu2(z)
        if self.desc_skip_connection is True:
            z = self.final(z + x_out)
        else:
            z = self.final(z)
        return z


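# Shape sketch for the head (sizes are illustrative): a (batch, seq, embed) input is
# reduced to its first token before the two residual blocks, so either a sequence of
# hidden states or an already-pooled vector is accepted.
#
#   head = PredictionTaskHead(embed_dim=768, output_size=3)
#   head(torch.randn(8, 128, 768)).shape   # torch.Size([8, 3])
#   head(torch.randn(8, 768)).shape        # torch.Size([8, 3])

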
class AbstractNormalizer(torch.nn.Module):
    def __init__(self, num_outputs: Optional[int] = None):
        super().__init__()
        self.num_outputs = num_outputs

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Remove normalization"""
        raise NotImplementedError

    def inverse(self, x: torch.Tensor) -> torch.Tensor:
        """Apply normalization"""
        raise NotImplementedError

    def _fit(self, x: MaskedTensor) -> dict:
        """Fit the normalization parameters"""
        raise NotImplementedError

    def to_config(self) -> dict:
        return {"class": self.__class__.__name__, "num_outputs": self.num_outputs}

    def leader_fit(self, ds, rank: int, broadcast: Callable):
        state = None
        if rank == 0:
            state = self.fit(ds)
        state = broadcast(state)
        self.load_state_dict(state)

    def fit(self, ds, name: str = "target") -> dict:
        """Fit the normalization parameters on dataset"""
        if isinstance(ds, IterableDataset):
            target = []
            mask = []
            for x in ds:
                target.append(x[name])
                mask.append(x[f"{name}_mask"])

            target = torch.stack(target)
            mask = torch.stack(mask)
        else:
            target = torch.stack([torch.tensor(x) for x in ds[name]])
            mask = torch.stack([torch.tensor(x) for x in ds[f"{name}_mask"]])

        target = masked_tensor(target, mask)

        state = self._fit(target)
        return state

    @classmethod
    def get(
        cls, transform: Optional[Union[List[str], str]], num_outputs: int
    ) -> "AbstractNormalizer":
        if isinstance(transform, list):
            assert len(transform) == num_outputs
            return ChannelWiseTransform([cls.get(t, 1) for t in transform])
        elif transform in ["standardize", Standardize.__name__]:
            return Standardize(num_outputs)
        elif transform in ["power_transform", PowerTransform.__name__]:
            return PowerTransform(num_outputs)
        elif transform in ["log_transform", LogTransform.__name__]:
            return LogTransform(num_outputs)
        elif transform in ["max_scale", MaxScaleTransform.__name__]:
            return MaxScaleTransform(num_outputs)
        else:
            return IdentityTransform()


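# Dispatch sketch for AbstractNormalizer.get (the names match the branches above):
#
#   AbstractNormalizer.get("standardize", 3)                    # -> Standardize(3)
#   AbstractNormalizer.get(["LogTransform", "standardize"], 2)
#   # -> ChannelWiseTransform([LogTransform(1), Standardize(1)])
#   AbstractNormalizer.get(None, 1)                             # -> IdentityTransform()
#
# Anything unrecognized falls through to IdentityTransform, so a typo in a config
# silently disables normalization.

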
class Standardize(AbstractNormalizer):
    def __init__(self, num_outputs: int, eps: float = 1e-8):
        super().__init__(num_outputs)
        self.register_buffer("mean", torch.zeros(num_outputs))
        self.register_buffer("std", torch.zeros(num_outputs))
        self.eps = float(eps)
        assert 0 <= self.eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (self.std * x) + self.mean

    def inverse(self, x: torch.Tensor) -> torch.Tensor:
        return (x - self.mean) / self.std

    def fit(self, ds, name: str = "target") -> dict:
        num_outputs = self.num_outputs
        assert num_outputs is not None
        mean = torch.zeros(num_outputs)
        m2 = torch.zeros(num_outputs)
        n = torch.zeros(num_outputs, dtype=torch.int)
        for row in ds:
            target = torch.tensor(row[name])
            mask = torch.tensor(row[f"{name}_mask"])
            x = masked_tensor(target, mask)
            n += mask.view(-1, num_outputs).sum(0)
            xs = x.view(-1, num_outputs).sum(0)
            delta = xs - mean
            # Streaming mean/variance update; channels masked out in this row
            # contribute nothing.
            mean += (delta / n).get_data().masked_fill(~delta.get_mask(), 0)
            delta2 = xs - mean
            m2 += (delta * delta2).get_data().masked_fill(~delta.get_mask(), 0)

        self.mean = mean.to(self.mean)
        self.std = (m2 / n).sqrt().to(self.std) + self.eps
        self.mean[self.mean.isnan()] = 0
        self.std[self.std.isnan()] = 1
        logging.debug("Fitted %s", self.state_dict())
        return self.state_dict()

    def _fit(self, target: MaskedTensor) -> dict:
        self.mean = target.mean(0).get_data().to(self.mean)
        self.std = target.std(0).get_data().to(self.std) + self.eps
        return self.state_dict()

    def load_state_dict(
        self, state_dict: Dict[str, Any], strict: bool = True, assign: bool = False
    ):
        # Accept checkpoints where the buffers were saved under a "transform." prefix.
        if "transform.mean" in state_dict:
            state_dict = state_dict.copy()
            state_dict["mean"] = state_dict.pop("transform.mean")
            state_dict["std"] = state_dict.pop("transform.std")

        if assign:
            # Re-register the buffers so they adopt the incoming tensors directly.
            for key, value in state_dict.items():
                if key in ["mean", "std"]:
                    self.register_buffer(key, value)
            result = None
        else:
            result = super().load_state_dict(state_dict, strict=strict, assign=False)

        logging.debug("After loading: mean=%s, std=%s", self.mean, self.std)
        return result


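# Per the base-class docstrings, `inverse` applies the normalization (raw targets to
# normalized space) and `forward` removes it (model outputs back to target units), so a
# fitted Standardize round-trips. A sketch on synthetic data:
#
#   norm = Standardize(num_outputs=2)
#   x = torch.randn(100, 2) * 5 + 3
#   norm._fit(masked_tensor(x, torch.ones_like(x, dtype=torch.bool)))
#   torch.allclose(norm.forward(norm.inverse(x)), x, atol=1e-4)   # True

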
class TokenTaskHead(nn.Module):
    def __init__(
        self, embed_dim: int, output_size: int = 1, dropout: float = 0.2
    ) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.Dropout(dropout),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
            nn.Dropout(dropout),
            nn.GELU(),
            nn.Linear(embed_dim, output_size),
        )

    def forward(self, emb):
        return self.layers(emb)


class TokenPairwiseDistance(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        dropout: float = 0.2,
        num_attention_heads: int = 1,
        num_layers: int = 1,
        activation: str = "relu",
        ff_ratio: int = 2,
    ) -> None:
        super().__init__()
        enc_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_attention_heads,
            dim_feedforward=ff_ratio * embed_dim,
            dropout=dropout,
            activation=activation,
            batch_first=True,
            norm_first=True,
        )
        self.interaction = nn.TransformerEncoder(enc_layer, num_layers)
        self.pairwise_distance = PairwiseMLP(embed_dim, dropout)
        self.distance1 = nn.Sequential(
            nn.Linear(embed_dim, embed_dim), nn.Dropout(dropout), nn.GELU()
        )
        self.distance2 = nn.Linear(embed_dim, 1)

    def forward(self, hs: torch.Tensor) -> torch.Tensor:
        hs = self.interaction(hs)

        with torch.autocast("cuda", dtype=torch.float32):
            pw_dist = self.pairwise_distance(hs)
            d = self.distance1(pw_dist) + pw_dist
            d = self.distance2(d).squeeze(-1)
        return F.relu(F.elu(d) + 1)


class BiPairwiseBlock(nn.Module):
    def __init__(self, d_model: int, bias: bool = True, device=None, dtype=None):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}

        self.bi_weight = nn.Parameter(
            torch.empty((d_model, d_model), **factory_kwargs)
        )
        self.lin_weight = nn.Parameter(
            torch.empty((d_model, d_model), **factory_kwargs)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(d_model, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

        # Keep the bilinear weight symmetric under gradient updates.
        self.bi_weight.register_hook(lambda grad: 0.5 * (grad + grad.T))

    def reset_parameters(self):
        nn.init.xavier_normal_(self.lin_weight, gain=nn.init.calculate_gain("relu"))
        nn.init.xavier_normal_(self.bi_weight, gain=nn.init.calculate_gain("relu"))
        with torch.no_grad():
            self.bi_weight.copy_(0.5 * (self.bi_weight + self.bi_weight.T))

        if self.bias is not None:
            bound = 1 / math.sqrt(self.bias.size(0))
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor):
        y_bi = torch.einsum("...ld,df,...rf->...lrf", x, self.bi_weight, x)
        y_bi = 0.5 * (y_bi + y_bi.transpose(-3, -2))

        x_linear = x.unsqueeze(-2) + x.unsqueeze(-3)
        return y_bi + F.linear(x_linear, self.lin_weight, self.bias)


class PairwiseMLP(nn.Module):
    def __init__(
        self,
        d_model: int,
        dropout: float = 0.2,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2 * d_model, d_model),
            nn.Dropout(dropout),
            nn.GELU(),
            nn.Linear(d_model, d_model),
            nn.GELU(),
        )

    def forward(self, x: torch.Tensor):
        _, N, _ = x.shape
        x_l = x.unsqueeze(-2).expand(-1, N, N, -1)
        x_r = x.unsqueeze(-3).expand(-1, N, N, -1)
        x_pw = torch.cat([x_l, x_r], dim=-1)
        y = self.mlp(x_pw)
        return 0.5 * (y + y.transpose(1, 2))


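# Shape sketch for the pairwise modules (sizes are illustrative): token embeddings of
# shape (batch, N, d_model) are expanded into per-pair features and mapped back to
# d_model, and the result is symmetric in the two token axes by construction.
#
#   pw = PairwiseMLP(d_model=64)
#   y = pw(torch.randn(2, 10, 64))              # shape (2, 10, 10, 64)
#   torch.allclose(y, y.transpose(1, 2))        # True
#
#   bi = BiPairwiseBlock(d_model=64)
#   bi(torch.randn(2, 10, 64)).shape            # torch.Size([2, 10, 10, 64])

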
class ChannelWiseTransform(AbstractNormalizer):
    def __init__(self, transforms: List[AbstractNormalizer]):
        super().__init__(len(transforms))
        self.transforms = torch.nn.ModuleList(transforms)

    def to_config(self) -> dict:
        return {
            "class": [t.__class__.__name__ for t in self.transforms],
            "num_outputs": self.num_outputs,
        }

    def inverse(self, x: torch.Tensor) -> torch.Tensor:
        return torch.cat(
            [
                transform.inverse(x[:, [idx]])
                for idx, transform in enumerate(self.transforms)
            ],
            dim=1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.cat(
            [
                transform.forward(x[:, [idx]])
                for idx, transform in enumerate(self.transforms)
            ],
            dim=1,
        )

    def _fit(self, x: MaskedTensor) -> dict:
        for idx, transform in enumerate(self.transforms):
            transform._fit(x[:, [idx]])
        return self.state_dict()


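# Config round-trip sketch for ChannelWiseTransform, matching the to_config()/get()
# pair above (the channel choice is illustrative):
#
#   cw = ChannelWiseTransform([Standardize(1), LogTransform(1)])
#   cfg = cw.to_config()   # {"class": ["Standardize", "LogTransform"], "num_outputs": 2}
#   rebuilt = AbstractNormalizer.get(cfg["class"], cfg["num_outputs"])
#   isinstance(rebuilt.transforms[1], LogTransform)   # True

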
class LogTransform(Standardize):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.exp(super().forward(x))

    def inverse(self, x: torch.Tensor) -> torch.Tensor:
        return super().inverse(torch.log(x))

    def _fit(self, target: MaskedTensor) -> dict:
        return super()._fit(torch.log(target))


class PowerTransform(AbstractNormalizer):
    """
    Apply a feature-wise power transform (Yeo-Johnson) to make the data more
    Gaussian-like, then standardize the transformed targets to zero mean and
    unit variance.
    """

    def __init__(self, num_outputs, eps: float = 1e-8):
        super().__init__(num_outputs)
        self.num_outputs = num_outputs
        self.register_buffer("lmbdas", torch.zeros(num_outputs))
        self.register_buffer("mean", torch.zeros(num_outputs))
        self.register_buffer("std", torch.zeros(num_outputs))
        self.eps = float(eps)
        assert 0 <= self.eps

    def _yeo_johnson_transform(self, x, lmbda):
        """
        Return transformed input x following Yeo-Johnson transform with
        parameter lambda.
        Adapted from
        https://github.com/scikit-learn/scikit-learn/blob/fbb32eae5/sklearn/preprocessing/_data.py#L3354
        """
        x_out = x.clone()
        eps = torch.finfo(x.dtype).eps
        pos = x >= 0

        # Non-negative branch
        if abs(lmbda) < eps:
            x_out[pos] = torch.log1p(x[pos])
        else:
            x_out[pos] = (torch.pow(x[pos] + 1, lmbda) - 1) / lmbda

        # Negative branch
        if abs(lmbda - 2) > eps:
            x_out[~pos] = -(torch.pow(-x[~pos] + 1, 2 - lmbda) - 1) / (2 - lmbda)
        else:
            x_out[~pos] = -torch.log1p(-x[~pos])

        return x_out

    def _yeo_johnson_inverse_transform(self, x, lmbda):
        """
        Return inverse-transformed input x following Yeo-Johnson inverse
        transform with parameter lambda.
        Adapted from
        https://github.com/scikit-learn/scikit-learn/blob/fbb32eae5/sklearn/preprocessing/_data.py#L3383
        """
        x_out = x.clone()
        pos = x >= 0
        eps = torch.finfo(x.dtype).eps

        # Non-negative branch
        if abs(lmbda) < eps:
            x_out[pos] = torch.exp(x[pos]) - 1
        else:
            x_out[pos] = torch.pow(x[pos] * lmbda + 1, 1 / lmbda) - 1

        # Negative branch
        if abs(lmbda - 2) > eps:
            x_out[~pos] = 1 - torch.pow(-(2 - lmbda) * x[~pos] + 1, 1 / (2 - lmbda))
        else:
            x_out[~pos] = 1 - torch.exp(-x[~pos])
        return x_out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Undo the standardization, then undo the Yeo-Johnson transform.
        x = (self.std * x) + self.mean
        x_out = torch.zeros_like(x)
        for i in range(self.num_outputs):
            x_out[:, i] = self._yeo_johnson_inverse_transform(x[:, i], self.lmbdas[i])
        return x_out

    def inverse(self, x: torch.Tensor) -> torch.Tensor:
        x_out = torch.zeros_like(x)
        for i in range(self.num_outputs):
            x_out[:, i] = self._yeo_johnson_transform(x[:, i], self.lmbdas[i])
        # Standardize the transformed targets.
        x_out = (x_out - self.mean) / self.std
        return x_out

    def _fit(self, target: MaskedTensor) -> dict:
        # Fit the per-channel Yeo-Johnson lambdas with scikit-learn, then the
        # standardization statistics on the transformed data.
        from sklearn.preprocessing import (
            PowerTransformer as _PowerTransformer,
        )

        transformer = _PowerTransformer(method="yeo-johnson", standardize=False)
        target = torch.tensor(transformer.fit_transform(target.get_data().numpy()))
        self.lmbdas = torch.tensor(transformer.lambdas_)

        self.mean = target.mean(0).to(self.mean)
        self.std = target.std(0).to(self.std) + self.eps
        return self.state_dict()


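# Fit/round-trip sketch for PowerTransform (requires scikit-learn; the data is
# synthetic and illustrative):
#
#   pt = PowerTransform(num_outputs=1)
#   x = torch.rand(256, 1) * 10.0
#   pt._fit(masked_tensor(x, torch.ones_like(x, dtype=torch.bool)))
#   torch.allclose(pt.forward(pt.inverse(x)), x, atol=1e-3)   # True

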
class MaxScaleTransform(AbstractNormalizer):
    """
    Scale targets by a fixed maximum value: ``inverse`` divides by ``max`` and
    ``forward`` multiplies by it. The value is set at construction; ``_fit`` is a
    no-op.
    """

    def __init__(self, mx: int, eps: float = 1e-8):
        super().__init__(1)
        self.num_outputs = 1
        self.max = mx
        self.eps = float(eps)
        assert 0 <= self.eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Undo the scaling.
        x_out = self.max * x
        return x_out

    def inverse(self, x: torch.Tensor) -> torch.Tensor:
        x_out = x / self.max
        return x_out

    def _fit(self, target: MaskedTensor) -> dict:
        return self.state_dict()


class IdentityTransform(AbstractNormalizer):
    def inverse(self, x: torch.Tensor) -> torch.Tensor:
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

    def _fit(self, x: MaskedTensor) -> dict:
        return self.state_dict()