File size: 892 Bytes
44a2478
 
2034231
a7ab644
44a2478
 
 
2034231
a7ab644
2be7d3c
a7ab644
2034231
fad4732
44a2478
 
 
fad4732
 
 
 
44a2478
 
a7ab644
 
2034231
a7ab644
44a2478
fad4732
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from huggingface_hub import hf_hub_download
import torch
from pinn_electromagnetics.models import GeoNetHash, MaxwellPINN

def geonet(pretrained=False, map_location="cpu"):
    """Build a ``GeoNetHash`` model, optionally initialised from released weights.

    Args:
        pretrained: When truthy, download ``geonet_real_v30.pth`` from the
            ``ayda138000/DualMaxwell`` Hugging Face repository and load it
            into the model as its state dict.
        map_location: Forwarded to ``torch.load`` to control where the
            checkpoint tensors are mapped (defaults to ``"cpu"``).

    Returns:
        The constructed ``GeoNetHash`` instance.
    """
    net = GeoNetHash()
    if not pretrained:
        return net

    checkpoint_path = hf_hub_download(
        repo_id="ayda138000/DualMaxwell",
        filename="geonet_real_v30.pth",
    )
    # NOTE(review): torch.load unpickles the downloaded file. Because it comes
    # from a remote repo, consider passing weights_only=True (torch >= 1.13).
    state = torch.load(checkpoint_path, map_location=map_location)
    net.load_state_dict(state)
    return net

def maxwell(pretrained=False, map_location="cpu"):
    """Build a ``MaxwellPINN`` model, optionally initialised from released weights.

    Args:
        pretrained: When truthy, download ``physnet_v31_real.pth`` from the
            ``ayda138000/DualMaxwell`` Hugging Face repository and load it
            into the model as its state dict.
        map_location: Forwarded to ``torch.load`` to control where the
            checkpoint tensors are mapped (defaults to ``"cpu"``).

    Returns:
        The constructed ``MaxwellPINN`` instance.
    """
    pinn = MaxwellPINN()
    if not pretrained:
        return pinn

    checkpoint_path = hf_hub_download(
        repo_id="ayda138000/DualMaxwell",
        filename="physnet_v31_real.pth",
    )
    # NOTE(review): torch.load unpickles the downloaded file. Because it comes
    # from a remote repo, consider passing weights_only=True (torch >= 1.13).
    state = torch.load(checkpoint_path, map_location=map_location)
    pinn.load_state_dict(state)
    return pinn