POSEIDON / poseidon_model.py
MashaMash's picture
Update poseidon_model.py
2c810c3 verified
import sys
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from huggingface_hub import hf_hub_download
from torchvision.transforms.functional import resize
sys.path.append(os.path.abspath("poseidon_demo/external/poseidon"))
from external.poseidon.scOT.model import ScOT, ScOTConfig
def load_model():
    """
    Build a POSEIDON (ScOT) model and put it in evaluation mode.

    The model is constructed directly from a ``ScOTConfig`` with four input
    channels and skip connections enabled at all four levels. No checkpoint
    or state dict is read here, so the weights are whatever ``ScOT``'s
    constructor initialises them to.

    Returns:
        model (ScOT): The POSEIDON model, already switched to evaluation mode.
    """
    # NOTE(review): if pretrained weights are intended, this would need e.g.
    # ScOT.from_pretrained(...) — confirm with the demo's expectations.
    skip_flags = [True] * 4
    cfg = ScOTConfig(num_channels=4, skip_connections=skip_flags)
    poseidon = ScOT(cfg)
    poseidon.eval()  # inference-only usage: switch the module out of training mode
    return poseidon
def run_inference_by_domain(model, domain):
    """
    Synthesize a domain-flavoured 4-channel 224x224 field and run POSEIDON on it.

    Args:
        model (ScOT): The POSEIDON model. It is called as
            ``model(pixel_values=..., time=...)`` and the prediction is read
            from the ``.output`` attribute of the result.
        domain (str): One of 'Fluid Dynamics', 'Finance', 'Quantum',
            'Biology / Medicine'. Any other value falls back to pure
            standard-normal noise.

    Returns:
        np.ndarray: The model output with every size-1 dimension squeezed out.
    """
    side = 224
    channels = 4

    def grid(lo, hi):
        # Square side x side coordinate grid, "ij" (row, column) indexing.
        axis = torch.linspace(lo, hi, side)
        return torch.meshgrid(axis, axis, indexing="ij")

    def fluid():
        # Sharp centred Gaussian bump, identical in every channel.
        X, Y = grid(-1, 1)
        bump = torch.exp(-(X**2 + Y**2) * 10)
        return bump.expand(channels, side, side).unsqueeze(0)

    def finance():
        # Left-to-right linear ramp plus small independent noise per channel.
        ramp = torch.linspace(0, 1, side).reshape(1, -1).repeat(side, 1)
        jitter = torch.randn(channels, side, side) * 0.05
        return (ramp + jitter).unsqueeze(0)

    def quantum():
        # Separable standing wave sin(x) * sin(y) over [0, 4*pi]^2.
        X, Y = grid(0, 4 * torch.pi)
        wave = torch.sin(X) * torch.sin(Y)
        return wave.expand(channels, side, side).unsqueeze(0)

    def biology():
        # Broad centred Gaussian bump with per-channel noise added.
        X, Y = grid(-1, 1)
        tissue = torch.exp(-(X**2 + Y**2) * 5)
        return (torch.randn(channels, side, side) * 0.2 + tissue).unsqueeze(0)

    builders = (
        ("Fluid Dynamics", fluid),
        ("Finance", finance),
        ("Quantum", quantum),
        ("Biology / Medicine", biology),
    )
    make_input = next((build for name, build in builders if domain == name), None)

    if make_input is None:
        pixels = torch.randn(1, channels, side, side)
    else:
        pixels = make_input()

    with torch.no_grad():
        result = model(pixel_values=pixels, time=torch.tensor([0.0])).output
    return result.squeeze().numpy()
