# mist-28M-0uiq7o7m-freesolv / modeling_mist_finetuned.py
from datasets import IterableDataset
from smirk import SmirkTokenizerFast
from torch.masked import MaskedTensor, masked_tensor
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    DataCollatorWithPadding,
    PreTrainedModel,
    PretrainedConfig,
)
from typing import Any, Callable, Dict, List, Optional, Union
import logging
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
AutoTokenizer.register("SmirkTokenizer", fast_tokenizer_class=SmirkTokenizerFast)
MODEL_TYPE_ALIASES = {}
def build_encoder(enc_dict: Dict[str, Any]):
mtype = enc_dict.get("model_type")
if mtype:
base = MODEL_TYPE_ALIASES.get(mtype, mtype)
cfg_cls = AutoConfig.for_model(base)
enc_cfg = cfg_cls.from_dict(enc_dict)
elif enc_dict.get("_name_or_path"):
enc_cfg = AutoConfig.from_pretrained(enc_dict["_name_or_path"])
else:
raise KeyError("encoder config missing 'model_type' or '_name_or_path'")
if hasattr(enc_cfg, "add_pooling_layer"):
enc_cfg.add_pooling_layer = False
return AutoModel.from_config(enc_cfg)
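# Illustrative sketch of the kind of dictionary `build_encoder` accepts. The keys and
# values below are placeholders, not this checkpoint's actual encoder config:
#
#   enc_dict = {"model_type": "roberta", "hidden_size": 256, "num_hidden_layers": 6}
#   encoder = build_encoder(enc_dict)  # resolves the config class, builds a bare encoder
#
# Any alias registered in MODEL_TYPE_ALIASES is applied to "model_type" before lookup,
# and pooling is disabled whenever the resolved config supports it.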
class MISTFinetunedConfig(PretrainedConfig):
"""HF config for a single-task MIST wrapper."""
model_type = "mist_finetuned"
def __init__(
self,
encoder: Optional[Dict[str, Any]] = None,
task_network: Optional[Dict[str, Any]] = None,
transform: Optional[Dict[str, Any]] = None,
channels: Optional[List[Dict[str, Any]]] = None,
tokenizer_class: Optional[str] = "SmirkTokenizer",
**kwargs,
):
super().__init__(**kwargs)
self.encoder = encoder or {}
self.task_network = task_network or {}
self.transform = transform or {}
self.channels = channels
self.tokenizer_class = tokenizer_class
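# Illustrative sketch of the nested dictionaries this config carries; the concrete values
# are placeholders rather than the ones stored with this checkpoint:
#
#   cfg = MISTFinetunedConfig(
#       encoder={"model_type": "roberta", "hidden_size": 256},
#       task_network={"embed_dim": 256, "output_size": 1, "dropout": 0.2},
#       transform={"class": "Standardize", "num_outputs": 1},
#       channels=[{"name": "target", "description": None, "unit": None}],
#   )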
class MISTFinetuned(PreTrainedModel):
    """Single-task MIST wrapper: a transformer encoder, one prediction head, and an output transform."""
    config_class = MISTFinetunedConfig
def __init__(self, config: MISTFinetunedConfig):
super().__init__(config)
self.encoder = build_encoder_from_dict(config.encoder)
tn = config.task_network
self.task_network = PredictionTaskHead(
embed_dim=tn["embed_dim"],
output_size=tn["output_size"],
dropout=tn["dropout"],
)
self.transform = AbstractNormalizer.get(
config.transform["class"], config.transform["num_outputs"]
)
self.channels = config.channels
self.tokenizer = None
self.post_init()
@classmethod
def from_components(
cls,
encoder: PreTrainedModel,
task_network: nn.Module,
transform: Any,
tokenizer: Optional[Any] = None,
channels: Optional[List[Dict[str, Any]]] = None,
) -> "MISTFinetuned":
cfg = MISTFinetunedConfig(
encoder=encoder.config.to_dict(),
task_network={
"embed_dim": encoder.config.hidden_size,
"output_size": task_network.final.out_features,
"dropout": task_network.dropout1.p,
},
transform=transform.to_config(),
channels=channels,
            tokenizer_class=(
                type(tokenizer).__name__ if tokenizer is not None else "SmirkTokenizer"
            ),
)
model = cls(cfg)
# load component weights
model.encoder.load_state_dict(encoder.state_dict(), strict=False)
model.task_network.load_state_dict(task_network.state_dict())
model.transform.load_state_dict(transform.state_dict())
model.tokenizer = tokenizer
return model
def forward(self, input_ids, attention_mask=None):
hs = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state
y = self.task_network(hs)
return self.transform.forward(y)
def _resolve_tokenizer(self, tokenizer):
if tokenizer is not None:
return tokenizer
if getattr(self, "tokenizer", None) is not None:
return self.tokenizer
try:
return AutoTokenizer.from_pretrained(
self.name_or_path, use_fast=True, trust_remote_code=True
)
except Exception:
return AutoTokenizer.from_pretrained(
self.config._name_or_path, use_fast=True, trust_remote_code=True
)
def embed(self, smi: List[str], tokenizer=None):
tok = self._resolve_tokenizer(tokenizer)
batch = tok(smi)
batch = DataCollatorWithPadding(tok)(batch)
input_ids = batch["input_ids"].to(self.device)
attention_mask = batch["attention_mask"].to(self.device)
with torch.inference_mode():
hs = self.encoder(
input_ids, attention_mask=attention_mask
).last_hidden_state[:, 0, :]
return hs.to("cpu")
def predict(self, smi: List[str], return_dict: bool = True, tokenizer=None):
tok = self._resolve_tokenizer(tokenizer)
batch = tok(smi)
collate_fn = DataCollatorWithPadding(tok)
batch = collate_fn(batch)
batch = {
"input_ids": batch["input_ids"].to(self.encoder.device),
"attention_mask": batch["attention_mask"].to(self.encoder.device),
}
with torch.inference_mode():
out = self(**batch).cpu()
if self.channels is None or not return_dict:
return out
return annotate_prediction(out, maybe_get_annotated_channels(self.channels))
def save_pretrained(self, save_directory, **kwargs):
super().save_pretrained(save_directory, **kwargs)
if getattr(self, "tokenizer", None) is not None:
self.tokenizer.save_pretrained(save_directory)
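# Illustrative usage sketch (not executed on import). The repo id below is a placeholder
# for wherever this checkpoint lives; `predict` returns a dict keyed by channel name when
# `channels` metadata is configured, otherwise a raw tensor:
#
#   model = MISTFinetuned.from_pretrained("user/mist-finetuned-checkpoint", trust_remote_code=True)
#   preds = model.predict(["CCO", "c1ccccc1"])   # per-channel predictions
#   feats = model.embed(["CCO", "c1ccccc1"])     # CLS-token embeddings, moved to CPU
#   model.save_pretrained("./exported-model")    # also saves the tokenizer if one is attached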
def maybe_get_annotated_channels(channels: List[Any]):
for chn in channels:
if isinstance(chn, str):
yield {"name": chn, "description": None, "unit": None}
else:
yield chn
def annotate_prediction(
y: torch.Tensor, channels: List[Dict[str, str]]
) -> Dict[str, Dict[str, Any]]:
out: Dict[str, Dict[str, Any]] = {}
for idx, chn in enumerate(channels):
channel_info = {f: v for f, v in chn.items() if f != "name"}
out[chn["name"]] = {"value": y[:, idx], **channel_info}
return out
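# For reference, `annotate_prediction` maps a (batch, n_channels) tensor to a dict of
# per-channel entries. A hypothetical single-channel example:
#
#   annotate_prediction(y, [{"name": "target", "description": None, "unit": None}])
#   # -> {"target": {"value": y[:, 0], "description": None, "unit": None}}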
def build_encoder_from_dict(enc_dict):
if "model_type" in enc_dict:
cfg_cls = AutoConfig.for_model(enc_dict["model_type"])
enc_cfg = cfg_cls.from_dict(enc_dict, strict=False)
elif "_name_or_path" in enc_dict:
enc_cfg = AutoConfig.from_pretrained(enc_dict["_name_or_path"], strict=False)
else:
raise KeyError("Encoder config is missing 'model_type' and '_name_or_path.")
# Ensure pooling layer is disabled to match saved checkpoints
if hasattr(enc_cfg, "add_pooling_layer"):
enc_cfg.add_pooling_layer = False
return AutoModel.from_config(enc_cfg)
class MISTMultiTaskConfig(PretrainedConfig):
"""HuggingFace config for a multi-task MIST wrapper."""
model_type = "mist_multitask"
def __init__(
self,
encoder: Optional[Dict[str, Any]] = None,
task_networks: Optional[List[Dict[str, Any]]] = None,
transforms: Optional[List[Dict[str, Any]]] = None,
channels: Optional[List[Dict[str, Any]]] = None,
tokenizer_class: Optional[str] = "SmirkTokenizer",
**kwargs,
):
super().__init__(**kwargs)
self.encoder = encoder or {}
self.task_networks = task_networks or []
self.transforms = transforms or []
self.channels = channels
self.tokenizer_class = tokenizer_class
class MISTMultiTask(PreTrainedModel):
    """Multi-task MIST wrapper: one shared encoder with a prediction head and transform per task."""
    config_class = MISTMultiTaskConfig
def __init__(self, config: MISTMultiTaskConfig):
super().__init__(config)
self.encoder = build_encoder_from_dict(config.encoder)
self.task_networks = nn.ModuleList(
[
PredictionTaskHead(
embed_dim=tn["embed_dim"],
output_size=tn["output_size"],
dropout=tn["dropout"],
)
for tn in config.task_networks
]
)
self.transforms = nn.ModuleList(
[
AbstractNormalizer.get(tf_cfg["class"], tf_cfg["num_outputs"])
for tf_cfg in config.transforms
]
)
assert len(self.task_networks) == len(
self.transforms
), "task_networks and transforms must align"
self.channels = config.channels
self.tokenizer = None
self.post_init()
@classmethod
def from_components(
cls,
encoder: PreTrainedModel,
task_networks: List[nn.Module],
transforms: List[Any],
tokenizer: Optional[Any] = None,
channels: Optional[List[Dict[str, Any]]] = None,
) -> "MISTMultiTask":
cfg = MISTMultiTaskConfig(
encoder=encoder.config.to_dict(),
task_networks=[
{
"embed_dim": encoder.config.hidden_size,
"output_size": tn.final.out_features,
"dropout": tn.dropout1.p,
}
for tn in task_networks
],
transforms=[tf.to_config() for tf in transforms],
channels=channels,
            tokenizer_class=(
                type(tokenizer).__name__ if tokenizer is not None else "SmirkTokenizer"
            ),
)
model = cls(cfg)
model.encoder.load_state_dict(encoder.state_dict(), strict=False)
for dst, src in zip(model.task_networks, task_networks):
dst.load_state_dict(src.state_dict())
for dst, src in zip(model.transforms, transforms):
dst.load_state_dict(src.state_dict())
model.tokenizer = tokenizer
return model
def forward(self, input_ids, attention_mask=None):
hs = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state
outs = []
for tn, tf in zip(self.task_networks, self.transforms):
outs.append(tf.forward(tn(hs)))
return torch.cat(outs, dim=-1)
def _resolve_tokenizer(self, tokenizer):
if tokenizer is not None:
return tokenizer
if getattr(self, "tokenizer", None) is not None:
return self.tokenizer
try:
return AutoTokenizer.from_pretrained(
self.name_or_path, use_fast=True, trust_remote_code=True
)
except Exception:
return AutoTokenizer.from_pretrained(
self.config._name_or_path, use_fast=True, trust_remote_code=True
)
def predict(self, smi: List[str], tokenizer=None):
tok = self._resolve_tokenizer(tokenizer)
batch = tok(smi)
batch = DataCollatorWithPadding(tok)(batch)
inputs = {k: v.to(self.device) for k, v in batch.items()}
with torch.inference_mode():
out = self(**inputs).cpu()
if self.channels is None:
return out
return annotate_prediction(out, maybe_get_annotated_channels(self.channels))
def embed(self, smi: List[str], tokenizer=None):
tok = self._resolve_tokenizer(tokenizer)
batch = tok(smi)
batch = DataCollatorWithPadding(tok)(batch)
input_ids = batch["input_ids"].to(self.device)
attention_mask = batch["attention_mask"].to(self.device)
with torch.inference_mode():
hs = self.encoder(
input_ids, attention_mask=attention_mask
).last_hidden_state[:, 0, :]
return hs.to("cpu")
def save_pretrained(self, save_directory, **kwargs):
super().save_pretrained(save_directory, **kwargs)
if getattr(self, "tokenizer", None) is not None:
self.tokenizer.save_pretrained(save_directory)
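# Illustrative multi-task usage (placeholder repo id): every head's denormalized output is
# concatenated along the last dimension, so two single-output heads yield a (batch, 2)
# tensor, or a channel-keyed dict when `channels` metadata is present:
#
#   model = MISTMultiTask.from_pretrained("user/mist-multitask-checkpoint", trust_remote_code=True)
#   out = model.predict(["CCO", "c1ccccc1"])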
class PredictionTaskHead(nn.Module):
def __init__(
self, embed_dim: int, output_size: int = 1, dropout: float = 0.2
) -> None:
super().__init__()
self.desc_skip_connection = True
self.fc1 = nn.Linear(embed_dim, embed_dim)
self.dropout1 = nn.Dropout(dropout)
self.relu1 = nn.GELU()
self.fc2 = nn.Linear(embed_dim, embed_dim)
self.dropout2 = nn.Dropout(dropout)
self.relu2 = nn.GELU()
self.final = nn.Linear(embed_dim, output_size)
def forward(self, emb):
if emb.ndim > 2:
emb = emb[:, 0, :]
x_out = self.fc1(emb)
x_out = self.dropout1(x_out)
x_out = self.relu1(x_out)
if self.desc_skip_connection is True:
x_out = x_out + emb
z = self.fc2(x_out)
z = self.dropout2(z)
z = self.relu2(z)
if self.desc_skip_connection is True:
z = self.final(z + x_out)
else:
z = self.final(z)
return z
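# Shape sketch for the head above (values are illustrative): given token states of shape
# (batch, seq_len, embed_dim), the head pools the first ([CLS]) token to (batch, embed_dim),
# applies two GELU blocks with residual connections, and projects to (batch, output_size):
#
#   head = PredictionTaskHead(embed_dim=256, output_size=1)
#   y = head(torch.randn(8, 32, 256))   # -> shape (8, 1)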
class AbstractNormalizer(torch.nn.Module):
def __init__(self, num_outputs: Optional[int] = None):
super().__init__()
self.num_outputs = num_outputs
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map normalized model outputs back to the original target scale."""
        raise NotImplementedError
    def inverse(self, x: torch.Tensor) -> torch.Tensor:
        """Map raw targets into the normalized space used for training."""
        raise NotImplementedError
    def _fit(self, x: MaskedTensor) -> dict:
        """Fit normalization parameters from a masked target tensor."""
        raise NotImplementedError
def to_config(self) -> dict:
return {"class": self.__class__.__name__, "num_outputs": self.num_outputs}
def leader_fit(self, ds, rank: int, broadcast: Callable):
state = None
if rank == 0:
state = self.fit(ds)
state = broadcast(state)
self.load_state_dict(state)
def fit(self, ds, name: str = "target") -> dict:
"""Fit the normalization parameters on dataset"""
if isinstance(ds, IterableDataset):
target = []
mask = []
for x in ds:
target.append(x[name])
mask.append(x[f"{name}_mask"])
target = torch.stack(target)
mask = torch.stack(mask)
else:
target = torch.stack([torch.tensor(x) for x in ds[name]])
mask = torch.stack([torch.tensor(x) for x in ds[f"{name}_mask"]])
# Use masked tensor to compute normalization parameters
target = masked_tensor(target, mask)
state = self._fit(target)
return state
@classmethod
def get(
cls, transform: Optional[Union[list[str], str]], num_outputs: int
) -> "AbstractNormalizer":
if isinstance(transform, list):
assert len(transform) == num_outputs
return ChannelWiseTransform([cls.get(t, 1) for t in transform])
elif transform in ["standardize", Standardize.__name__]:
return Standardize(num_outputs)
elif transform in ["power_transform", PowerTransform.__name__]:
return PowerTransform(num_outputs)
elif transform in ["log_transform", LogTransform.__name__]:
return LogTransform(num_outputs)
elif transform in ["max_scale", MaxScaleTransform.__name__]:
return MaxScaleTransform(num_outputs)
else:
return IdentityTransform()
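# Illustrative examples of how transform configs resolve (class names as stored by
# `to_config`, or their snake_case aliases):
#
#   AbstractNormalizer.get("Standardize", 1)                      # -> Standardize(1)
#   AbstractNormalizer.get("log_transform", 1)                    # -> LogTransform(1)
#   AbstractNormalizer.get(["Standardize", "PowerTransform"], 2)  # -> ChannelWiseTransform
#   AbstractNormalizer.get(None, 1)                               # -> IdentityTransform()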
class Standardize(AbstractNormalizer):
def __init__(self, num_outputs: int, eps: float = 1e-8):
super().__init__(num_outputs)
self.register_buffer("mean", torch.zeros(num_outputs))
self.register_buffer("std", torch.zeros(num_outputs))
self.eps = float(eps)
assert 0 <= self.eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
return (self.std * x) + self.mean
def inverse(self, x: torch.Tensor) -> torch.Tensor:
return (x - self.mean) / self.std
def fit(self, ds, name: str = "target") -> dict:
num_outputs = self.num_outputs
assert num_outputs is not None
mean = torch.zeros(num_outputs)
m2 = torch.zeros(num_outputs)
n = torch.zeros(num_outputs, dtype=torch.int)
for row in ds:
target = torch.tensor(row[name])
mask = torch.tensor(row[f"{name}_mask"])
x = masked_tensor(target, mask)
n += mask.view(-1, num_outputs).sum(0)
xs = x.view(-1, num_outputs).sum(0)
delta = xs - mean
# Only update masked values
mean += (delta / n).get_data().masked_fill(~delta.get_mask(), 0)
delta2 = xs - mean
m2 += (delta * delta2).get_data().masked_fill(~delta.get_mask(), 0)
self.mean = mean.to(self.mean)
self.std = (m2 / n).sqrt().to(self.std) + self.eps
self.mean[self.mean.isnan()] = 0
self.std[self.std.isnan()] = 1
logging.debug("Fitted %s", self.state_dict())
return self.state_dict()
def _fit(self, target: MaskedTensor) -> dict:
self.mean = target.mean(0).get_data().to(self.mean)
self.std = target.std(0).get_data().to(self.std) + self.eps
return self.state_dict()
def load_state_dict(
self, state_dict: dict[str, Any], strict: bool = True, assign: bool = False
):
# Handle legacy case where keys have "transform." prefix
if "transform.mean" in state_dict:
state_dict = state_dict.copy() # Don't modify original
state_dict["mean"] = state_dict.pop("transform.mean")
state_dict["std"] = state_dict.pop("transform.std")
if assign:
# Manually assign buffers when assign=True
for key, value in state_dict.items():
if key in ["mean", "std"]:
# Use register_buffer to properly replace the buffer
self.register_buffer(key, value)
result = None # No incompatible keys when we do it manually
else:
result = super().load_state_dict(state_dict, strict=strict, assign=False)
logging.debug(f" After loading: mean={self.mean}, std={self.std}")
return result
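# Round-trip sketch for Standardize (illustrative values): `inverse` maps raw targets to
# z-scores for training and `forward` maps model outputs back to the original scale, so
# after fitting, forward(inverse(x)) recovers x up to the eps added to std:
#
#   norm = Standardize(num_outputs=1)
#   norm.load_state_dict({"mean": torch.tensor([2.0]), "std": torch.tensor([0.5])})
#   z = norm.inverse(torch.tensor([[3.0]]))   # -> [[2.0]]
#   norm.forward(z)                           # -> [[3.0]]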
class TokenTaskHead(nn.Module):
def __init__(
self, embed_dim: int, output_size: int = 1, dropout: float = 0.2
) -> None:
super().__init__()
self.layers = nn.Sequential(
nn.Linear(embed_dim, embed_dim),
nn.Dropout(dropout),
nn.GELU(),
nn.Linear(embed_dim, embed_dim),
nn.Dropout(dropout),
nn.GELU(),
nn.Linear(embed_dim, output_size),
)
def forward(self, emb):
return self.layers(emb)
class TokenPairwiseDistance(nn.Module):
def __init__(
self,
embed_dim: int,
dropout: float = 0.2,
num_attention_heads: int = 1,
num_layers: int = 1,
activation: str = "relu",
ff_ratio: int = 2,
) -> None:
super().__init__()
enc_layer = nn.TransformerEncoderLayer(
d_model=embed_dim,
nhead=num_attention_heads,
dim_feedforward=ff_ratio * embed_dim,
dropout=dropout,
batch_first=True,
norm_first=True,
)
self.interaction = nn.TransformerEncoder(enc_layer, num_layers)
self.pairwise_distance = PairwiseMLP(embed_dim, dropout)
self.distance1 = nn.Sequential(
nn.Linear(embed_dim, embed_dim), nn.Dropout(dropout), nn.GELU()
)
self.distance2 = nn.Linear(embed_dim, 1)
def forward(self, hs: torch.Tensor) -> torch.Tensor:
hs = self.interaction(hs)
with torch.autocast("cuda", dtype=torch.float32):
pw_dist = self.pairwise_distance(hs)
d = self.distance1(pw_dist) + pw_dist
d = self.distance2(d).squeeze(-1)
return F.relu(F.elu(d) + 1)
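# Shape note (illustrative): given token states of shape (batch, seq_len, embed_dim), the
# module above returns a symmetric (batch, seq_len, seq_len) matrix of non-negative
# distances; F.relu(F.elu(d) + 1) keeps predictions positive while staying smooth near zero.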
class BiPairwiseBlock(nn.Module):
def __init__(self, d_model: int, bias: bool = True, device=None, dtype=None):
super().__init__()
factory_kwargs = {"device": device, "dtype": dtype}
self.bi_weight = nn.Parameter(torch.empty((d_model, d_model), **factory_kwargs))
self.lin_weight = nn.Parameter(
torch.empty((d_model, d_model), **factory_kwargs)
)
if bias:
self.bias = nn.Parameter(torch.empty(d_model, **factory_kwargs))
else:
self.register_parameter("bias", None)
self.reset_parameters()
# Gradient hook to enforce symmetry
self.bi_weight.register_hook(lambda grad: 0.5 * (grad + grad.T))
def reset_parameters(self):
nn.init.xavier_normal_(self.lin_weight, gain=nn.init.calculate_gain("relu"))
nn.init.xavier_normal_(self.bi_weight, gain=nn.init.calculate_gain("relu"))
with torch.no_grad():
self.bi_weight.copy_(0.5 * (self.bi_weight + self.bi_weight.T))
if self.bias is not None:
bound = 1 / math.sqrt(self.bias.size(0))
nn.init.uniform_(self.bias, -bound, bound)
def forward(self, x: torch.Tensor):
y_bi = torch.einsum("...ld,df,...rf->...lrf", x, self.bi_weight, x)
y_bi = 0.5 * (y_bi + y_bi.transpose(-3, -2)) # Enforce symmetry
x_linear = x.unsqueeze(-2) + x.unsqueeze(-3)
return y_bi + F.linear(x_linear, self.lin_weight, self.bias)
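# Illustrative formula for the block above: for token features x of shape (..., L, d), the
# bilinear term is y_bi[..., l, r, f] = sum_d x[..., l, d] * W_bi[d, f] * x[..., r, f],
# explicitly symmetrized over (l, r), plus a linear map of the pairwise sum
# x[..., l, :] + x[..., r, :]. The gradient hook symmetrizes W_bi's gradient so the
# symmetric initialization is preserved during training.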
class PairwiseMLP(nn.Module):
def __init__(
self,
d_model: int,
dropout: float = 0.2,
device=None,
dtype=None,
) -> None:
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(2 * d_model, d_model),
nn.Dropout(dropout),
nn.GELU(),
nn.Linear(d_model, d_model),
nn.GELU(),
)
def forward(self, x: torch.Tensor):
_, N, _ = x.shape
x_l = x.unsqueeze(-2).expand(-1, N, N, -1)
x_r = x.unsqueeze(-3).expand(-1, N, N, -1)
x_pw = torch.cat([x_l, x_r], dim=-1)
y = self.mlp(x_pw)
return 0.5 * (y + y.transpose(1, 2))
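# Shape sketch (illustrative): PairwiseMLP concatenates every ordered pair of token
# features, so (batch, N, d) becomes (batch, N, N, 2*d) before the MLP, and the output is
# averaged with its transpose to give a symmetric (batch, N, N, d) tensor.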
class ChannelWiseTransform(AbstractNormalizer):
def __init__(self, transforms: list[AbstractNormalizer]):
super().__init__(len(transforms))
self.transforms = torch.nn.ModuleList(transforms)
def to_config(self) -> dict:
return {
"class": [t.__class__.__name__ for t in self.transforms],
"num_outputs": self.num_outputs,
}
def inverse(self, x: torch.Tensor) -> torch.Tensor:
return torch.cat(
[
transform.inverse(x[:, [idx]])
for idx, transform in enumerate(self.transforms)
],
dim=1,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.cat(
[
transform.forward(x[:, [idx]])
for idx, transform in enumerate(self.transforms)
],
dim=1,
)
def _fit(self, x: MaskedTensor) -> dict:
for idx, transform in enumerate(self.transforms):
transform._fit(x[:, [idx]])
return self.state_dict()
class LogTransform(Standardize):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.exp(super().forward(x))
def inverse(self, x: torch.Tensor) -> torch.Tensor:
return super().inverse(torch.log(x))
def _fit(self, target: MaskedTensor) -> dict:
return super()._fit(torch.log(target))
class PowerTransform(AbstractNormalizer):
"""
Apply a power transform (Yeo-Johnson) featurewise to make data more Gaussian-like.
Followed by applying a zero-mean, unit-variance normalization to the
transformed output to rescale targets to [-1, 1].
"""
def __init__(self, num_outputs, eps: float = 1e-8):
super().__init__(num_outputs)
self.num_outputs = num_outputs
self.register_buffer("lmbdas", torch.zeros(num_outputs))
self.register_buffer("mean", torch.zeros(num_outputs))
self.register_buffer("std", torch.zeros(num_outputs))
self.eps = float(eps)
assert 0 <= self.eps
def _yeo_johnson_transform(self, x, lmbda):
"""
Return transformed input x following Yeo-Johnson transform with
parameter lambda.
Adapted from
https://github.com/scikit-learn/scikit-learn/blob/fbb32eae5/sklearn/preprocessing/_data.py#L3354
"""
x_out = x.clone()
eps = torch.finfo(x.dtype).eps
pos = x >= 0 # binary mask
# when x >= 0
if abs(lmbda) < eps:
x_out[pos] = torch.log1p(x[pos])
else: # lmbda != 0
x_out[pos] = (torch.pow(x[pos] + 1, lmbda) - 1) / lmbda
# when x < 0
if abs(lmbda - 2) > eps:
x_out[~pos] = -(torch.pow(-x[~pos] + 1, 2 - lmbda) - 1) / (2 - lmbda)
else: # lmbda == 2
x_out[~pos] = -torch.log1p(-x[~pos])
return x_out
def _yeo_johnson_inverse_transform(self, x, lmbda):
"""
Return inverse-transformed input x following Yeo-Johnson inverse
transform with parameter lambda.
Adapted from
https://github.com/scikit-learn/scikit-learn/blob/fbb32eae5/sklearn/preprocessing/_data.py#L3383
"""
x_out = x.clone()
pos = x >= 0
eps = torch.finfo(x.dtype).eps
# when x >= 0
if abs(lmbda) < eps: # lmbda == 0
x_out[pos] = torch.exp(x[pos]) - 1
else: # lmbda != 0
x_out[pos] = torch.pow(x[pos] * lmbda + 1, 1 / lmbda) - 1
# when x < 0
if abs(lmbda - 2) > eps: # lmbda != 2
x_out[~pos] = 1 - torch.pow(-(2 - lmbda) * x[~pos] + 1, 1 / (2 - lmbda))
else: # lmbda == 2
x_out[~pos] = 1 - torch.exp(-x[~pos])
return x_out
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Undo standardization
x = (self.std * x) + self.mean
x_out = torch.zeros_like(x)
for i in range(self.num_outputs):
x_out[:, i] = self._yeo_johnson_inverse_transform(x[:, i], self.lmbdas[i])
return x_out
def inverse(self, x: torch.Tensor) -> torch.Tensor:
x_out = torch.zeros_like(x)
for i in range(self.num_outputs):
x_out[:, i] = self._yeo_johnson_transform(x[:, i], self.lmbdas[i])
# Standardization
x_out = (x_out - self.mean) / self.std
return x_out
def _fit(self, target: MaskedTensor) -> dict:
# Fit Yeo-Johnson lambdas
from sklearn.preprocessing import (
PowerTransformer as _PowerTransformer, # noqa: F811
)
transformer = _PowerTransformer(method="yeo-johnson", standardize=False)
target = torch.tensor(transformer.fit_transform(target.get_data().numpy()))
self.lmbdas = torch.tensor(transformer.lambdas_)
# Fit standardization scaling
self.mean = target.mean(0).to(self.mean)
self.std = target.std(0).to(self.std) + self.eps
return self.state_dict()
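# Worked example of the Yeo-Johnson mapping used above (lambda is fitted per channel by
# scikit-learn): for lambda = 1 the transform is the identity, and for lambda = 0 a
# non-negative x maps to log1p(x); e.g. with lambda = 0, x = 1.0 transforms to
# log(2) ~= 0.6931, and the inverse, exp(0.6931) - 1, recovers 1.0.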
class MaxScaleTransform(AbstractNormalizer):
    """
    Divide by a fixed maximum value (e.g. the training-set maximum) supplied at
    construction time; ``_fit`` is a no-op for this transform.
    """
def __init__(self, mx: int, eps: float = 1e-8):
super().__init__(1)
self.num_outputs = 1
self.max = mx
self.eps = float(eps)
assert 0 <= self.eps
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Undo the max scaling
        x_out = self.max * x
        return x_out
def inverse(self, x: torch.Tensor) -> torch.Tensor:
x_out = x / self.max
return x_out
def _fit(self, target: MaskedTensor) -> dict:
return self.state_dict()
class IdentityTransform(AbstractNormalizer):
def inverse(self, x: torch.Tensor) -> torch.Tensor:
return x
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x
def _fit(self, x: MaskedTensor) -> dict:
return self.state_dict()