xxwyyds commited on
Commit
473ac76
·
verified ·
1 Parent(s): 7da90fb

Update src/predict.py

Browse files
Files changed (1) hide show
  1. src/predict.py +14 -2
src/predict.py CHANGED
@@ -11,6 +11,15 @@ from src.models.loupe.modeling_loupe import LoupeModel
11
  from src.models.loupe.image_precessing_loupe import LoupeImageProcessor
12
  from src.lit_model import LitModel
13
 
 
 
 
 
 
 
 
 
 
14
 
15
  # Initialize hydra
16
  hydra.initialize(config_path="../configs", version_base=None)
@@ -19,16 +28,19 @@ hydra.initialize(config_path="../configs", version_base=None)
19
  # Load model configuration
20
 
21
  cfg = hydra.compose(config_name="infer")
 
22
  # seg:/home/xxw/Loupe/model_weigths/seg/model.safetensors
23
- cfg.ckpt.checkpoint_paths = ["model_weigths/seg/model.safetensors"]
24
  loupe_config = LoupeConfig(stage=cfg.stage.name, **cfg.model)
25
  loupe = LoupeModel(loupe_config)
26
  model = LitModel(cfg, loupe)
27
  processor = LoupeImageProcessor(loupe_config)
28
 
29
  # cls:/home/xxw/Loupe/model_weigths/cls/model.safetensors
 
30
  cfc = hydra.compose(config_name="infer")
31
- cfc.ckpt.checkpoint_paths = ["model_weigths/cls/model.safetensors"]
 
32
  cls_loupe_config = LoupeConfig(stage=cfc.stage.name, **cfc.model)
33
  cls_loupe = LoupeModel(cls_loupe_config)
34
  cls_model = LitModel(cfc, cls_loupe)
 
11
  from src.models.loupe.image_precessing_loupe import LoupeImageProcessor
12
  from src.lit_model import LitModel
13
 
14
+ from huggingface_hub import hf_hub_download
15
+
16
+ ckpt_path = hf_hub_download(
17
+ repo_id="xxwyyds/Loupe",
18
+ filename="loupe_model/pretrained_weights/pe/PE-Core-L14-336.pt"
19
+ )
20
+
21
+
22
+ seg_ckpt = hf_hub_download(repo_id="xxwyyds/Loupe", filename="loupe_model/model_weigths/seg/model.safetensors")
23
 
24
  # Initialize hydra
25
  hydra.initialize(config_path="../configs", version_base=None)
 
28
  # Load model configuration
29
 
30
  cfg = hydra.compose(config_name="infer")
31
+ cfg.model.backbone_path = ckpt_path
32
  # seg:/home/xxw/Loupe/model_weigths/seg/model.safetensors
33
+ cfg.ckpt.checkpoint_paths = [seg_ckpt]
34
  loupe_config = LoupeConfig(stage=cfg.stage.name, **cfg.model)
35
  loupe = LoupeModel(loupe_config)
36
  model = LitModel(cfg, loupe)
37
  processor = LoupeImageProcessor(loupe_config)
38
 
39
  # cls:/home/xxw/Loupe/model_weigths/cls/model.safetensors
40
+ cls_ckpt = hf_hub_download(repo_id="xxwyyds/Loupe", filename="loupe_model/model_weigths/cls/model.safetensors")
41
  cfc = hydra.compose(config_name="infer")
42
+ cfc.ckpt.checkpoint_paths = [cls_ckpt]
43
+ cfc.model.backbone_path = ckpt_path
44
  cls_loupe_config = LoupeConfig(stage=cfc.stage.name, **cfc.model)
45
  cls_loupe = LoupeModel(cls_loupe_config)
46
  cls_model = LitModel(cfc, cls_loupe)