def run_inference_on_dataset(model, dataset_name):
    """
    Download one field from a scientific dataset on the Hugging Face Hub and
    run POSEIDON on it.

    The first entry along the ``sample`` and ``time`` dimensions (whichever of
    them the variable has) is selected. The slice is then coerced to shape
    (4, 224, 224): spatial axes are resized, and channels are zero-padded or
    truncated. Finally it is passed to the model with ``time = 0``.

    Args:
        model (ScOT): The POSEIDON model.
        dataset_name (str): Key into the built-in dataset table, one of
            ``"fluids.incompressible.Sines"``, ``"fluids.compressible.Riemann"``,
            ``"reaction_diffusion.AllenCahn"``.

    Returns:
        tuple: ``(input_array, output_array)`` as numpy arrays, where
        ``input_array`` is the (4, 224, 224) field actually fed to the model
        (batch axis removed).

    Raises:
        ValueError: If ``dataset_name`` is not in the built-in table.
    """
    dataset_mapping = {
        "fluids.incompressible.Sines": {
            "repo_id": "camlab-ethz/NS-Sines",
            "filename": "velocity_0.nc",
            "variable": "velocity"
        },
        "fluids.compressible.Riemann": {
            "repo_id": "camlab-ethz/CE-RP",
            "filename": "data_0.nc",
            "variable": "data"
        },
        "reaction_diffusion.AllenCahn": {
            "repo_id": "camlab-ethz/ACE",
            "filename": "solution_0.nc",
            "variable": "solution"
        }
    }
    entry = dataset_mapping.get(dataset_name)
    if entry is None:
        raise ValueError(f"Unknown dataset name: {dataset_name}")
    file_path = hf_hub_download(
        repo_id=entry["repo_id"],
        filename=entry["filename"],
        repo_type="dataset"
    )

    # Use the dataset as a context manager so the underlying netCDF file is
    # closed once the slice has been read into memory via `.values`.
    with xr.open_dataset(file_path, engine="netcdf4") as ds:
        var = ds[entry["variable"]]
        print(f"Loaded shape: {var.shape}, dims: {var.dims}")
        # Take index 0 along "sample" and "time", but only for dimensions the
        # variable actually has: isel raises on a dimension that is missing.
        first = {dim: 0 for dim in ("sample", "time") if dim in var.dims}
        sample = var.isel(first).values.astype(np.float32)

    # Bring the slice to exactly 3 axes. NOTE(review): this assumes the last
    # two axes are the spatial grid and any leading axes are channel-like.
    if sample.ndim > 3:
        sample = np.squeeze(sample)
        if sample.ndim > 3:
            # Still too many axes after dropping singletons: fold all leading
            # axes into one channel axis so the channel logic below works.
            sample = sample.reshape(-1, *sample.shape[-2:])
    while sample.ndim < 3:
        sample = np.expand_dims(sample, 0)

    tensor = torch.tensor(sample)
    if tensor.shape[-1] != 224 or tensor.shape[-2] != 224:
        tensor = resize(tensor, size=[224, 224])

    # The model is configured for 4 input channels: pad with zeros or truncate.
    n_channels = tensor.shape[0]
    if n_channels < 4:
        padding = torch.zeros((4 - n_channels, 224, 224), dtype=tensor.dtype)
        tensor = torch.cat([tensor, padding], dim=0)
    elif n_channels > 4:
        tensor = tensor[:4]

    input_tensor = tensor.unsqueeze(0)  # add batch axis -> (1, 4, 224, 224)
    time_tensor = torch.tensor([0.0])
    with torch.no_grad():
        output = model(pixel_values=input_tensor, time=time_tensor).output
    return tensor.squeeze().numpy(), output.squeeze().numpy()
def plot_output(output_array, cmap="inferno", contrast=2.0):
    """
    Plot a model output as a contrast-enhanced heatmap.

    Values are min-max normalised to [0, 1] and then raised to the power
    ``contrast``. Exponents above 1 darken mid-tones; exponents below 1
    brighten them.

    Args:
        output_array (np.ndarray): Output from the model. A 2-D array is
            plotted as-is. For a 3-D array (treated as channel-first), only
            channel 0 is plotted.
        cmap (str): Colormap used for visualization.
        contrast (float): Exponent applied after normalisation.

    Returns:
        matplotlib.figure.Figure: The heatmap figure.
    """
    field = np.asarray(output_array)
    if field.ndim == 3:
        # sns.heatmap only draws 2-D data. Squeezed model outputs are
        # presumably (C, H, W), so show the first channel.
        field = field[0]

    field = field - field.min()
    peak = field.max()
    if peak > 0:
        # A constant field has peak 0. Skip the division so the plot shows
        # zeros instead of an all-NaN (blank) image.
        field = field / peak
    field = field ** contrast

    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(
        field,
        ax=ax,
        cmap=cmap,
        cbar=True,
        square=True,
        xticklabels=False,
        yticklabels=False,
        linewidths=0,
    )
    ax.set_title("POSEIDON Output")
    ax.axis("off")
    return fig
def plot_comparison(input_array, output_array, cmap="inferno"):
    """
    Plot the model input and the model output side by side.

    Both arrays follow the same rule: a 3-D array is treated as channel-first
    and only channel 0 is shown, while a 2-D array is shown as-is. (Previously
    only the input was channel-indexed, so a (C, H, W) output made ``imshow``
    raise.)

    Args:
        input_array (np.ndarray): Ground truth or input data.
        output_array (np.ndarray): Output predicted by the model.
        cmap (str): Colormap used for both plots.

    Returns:
        matplotlib.figure.Figure: Figure showing input vs output.
    """
    def first_channel(arr):
        # Reduce a channel-first 3-D array to its first channel; keep 2-D as-is.
        arr = np.asarray(arr)
        return arr[0] if arr.ndim == 3 else arr

    fig, axs = plt.subplots(1, 2, figsize=(10, 4))
    panels = (
        (input_array, "Ground Truth"),
        (output_array, "POSEIDON Prediction"),
    )
    for ax, (data, title) in zip(axs, panels):
        ax.imshow(first_channel(data), cmap=cmap)
        ax.set_title(title)
        ax.axis("off")
    # Lay out this figure explicitly rather than pyplot's implicit current one.
    fig.tight_layout()
    return fig