# RRF / inference.py
# Last commit: e8a12bd "Update inference.py" by antonypamo (verified).
# File size at that commit: 666 bytes.
import torch
from model import * # import extracted model classes
def load_model(weights_path="model.pt", device=None, *, model_factory=None, strict=True):
    """Build the model, load saved weights into it, and return it in eval mode.

    Args:
        weights_path: Path to a checkpoint file readable by ``torch.load``.
            Its contents are passed directly to ``load_state_dict``, so the
            file must hold a state dict, not a pickled whole model.
        device: Target device (e.g. ``"cpu"``, ``"cuda"`` or a
            ``torch.device``). ``None`` selects ``"cuda"`` when CUDA is
            available, otherwise ``"cpu"``.
        model_factory: Zero-argument callable returning the ``nn.Module`` to
            load the weights into. ``None`` uses ``IcosahedralRRF()`` from
            ``model.py``.
        strict: Forwarded to ``load_state_dict``. When True, missing or
            unexpected keys in the checkpoint raise an error.

    Returns:
        The model with the loaded weights, moved to ``device`` and switched
        to eval mode.

    Raises:
        RuntimeError: From ``load_state_dict`` when the checkpoint does not
            match the model (key mismatches only raise when ``strict``).
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model = IcosahedralRRF() if model_factory is None else model_factory()

    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. Consider passing weights_only=True (available since
    # torch 1.13, and the default from torch 2.6 on).
    state = torch.load(weights_path, map_location=device)
    model.load_state_dict(state, strict=strict)

    model.to(device)
    # Eval mode switches layers such as dropout/batch-norm to inference behaviour.
    model.eval()
    return model
if __name__ == "__main__":
    # Smoke test: load the weights and run one forward pass on random data.
    # Choose the device here so the dummy input is created on the same device
    # as the model. Otherwise, on a CUDA machine, load_model() would put the
    # model on the GPU while the input stayed on the CPU, and the forward pass
    # would fail with a device-mismatch error.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model(device=device)

    # Placeholder input of shape (1, 3, 224, 224).
    # NOTE(review): this shape is not derived from IcosahedralRRF — confirm
    # what input the model actually expects.
    x = torch.randn(1, 3, 224, 224, device=device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        y = model(x)
    print("Output shape:", y.shape)