init project
Browse files- modules/pe3r/models.py +6 -3
modules/pe3r/models.py
CHANGED
|
@@ -2,7 +2,7 @@ import os
|
|
| 2 |
import sys
|
| 3 |
sys.path.append(os.path.abspath('./modules/ultralytics'))
|
| 4 |
|
| 5 |
-
from transformers import AutoTokenizer, AutoModel, AutoProcessor
|
| 6 |
from modules.mast3r.model import AsymmetricMASt3R
|
| 7 |
|
| 8 |
# from modules.sam2.build_sam import build_sam2_video_predictor
|
|
@@ -32,10 +32,13 @@ class Models:
|
|
| 32 |
|
| 33 |
# -- mobilesamv2 & sam1 --
|
| 34 |
# SAM1_ENCODER_CKP = './checkpoints/sam_vit_h.pt'
|
| 35 |
-
SAM1_ENCODER_CKP = 'facebook/sam-vit-huge/model.safetensors'
|
| 36 |
SAM1_DECODER_CKP = './checkpoints/Prompt_guided_Mask_Decoder.pt'
|
| 37 |
self.mobilesamv2 = sam_model_registry['sam_vit_h'](None)
|
| 38 |
-
image_encoder=sam_model_registry['sam_vit_h_encoder'](SAM1_ENCODER_CKP)
|
|
|
|
|
|
|
|
|
|
| 39 |
prompt_encoder, mask_decoder = sam_model_registry['prompt_guided_decoder'](SAM1_DECODER_CKP)
|
| 40 |
self.mobilesamv2.prompt_encoder = prompt_encoder
|
| 41 |
self.mobilesamv2.mask_decoder = mask_decoder
|
|
|
|
| 2 |
import sys
|
| 3 |
sys.path.append(os.path.abspath('./modules/ultralytics'))
|
| 4 |
|
| 5 |
+
from transformers import AutoTokenizer, AutoModel, AutoProcessor, SamModel
|
| 6 |
from modules.mast3r.model import AsymmetricMASt3R
|
| 7 |
|
| 8 |
# from modules.sam2.build_sam import build_sam2_video_predictor
|
|
|
|
| 32 |
|
| 33 |
# -- mobilesamv2 & sam1 --
|
| 34 |
# SAM1_ENCODER_CKP = './checkpoints/sam_vit_h.pt'
|
| 35 |
+
# SAM1_ENCODER_CKP = 'facebook/sam-vit-huge/model.safetensors'
|
| 36 |
SAM1_DECODER_CKP = './checkpoints/Prompt_guided_Mask_Decoder.pt'
|
| 37 |
self.mobilesamv2 = sam_model_registry['sam_vit_h'](None)
|
| 38 |
+
# image_encoder=sam_model_registry['sam_vit_h_encoder'](SAM1_ENCODER_CKP)
|
| 39 |
+
sam1 = SamModel.from_pretrained("facebook/sam-vit-huge", device=device)
|
| 40 |
+
image_encoder = sam1.image_encoder
|
| 41 |
+
|
| 42 |
prompt_encoder, mask_decoder = sam_model_registry['prompt_guided_decoder'](SAM1_DECODER_CKP)
|
| 43 |
self.mobilesamv2.prompt_encoder = prompt_encoder
|
| 44 |
self.mobilesamv2.mask_decoder = mask_decoder
|