Nick Konovalchuk commited on
Commit ·
be88154
1
Parent(s): 5f8446e
feat(hub): loading a custom model with `torch.hub.load` (#1396)
Browse files- hubconf.py +3 -0
- yolox/models/build.py +44 -24
hubconf.py
CHANGED
|
@@ -5,6 +5,8 @@
|
|
| 5 |
Usage example:
|
| 6 |
import torch
|
| 7 |
model = torch.hub.load("Megvii-BaseDetection/YOLOX", "yolox_s")
|
|
|
|
|
|
|
| 8 |
"""
|
| 9 |
dependencies = ["torch"]
|
| 10 |
|
|
@@ -16,4 +18,5 @@ from yolox.models import ( # isort:skip # noqa: F401, E402
|
|
| 16 |
yolox_l,
|
| 17 |
yolox_x,
|
| 18 |
yolov3,
|
|
|
|
| 19 |
)
|
|
|
|
| 5 |
Usage example:
|
| 6 |
import torch
|
| 7 |
model = torch.hub.load("Megvii-BaseDetection/YOLOX", "yolox_s")
|
| 8 |
+
model = torch.hub.load("Megvii-BaseDetection/YOLOX", "yolox_custom",
|
| 9 |
+
exp_path="exp.py", ckpt_path="ckpt.pth")
|
| 10 |
"""
|
| 11 |
dependencies = ["torch"]
|
| 12 |
|
|
|
|
| 18 |
yolox_l,
|
| 19 |
yolox_x,
|
| 20 |
yolov3,
|
| 21 |
+
yolox_custom
|
| 22 |
)
|
yolox/models/build.py
CHANGED
|
@@ -14,6 +14,7 @@ __all__ = [
|
|
| 14 |
"yolox_l",
|
| 15 |
"yolox_x",
|
| 16 |
"yolov3",
|
|
|
|
| 17 |
]
|
| 18 |
|
| 19 |
_CKPT_ROOT_URL = "https://github.com/Megvii-BaseDetection/YOLOX/releases/download"
|
|
@@ -28,16 +29,20 @@ _CKPT_FULL_PATH = {
|
|
| 28 |
}
|
| 29 |
|
| 30 |
|
| 31 |
-
def create_yolox_model(
|
| 32 |
-
|
| 33 |
-
) -> nn.Module:
|
| 34 |
"""creates and loads a YOLOX model
|
| 35 |
|
| 36 |
Args:
|
| 37 |
-
name (str): name of model. for example, "yolox-s", "yolox-tiny"
|
|
|
|
| 38 |
pretrained (bool): load pretrained weights into the model. Default to True.
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
Returns:
|
| 43 |
YOLOX model (nn.Module)
|
|
@@ -48,44 +53,59 @@ def create_yolox_model(
|
|
| 48 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 49 |
device = torch.device(device)
|
| 50 |
|
| 51 |
-
assert name in _CKPT_FULL_PATH
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
ckpt =
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
yolox_model.to(device)
|
| 63 |
return yolox_model
|
| 64 |
|
| 65 |
|
| 66 |
-
def yolox_nano(pretrained=True, num_classes=80, device=None):
|
| 67 |
return create_yolox_model("yolox-nano", pretrained, num_classes, device)
|
| 68 |
|
| 69 |
|
| 70 |
-
def yolox_tiny(pretrained=True, num_classes=80, device=None):
|
| 71 |
return create_yolox_model("yolox-tiny", pretrained, num_classes, device)
|
| 72 |
|
| 73 |
|
| 74 |
-
def yolox_s(pretrained=True, num_classes=80, device=None):
|
| 75 |
return create_yolox_model("yolox-s", pretrained, num_classes, device)
|
| 76 |
|
| 77 |
|
| 78 |
-
def yolox_m(pretrained=True, num_classes=80, device=None):
|
| 79 |
return create_yolox_model("yolox-m", pretrained, num_classes, device)
|
| 80 |
|
| 81 |
|
| 82 |
-
def yolox_l(pretrained=True, num_classes=80, device=None):
|
| 83 |
return create_yolox_model("yolox-l", pretrained, num_classes, device)
|
| 84 |
|
| 85 |
|
| 86 |
-
def yolox_x(pretrained=True, num_classes=80, device=None):
|
| 87 |
return create_yolox_model("yolox-x", pretrained, num_classes, device)
|
| 88 |
|
| 89 |
|
| 90 |
-
def yolov3(pretrained=True, num_classes=80, device=None):
|
| 91 |
-
return create_yolox_model("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
"yolox_l",
|
| 15 |
"yolox_x",
|
| 16 |
"yolov3",
|
| 17 |
+
"yolox_custom"
|
| 18 |
]
|
| 19 |
|
| 20 |
_CKPT_ROOT_URL = "https://github.com/Megvii-BaseDetection/YOLOX/releases/download"
|
|
|
|
| 29 |
}
|
| 30 |
|
| 31 |
|
| 32 |
+
def create_yolox_model(name: str, pretrained: bool = True, num_classes: int = 80, device=None,
|
| 33 |
+
exp_path: str = None, ckpt_path: str = None) -> nn.Module:
|
|
|
|
| 34 |
"""creates and loads a YOLOX model
|
| 35 |
|
| 36 |
Args:
|
| 37 |
+
name (str): name of model. for example, "yolox-s", "yolox-tiny" or "yolox_custom"
|
| 38 |
+
if you want to load your own model.
|
| 39 |
pretrained (bool): load pretrained weights into the model. Default to True.
|
| 40 |
+
device (str): default device to for model. Default to None.
|
| 41 |
+
num_classes (int): number of model classes. Default to 80.
|
| 42 |
+
exp_path (str): path to your own experiment file. Required if name="yolox_custom"
|
| 43 |
+
ckpt_path (str): path to your own ckpt. Required if name="yolox_custom" and you want to
|
| 44 |
+
load a pretrained model
|
| 45 |
+
|
| 46 |
|
| 47 |
Returns:
|
| 48 |
YOLOX model (nn.Module)
|
|
|
|
| 53 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 54 |
device = torch.device(device)
|
| 55 |
|
| 56 |
+
assert name in _CKPT_FULL_PATH or name == "yolox_custom", \
|
| 57 |
+
f"user should use one of value in {_CKPT_FULL_PATH.keys()} or \"yolox_custom\""
|
| 58 |
+
if name in _CKPT_FULL_PATH:
|
| 59 |
+
exp: Exp = get_exp(exp_name=name)
|
| 60 |
+
exp.num_classes = num_classes
|
| 61 |
+
yolox_model = exp.get_model()
|
| 62 |
+
if pretrained and num_classes == 80:
|
| 63 |
+
weights_url = _CKPT_FULL_PATH[name]
|
| 64 |
+
ckpt = load_state_dict_from_url(weights_url, map_location="cpu")
|
| 65 |
+
if "model" in ckpt:
|
| 66 |
+
ckpt = ckpt["model"]
|
| 67 |
+
yolox_model.load_state_dict(ckpt)
|
| 68 |
+
else:
|
| 69 |
+
assert exp_path is not None, "for a \"yolox_custom\" model exp_path must be provided"
|
| 70 |
+
exp: Exp = get_exp(exp_file=exp_path)
|
| 71 |
+
yolox_model = exp.get_model()
|
| 72 |
+
if ckpt_path:
|
| 73 |
+
ckpt = torch.load(ckpt_path, map_location="cpu")
|
| 74 |
+
if "model" in ckpt:
|
| 75 |
+
ckpt = ckpt["model"]
|
| 76 |
+
yolox_model.load_state_dict(ckpt)
|
| 77 |
|
| 78 |
yolox_model.to(device)
|
| 79 |
return yolox_model
|
| 80 |
|
| 81 |
|
| 82 |
+
def yolox_nano(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
|
| 83 |
return create_yolox_model("yolox-nano", pretrained, num_classes, device)
|
| 84 |
|
| 85 |
|
| 86 |
+
def yolox_tiny(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
|
| 87 |
return create_yolox_model("yolox-tiny", pretrained, num_classes, device)
|
| 88 |
|
| 89 |
|
| 90 |
+
def yolox_s(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
|
| 91 |
return create_yolox_model("yolox-s", pretrained, num_classes, device)
|
| 92 |
|
| 93 |
|
| 94 |
+
def yolox_m(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
|
| 95 |
return create_yolox_model("yolox-m", pretrained, num_classes, device)
|
| 96 |
|
| 97 |
|
| 98 |
+
def yolox_l(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
|
| 99 |
return create_yolox_model("yolox-l", pretrained, num_classes, device)
|
| 100 |
|
| 101 |
|
| 102 |
+
def yolox_x(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
|
| 103 |
return create_yolox_model("yolox-x", pretrained, num_classes, device)
|
| 104 |
|
| 105 |
|
| 106 |
+
def yolov3(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
|
| 107 |
+
return create_yolox_model("yolov3", pretrained, num_classes, device)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def yolox_custom(ckpt_path: str = None, exp_path: str = None, device: str = None) -> nn.Module:
|
| 111 |
+
return create_yolox_model("yolox_custom", ckpt_path=ckpt_path, exp_path=exp_path, device=device)
|