SafeTensors version of the pcunwa/BS-Roformer-Inst-FNO model, for inference with torch >= 2.6 and neuraloperator >= 2.0.
Usage
To use this model, the mask estimator must be replaced with one based on a 1D FNO (Fourier Neural Operator), as shown below:
from neuralop.models import FNO
class MaskEstimatorFNO(Module):
    """Per-band mask estimator that uses a 1D Fourier Neural Operator instead of an MLP.

    One ``FNO -> GLU`` head is built for each entry of ``dim_inputs``. ``forward``
    unbinds its input along dim -2, runs each slice through the matching head with
    autocast disabled (full precision), and concatenates the results along the last
    dimension.

    Args:
        dim: Channel width of the incoming band features; also used as the FNO
            hidden width.
        dim_inputs: Output width of each band head (one head per entry).
    """

    @beartype
    def __init__(
        self,
        dim,
        dim_inputs: Tuple[int, ...],
    ):
        super().__init__()

        self.dim_inputs = dim_inputs
        self.to_freqs = ModuleList([])

        for dim_in in dim_inputs:
            mlp = nn.Sequential(
                # FNO1d was replaced by FNO for neuraloperator 2.0 compatibility;
                # a single-entry n_modes tuple makes the operator 1D.
                FNO(
                    n_modes=(64,),
                    hidden_channels=dim,
                    in_channels=dim,
                    out_channels=dim_in * 2,
                    lifting_channel_ratio=2,
                    projection_channel_ratio=2,
                    n_layers=3,
                    separable=True,
                ),
                # GLU over the channel axis (-2 in the 'b c t' layout used in forward)
                # halves dim_in * 2 back down to dim_in.
                nn.GLU(dim=-2),
            )

            self.to_freqs.append(mlp)

    def forward(self, x):
        """Apply each band head to its slice of ``x`` and concatenate the outputs.

        Args:
            x: 4D tensor laid out as (batch, time, band, channel); it is unbound
                along the band axis (dim -2), one slice per head.

        Returns:
            float32 tensor of shape (batch, time, sum(dim_inputs)).
        """
        x = x.unbind(dim=-2)

        outs = []

        for band_features, mlp in zip(x, self.to_freqs):
            band_features = rearrange(band_features, 'b t c -> b c t')
            # The FNO is run in full precision. Autocast is disabled here, so the
            # input must be cast explicitly: a half-precision tensor produced inside
            # an outer autocast region would otherwise hit float32 weights and fail.
            with torch.autocast(device_type='cuda', enabled=False, dtype=torch.float32):
                freq_out = mlp(band_features.float()).float()
            freq_out = rearrange(freq_out, 'b c t -> b t c')
            outs.append(freq_out)

        return torch.cat(outs, dim=-1)
Don't forget to replace `MaskEstimator` with `MaskEstimatorFNO` in the `BSRoformer` class.
If you want to keep the original `MaskEstimator` available, note that the yaml config file also contains a parameter named `use_mask_estimator_fno`. You can add it to the `BSRoformer.__init__` signature and use a simple if/else statement to switch between the two mask estimators.
Example:
class BSRoformer(Module):
    # Excerpt: only the parts relevant to selecting the mask estimator are shown.

    @beartype
    def __init__(
        self,
        # ... other parameters
        # Defaults to False so existing BS-Roformer configs that lack this key keep
        # the original MLP MaskEstimator (and still match their checkpoints).
        # Configs for FNO checkpoints must set `use_mask_estimator_fno: true`.
        use_mask_estimator_fno: bool = False,
    ):
        # ... other code

        self.mask_estimators = nn.ModuleList([])

        for _ in range(num_stems):
            if use_mask_estimator_fno:
                mask_estimator = MaskEstimatorFNO(
                    dim=dim,
                    dim_inputs=freqs_per_bands_with_complex,
                )
            else:
                mask_estimator = MaskEstimator(
                    dim=dim,
                    dim_inputs=freqs_per_bands_with_complex,
                    depth=mask_estimator_depth,
                    mlp_expansion_factor=mlp_expansion_factor,
                )

            self.mask_estimators.append(mask_estimator)
Also, this is a safetensors model, so it must be loaded with `safetensors.torch.load_file` instead of `torch.load` in `utils/model_utils.py` (MSST) or `webui/msst_infer.py` (MSST-WebUI).
from safetensors.torch import load_file

# The loader is shared by every model, so only .safetensors files go through
# load_file; checkpoints saved with torch.save (.ckpt/.pth/...) keep using
# torch.load. Keep any extra post-processing your existing torch.load path does.
if str(model_path).endswith(".safetensors"):
    state_dict = load_file(model_path, device=self.device)
else:
    state_dict = torch.load(model_path, map_location=self.device)
model.load_state_dict(state_dict)
Requirements
You need to install neuraloperator>=2.0.0 and safetensors>=0.7.0 (a safetensors version that supports the complex64 (`C64`) dtype):
pip install "neuraloperator>=2.0.0" "safetensors>=0.7.0"