DualMaxwell / hubconf.py
ayda138000's picture
Update hubconf.py
91a61e7 verified
raw
history blame contribute delete
892 Bytes
from huggingface_hub import hf_hub_download
import torch
from pinn_electromagnetics.models import GeoNetHash, MaxwellPINN
def geonet(pretrained=False, map_location="cpu"):
    """Build a ``GeoNetHash`` model, optionally with pretrained weights.

    Args:
        pretrained: If True, download ``geonet_real_v30.pth`` from the
            ``ayda138000/DualMaxwell`` Hugging Face repo and load it into
            the model as a state dict.
        map_location: Forwarded to ``torch.load`` to choose where the
            checkpoint tensors are placed (default ``"cpu"``).

    Returns:
        The ``GeoNetHash`` instance, with weights loaded when requested.

    Raises:
        Errors from ``hf_hub_download`` (e.g. network or missing-file
        failures) and from ``load_state_dict`` (e.g. key or shape mismatch)
        are not caught and propagate to the caller.
    """
    model = GeoNetHash()
    if pretrained:
        file_path = hf_hub_download(
            repo_id="ayda138000/DualMaxwell",
            filename="geonet_real_v30.pth",
        )
        # The checkpoint is fetched from the network. Plain torch.load
        # unpickles arbitrary objects, which can execute code.
        # weights_only=True restricts unpickling to tensors and basic
        # containers, which is all a state dict needs (torch >= 1.13).
        state_dict = torch.load(
            file_path, map_location=map_location, weights_only=True
        )
        model.load_state_dict(state_dict)
    return model
def maxwell(pretrained=False, map_location="cpu"):
    """Build a ``MaxwellPINN`` model, optionally with pretrained weights.

    Args:
        pretrained: If True, download ``physnet_v31_real.pth`` from the
            ``ayda138000/DualMaxwell`` Hugging Face repo and load it into
            the model as a state dict.
        map_location: Forwarded to ``torch.load`` to choose where the
            checkpoint tensors are placed (default ``"cpu"``).

    Returns:
        The ``MaxwellPINN`` instance, with weights loaded when requested.

    Raises:
        Errors from ``hf_hub_download`` (e.g. network or missing-file
        failures) and from ``load_state_dict`` (e.g. key or shape mismatch)
        are not caught and propagate to the caller.
    """
    model = MaxwellPINN()
    if pretrained:
        file_path = hf_hub_download(
            repo_id="ayda138000/DualMaxwell",
            filename="physnet_v31_real.pth",
        )
        # The checkpoint is fetched from the network. Plain torch.load
        # unpickles arbitrary objects, which can execute code.
        # weights_only=True restricts unpickling to tensors and basic
        # containers, which is all a state dict needs (torch >= 1.13).
        state_dict = torch.load(
            file_path, map_location=map_location, weights_only=True
        )
        model.load_state_dict(state_dict)
    return model