File size: 892 Bytes
44a2478 2034231 a7ab644 44a2478 2034231 a7ab644 2be7d3c a7ab644 2034231 fad4732 44a2478 fad4732 44a2478 a7ab644 2034231 a7ab644 44a2478 fad4732 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 | from huggingface_hub import hf_hub_download
import torch
from pinn_electromagnetics.models import GeoNetHash, MaxwellPINN
# Hugging Face Hub repository hosting the pretrained checkpoints for both models.
_REPO_ID = "ayda138000/DualMaxwell"


def _load_state_dict(model, file_path, map_location="cpu"):
    """Load the state-dict checkpoint at *file_path* into *model* and return it.

    ``weights_only=True`` restricts unpickling to tensors and plain containers,
    so a tampered checkpoint cannot run arbitrary code (without it, older
    ``torch.load`` versions use full ``pickle``). It is the default from
    PyTorch 2.6; passing it explicitly extends that protection to earlier
    versions that accept the keyword (1.13+).

    Raises:
        pickle.UnpicklingError: if the checkpoint holds objects outside the
            ``weights_only`` allowlist.
        RuntimeError: from ``load_state_dict`` when keys/shapes do not match
            the model (strict loading is left at its default).
    """
    state_dict = torch.load(file_path, map_location=map_location, weights_only=True)
    model.load_state_dict(state_dict)
    return model


def _maybe_load_pretrained(model, filename, pretrained, map_location):
    """Return *model*, first loading *filename* from the Hub repo if *pretrained*."""
    if pretrained:
        # hf_hub_download returns a local path; repeated calls reuse its cache.
        file_path = hf_hub_download(repo_id=_REPO_ID, filename=filename)
        _load_state_dict(model, file_path, map_location)
    return model


def geonet(pretrained=False, map_location="cpu"):
    """Build a ``GeoNetHash`` model, optionally with pretrained weights.

    Args:
        pretrained: If true, download ``geonet_real_v30.pth`` from the
            ``ayda138000/DualMaxwell`` Hugging Face repo and load it as the
            model's state dict.
        map_location: Forwarded to ``torch.load`` to remap tensor storage
            (default ``"cpu"``).

    Returns:
        The ``GeoNetHash`` instance. No ``.eval()`` call is made here; switch
        modes yourself before inference.
    """
    return _maybe_load_pretrained(
        GeoNetHash(), "geonet_real_v30.pth", pretrained, map_location
    )


def maxwell(pretrained=False, map_location="cpu"):
    """Build a ``MaxwellPINN`` model, optionally with pretrained weights.

    Args:
        pretrained: If true, download ``physnet_v31_real.pth`` from the
            ``ayda138000/DualMaxwell`` Hugging Face repo and load it as the
            model's state dict.
        map_location: Forwarded to ``torch.load`` to remap tensor storage
            (default ``"cpu"``).

    Returns:
        The ``MaxwellPINN`` instance. No ``.eval()`` call is made here; switch
        modes yourself before inference.
    """
    return _maybe_load_pretrained(
        MaxwellPINN(), "physnet_v31_real.pth", pretrained, map_location
    )