白鹭先生
commited on
Commit
·
4a9bffc
1
Parent(s):
c1f6ea5
添加模型
Browse files- frame_field_learning/inference.py +15 -15
- requirements.txt +1 -0
frame_field_learning/inference.py
CHANGED
|
@@ -5,7 +5,7 @@ import scipy
|
|
| 5 |
|
| 6 |
import numpy as np
|
| 7 |
import torch
|
| 8 |
-
|
| 9 |
from . import local_utils
|
| 10 |
from . import polygonize
|
| 11 |
|
|
@@ -145,20 +145,20 @@ def load_checkpoint(model, checkpoints_dirpath, device):
|
|
| 145 |
"""
|
| 146 |
Loads best val checkpoint in checkpoints_dirpath
|
| 147 |
"""
|
| 148 |
-
filepaths = python_utils.get_filepaths(checkpoints_dirpath, startswith_str="checkpoint.best_val.",
|
| 149 |
-
|
| 150 |
-
if len(filepaths):
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
else:
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
device = torch.device(device)
|
| 163 |
checkpoint = torch.load(filepath, map_location=device) # map_location is used to load on current device
|
| 164 |
|
|
|
|
| 5 |
|
| 6 |
import numpy as np
|
| 7 |
import torch
|
| 8 |
+
from huggingface_hub import hf_hub_download
|
| 9 |
from . import local_utils
|
| 10 |
from . import polygonize
|
| 11 |
|
|
|
|
| 145 |
"""
|
| 146 |
Loads best val checkpoint in checkpoints_dirpath
|
| 147 |
"""
|
| 148 |
+
# filepaths = python_utils.get_filepaths(checkpoints_dirpath, startswith_str="checkpoint.best_val.",
|
| 149 |
+
# endswith_str=".tar")
|
| 150 |
+
# if len(filepaths):
|
| 151 |
+
# filepaths = sorted(filepaths)
|
| 152 |
+
# filepath = filepaths[-1] # Last best val checkpoint filepath in case there is more than one
|
| 153 |
+
# print_utils.print_info("Loading best val checkpoint: {}".format(filepath))
|
| 154 |
+
# else:
|
| 155 |
+
# # No best val checkpoint fount: find last checkpoint:
|
| 156 |
+
# filepaths = python_utils.get_filepaths(checkpoints_dirpath, endswith_str=".tar",
|
| 157 |
+
# startswith_str="checkpoint.")
|
| 158 |
+
# filepaths = sorted(filepaths)
|
| 159 |
+
# filepath = filepaths[-1] # Last checkpoint
|
| 160 |
+
# print_utils.print_info("Loading last checkpoint: {}".format(filepath))
|
| 161 |
+
filepath = hf_hub_download(repo_id="Egrt/Luuuu", filename="checkpoint.best_val.epoch_000047.tar")
|
| 162 |
device = torch.device(device)
|
| 163 |
checkpoint = torch.load(filepath, map_location=device) # map_location is used to load on current device
|
| 164 |
|
requirements.txt
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
cython
|
|
|
|
| 2 |
scipy==1.4.1
|
| 3 |
numpy==1.22.3
|
| 4 |
matplotlib==3.3.2
|
|
|
|
| 1 |
cython
|
| 2 |
+
huggingface_hub
|
| 3 |
scipy==1.4.1
|
| 4 |
numpy==1.22.3
|
| 5 |
matplotlib==3.3.2
|