Commit d5c53f9 (0 parents): "fresh start without image history"
Browse files- .gitattributes +35 -0
- .gitignore +5 -0
- README.md +14 -0
- app.py +40 -0
- configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- configs/sam2/sam2_hiera_b+.yaml +113 -0
- configs/sam2/sam2_hiera_l.yaml +117 -0
- configs/sam2/sam2_hiera_s.yaml +116 -0
- configs/sam2/sam2_hiera_t.yaml +118 -0
- requirements.txt +9 -0
- sam2 +1 -0
- sam2segment_structure.py +887 -0
- yolo11n.pt +3 -0
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.jpg
|
| 2 |
+
*.png
|
| 3 |
+
driver_182_30frame/
|
| 4 |
+
*.jpg
|
| 5 |
+
*.png
|
README.md
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: DBDLD
|
| 3 |
+
emoji: 🦀
|
| 4 |
+
colorFrom: indigo
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 5.25.2
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
short_description: The backdoor trigger demo
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from sam2segment_structure import generate_trigger_crop
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
# Dummy lane annotation (later this could be loaded dynamically from a JSON
# file or a user upload instead of being hard-coded).
# Looks like TuSimple-style lane-detection format — TODO confirm:
#   - "lanes": per-lane x-coordinates, one per entry in "h_samples";
#     -2 appears to mark "no lane point at this row" — verify downstream.
#   - "h_samples": the image rows (y-coordinates) at which lane x-positions
#     are sampled.
#   - "raw_file": relative path of the source frame; overwritten per request
#     by process_trigger_with_path.
dummy_lane_data = {
    "lanes": [[-2, -2, -2, 814, 751, 688, 625, 562, 500, 438, 373, 305, 234, 160, 88, 16, -64, -2, -2, -2]],
    "h_samples": [200, 210, 220, 230, 240, 250, 260, 270, 280, 290, 300, 310, 320, 330, 340, 350, 360, 370, 380, 390],
    "raw_file": "driver_182_30frame/06010513_0036.MP4/00270.jpg"
}
|
| 11 |
+
|
| 12 |
+
def process_trigger_with_path(input_image, save_path):
    """Save the uploaded image to *save_path* and generate the trigger crop.

    Args:
        input_image: PIL image uploaded through the Gradio UI.
        save_path: Relative path (under the working directory) where the
            image is stored, e.g. "driver_182_30frame/.../00270.jpg".

    Returns:
        Tuple ``(crop_path, mask_path)`` as produced by
        ``generate_trigger_crop``.

    Raises:
        ValueError: If *save_path* resolves outside the working directory.
    """
    # SECURITY: save_path comes straight from an untrusted user textbox.
    # Resolve it and verify it stays inside the working directory, otherwise
    # a value like "../../some/system/file" allows an arbitrary file write.
    base_dir = os.path.realpath(os.getcwd())
    resolved = os.path.realpath(os.path.join(base_dir, save_path))
    if os.path.commonpath([base_dir, resolved]) != base_dir:
        raise ValueError(f"save_path escapes the working directory: {save_path!r}")

    # Ensure the target directory exists. dirname is "" for a bare filename,
    # which os.makedirs would reject — only create when non-empty.
    target_dir = os.path.dirname(save_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    # Persist the upload where the downstream pipeline expects it.
    input_image.save(save_path)

    # Keep the lane annotation pointing at the frame we just wrote.
    dummy_lane_data["raw_file"] = save_path

    # Run the main trigger-generation pipeline.
    crop_path, mask_path = generate_trigger_crop(save_path, dummy_lane_data)
    return crop_path, mask_path
|
| 25 |
+
|
| 26 |
+
# Gradio UI wiring: one image upload plus a target-path textbox; the outputs
# are the generated crop and its mask, both rendered from file paths.
_demo_inputs = [
    gr.Image(type="pil", label="Upload Image"),
    gr.Textbox(label="Path to Save Image (e.g. driver_182_30frame/06010513_0036.MP4/00270.jpg)"),
]
_demo_outputs = [
    gr.Image(type="filepath", label="Cropped Image"),
    gr.Image(type="filepath", label="Cropped Mask"),
]

demo = gr.Interface(
    fn=process_trigger_with_path,
    inputs=_demo_inputs,
    outputs=_demo_outputs,
    title="DBDLD Trigger Demo",
    description="Upload an image and specify the target save path. The crop and mask will be generated accordingly.",
)

demo.launch()
|
configs/sam2.1/sam2.1_hiera_b+.yaml
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# Model
|
| 4 |
+
model:
|
| 5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
| 6 |
+
image_encoder:
|
| 7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
+
scalp: 1
|
| 9 |
+
trunk:
|
| 10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
+
embed_dim: 112
|
| 12 |
+
num_heads: 2
|
| 13 |
+
neck:
|
| 14 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 15 |
+
position_encoding:
|
| 16 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 17 |
+
num_pos_feats: 256
|
| 18 |
+
normalize: true
|
| 19 |
+
scale: null
|
| 20 |
+
temperature: 10000
|
| 21 |
+
d_model: 256
|
| 22 |
+
backbone_channel_list: [896, 448, 224, 112]
|
| 23 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 24 |
+
fpn_interp_model: nearest
|
| 25 |
+
|
| 26 |
+
memory_attention:
|
| 27 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 28 |
+
d_model: 256
|
| 29 |
+
pos_enc_at_input: true
|
| 30 |
+
layer:
|
| 31 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 32 |
+
activation: relu
|
| 33 |
+
dim_feedforward: 2048
|
| 34 |
+
dropout: 0.1
|
| 35 |
+
pos_enc_at_attn: false
|
| 36 |
+
self_attention:
|
| 37 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 38 |
+
rope_theta: 10000.0
|
| 39 |
+
feat_sizes: [64, 64]
|
| 40 |
+
embedding_dim: 256
|
| 41 |
+
num_heads: 1
|
| 42 |
+
downsample_rate: 1
|
| 43 |
+
dropout: 0.1
|
| 44 |
+
d_model: 256
|
| 45 |
+
pos_enc_at_cross_attn_keys: true
|
| 46 |
+
pos_enc_at_cross_attn_queries: false
|
| 47 |
+
cross_attention:
|
| 48 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 49 |
+
rope_theta: 10000.0
|
| 50 |
+
feat_sizes: [64, 64]
|
| 51 |
+
rope_k_repeat: True
|
| 52 |
+
embedding_dim: 256
|
| 53 |
+
num_heads: 1
|
| 54 |
+
downsample_rate: 1
|
| 55 |
+
dropout: 0.1
|
| 56 |
+
kv_in_dim: 64
|
| 57 |
+
num_layers: 4
|
| 58 |
+
|
| 59 |
+
memory_encoder:
|
| 60 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 61 |
+
out_dim: 64
|
| 62 |
+
position_encoding:
|
| 63 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 64 |
+
num_pos_feats: 64
|
| 65 |
+
normalize: true
|
| 66 |
+
scale: null
|
| 67 |
+
temperature: 10000
|
| 68 |
+
mask_downsampler:
|
| 69 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 70 |
+
kernel_size: 3
|
| 71 |
+
stride: 2
|
| 72 |
+
padding: 1
|
| 73 |
+
fuser:
|
| 74 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
| 75 |
+
layer:
|
| 76 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 77 |
+
dim: 256
|
| 78 |
+
kernel_size: 7
|
| 79 |
+
padding: 3
|
| 80 |
+
layer_scale_init_value: 1e-6
|
| 81 |
+
use_dwconv: True # depth-wise convs
|
| 82 |
+
num_layers: 2
|
| 83 |
+
|
| 84 |
+
num_maskmem: 7
|
| 85 |
+
image_size: 1024
|
| 86 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 87 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 88 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 89 |
+
use_mask_input_as_output_without_sam: true
|
| 90 |
+
# Memory
|
| 91 |
+
directly_add_no_mem_embed: true
|
| 92 |
+
no_obj_embed_spatial: true
|
| 93 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 94 |
+
use_high_res_features_in_sam: true
|
| 95 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 96 |
+
multimask_output_in_sam: true
|
| 97 |
+
# SAM heads
|
| 98 |
+
iou_prediction_use_sigmoid: True
|
| 99 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 100 |
+
use_obj_ptrs_in_encoder: true
|
| 101 |
+
add_tpos_enc_to_obj_ptrs: true
|
| 102 |
+
proj_tpos_enc_in_obj_ptrs: true
|
| 103 |
+
use_signed_tpos_enc_to_obj_ptrs: true
|
| 104 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 105 |
+
# object occlusion prediction
|
| 106 |
+
pred_obj_scores: true
|
| 107 |
+
pred_obj_scores_mlp: true
|
| 108 |
+
fixed_no_obj_ptr: true
|
| 109 |
+
# multimask tracking settings
|
| 110 |
+
multimask_output_for_tracking: true
|
| 111 |
+
use_multimask_token_for_obj_ptr: true
|
| 112 |
+
multimask_min_pt_num: 0
|
| 113 |
+
multimask_max_pt_num: 1
|
| 114 |
+
use_mlp_for_obj_ptr_proj: true
|
| 115 |
+
# Compilation flag
|
| 116 |
+
compile_image_encoder: False
|
configs/sam2.1/sam2.1_hiera_l.yaml
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# Model
|
| 4 |
+
model:
|
| 5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
| 6 |
+
image_encoder:
|
| 7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
+
scalp: 1
|
| 9 |
+
trunk:
|
| 10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
+
embed_dim: 144
|
| 12 |
+
num_heads: 2
|
| 13 |
+
stages: [2, 6, 36, 4]
|
| 14 |
+
global_att_blocks: [23, 33, 43]
|
| 15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
+
window_spec: [8, 4, 16, 8]
|
| 17 |
+
neck:
|
| 18 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 19 |
+
position_encoding:
|
| 20 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 21 |
+
num_pos_feats: 256
|
| 22 |
+
normalize: true
|
| 23 |
+
scale: null
|
| 24 |
+
temperature: 10000
|
| 25 |
+
d_model: 256
|
| 26 |
+
backbone_channel_list: [1152, 576, 288, 144]
|
| 27 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 28 |
+
fpn_interp_model: nearest
|
| 29 |
+
|
| 30 |
+
memory_attention:
|
| 31 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 32 |
+
d_model: 256
|
| 33 |
+
pos_enc_at_input: true
|
| 34 |
+
layer:
|
| 35 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 36 |
+
activation: relu
|
| 37 |
+
dim_feedforward: 2048
|
| 38 |
+
dropout: 0.1
|
| 39 |
+
pos_enc_at_attn: false
|
| 40 |
+
self_attention:
|
| 41 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 42 |
+
rope_theta: 10000.0
|
| 43 |
+
feat_sizes: [64, 64]
|
| 44 |
+
embedding_dim: 256
|
| 45 |
+
num_heads: 1
|
| 46 |
+
downsample_rate: 1
|
| 47 |
+
dropout: 0.1
|
| 48 |
+
d_model: 256
|
| 49 |
+
pos_enc_at_cross_attn_keys: true
|
| 50 |
+
pos_enc_at_cross_attn_queries: false
|
| 51 |
+
cross_attention:
|
| 52 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 53 |
+
rope_theta: 10000.0
|
| 54 |
+
feat_sizes: [64, 64]
|
| 55 |
+
rope_k_repeat: True
|
| 56 |
+
embedding_dim: 256
|
| 57 |
+
num_heads: 1
|
| 58 |
+
downsample_rate: 1
|
| 59 |
+
dropout: 0.1
|
| 60 |
+
kv_in_dim: 64
|
| 61 |
+
num_layers: 4
|
| 62 |
+
|
| 63 |
+
memory_encoder:
|
| 64 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 65 |
+
out_dim: 64
|
| 66 |
+
position_encoding:
|
| 67 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 68 |
+
num_pos_feats: 64
|
| 69 |
+
normalize: true
|
| 70 |
+
scale: null
|
| 71 |
+
temperature: 10000
|
| 72 |
+
mask_downsampler:
|
| 73 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 74 |
+
kernel_size: 3
|
| 75 |
+
stride: 2
|
| 76 |
+
padding: 1
|
| 77 |
+
fuser:
|
| 78 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
| 79 |
+
layer:
|
| 80 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 81 |
+
dim: 256
|
| 82 |
+
kernel_size: 7
|
| 83 |
+
padding: 3
|
| 84 |
+
layer_scale_init_value: 1e-6
|
| 85 |
+
use_dwconv: True # depth-wise convs
|
| 86 |
+
num_layers: 2
|
| 87 |
+
|
| 88 |
+
num_maskmem: 7
|
| 89 |
+
image_size: 1024
|
| 90 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 91 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 92 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 93 |
+
use_mask_input_as_output_without_sam: true
|
| 94 |
+
# Memory
|
| 95 |
+
directly_add_no_mem_embed: true
|
| 96 |
+
no_obj_embed_spatial: true
|
| 97 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 98 |
+
use_high_res_features_in_sam: true
|
| 99 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 100 |
+
multimask_output_in_sam: true
|
| 101 |
+
# SAM heads
|
| 102 |
+
iou_prediction_use_sigmoid: True
|
| 103 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 104 |
+
use_obj_ptrs_in_encoder: true
|
| 105 |
+
add_tpos_enc_to_obj_ptrs: true
|
| 106 |
+
proj_tpos_enc_in_obj_ptrs: true
|
| 107 |
+
use_signed_tpos_enc_to_obj_ptrs: true
|
| 108 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 109 |
+
# object occlusion prediction
|
| 110 |
+
pred_obj_scores: true
|
| 111 |
+
pred_obj_scores_mlp: true
|
| 112 |
+
fixed_no_obj_ptr: true
|
| 113 |
+
# multimask tracking settings
|
| 114 |
+
multimask_output_for_tracking: true
|
| 115 |
+
use_multimask_token_for_obj_ptr: true
|
| 116 |
+
multimask_min_pt_num: 0
|
| 117 |
+
multimask_max_pt_num: 1
|
| 118 |
+
use_mlp_for_obj_ptr_proj: true
|
| 119 |
+
# Compilation flag
|
| 120 |
+
compile_image_encoder: False
|
configs/sam2.1/sam2.1_hiera_s.yaml
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# Model
|
| 4 |
+
model:
|
| 5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
| 6 |
+
image_encoder:
|
| 7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
+
scalp: 1
|
| 9 |
+
trunk:
|
| 10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
+
embed_dim: 96
|
| 12 |
+
num_heads: 1
|
| 13 |
+
stages: [1, 2, 11, 2]
|
| 14 |
+
global_att_blocks: [7, 10, 13]
|
| 15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
+
neck:
|
| 17 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 18 |
+
position_encoding:
|
| 19 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 20 |
+
num_pos_feats: 256
|
| 21 |
+
normalize: true
|
| 22 |
+
scale: null
|
| 23 |
+
temperature: 10000
|
| 24 |
+
d_model: 256
|
| 25 |
+
backbone_channel_list: [768, 384, 192, 96]
|
| 26 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 27 |
+
fpn_interp_model: nearest
|
| 28 |
+
|
| 29 |
+
memory_attention:
|
| 30 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 31 |
+
d_model: 256
|
| 32 |
+
pos_enc_at_input: true
|
| 33 |
+
layer:
|
| 34 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 35 |
+
activation: relu
|
| 36 |
+
dim_feedforward: 2048
|
| 37 |
+
dropout: 0.1
|
| 38 |
+
pos_enc_at_attn: false
|
| 39 |
+
self_attention:
|
| 40 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 41 |
+
rope_theta: 10000.0
|
| 42 |
+
feat_sizes: [64, 64]
|
| 43 |
+
embedding_dim: 256
|
| 44 |
+
num_heads: 1
|
| 45 |
+
downsample_rate: 1
|
| 46 |
+
dropout: 0.1
|
| 47 |
+
d_model: 256
|
| 48 |
+
pos_enc_at_cross_attn_keys: true
|
| 49 |
+
pos_enc_at_cross_attn_queries: false
|
| 50 |
+
cross_attention:
|
| 51 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 52 |
+
rope_theta: 10000.0
|
| 53 |
+
feat_sizes: [64, 64]
|
| 54 |
+
rope_k_repeat: True
|
| 55 |
+
embedding_dim: 256
|
| 56 |
+
num_heads: 1
|
| 57 |
+
downsample_rate: 1
|
| 58 |
+
dropout: 0.1
|
| 59 |
+
kv_in_dim: 64
|
| 60 |
+
num_layers: 4
|
| 61 |
+
|
| 62 |
+
memory_encoder:
|
| 63 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 64 |
+
out_dim: 64
|
| 65 |
+
position_encoding:
|
| 66 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 67 |
+
num_pos_feats: 64
|
| 68 |
+
normalize: true
|
| 69 |
+
scale: null
|
| 70 |
+
temperature: 10000
|
| 71 |
+
mask_downsampler:
|
| 72 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 73 |
+
kernel_size: 3
|
| 74 |
+
stride: 2
|
| 75 |
+
padding: 1
|
| 76 |
+
fuser:
|
| 77 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
| 78 |
+
layer:
|
| 79 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 80 |
+
dim: 256
|
| 81 |
+
kernel_size: 7
|
| 82 |
+
padding: 3
|
| 83 |
+
layer_scale_init_value: 1e-6
|
| 84 |
+
use_dwconv: True # depth-wise convs
|
| 85 |
+
num_layers: 2
|
| 86 |
+
|
| 87 |
+
num_maskmem: 7
|
| 88 |
+
image_size: 1024
|
| 89 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 90 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 91 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 92 |
+
use_mask_input_as_output_without_sam: true
|
| 93 |
+
# Memory
|
| 94 |
+
directly_add_no_mem_embed: true
|
| 95 |
+
no_obj_embed_spatial: true
|
| 96 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 97 |
+
use_high_res_features_in_sam: true
|
| 98 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 99 |
+
multimask_output_in_sam: true
|
| 100 |
+
# SAM heads
|
| 101 |
+
iou_prediction_use_sigmoid: True
|
| 102 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 103 |
+
use_obj_ptrs_in_encoder: true
|
| 104 |
+
add_tpos_enc_to_obj_ptrs: true
|
| 105 |
+
proj_tpos_enc_in_obj_ptrs: true
|
| 106 |
+
use_signed_tpos_enc_to_obj_ptrs: true
|
| 107 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 108 |
+
# object occlusion prediction
|
| 109 |
+
pred_obj_scores: true
|
| 110 |
+
pred_obj_scores_mlp: true
|
| 111 |
+
fixed_no_obj_ptr: true
|
| 112 |
+
# multimask tracking settings
|
| 113 |
+
multimask_output_for_tracking: true
|
| 114 |
+
use_multimask_token_for_obj_ptr: true
|
| 115 |
+
multimask_min_pt_num: 0
|
| 116 |
+
multimask_max_pt_num: 1
|
| 117 |
+
use_mlp_for_obj_ptr_proj: true
|
| 118 |
+
# Compilation flag
|
| 119 |
+
compile_image_encoder: False
|
configs/sam2.1/sam2.1_hiera_t.yaml
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# Model
|
| 4 |
+
model:
|
| 5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
| 6 |
+
image_encoder:
|
| 7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
+
scalp: 1
|
| 9 |
+
trunk:
|
| 10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
+
embed_dim: 96
|
| 12 |
+
num_heads: 1
|
| 13 |
+
stages: [1, 2, 7, 2]
|
| 14 |
+
global_att_blocks: [5, 7, 9]
|
| 15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
+
neck:
|
| 17 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 18 |
+
position_encoding:
|
| 19 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 20 |
+
num_pos_feats: 256
|
| 21 |
+
normalize: true
|
| 22 |
+
scale: null
|
| 23 |
+
temperature: 10000
|
| 24 |
+
d_model: 256
|
| 25 |
+
backbone_channel_list: [768, 384, 192, 96]
|
| 26 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 27 |
+
fpn_interp_model: nearest
|
| 28 |
+
|
| 29 |
+
memory_attention:
|
| 30 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 31 |
+
d_model: 256
|
| 32 |
+
pos_enc_at_input: true
|
| 33 |
+
layer:
|
| 34 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 35 |
+
activation: relu
|
| 36 |
+
dim_feedforward: 2048
|
| 37 |
+
dropout: 0.1
|
| 38 |
+
pos_enc_at_attn: false
|
| 39 |
+
self_attention:
|
| 40 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 41 |
+
rope_theta: 10000.0
|
| 42 |
+
feat_sizes: [64, 64]
|
| 43 |
+
embedding_dim: 256
|
| 44 |
+
num_heads: 1
|
| 45 |
+
downsample_rate: 1
|
| 46 |
+
dropout: 0.1
|
| 47 |
+
d_model: 256
|
| 48 |
+
pos_enc_at_cross_attn_keys: true
|
| 49 |
+
pos_enc_at_cross_attn_queries: false
|
| 50 |
+
cross_attention:
|
| 51 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 52 |
+
rope_theta: 10000.0
|
| 53 |
+
feat_sizes: [64, 64]
|
| 54 |
+
rope_k_repeat: True
|
| 55 |
+
embedding_dim: 256
|
| 56 |
+
num_heads: 1
|
| 57 |
+
downsample_rate: 1
|
| 58 |
+
dropout: 0.1
|
| 59 |
+
kv_in_dim: 64
|
| 60 |
+
num_layers: 4
|
| 61 |
+
|
| 62 |
+
memory_encoder:
|
| 63 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 64 |
+
out_dim: 64
|
| 65 |
+
position_encoding:
|
| 66 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 67 |
+
num_pos_feats: 64
|
| 68 |
+
normalize: true
|
| 69 |
+
scale: null
|
| 70 |
+
temperature: 10000
|
| 71 |
+
mask_downsampler:
|
| 72 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 73 |
+
kernel_size: 3
|
| 74 |
+
stride: 2
|
| 75 |
+
padding: 1
|
| 76 |
+
fuser:
|
| 77 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
| 78 |
+
layer:
|
| 79 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 80 |
+
dim: 256
|
| 81 |
+
kernel_size: 7
|
| 82 |
+
padding: 3
|
| 83 |
+
layer_scale_init_value: 1e-6
|
| 84 |
+
use_dwconv: True # depth-wise convs
|
| 85 |
+
num_layers: 2
|
| 86 |
+
|
| 87 |
+
num_maskmem: 7
|
| 88 |
+
image_size: 1024
|
| 89 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 90 |
+
# SAM decoder
|
| 91 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 92 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 93 |
+
use_mask_input_as_output_without_sam: true
|
| 94 |
+
# Memory
|
| 95 |
+
directly_add_no_mem_embed: true
|
| 96 |
+
no_obj_embed_spatial: true
|
| 97 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 98 |
+
use_high_res_features_in_sam: true
|
| 99 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 100 |
+
multimask_output_in_sam: true
|
| 101 |
+
# SAM heads
|
| 102 |
+
iou_prediction_use_sigmoid: True
|
| 103 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 104 |
+
use_obj_ptrs_in_encoder: true
|
| 105 |
+
add_tpos_enc_to_obj_ptrs: true
|
| 106 |
+
proj_tpos_enc_in_obj_ptrs: true
|
| 107 |
+
use_signed_tpos_enc_to_obj_ptrs: true
|
| 108 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 109 |
+
# object occlusion prediction
|
| 110 |
+
pred_obj_scores: true
|
| 111 |
+
pred_obj_scores_mlp: true
|
| 112 |
+
fixed_no_obj_ptr: true
|
| 113 |
+
# multimask tracking settings
|
| 114 |
+
multimask_output_for_tracking: true
|
| 115 |
+
use_multimask_token_for_obj_ptr: true
|
| 116 |
+
multimask_min_pt_num: 0
|
| 117 |
+
multimask_max_pt_num: 1
|
| 118 |
+
use_mlp_for_obj_ptr_proj: true
|
| 119 |
+
# Compilation flag
|
| 120 |
+
# HieraT does not currently support compilation, should always be set to False
|
| 121 |
+
compile_image_encoder: False
|
configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml
ADDED
|
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
scratch:
|
| 4 |
+
resolution: 1024
|
| 5 |
+
train_batch_size: 1
|
| 6 |
+
num_train_workers: 10
|
| 7 |
+
num_frames: 8
|
| 8 |
+
max_num_objects: 3
|
| 9 |
+
base_lr: 5.0e-6
|
| 10 |
+
vision_lr: 3.0e-06
|
| 11 |
+
phases_per_epoch: 1
|
| 12 |
+
num_epochs: 40
|
| 13 |
+
|
| 14 |
+
dataset:
|
| 15 |
+
# PATHS to Dataset
|
| 16 |
+
img_folder: null # PATH to MOSE JPEGImages folder
|
| 17 |
+
gt_folder: null # PATH to MOSE Annotations folder
|
| 18 |
+
file_list_txt: training/assets/MOSE_sample_train_list.txt # Optional PATH to filelist containing a subset of videos to be used for training
|
| 19 |
+
multiplier: 2
|
| 20 |
+
|
| 21 |
+
# Video transforms
|
| 22 |
+
vos:
|
| 23 |
+
train_transforms:
|
| 24 |
+
- _target_: training.dataset.transforms.ComposeAPI
|
| 25 |
+
transforms:
|
| 26 |
+
- _target_: training.dataset.transforms.RandomHorizontalFlip
|
| 27 |
+
consistent_transform: True
|
| 28 |
+
- _target_: training.dataset.transforms.RandomAffine
|
| 29 |
+
degrees: 25
|
| 30 |
+
shear: 20
|
| 31 |
+
image_interpolation: bilinear
|
| 32 |
+
consistent_transform: True
|
| 33 |
+
- _target_: training.dataset.transforms.RandomResizeAPI
|
| 34 |
+
sizes: ${scratch.resolution}
|
| 35 |
+
square: true
|
| 36 |
+
consistent_transform: True
|
| 37 |
+
- _target_: training.dataset.transforms.ColorJitter
|
| 38 |
+
consistent_transform: True
|
| 39 |
+
brightness: 0.1
|
| 40 |
+
contrast: 0.03
|
| 41 |
+
saturation: 0.03
|
| 42 |
+
hue: null
|
| 43 |
+
- _target_: training.dataset.transforms.RandomGrayscale
|
| 44 |
+
p: 0.05
|
| 45 |
+
consistent_transform: True
|
| 46 |
+
- _target_: training.dataset.transforms.ColorJitter
|
| 47 |
+
consistent_transform: False
|
| 48 |
+
brightness: 0.1
|
| 49 |
+
contrast: 0.05
|
| 50 |
+
saturation: 0.05
|
| 51 |
+
hue: null
|
| 52 |
+
- _target_: training.dataset.transforms.ToTensorAPI
|
| 53 |
+
- _target_: training.dataset.transforms.NormalizeAPI
|
| 54 |
+
mean: [0.485, 0.456, 0.406]
|
| 55 |
+
std: [0.229, 0.224, 0.225]
|
| 56 |
+
|
| 57 |
+
trainer:
|
| 58 |
+
_target_: training.trainer.Trainer
|
| 59 |
+
mode: train_only
|
| 60 |
+
max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}}
|
| 61 |
+
accelerator: cuda
|
| 62 |
+
seed_value: 123
|
| 63 |
+
|
| 64 |
+
model:
|
| 65 |
+
_target_: training.model.sam2.SAM2Train
|
| 66 |
+
image_encoder:
|
| 67 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 68 |
+
scalp: 1
|
| 69 |
+
trunk:
|
| 70 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 71 |
+
embed_dim: 112
|
| 72 |
+
num_heads: 2
|
| 73 |
+
drop_path_rate: 0.1
|
| 74 |
+
neck:
|
| 75 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 76 |
+
position_encoding:
|
| 77 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 78 |
+
num_pos_feats: 256
|
| 79 |
+
normalize: true
|
| 80 |
+
scale: null
|
| 81 |
+
temperature: 10000
|
| 82 |
+
d_model: 256
|
| 83 |
+
backbone_channel_list: [896, 448, 224, 112]
|
| 84 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 85 |
+
fpn_interp_model: nearest
|
| 86 |
+
|
| 87 |
+
memory_attention:
|
| 88 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 89 |
+
d_model: 256
|
| 90 |
+
pos_enc_at_input: true
|
| 91 |
+
layer:
|
| 92 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 93 |
+
activation: relu
|
| 94 |
+
dim_feedforward: 2048
|
| 95 |
+
dropout: 0.1
|
| 96 |
+
pos_enc_at_attn: false
|
| 97 |
+
self_attention:
|
| 98 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 99 |
+
rope_theta: 10000.0
|
| 100 |
+
feat_sizes: [64, 64]
|
| 101 |
+
embedding_dim: 256
|
| 102 |
+
num_heads: 1
|
| 103 |
+
downsample_rate: 1
|
| 104 |
+
dropout: 0.1
|
| 105 |
+
d_model: 256
|
| 106 |
+
pos_enc_at_cross_attn_keys: true
|
| 107 |
+
pos_enc_at_cross_attn_queries: false
|
| 108 |
+
cross_attention:
|
| 109 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 110 |
+
rope_theta: 10000.0
|
| 111 |
+
feat_sizes: [64, 64]
|
| 112 |
+
rope_k_repeat: True
|
| 113 |
+
embedding_dim: 256
|
| 114 |
+
num_heads: 1
|
| 115 |
+
downsample_rate: 1
|
| 116 |
+
dropout: 0.1
|
| 117 |
+
kv_in_dim: 64
|
| 118 |
+
num_layers: 4
|
| 119 |
+
|
| 120 |
+
memory_encoder:
|
| 121 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 122 |
+
out_dim: 64
|
| 123 |
+
position_encoding:
|
| 124 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 125 |
+
num_pos_feats: 64
|
| 126 |
+
normalize: true
|
| 127 |
+
scale: null
|
| 128 |
+
temperature: 10000
|
| 129 |
+
mask_downsampler:
|
| 130 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 131 |
+
kernel_size: 3
|
| 132 |
+
stride: 2
|
| 133 |
+
padding: 1
|
| 134 |
+
fuser:
|
| 135 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
| 136 |
+
layer:
|
| 137 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 138 |
+
dim: 256
|
| 139 |
+
kernel_size: 7
|
| 140 |
+
padding: 3
|
| 141 |
+
layer_scale_init_value: 1e-6
|
| 142 |
+
use_dwconv: True # depth-wise convs
|
| 143 |
+
num_layers: 2
|
| 144 |
+
|
| 145 |
+
num_maskmem: 7
|
| 146 |
+
image_size: ${scratch.resolution}
|
| 147 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 148 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 149 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 150 |
+
use_mask_input_as_output_without_sam: true
|
| 151 |
+
# Memory
|
| 152 |
+
directly_add_no_mem_embed: true
|
| 153 |
+
no_obj_embed_spatial: true
|
| 154 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 155 |
+
use_high_res_features_in_sam: true
|
| 156 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 157 |
+
multimask_output_in_sam: true
|
| 158 |
+
# SAM heads
|
| 159 |
+
iou_prediction_use_sigmoid: True
|
| 160 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 161 |
+
use_obj_ptrs_in_encoder: true
|
| 162 |
+
add_tpos_enc_to_obj_ptrs: true
|
| 163 |
+
proj_tpos_enc_in_obj_ptrs: true
|
| 164 |
+
use_signed_tpos_enc_to_obj_ptrs: true
|
| 165 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 166 |
+
# object occlusion prediction
|
| 167 |
+
pred_obj_scores: true
|
| 168 |
+
pred_obj_scores_mlp: true
|
| 169 |
+
fixed_no_obj_ptr: true
|
| 170 |
+
# multimask tracking settings
|
| 171 |
+
multimask_output_for_tracking: true
|
| 172 |
+
use_multimask_token_for_obj_ptr: true
|
| 173 |
+
multimask_min_pt_num: 0
|
| 174 |
+
multimask_max_pt_num: 1
|
| 175 |
+
use_mlp_for_obj_ptr_proj: true
|
| 176 |
+
# Compilation flag
|
| 177 |
+
# compile_image_encoder: False
|
| 178 |
+
|
| 179 |
+
####### Training specific params #######
|
| 180 |
+
# box/point input and corrections
|
| 181 |
+
prob_to_use_pt_input_for_train: 0.5
|
| 182 |
+
prob_to_use_pt_input_for_eval: 0.0
|
| 183 |
+
prob_to_use_box_input_for_train: 0.5 # 0.5*0.5 = 0.25 prob to use box instead of points
|
| 184 |
+
prob_to_use_box_input_for_eval: 0.0
|
| 185 |
+
prob_to_sample_from_gt_for_train: 0.1 # with a small prob, sampling correction points from GT mask instead of prediction errors
|
| 186 |
+
num_frames_to_correct_for_train: 2 # iteratively sample on random 1~2 frames (always include the first frame)
|
| 187 |
+
num_frames_to_correct_for_eval: 1 # only iteratively sample on first frame
|
| 188 |
+
rand_frames_to_correct_for_train: True # random #init-cond-frame ~ 2
|
| 189 |
+
add_all_frames_to_correct_as_cond: True # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame)
|
| 190 |
+
# maximum 2 initial conditioning frames
|
| 191 |
+
num_init_cond_frames_for_train: 2
|
| 192 |
+
rand_init_cond_frames_for_train: True # random 1~2
|
| 193 |
+
num_correction_pt_per_frame: 7
|
| 194 |
+
use_act_ckpt_iterative_pt_sampling: false
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
num_init_cond_frames_for_eval: 1 # only mask on the first frame
|
| 199 |
+
forward_backbone_per_frame_for_eval: True
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
data:
|
| 203 |
+
train:
|
| 204 |
+
_target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
|
| 205 |
+
phases_per_epoch: ${scratch.phases_per_epoch}
|
| 206 |
+
batch_sizes:
|
| 207 |
+
- ${scratch.train_batch_size}
|
| 208 |
+
|
| 209 |
+
datasets:
|
| 210 |
+
- _target_: training.dataset.utils.RepeatFactorWrapper
|
| 211 |
+
dataset:
|
| 212 |
+
_target_: training.dataset.utils.ConcatDataset
|
| 213 |
+
datasets:
|
| 214 |
+
- _target_: training.dataset.vos_dataset.VOSDataset
|
| 215 |
+
transforms: ${vos.train_transforms}
|
| 216 |
+
training: true
|
| 217 |
+
video_dataset:
|
| 218 |
+
_target_: training.dataset.vos_raw_dataset.PNGRawDataset
|
| 219 |
+
img_folder: ${dataset.img_folder}
|
| 220 |
+
gt_folder: ${dataset.gt_folder}
|
| 221 |
+
file_list_txt: ${dataset.file_list_txt}
|
| 222 |
+
sampler:
|
| 223 |
+
_target_: training.dataset.vos_sampler.RandomUniformSampler
|
| 224 |
+
num_frames: ${scratch.num_frames}
|
| 225 |
+
max_num_objects: ${scratch.max_num_objects}
|
| 226 |
+
multiplier: ${dataset.multiplier}
|
| 227 |
+
shuffle: True
|
| 228 |
+
num_workers: ${scratch.num_train_workers}
|
| 229 |
+
pin_memory: True
|
| 230 |
+
drop_last: True
|
| 231 |
+
collate_fn:
|
| 232 |
+
_target_: training.utils.data_utils.collate_fn
|
| 233 |
+
_partial_: true
|
| 234 |
+
dict_key: all
|
| 235 |
+
|
| 236 |
+
optim:
|
| 237 |
+
amp:
|
| 238 |
+
enabled: True
|
| 239 |
+
amp_dtype: bfloat16
|
| 240 |
+
|
| 241 |
+
optimizer:
|
| 242 |
+
_target_: torch.optim.AdamW
|
| 243 |
+
|
| 244 |
+
gradient_clip:
|
| 245 |
+
_target_: training.optimizer.GradientClipper
|
| 246 |
+
max_norm: 0.1
|
| 247 |
+
norm_type: 2
|
| 248 |
+
|
| 249 |
+
param_group_modifiers:
|
| 250 |
+
- _target_: training.optimizer.layer_decay_param_modifier
|
| 251 |
+
_partial_: True
|
| 252 |
+
layer_decay_value: 0.9
|
| 253 |
+
apply_to: 'image_encoder.trunk'
|
| 254 |
+
overrides:
|
| 255 |
+
- pattern: '*pos_embed*'
|
| 256 |
+
value: 1.0
|
| 257 |
+
|
| 258 |
+
options:
|
| 259 |
+
lr:
|
| 260 |
+
- scheduler:
|
| 261 |
+
_target_: fvcore.common.param_scheduler.CosineParamScheduler
|
| 262 |
+
start_value: ${scratch.base_lr}
|
| 263 |
+
end_value: ${divide:${scratch.base_lr},10}
|
| 264 |
+
- scheduler:
|
| 265 |
+
_target_: fvcore.common.param_scheduler.CosineParamScheduler
|
| 266 |
+
start_value: ${scratch.vision_lr}
|
| 267 |
+
end_value: ${divide:${scratch.vision_lr},10}
|
| 268 |
+
param_names:
|
| 269 |
+
- 'image_encoder.*'
|
| 270 |
+
weight_decay:
|
| 271 |
+
- scheduler:
|
| 272 |
+
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
|
| 273 |
+
value: 0.1
|
| 274 |
+
- scheduler:
|
| 275 |
+
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
|
| 276 |
+
value: 0.0
|
| 277 |
+
param_names:
|
| 278 |
+
- '*bias*'
|
| 279 |
+
module_cls_names: ['torch.nn.LayerNorm']
|
| 280 |
+
|
| 281 |
+
loss:
|
| 282 |
+
all:
|
| 283 |
+
_target_: training.loss_fns.MultiStepMultiMasksAndIous
|
| 284 |
+
weight_dict:
|
| 285 |
+
loss_mask: 20
|
| 286 |
+
loss_dice: 1
|
| 287 |
+
loss_iou: 1
|
| 288 |
+
loss_class: 1
|
| 289 |
+
supervise_all_iou: true
|
| 290 |
+
iou_use_l1_loss: true
|
| 291 |
+
pred_obj_scores: true
|
| 292 |
+
focal_gamma_obj_score: 0.0
|
| 293 |
+
focal_alpha_obj_score: -1.0
|
| 294 |
+
|
| 295 |
+
distributed:
|
| 296 |
+
backend: nccl
|
| 297 |
+
find_unused_parameters: True
|
| 298 |
+
|
| 299 |
+
logging:
|
| 300 |
+
tensorboard_writer:
|
| 301 |
+
_target_: training.utils.logger.make_tensorboard_logger
|
| 302 |
+
log_dir: ${launcher.experiment_log_dir}/tensorboard
|
| 303 |
+
flush_secs: 120
|
| 304 |
+
should_log: True
|
| 305 |
+
log_dir: ${launcher.experiment_log_dir}/logs
|
| 306 |
+
log_freq: 10
|
| 307 |
+
|
| 308 |
+
# initialize from a SAM 2 checkpoint
|
| 309 |
+
checkpoint:
|
| 310 |
+
save_dir: ${launcher.experiment_log_dir}/checkpoints
|
| 311 |
+
save_freq: 0 # 0 only last checkpoint is saved.
|
| 312 |
+
model_weight_initializer:
|
| 313 |
+
_partial_: True
|
| 314 |
+
_target_: training.utils.checkpoint_utils.load_state_dict_into_model
|
| 315 |
+
strict: True
|
| 316 |
+
ignore_unexpected_keys: null
|
| 317 |
+
ignore_missing_keys: null
|
| 318 |
+
|
| 319 |
+
state_dict:
|
| 320 |
+
_target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
|
| 321 |
+
checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
|
| 322 |
+
ckpt_state_dict_keys: ['model']
|
| 323 |
+
|
| 324 |
+
launcher:
|
| 325 |
+
num_nodes: 1
|
| 326 |
+
gpus_per_node: 8
|
| 327 |
+
experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name}
|
| 328 |
+
|
| 329 |
+
# SLURM args if running on a cluster
|
| 330 |
+
submitit:
|
| 331 |
+
partition: null
|
| 332 |
+
account: null
|
| 333 |
+
qos: null
|
| 334 |
+
cpus_per_task: 10
|
| 335 |
+
use_cluster: false
|
| 336 |
+
timeout_hour: 24
|
| 337 |
+
name: null
|
| 338 |
+
port_range: [10000, 65000]
|
| 339 |
+
|
configs/sam2/sam2_hiera_b+.yaml
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# Model
|
| 4 |
+
model:
|
| 5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
| 6 |
+
image_encoder:
|
| 7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
+
scalp: 1
|
| 9 |
+
trunk:
|
| 10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
+
embed_dim: 112
|
| 12 |
+
num_heads: 2
|
| 13 |
+
neck:
|
| 14 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 15 |
+
position_encoding:
|
| 16 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 17 |
+
num_pos_feats: 256
|
| 18 |
+
normalize: true
|
| 19 |
+
scale: null
|
| 20 |
+
temperature: 10000
|
| 21 |
+
d_model: 256
|
| 22 |
+
backbone_channel_list: [896, 448, 224, 112]
|
| 23 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 24 |
+
fpn_interp_model: nearest
|
| 25 |
+
|
| 26 |
+
memory_attention:
|
| 27 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 28 |
+
d_model: 256
|
| 29 |
+
pos_enc_at_input: true
|
| 30 |
+
layer:
|
| 31 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 32 |
+
activation: relu
|
| 33 |
+
dim_feedforward: 2048
|
| 34 |
+
dropout: 0.1
|
| 35 |
+
pos_enc_at_attn: false
|
| 36 |
+
self_attention:
|
| 37 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 38 |
+
rope_theta: 10000.0
|
| 39 |
+
feat_sizes: [64, 64]
|
| 40 |
+
embedding_dim: 256
|
| 41 |
+
num_heads: 1
|
| 42 |
+
downsample_rate: 1
|
| 43 |
+
dropout: 0.1
|
| 44 |
+
d_model: 256
|
| 45 |
+
pos_enc_at_cross_attn_keys: true
|
| 46 |
+
pos_enc_at_cross_attn_queries: false
|
| 47 |
+
cross_attention:
|
| 48 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 49 |
+
rope_theta: 10000.0
|
| 50 |
+
feat_sizes: [64, 64]
|
| 51 |
+
rope_k_repeat: True
|
| 52 |
+
embedding_dim: 256
|
| 53 |
+
num_heads: 1
|
| 54 |
+
downsample_rate: 1
|
| 55 |
+
dropout: 0.1
|
| 56 |
+
kv_in_dim: 64
|
| 57 |
+
num_layers: 4
|
| 58 |
+
|
| 59 |
+
memory_encoder:
|
| 60 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 61 |
+
out_dim: 64
|
| 62 |
+
position_encoding:
|
| 63 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 64 |
+
num_pos_feats: 64
|
| 65 |
+
normalize: true
|
| 66 |
+
scale: null
|
| 67 |
+
temperature: 10000
|
| 68 |
+
mask_downsampler:
|
| 69 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 70 |
+
kernel_size: 3
|
| 71 |
+
stride: 2
|
| 72 |
+
padding: 1
|
| 73 |
+
fuser:
|
| 74 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
| 75 |
+
layer:
|
| 76 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 77 |
+
dim: 256
|
| 78 |
+
kernel_size: 7
|
| 79 |
+
padding: 3
|
| 80 |
+
layer_scale_init_value: 1e-6
|
| 81 |
+
use_dwconv: True # depth-wise convs
|
| 82 |
+
num_layers: 2
|
| 83 |
+
|
| 84 |
+
num_maskmem: 7
|
| 85 |
+
image_size: 1024
|
| 86 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 87 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 88 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 89 |
+
use_mask_input_as_output_without_sam: true
|
| 90 |
+
# Memory
|
| 91 |
+
directly_add_no_mem_embed: true
|
| 92 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 93 |
+
use_high_res_features_in_sam: true
|
| 94 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 95 |
+
multimask_output_in_sam: true
|
| 96 |
+
# SAM heads
|
| 97 |
+
iou_prediction_use_sigmoid: True
|
| 98 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 99 |
+
use_obj_ptrs_in_encoder: true
|
| 100 |
+
add_tpos_enc_to_obj_ptrs: false
|
| 101 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 102 |
+
# object occlusion prediction
|
| 103 |
+
pred_obj_scores: true
|
| 104 |
+
pred_obj_scores_mlp: true
|
| 105 |
+
fixed_no_obj_ptr: true
|
| 106 |
+
# multimask tracking settings
|
| 107 |
+
multimask_output_for_tracking: true
|
| 108 |
+
use_multimask_token_for_obj_ptr: true
|
| 109 |
+
multimask_min_pt_num: 0
|
| 110 |
+
multimask_max_pt_num: 1
|
| 111 |
+
use_mlp_for_obj_ptr_proj: true
|
| 112 |
+
# Compilation flag
|
| 113 |
+
compile_image_encoder: False
|
configs/sam2/sam2_hiera_l.yaml
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# Model
|
| 4 |
+
model:
|
| 5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
| 6 |
+
image_encoder:
|
| 7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
+
scalp: 1
|
| 9 |
+
trunk:
|
| 10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
+
embed_dim: 144
|
| 12 |
+
num_heads: 2
|
| 13 |
+
stages: [2, 6, 36, 4]
|
| 14 |
+
global_att_blocks: [23, 33, 43]
|
| 15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
+
window_spec: [8, 4, 16, 8]
|
| 17 |
+
neck:
|
| 18 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 19 |
+
position_encoding:
|
| 20 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 21 |
+
num_pos_feats: 256
|
| 22 |
+
normalize: true
|
| 23 |
+
scale: null
|
| 24 |
+
temperature: 10000
|
| 25 |
+
d_model: 256
|
| 26 |
+
backbone_channel_list: [1152, 576, 288, 144]
|
| 27 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 28 |
+
fpn_interp_model: nearest
|
| 29 |
+
|
| 30 |
+
memory_attention:
|
| 31 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 32 |
+
d_model: 256
|
| 33 |
+
pos_enc_at_input: true
|
| 34 |
+
layer:
|
| 35 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 36 |
+
activation: relu
|
| 37 |
+
dim_feedforward: 2048
|
| 38 |
+
dropout: 0.1
|
| 39 |
+
pos_enc_at_attn: false
|
| 40 |
+
self_attention:
|
| 41 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 42 |
+
rope_theta: 10000.0
|
| 43 |
+
feat_sizes: [64, 64]
|
| 44 |
+
embedding_dim: 256
|
| 45 |
+
num_heads: 1
|
| 46 |
+
downsample_rate: 1
|
| 47 |
+
dropout: 0.1
|
| 48 |
+
d_model: 256
|
| 49 |
+
pos_enc_at_cross_attn_keys: true
|
| 50 |
+
pos_enc_at_cross_attn_queries: false
|
| 51 |
+
cross_attention:
|
| 52 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 53 |
+
rope_theta: 10000.0
|
| 54 |
+
feat_sizes: [64, 64]
|
| 55 |
+
rope_k_repeat: True
|
| 56 |
+
embedding_dim: 256
|
| 57 |
+
num_heads: 1
|
| 58 |
+
downsample_rate: 1
|
| 59 |
+
dropout: 0.1
|
| 60 |
+
kv_in_dim: 64
|
| 61 |
+
num_layers: 4
|
| 62 |
+
|
| 63 |
+
memory_encoder:
|
| 64 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 65 |
+
out_dim: 64
|
| 66 |
+
position_encoding:
|
| 67 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 68 |
+
num_pos_feats: 64
|
| 69 |
+
normalize: true
|
| 70 |
+
scale: null
|
| 71 |
+
temperature: 10000
|
| 72 |
+
mask_downsampler:
|
| 73 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 74 |
+
kernel_size: 3
|
| 75 |
+
stride: 2
|
| 76 |
+
padding: 1
|
| 77 |
+
fuser:
|
| 78 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
| 79 |
+
layer:
|
| 80 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 81 |
+
dim: 256
|
| 82 |
+
kernel_size: 7
|
| 83 |
+
padding: 3
|
| 84 |
+
layer_scale_init_value: 1e-6
|
| 85 |
+
use_dwconv: True # depth-wise convs
|
| 86 |
+
num_layers: 2
|
| 87 |
+
|
| 88 |
+
num_maskmem: 7
|
| 89 |
+
image_size: 1024
|
| 90 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 91 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 92 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 93 |
+
use_mask_input_as_output_without_sam: true
|
| 94 |
+
# Memory
|
| 95 |
+
directly_add_no_mem_embed: true
|
| 96 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 97 |
+
use_high_res_features_in_sam: true
|
| 98 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 99 |
+
multimask_output_in_sam: true
|
| 100 |
+
# SAM heads
|
| 101 |
+
iou_prediction_use_sigmoid: True
|
| 102 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 103 |
+
use_obj_ptrs_in_encoder: true
|
| 104 |
+
add_tpos_enc_to_obj_ptrs: false
|
| 105 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 106 |
+
# object occlusion prediction
|
| 107 |
+
pred_obj_scores: true
|
| 108 |
+
pred_obj_scores_mlp: true
|
| 109 |
+
fixed_no_obj_ptr: true
|
| 110 |
+
# multimask tracking settings
|
| 111 |
+
multimask_output_for_tracking: true
|
| 112 |
+
use_multimask_token_for_obj_ptr: true
|
| 113 |
+
multimask_min_pt_num: 0
|
| 114 |
+
multimask_max_pt_num: 1
|
| 115 |
+
use_mlp_for_obj_ptr_proj: true
|
| 116 |
+
# Compilation flag
|
| 117 |
+
compile_image_encoder: False
|
configs/sam2/sam2_hiera_s.yaml
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# Model
|
| 4 |
+
model:
|
| 5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
| 6 |
+
image_encoder:
|
| 7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
+
scalp: 1
|
| 9 |
+
trunk:
|
| 10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
+
embed_dim: 96
|
| 12 |
+
num_heads: 1
|
| 13 |
+
stages: [1, 2, 11, 2]
|
| 14 |
+
global_att_blocks: [7, 10, 13]
|
| 15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
+
neck:
|
| 17 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 18 |
+
position_encoding:
|
| 19 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 20 |
+
num_pos_feats: 256
|
| 21 |
+
normalize: true
|
| 22 |
+
scale: null
|
| 23 |
+
temperature: 10000
|
| 24 |
+
d_model: 256
|
| 25 |
+
backbone_channel_list: [768, 384, 192, 96]
|
| 26 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 27 |
+
fpn_interp_model: nearest
|
| 28 |
+
|
| 29 |
+
memory_attention:
|
| 30 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 31 |
+
d_model: 256
|
| 32 |
+
pos_enc_at_input: true
|
| 33 |
+
layer:
|
| 34 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 35 |
+
activation: relu
|
| 36 |
+
dim_feedforward: 2048
|
| 37 |
+
dropout: 0.1
|
| 38 |
+
pos_enc_at_attn: false
|
| 39 |
+
self_attention:
|
| 40 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 41 |
+
rope_theta: 10000.0
|
| 42 |
+
feat_sizes: [64, 64]
|
| 43 |
+
embedding_dim: 256
|
| 44 |
+
num_heads: 1
|
| 45 |
+
downsample_rate: 1
|
| 46 |
+
dropout: 0.1
|
| 47 |
+
d_model: 256
|
| 48 |
+
pos_enc_at_cross_attn_keys: true
|
| 49 |
+
pos_enc_at_cross_attn_queries: false
|
| 50 |
+
cross_attention:
|
| 51 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 52 |
+
rope_theta: 10000.0
|
| 53 |
+
feat_sizes: [64, 64]
|
| 54 |
+
rope_k_repeat: True
|
| 55 |
+
embedding_dim: 256
|
| 56 |
+
num_heads: 1
|
| 57 |
+
downsample_rate: 1
|
| 58 |
+
dropout: 0.1
|
| 59 |
+
kv_in_dim: 64
|
| 60 |
+
num_layers: 4
|
| 61 |
+
|
| 62 |
+
memory_encoder:
|
| 63 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 64 |
+
out_dim: 64
|
| 65 |
+
position_encoding:
|
| 66 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 67 |
+
num_pos_feats: 64
|
| 68 |
+
normalize: true
|
| 69 |
+
scale: null
|
| 70 |
+
temperature: 10000
|
| 71 |
+
mask_downsampler:
|
| 72 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 73 |
+
kernel_size: 3
|
| 74 |
+
stride: 2
|
| 75 |
+
padding: 1
|
| 76 |
+
fuser:
|
| 77 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
| 78 |
+
layer:
|
| 79 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 80 |
+
dim: 256
|
| 81 |
+
kernel_size: 7
|
| 82 |
+
padding: 3
|
| 83 |
+
layer_scale_init_value: 1e-6
|
| 84 |
+
use_dwconv: True # depth-wise convs
|
| 85 |
+
num_layers: 2
|
| 86 |
+
|
| 87 |
+
num_maskmem: 7
|
| 88 |
+
image_size: 1024
|
| 89 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 90 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 91 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 92 |
+
use_mask_input_as_output_without_sam: true
|
| 93 |
+
# Memory
|
| 94 |
+
directly_add_no_mem_embed: true
|
| 95 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 96 |
+
use_high_res_features_in_sam: true
|
| 97 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 98 |
+
multimask_output_in_sam: true
|
| 99 |
+
# SAM heads
|
| 100 |
+
iou_prediction_use_sigmoid: True
|
| 101 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 102 |
+
use_obj_ptrs_in_encoder: true
|
| 103 |
+
add_tpos_enc_to_obj_ptrs: false
|
| 104 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 105 |
+
# object occlusion prediction
|
| 106 |
+
pred_obj_scores: true
|
| 107 |
+
pred_obj_scores_mlp: true
|
| 108 |
+
fixed_no_obj_ptr: true
|
| 109 |
+
# multimask tracking settings
|
| 110 |
+
multimask_output_for_tracking: true
|
| 111 |
+
use_multimask_token_for_obj_ptr: true
|
| 112 |
+
multimask_min_pt_num: 0
|
| 113 |
+
multimask_max_pt_num: 1
|
| 114 |
+
use_mlp_for_obj_ptr_proj: true
|
| 115 |
+
# Compilation flag
|
| 116 |
+
compile_image_encoder: False
|
configs/sam2/sam2_hiera_t.yaml
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# Model
|
| 4 |
+
model:
|
| 5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
| 6 |
+
image_encoder:
|
| 7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
+
scalp: 1
|
| 9 |
+
trunk:
|
| 10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
+
embed_dim: 96
|
| 12 |
+
num_heads: 1
|
| 13 |
+
stages: [1, 2, 7, 2]
|
| 14 |
+
global_att_blocks: [5, 7, 9]
|
| 15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
+
neck:
|
| 17 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 18 |
+
position_encoding:
|
| 19 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 20 |
+
num_pos_feats: 256
|
| 21 |
+
normalize: true
|
| 22 |
+
scale: null
|
| 23 |
+
temperature: 10000
|
| 24 |
+
d_model: 256
|
| 25 |
+
backbone_channel_list: [768, 384, 192, 96]
|
| 26 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 27 |
+
fpn_interp_model: nearest
|
| 28 |
+
|
| 29 |
+
memory_attention:
|
| 30 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 31 |
+
d_model: 256
|
| 32 |
+
pos_enc_at_input: true
|
| 33 |
+
layer:
|
| 34 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 35 |
+
activation: relu
|
| 36 |
+
dim_feedforward: 2048
|
| 37 |
+
dropout: 0.1
|
| 38 |
+
pos_enc_at_attn: false
|
| 39 |
+
self_attention:
|
| 40 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 41 |
+
rope_theta: 10000.0
|
| 42 |
+
feat_sizes: [64, 64]
|
| 43 |
+
embedding_dim: 256
|
| 44 |
+
num_heads: 1
|
| 45 |
+
downsample_rate: 1
|
| 46 |
+
dropout: 0.1
|
| 47 |
+
d_model: 256
|
| 48 |
+
pos_enc_at_cross_attn_keys: true
|
| 49 |
+
pos_enc_at_cross_attn_queries: false
|
| 50 |
+
cross_attention:
|
| 51 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 52 |
+
rope_theta: 10000.0
|
| 53 |
+
feat_sizes: [64, 64]
|
| 54 |
+
rope_k_repeat: True
|
| 55 |
+
embedding_dim: 256
|
| 56 |
+
num_heads: 1
|
| 57 |
+
downsample_rate: 1
|
| 58 |
+
dropout: 0.1
|
| 59 |
+
kv_in_dim: 64
|
| 60 |
+
num_layers: 4
|
| 61 |
+
|
| 62 |
+
memory_encoder:
|
| 63 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 64 |
+
out_dim: 64
|
| 65 |
+
position_encoding:
|
| 66 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 67 |
+
num_pos_feats: 64
|
| 68 |
+
normalize: true
|
| 69 |
+
scale: null
|
| 70 |
+
temperature: 10000
|
| 71 |
+
mask_downsampler:
|
| 72 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 73 |
+
kernel_size: 3
|
| 74 |
+
stride: 2
|
| 75 |
+
padding: 1
|
| 76 |
+
fuser:
|
| 77 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
| 78 |
+
layer:
|
| 79 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 80 |
+
dim: 256
|
| 81 |
+
kernel_size: 7
|
| 82 |
+
padding: 3
|
| 83 |
+
layer_scale_init_value: 1e-6
|
| 84 |
+
use_dwconv: True # depth-wise convs
|
| 85 |
+
num_layers: 2
|
| 86 |
+
|
| 87 |
+
num_maskmem: 7
|
| 88 |
+
image_size: 1024
|
| 89 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 90 |
+
# SAM decoder
|
| 91 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 92 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 93 |
+
use_mask_input_as_output_without_sam: true
|
| 94 |
+
# Memory
|
| 95 |
+
directly_add_no_mem_embed: true
|
| 96 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 97 |
+
use_high_res_features_in_sam: true
|
| 98 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 99 |
+
multimask_output_in_sam: true
|
| 100 |
+
# SAM heads
|
| 101 |
+
iou_prediction_use_sigmoid: True
|
| 102 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 103 |
+
use_obj_ptrs_in_encoder: true
|
| 104 |
+
add_tpos_enc_to_obj_ptrs: false
|
| 105 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 106 |
+
# object occlusion prediction
|
| 107 |
+
pred_obj_scores: true
|
| 108 |
+
pred_obj_scores_mlp: true
|
| 109 |
+
fixed_no_obj_ptr: true
|
| 110 |
+
# multimask tracking settings
|
| 111 |
+
multimask_output_for_tracking: true
|
| 112 |
+
use_multimask_token_for_obj_ptr: true
|
| 113 |
+
multimask_min_pt_num: 0
|
| 114 |
+
multimask_max_pt_num: 1
|
| 115 |
+
use_mlp_for_obj_ptr_proj: true
|
| 116 |
+
# Compilation flag
|
| 117 |
+
# HieraT does not currently support compilation, should always be set to False
|
| 118 |
+
compile_image_encoder: False
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
numpy
|
| 3 |
+
opencv-python
|
| 4 |
+
gradio
|
| 5 |
+
matplotlib
|
| 6 |
+
Pillow
|
| 7 |
+
ultralytics
|
| 8 |
+
diffusers
|
| 9 |
+
huggingface_hub
|
sam2
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
/data_sdf/yifan/sam2
|
sam2segment_structure.py
ADDED
|
@@ -0,0 +1,887 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
import sys,os
# sys.path.append("/home/yifan/sam2")
# sys.path.append("/data_sdf/yifan/miniconda3/envs/sam2/lib/python3.10/site-packages")
from huggingface_hub import hf_hub_download
# Make the vendored sam2 checkout (symlinked next to this file) importable.
sys.path.append(os.path.join(os.path.dirname(__file__), "sam2"))
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import torch
import matplotlib.pyplot as plt
from PIL import Image
import cv2
import random
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

# Module-level model setup: download the SAM2.1 checkpoint from the Hub and
# build a single shared model/predictor used by the functions below.
# NOTE(review): assumes a CUDA device is present — TODO confirm for CPU-only hosts.
device = torch.device("cuda")
sam2_checkpoint = hf_hub_download(
    repo_id="Evan73/sam2-models",
    filename="sam2.1_hiera_large.pt"
)
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
# global sam2_model
predictor = SAM2ImagePredictor(sam2_model)
from ultralytics import YOLO
from diffusers.utils import load_image
import pickle
import os
import math

# Download and unpack precomputed attention heat-maps, then load the
# pickled dict mapping image paths to {'cls_heatmap': ..., 'reg_heatmap': ...}.
heatmap_zip = hf_hub_download(
    repo_id="Evan73/attention-heatmaps",
    filename="attention_heatmaps.zip"
)
import zipfile
import os

with zipfile.ZipFile(heatmap_zip, 'r') as zip_ref:
    zip_ref.extractall("heatmaps_lda")

with open("heatmaps_lda/attention_heatmaps.pkl", "rb") as f:
    heatmap_dict = pickle.load(f)
+
def load_yolov5_model():
    """Load the local YOLO11 detector and print its class-index mapping.

    (The function name is kept for backward compatibility with callers; the
    weights actually loaded are ``yolo11n.pt``.)
    """
    detector = YOLO("yolo11n.pt")
    # Dump the detector's class-id -> name table for debugging.
    print("YOLOv11 Class Names:")
    for class_idx, class_name in detector.names.items():
        print(f"{class_idx}: {class_name}")
    return detector
|
| 53 |
+
# Check whether a point lies inside a detected vehicle box.
def is_point_in_car_area(point, model, image):
    """
    Check the given point against YOLO vehicle detections.

    - point: (x, y) coordinate
    - model: YOLO detection model
    - image: input image (RGB)

    NOTE(review): despite the name, this returns **False** when the point IS
    inside a vehicle box and True when it is clear of all vehicles — callers
    (e.g. get_left_right_points, random_points_below) rely on this inverted
    meaning, so do not "fix" it without updating them.
    """
    # Run YOLO object detection on the image.
    results = model(image)

    # COCO class ids treated as vehicles: 2 = car, 5 = bus, 7 = truck.
    car_class_id = [2, 5, 7]
    image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    # Iterate per-image results (a single input image is assumed here).
    for result in results:
        # Extract detection boxes, confidences and class ids as numpy arrays.
        boxes = result.boxes.xyxy.cpu().numpy()  # (x_min, y_min, x_max, y_max)
        confidences = result.boxes.conf.cpu().numpy()  # unused; kept from original
        class_ids = result.boxes.cls.cpu().numpy().astype(int)

        for box, cls in zip(boxes, class_ids):
            if cls in car_class_id:
                x_min, y_min, x_max, y_max = box[:4]
                # Draw the detection box (debug visualisation).
                cv2.rectangle(image_bgr, (int(x_min), int(y_min)), (int(x_max), int(y_max)), (0, 255, 0), 2)
                # Point falls inside this vehicle's box -> not usable.
                if (x_min <= point[0] <= x_max) and (y_min <= point[1] <= y_max):
                    cv2.imwrite("yolo_res.jpg", image_bgr)
                    return False
    cv2.imwrite("yolo_res.jpg", image_bgr)
    print(f"检测结果已保存至 yolo_res.jpg")
    return True
| 90 |
+
|
| 91 |
+
def show_mask(mask, ax, image_path,random_color=False, borders=True, image=None, save_path=None):
    """
    Overlay *mask* on *ax* and randomly pick two diagonal points inside the
    mask region, drawing the resulting rectangle on the original image.

    Parameters:
    - mask: binary mask (H x W)
    - ax: matplotlib axis used for drawing
    - image_path: source image path (unused here; kept for interface parity)
    - random_color: use a random overlay colour instead of the default blue
    - borders: draw simplified mask contours and their bounding boxes
    - image: original image on which rectangles are drawn (BGR/ndarray)
    - save_path: where to save the rectangle visualisation

    Returns ((x1, y1), (x2, y2)) — the two diagonal points — on success, or
    None implicitly when no valid pair is found within the sampling budget.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        # Semi-transparent dodger blue.
        color = np.array([30/255, 144/255, 255/255, 0.6])

    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    cv2.imwrite("binary_mask.png", (mask * 255).astype(np.uint8))
    print("原始二值掩码已保存为 binary_mask.png")
    if borders:
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        # Lightly simplify each contour before drawing.
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=5)
        # Palette: one colour per contour bounding box (debug output).
        colors = [
            (255, 0, 0),     # red
            (0, 255, 0),     # green
            (0, 0, 255),     # blue
            (255, 255, 0),   # yellow
            (255, 0, 255),   # magenta
            (0, 255, 255),   # cyan
            (255, 128, 0),   # orange
            (128, 0, 255),   # purple
            (128, 128, 128), # grey
            (0, 128, 0)      # dark green
        ]

        for idx, contour in enumerate(contours):
            x, y, w, h = cv2.boundingRect(contour)
            print(f"轮廓{idx}: x={x}, y={y}, w={w}, h={h}")
            color = colors[idx % len(colors)]
            cv2.rectangle(image, (x, y), (x + w, y + h), color, 2)
        middle_save_path = "contours_colored_result.png"
        cv2.imwrite(middle_save_path, image)
        print(f"带颜色的轮廓结果已保存至 {middle_save_path}")
    if image is not None:
        # Re-detect mask boundaries (un-simplified) for the sampling step.
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            if w > 50 and h > 50:
                # Try progressively smaller square sizes; for each size make
                # up to 100 random attempts per diagonal orientation to place
                # both corners strictly inside the mask.
                for size in range(90,40,-5):
                    # First pass: second corner up-left of the sampled corner.
                    for _ in range(100):
                        random_x1 = random.randint(x, x + w - 50)
                        random_y1 = random.randint(y, y + h - 50)
                        random_x2 = random_x1 - size
                        random_y2 = random_y1 - size
                        try:
                            if save_path and mask[random_y1, random_x1] == 1 and mask[random_y2, random_x2] == 1:
                                cv2.rectangle(image,(random_x2, random_y2), (random_x1, random_y1), (0, 255, 0), 2)
                                cv2.imwrite(save_path, image)
                                print(f"Image with rectangle saved at {save_path}")
                                return (random_x1,random_y1),(random_x2,random_y2)
                        except:
                            # NOTE(review): bare except deliberately swallows
                            # out-of-bounds index errors from negative corners.
                            pass
                    # Second pass: second corner down-right of the sampled corner.
                    for _ in range(100):
                        random_x1 = random.randint(x, x + w - 50)
                        random_y1 = random.randint(y, y + h - 50)
                        random_x2 = random_x1 + size
                        random_y2 = random_y1 + size
                        try:
                            if save_path and mask[random_y1, random_x1] == 1 and mask[random_y2, random_x2] == 1:
                                cv2.rectangle(image,(random_x2, random_y2), (random_x1, random_y1), (0, 255, 0), 2)
                                cv2.imwrite(save_path, image)
                                print(f"Image with rectangle saved at {save_path}")
                                return (random_x1,random_y1),(random_x2,random_y2)
                        except:
                            pass

    # Fallback: no valid pair found — just show the overlay (returns None).
    ax.imshow(mask_image)
    plt.axis('off')
+
|
| 194 |
+
def attention_mask(mask, ax, image_path,strategy="LOA",random_color=False, borders=True, image=None, save_path=None):
    """
    Select a candidate box inside *mask* guided by a precomputed attention
    heat-map, and draw/save several debug visualisations along the way.

    Parameters:
    - mask: binary mask (H x W)
    - ax: matplotlib axis used for the fallback overlay
    - image_path: key into the module-level heatmap_dict; also re-read for output
    - strategy: "LDA" uses the classification heat-map; "LOA"/"LRA" use the
      regression heat-map
    - random_color: use a random overlay colour
    - borders: draw simplified mask contours and their bounding boxes
    - image: original image (ndarray); resized to the heat-map resolution
    - save_path: where to save the selected-box visualisation

    Returns ((x1, y1), (x2, y2)) — an expanded ~90 px square box mapped back
    to original-image coordinates — or None implicitly if no candidate passes.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        # Semi-transparent dodger blue.
        color = np.array([30/255, 144/255, 255/255, 0.6])
    # Remember the original image size so selections can be mapped back later.
    orig_w, orig_h = image.shape[1],image.shape[0]
    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    cv2.imwrite("binary_mask.png", (mask * 255).astype(np.uint8))
    print("原始二值掩码已保存为 binary_mask.png")
    candidates = []
    path = image_path
    # Precomputed heat-maps for this image (see module-level heatmap_dict).
    cls_heatmap = heatmap_dict[path]['cls_heatmap']
    reg_heatmap = heatmap_dict[path]['reg_heatmap']
    font = cv2.FONT_HERSHEY_SIMPLEX
    # Pick the guiding heat-map by strategy.
    if strategy == "LDA":
        combined = cls_heatmap.astype(np.float32)
    if strategy == "LOA" or strategy == "LRA":
        combined = reg_heatmap.astype(np.float32)
    print(mask.shape)
    # Work in heat-map resolution from here on: resize mask and images to it.
    mask = cv2.resize(mask, (combined.shape[1], combined.shape[0]), interpolation=cv2.INTER_NEAREST)
    mask = (mask > 0.5).astype(np.uint8)
    cv2.imwrite("crop_binary_mask.png", (mask * 255).astype(np.uint8))
    print("处理后的裁剪二值掩码已保存为 crop_binary_mask.png")
    print(combined.shape)
    vis_image = cv2.imread(image_path)
    vis_image = cv2.resize(vis_image,(combined.shape[1],combined.shape[0]))
    mask_image = cv2.resize(mask_image,(combined.shape[1],combined.shape[0]))
    image = cv2.resize(image,(combined.shape[1],combined.shape[0]))
    if borders:
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
        # Palette: one colour per contour bounding box (debug output).
        colors = [
            (255, 0, 0),     # red
            (0, 255, 0),     # green
            (0, 0, 255),     # blue
            (255, 255, 0),   # yellow
            (255, 0, 255),   # magenta
            (0, 255, 255),   # cyan
            (255, 128, 0),   # orange
            (128, 0, 255),   # purple
            (128, 128, 128), # grey
            (0, 128, 0)      # dark green
        ]
        for idx, contour in enumerate(contours):
            x, y, w, h = cv2.boundingRect(contour)
            print(f"轮廓{idx}: x={x}, y={y}, w={w}, h={h}")
            color = colors[idx % len(colors)]
            cv2.rectangle(image, (x, y), (x + w, y + h), color, 2)
        middle_save_path = "contours_colored_result.png"
        cv2.imwrite(middle_save_path, image)
        print(f"带颜色的轮廓结果已保存至 {middle_save_path}")
    if image is not None:
        # Slide a size x size window (5 px stride) over each large-enough
        # contour; keep windows fully inside the mask, scored by mean heat.
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            print("the contour is:",x, y, w, h)
            if w > 50 and h > 50:
                for size in range(50,40,-5):
                    for y_step in range(y, y+h - size,5):
                        for x_step in range(x, x+w - size,5):
                            x1, y1, x2, y2 = x_step, y_step, x_step + size, y_step + size
                            if mask[y1:y2, x1:x2].sum() >= size * size:  # window fully inside mask
                                heat_value = combined[y1:y2, x1:x2].mean()
                                if not math.isnan(heat_value):
                                    candidates.append(((x1, y1, x2, y2), heat_value))
                                    cv2.rectangle(vis_image, (x1, y1), (x2, y2), (0, 255, 0), 1)
                                    cv2.putText(vis_image, f'{heat_value:.1f}', (x1, y1 - 2), font, 0.4, (0, 0, 255), 1)
                    if not candidates:
                        print("⚠️ 没有找到满足掩码内区域的候选框")
                    else:
                        break
        cv2.imwrite("attention_vis.jpg", vis_image)
        print(f"Attention 候选框可视化已保存 attention_vis.jpg")
        # Rank candidates by heat value, highest first.
        # NOTE(review): the prints below index candidates[0]/[-1] and would
        # raise IndexError if no candidate was collected — TODO confirm callers
        # always produce at least one candidate.
        candidates.sort(key=lambda x: x[1], reverse=True)
        print(save_path,candidates[0],candidates[-1])
        for (x1, y1, x2, y2), _ in candidates:
            try:
                # Require both diagonal corners to be mask pixels.
                if mask[y1, x1] == 1 and mask[y2, x2] == 1:
                    if save_path:
                        image = cv2.imread(image_path)
                        image = cv2.resize(image,(combined.shape[1],combined.shape[0]))
                        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
                        cv2.imwrite(save_path, image)
                        print(f"Image with rectangle saved at {save_path}")
                        # Map the heat-map-resolution box back to original-image
                        # coordinates, then expand to a centred ~90 px square.
                        resize_w, resize_h = combined.shape[1],combined.shape[0]
                        scale_x = orig_w / resize_w
                        scale_y = orig_h / resize_h
                        x1_orig = int(x1 * scale_x)
                        x2_orig = int(x2 * scale_x)
                        y1_orig = int(y1 * scale_y)
                        y2_orig = int(y2 * scale_y)
                        cx = (x1_orig + x2_orig) // 2
                        cy = (y1_orig + y2_orig) // 2
                        target_size = 90
                        half = target_size // 2
                        x1_exp = max(0, cx - half)
                        y1_exp = max(0, cy - half)
                        x2_exp = min(orig_w - 1, cx + half)
                        y2_exp = min(orig_h - 1, cy + half)
                        print(f"扩展后的原图坐标: ({x1_exp}, {y1_exp}), ({x2_exp}, {y2_exp})")
                        image_full = cv2.imread(image_path)  # full-resolution read
                        cv2.rectangle(image_full, (x1_exp, y1_exp), (x2_exp, y2_exp), (0, 0, 255), 2)
                        cv2.imwrite("expanded_bbox_on_original.jpg", image_full)
                        print("📌 扩大后的候选框已绘制到原图并保存为 expanded_bbox_on_original.jpg")
                        return (x1_exp, y1_exp), (x2_exp, y2_exp)
            except Exception as e:
                # On out-of-bounds or similar issues, move to the next candidate.
                print("the error is:",e)
                pass

    # Fallback: nothing selected — just show the overlay (returns None).
    ax.imshow(mask_image)
    plt.axis('off')
|
| 368 |
+
def generate_gt_mask_from_intersection(random_rectangle, yolo_boxes, image, mask_img,sam2_model, threshold_iou):
    """
    Check whether the randomly generated rectangle overlaps (by IoU) one of
    the YOLO boxes; if so, prompt SAM inside that box to obtain a precise
    ground-truth mask and black out that region in *mask_img*.

    - random_rectangle: ((x1, y1), (x2, y2)) diagonal corner pair
    - yolo_boxes: iterable of [x_min, y_min, x_max, y_max] detections
    - image: source image (PIL or ndarray)
    - mask_img: working mask that gets the GT region zeroed out
    - sam2_model: SAM2 model used for mask prediction
    - threshold_iou: minimum IoU required to accept a YOLO box

    Returns (gt_mask, mask_img) on the first match; otherwise writes an
    all-black placeholder mask to disk and returns (None, mask_img).
    """
    image_np = np.array(image)
    x1_rect, y1_rect = random_rectangle[0]
    x2_rect, y2_rect = random_rectangle[1]
    # Filled rectangle mask of the random box. NOTE(review): currently passed
    # to get_gt_mask_from_sam but ignored there (the blanking line is disabled).
    rect_mask = np.zeros(image_np.shape[:2], dtype=np.uint8)
    cv2.rectangle(rect_mask, (x1_rect, y1_rect), (x2_rect, y2_rect), color=255, thickness=-1)

    # Normalise the corner pair to [x_min, y_min, x_max, y_max].
    rect_box = [min(x1_rect, x2_rect), min(y1_rect, y2_rect), max(x1_rect, x2_rect), max(y1_rect, y2_rect)]

    for box in yolo_boxes:
        iou = calculate_iou(rect_box, box)
        print(f"与YOLO box的IoU为: {iou}, 阈值: {threshold_iou}")

        if iou >= threshold_iou:
            # Sample three random positive prompt points inside the YOLO box.
            x_min, y_min, x_max, y_max = box
            input_point1 = (np.random.randint(x_min, x_max), np.random.randint(y_min, y_max))
            input_point2 = (np.random.randint(x_min, x_max), np.random.randint(y_min, y_max))
            input_point3 = (np.random.randint(x_min, x_max), np.random.randint(y_min, y_max))

            # Use SAM to generate the precise mask from the three prompts.
            gt_mask = get_gt_mask_from_sam(image, sam2_model, [input_point1, input_point2,input_point3], rect_mask)
            # Zero the GT region out of the working mask.
            mask_img[gt_mask > 0] = 0
            # Persist the GT mask for inspection.
            cv2.imwrite('gt_mask_from_sam.png', gt_mask)
            print(f"SAM生成的GT掩码已保存至 gt_mask_from_sam.png")

            return gt_mask,mask_img
    # No YOLO box passed the IoU threshold — persist an empty mask instead.
    h, w = image_np.shape[:2]
    black_mask = np.zeros((h, w), dtype=np.uint8)
    no_match_save_path = 'gt_mask_from_sam.png'
    cv2.imwrite(no_match_save_path, black_mask)
    print("未找到满足阈值条件的YOLO box。")
    print(f"未匹配成功,保存空掩码图至 {no_match_save_path}")
    return None,mask_img
|
| 408 |
+
def calculate_iou(boxA, boxB):
    """Return the pixel-inclusive IoU of two [x1, y1, x2, y2] boxes."""
    # Intersection rectangle; the +1 terms treat coordinates as inclusive
    # pixel indices, matching the area formulas below.
    left = max(boxA[0], boxB[0])
    top = max(boxA[1], boxB[1])
    right = min(boxA[2], boxB[2])
    bottom = min(boxA[3], boxB[3])
    overlap = max(0, right - left + 1) * max(0, bottom - top + 1)

    area_a = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
    area_b = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)

    # IoU = intersection over union of the two areas.
    return overlap / float(area_a + area_b - overlap)
| 423 |
+
def get_gt_mask_from_sam(image, sam2_model, input_points, rect_mask):
    """
    Use SAM2 with three positive point prompts to produce a 0/255 uint8 mask.

    Also saves the predicted mask and a copy of the image with the prompt
    points drawn on it. *rect_mask* is currently unused — the blanking step
    is commented out below.
    """
    predictor = SAM2ImagePredictor(sam2_model)
    print("load sam2")
    predictor.set_image(image)

    input_point_np = np.array(input_points)
    # All three prompts are positive (label 1).
    input_label = np.array([1, 1,1])

    # Single-mask prediction from the point prompts.
    masks, _, _ = predictor.predict(
        point_coords=input_point_np,
        point_labels=input_label,
        multimask_output=False,
    )

    # Convert the boolean mask to a 0/255 image.
    mask_img = masks[0].astype(np.uint8) * 255
    # mask_img[rect_mask == 255] = 0  # (disabled) blank the random-rectangle area

    # Persist the SAM mask for inspection.
    mask_save_path = 'sam_gt_mask.jpg'
    cv2.imwrite(mask_save_path, mask_img)
    print(f"SAM生成的掩码已保存至 {mask_save_path}")

    # Draw the prompt points on a copy of the source image.
    image_with_points = np.array(image).copy()
    for point in input_points:
        cv2.circle(image_with_points, point, radius=5, color=(255, 0, 0), thickness=-1)

    # Save the annotated image (convert RGB -> BGR for cv2.imwrite).
    point_marked_save_path = 'image_with_points.jpg'
    image_bgr = cv2.cvtColor(image_with_points, cv2.COLOR_RGB2BGR)
    cv2.imwrite(point_marked_save_path, image_bgr)
    print(f"带点标记的原图已保存至 {point_marked_save_path}")

    return mask_img
| 459 |
+
def show_points(coords, labels, ax, marker_size=375):
    """Scatter positive (green) and negative (red) prompt points on *ax*."""
    positives = coords[labels == 1]
    negatives = coords[labels == 0]
    # Shared star-marker styling for both scatters.
    star = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(positives[:, 0], positives[:, 1], color='green', **star)
    ax.scatter(negatives[:, 0], negatives[:, 1], color='red', **star)
| 465 |
+
def show_box(box, ax):
    """Draw an [x1, y1, x2, y2] box on *ax* as a green, unfilled rectangle."""
    x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
    rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                         edgecolor='green', facecolor=(0, 0, 0, 0), lw=2)
    ax.add_patch(rect)
| 470 |
+
def display_mask(mask, ax, random_color=False, borders = True):
    """Render *mask* as a semi-transparent overlay (with optional contours)
    on *ax*, and dump the overlay to check.jpg for debugging."""
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        # Semi-transparent dodger blue.
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    if borders:
        import cv2  # local import kept from the original
        contours, _ = cv2.findContours(mask,cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Try to smooth contours
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
    # Debug dump of the composed overlay.
    cv2.imwrite("check.jpg", mask_image)
    ax.imshow(mask_image)
| 487 |
+
def random_points_below(point, radius, min_distance, model, image, max_attempts=100):
    """
    Randomly sample two points in a band ~50 px below *point* until they are
    far enough apart and both lie clear of every detected vehicle box.

    Parameters:
    - point: (x, y) seed coordinate (may be float; coerced to int below)
    - radius: half-width of the horizontal/vertical sampling range
    - min_distance: minimum Euclidean distance between the two points
    - model: YOLO detector forwarded to is_point_in_car_area
    - image: image forwarded to is_point_in_car_area
    - max_attempts: cap on sampling rounds, to avoid an endless loop

    Returns [(x1, y1), (x2, y2)] on success, or None when no valid pair is
    found within max_attempts.
    """
    # Coerce sampling bounds to int: callers pass float seeds (e.g. 1540/2),
    # and random.randint raises for non-integer arguments on modern Python.
    x_lo, x_hi = int(point[0] - radius), int(point[0] + radius)
    y_lo, y_hi = int(point[1] + 50), int(point[1] + 50 + radius)  # band 50 px below

    for _ in range(max_attempts):
        # Draw two independent candidates inside the band.
        x1, y1 = random.randint(x_lo, x_hi), random.randint(y_lo, y_hi)
        x2, y2 = random.randint(x_lo, x_hi), random.randint(y_lo, y_hi)

        # Euclidean distance between the two candidates.
        distance = np.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)

        # NOTE: is_point_in_car_area returns True when the point is NOT
        # inside any detected vehicle box (inverted relative to its name).
        if (distance >= min_distance
                and is_point_in_car_area((x1, y1), model, image)
                and is_point_in_car_area((x2, y2), model, image)):
            return [(x1, y1), (x2, y2)]

    # Exhausted the attempt budget without finding a suitable pair.
    return None
+
|
| 519 |
+
def show_masks(image, masks, scores, image_path, strategy,point_coords=None, box_coords=None, input_labels=None, borders=True, save_path=None):
    """
    Visualise a predicted mask over *image*, then delegate to attention_mask
    to select the best candidate box inside it.

    NOTE(review): the function returns inside the first loop iteration, so
    only masks[0]/scores[0] are ever processed. Also, attention_mask returns
    None when it finds no candidate, in which case the tuple unpacking below
    raises TypeError — presumably callers guarantee a candidate exists; verify.
    """
    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        display_mask(mask, plt.gca(), borders=borders)
        if point_coords is not None:
            assert input_labels is not None
            show_points(point_coords, input_labels, plt.gca())
        if box_coords is not None:
            # Optional prompt box overlay.
            show_box(box_coords, plt.gca())
        plt.axis('off')
        plt.savefig('check.jpg', bbox_inches='tight', pad_inches=0)  # save the composed figure
        point1,point2 = attention_mask(mask, plt.gca(), image_path,strategy,borders=borders, image=image, save_path=save_path)
        return point1,point2
|
| 535 |
+
def random_crop(image, target_width, target_height, mask_point1, mask_point2):
    """Crop a target_width x target_height window centred on the midpoint of
    two diagonal points, clamped so the window never leaves the image.

    Returns (cropped_image, top_offset, left_offset, point1_relative,
    point2_relative), where the relative points are the inputs translated
    into the crop's coordinate frame.
    """
    width, height = image.size

    # The midpoint of the two diagonal points is the desired crop centre.
    mid_x = (mask_point1[0] + mask_point2[0]) // 2
    mid_y = (mask_point1[1] + mask_point2[1]) // 2

    left = mid_x - target_width // 2
    top = mid_y - target_height // 2
    right = left + target_width
    bottom = top + target_height

    # Clamp the window back inside the image bounds, preserving its size.
    if left < 0:
        left, right = 0, target_width
    if top < 0:
        top, bottom = 0, target_height
    if right > width:
        right, left = width, width - target_width
    if bottom > height:
        bottom, top = height, height - target_height

    # Offsets of the crop window inside the original image.
    top_padding = max(0, top)
    left_padding = max(0, left)

    cropped_image = image.crop((left, top, right, bottom))

    # Translate both points into the cropped frame.
    point1_relative = (mask_point1[0] - left, mask_point1[1] - top)
    point2_relative = (mask_point2[0] - left, mask_point2[1] - top)
    print("裁剪后点的相对位置为:")
    print("mask_point1:", point1_relative)
    print("mask_point2:", point2_relative)
    return cropped_image, top_padding, left_padding, point1_relative, point2_relative
|
| 577 |
+
def get_left_right_points(lane_data,image_path):
    """Pick trigger points from a TuSimple-style lane annotation.

    For each lane that has a valid sample at the middle row, scan upwards
    from two rows above the middle and keep the first valid point that lies
    inside a YOLO-detected vehicle box.  A lane with no valid middle-row
    sample instead triggers a fallback that samples random candidate points
    below a hard-coded seed position.

    Parameters:
        lane_data: dict with "lanes" (per-lane x coords, -2 = invalid sample)
            and "h_samples" (the shared y coordinates).
        image_path: path to the frame this annotation belongs to.

    Returns:
        list of (x, y) trigger points.
    """
    lanes = lane_data["lanes"]
    h_samples = lane_data["h_samples"]
    model = load_yolov5_model()
    # Middle index of h_samples.
    mid_idx = len(h_samples) // 2
    image = cv2.imread(image_path)
    # Leftmost / rightmost points found (right_point is never assigned below).
    left_point = None
    right_point = None
    points = []
    # Walk every lane polyline.
    for lane in lanes:
        # Drop invalid samples (x == -2).
        valid_points = [(x, y) for x, y in zip(lane, h_samples) if x != -2]

        if valid_points:
            if lane[mid_idx] != -2:
                # Scan upwards starting two rows above the middle sample.
                # NOTE(review): range(..., 0, -1) never visits index 0 —
                # confirm whether the first row should be reachable.
                for i in range(mid_idx-2,0,-1):
                    left_point = lane[i]
                    print(left_point)
                    if lane[i] != -2:
                        point = (left_point,h_samples[i])
                        # Keep the point only if it falls inside a detected vehicle.
                        FLAG = is_point_in_car_area(point, model, image)
                        print(point,FLAG)
                        if FLAG:
                            points.append((left_point,h_samples[i]))
                            break
            else:
                # Fallback: no valid sample at the middle row.
                # NOTE(review): the seed presumably assumes a half-resolution
                # ~1640x590 frame — confirm against the dataset in use.
                point = (1540/2, 590/2+30)  # seed point coordinates
                radius = 50  # maximum radius for the random points
                min_distance = 40  # minimum distance between two points
                # NOTE(review): this REPLACES any points collected from earlier
                # lanes instead of extending the list — confirm intended.
                points = random_points_below(point, radius, min_distance,model,image)
                # first_non_minus_two = next((x for x in lane if x != -2), None)
                # if first_non_minus_two:
                #     idx = lane.index(first_non_minus_two)
                #     for i in range(idx+5,idx,-1):
                #         left_point = lane[i]
                #         if lane[i] != -2:
                #             point = (left_point,h_samples[i])
                #             FLAG = is_point_in_car_area(point, model, image)
                #             if FLAG:
                #                 points.append((left_point,h_samples[i]))
                #                 break

    # return left_point, right_point
    return points
|
| 624 |
+
|
| 625 |
+
def sam2segment(image_path, points, strategy):
    """Segment the region indicated by `points` with SAM2 and derive corners.

    A first single-point prediction produces candidate masks; the logits of
    the best-scoring candidate seed a second, refined prediction that uses
    every point as a positive prompt.  Returns the two corner points
    produced by show_masks().
    """
    rgb = np.array(Image.open(image_path).convert("RGB"))
    predictor.set_image(rgb)

    # Pass 1: prompt with just the first point, ask for multiple candidates.
    first_pt = np.array([(points[0][0], points[0][1])])
    first_lbl = np.array([1])
    masks, scores, logits = predictor.predict(
        point_coords=first_pt,
        point_labels=first_lbl,
        multimask_output=True,
    )

    # Order candidates from best to worst score.
    order = np.argsort(scores)[::-1]
    masks = masks[order]
    scores = scores[order]
    logits = logits[order]

    # Seed pass 2 with the logits of the best-scoring candidate mask.
    mask_input = logits[np.argmax(scores), :, :]

    # Pass 2: prompt with every point (all positive), single mask output.
    all_pts = np.array([(p[0], p[1]) for p in points])
    all_lbls = np.array([1] * len(points))
    masks, scores, _ = predictor.predict(
        point_coords=all_pts,
        point_labels=all_lbls,
        mask_input=mask_input[None, :, :],
        multimask_output=False,
    )

    # Render the mask and extract the two corner points.
    point1, point2 = show_masks(rgb, masks, scores, image_path, strategy, save_path="masked_image.jpg")
    return point1, point2
|
| 659 |
+
|
| 660 |
+
def draw_point(image_path, points):
    """Draw the given points on the image at image_path and save the result."""
    canvas = cv2.imread(image_path)
    if canvas is None:
        print("Error: Image could not be loaded.")
        return

    # Filled green dot for every point.
    for pt in points:
        cv2.circle(canvas, pt, radius=5, color=(0, 255, 0), thickness=-1)

    # Persist the annotated image.
    output_path = "output_image_with_points.jpg"
    cv2.imwrite(output_path, canvas)
    print(f"Image saved with points at {output_path}")
|
| 673 |
+
|
| 674 |
+
def generate_mask(original_img_path, point1, point2):
    """Build a binary trigger mask for the image at original_img_path.

    A white rectangle is drawn from point1 towards point2, stopping at 95%
    of the way, on an otherwise black mask the same size as the image.
    Returns (mask_path, point1, end_point).
    """
    source = cv2.imread(original_img_path)

    # Mask must match the source image's dimensions.
    height, width, _ = source.shape

    # Start from an all-black single-channel mask.
    mask = np.zeros((height, width), dtype=np.uint8)

    # End corner at 95% of the way from point1 to point2
    # (despite the historical "3/4 point" naming elsewhere).
    end_point = (
        int(point1[0] + 0.95 * (point2[0] - point1[0])),
        int(point1[1] + 0.95 * (point2[1] - point1[1])),
    )

    # Fill the rectangle region with white.
    cv2.rectangle(mask, point1, end_point, color=255, thickness=-1)

    # Derive the output name from the source path and persist the mask.
    # NOTE(review): the replace() is a no-op unless the path contains 'test.jpg'.
    mask_path = original_img_path.replace('test.jpg', 'mask_test.jpg')
    cv2.imwrite(mask_path, mask)
    print(mask_path)
    return mask_path, point1, end_point
|
| 699 |
+
|
| 700 |
+
def extract_lanes_in_crop(lane_data, crop_x_min, crop_x_max, crop_y_min, crop_y_max):
    """
    Filter TuSimple `lanes`, keeping only the points that fall inside the crop box.
    """
    h_samples = lane_data["h_samples"]

    def _inside(x, y):
        # -2 marks an invalid sample; otherwise require (x, y) within the box.
        return x != -2 and crop_x_min <= x <= crop_x_max and crop_y_min <= y <= crop_y_max

    kept = []
    for lane in lane_data["lanes"]:
        pts = [(x, y) for x, y in zip(lane, h_samples) if _inside(x, y)]
        # Lanes with no surviving point are dropped entirely.
        if pts:
            kept.append(pts)

    return kept
|
| 717 |
+
|
| 718 |
+
|
| 719 |
+
def generate_trigger_crop(image_path: str, lane_data: dict):
    """Run the full trigger pipeline for one image.

    Picks trigger points from the lane annotation, segments the region with
    SAM2, then crops a 512x512 window around it from both the image and a
    rectangular trigger mask, writing both crops to disk.
    Returns (crop_image_path, crop_mask_path).
    """
    # 1. Trigger points from the lane annotation.
    trigger_points = get_left_right_points(lane_data, image_path)
    print(f"[INFO] 获取 trigger 点: {trigger_points}")
    draw_point(image_path, trigger_points)

    # 2. Mask corner points via SAM2.
    full_image = load_image(image_path)
    mask_point1, mask_point2 = sam2segment(image_path, trigger_points, "LDA")

    # 3. 512x512 crop of the original image.
    cropped, *_ = random_crop(full_image, 512, 512, mask_point1, mask_point2)
    input_crop_path = "crop.jpg"
    cropped.save(input_crop_path)

    # 4. Trigger mask, cropped the same way as the image.
    mask_path, point1, point2 = generate_mask(image_path, mask_point1, mask_point2)
    mask_crop, *_ = random_crop(load_image(mask_path), 512, 512, mask_point1, mask_point2)
    crop_mask_path = "crop_mask.jpg"
    cv2.imwrite(crop_mask_path, np.array(mask_crop))

    return input_crop_path, crop_mask_path
|
| 745 |
+
|
| 746 |
+
if __name__ == "__main__":
    # Demo / debug driver: run the full trigger-generation pipeline on one
    # hard-coded CULane frame with an inline TuSimple-style annotation.
    lane_data = {"lanes": [[-2, -2, -2, -2, -2, -2, -2, 814, 751, 688, 625, 562, 500, 438, 373, 305, 234, 160, 88, 16, -64, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2], [-2, -2, -2, -2, -2, -2, -2, 818, 801, 784, 768, 751, 734, 717, 701, 685, 668, 651, 634, 618, 601, 585, 568, 551, 535, 518, 502, 484, 468, 451, 435, 418, 401, 385, 368, 351, 335, 318, 301, 287], [-2, -2, -2, -2, -2, -2, -2, 863, 872, 881, 890, 899, 908, 918, 927, 936, 945, 954, 964, 972, 982, 991, 1000, 1009, 1018, 1027, 1036, 1046, 1055, 1064, 1073, 1082, 1091, 1100, 1109, 1119, 1128, 1137, 1146, 1154]], "h_samples": [200, 210, 220, 230, 240, 250, 260, 270, 280, 290, 300, 310, 320, 330, 340, 350, 360, 370, 380, 390, 400, 410, 420, 430, 440, 450, 460, 470, 480, 490, 500, 510, 520, 530, 540, 550, 560, 570, 580, 590], "raw_file": "driver_182_30frame/06010513_0036.MP4/00270.jpg"}

    image_path = "driver_182_30frame/06010513_0036.MP4/00270.jpg"
    # Trigger points from the annotation, drawn onto a debug image.
    points = get_left_right_points(lane_data,image_path)
    print(points)
    draw_point(image_path,points)
    # left_point, right_point = get_left_right_points(lane_data)
    # print(f"Left point: {left_point}, Right point: {right_point}")
    # sam2segment(image_path,left_point, right_point)
    # SAM2 segmentation gives the two mask corner points.
    image = load_image(image_path)
    mask_point1,mask_point2 = sam2segment(image_path,points,"LDA")
    # 512x512 crop of the original image around the mask midpoint.
    input_image,top_padding,left_padding,global_mask_point1_relative,global_mask_point2_relative = random_crop(image, 512, 512, mask_point1, mask_point2)
    input_image.save("crop.jpg")  # save directly via PIL's save()
    print(f"Image saved with points at crop.jpg")
    # Rectangular trigger mask, then the same 512x512 crop of it.
    # NOTE(review): the mask is generated from 'culane_test.jpg', not from
    # image_path — confirm this is intentional.
    mask_path, point1, point2 = generate_mask('culane_test.jpg', mask_point1, mask_point2)
    mask_img = load_image(mask_path)
    mask_img,top_padding,left_padding,global_mask_point1_relative,global_mask_point2_relative = random_crop(mask_img, 512, 512,mask_point1,mask_point2)

    mask_img = np.array(mask_img)
    # print(mask_img.shape)
    # Collect vehicle boxes from YOLO so the GT mask can be intersected with them.
    model = load_yolov5_model()
    yolo_results = model(input_image)
    yolo_boxes = []
    car_class_id = [2, 5, 7]  # car / bus / truck class ids; adjust as needed

    for result in yolo_results:
        boxes = result.boxes.xyxy.cpu().numpy()
        class_ids = result.boxes.cls.cpu().numpy().astype(int)

        for box, cls in zip(boxes, class_ids):
            if cls in car_class_id:
                x_min, y_min, x_max, y_max = box[:4]
                yolo_boxes.append([int(x_min), int(y_min), int(x_max), int(y_max)])
    # Refine the mask by intersecting it with the detected vehicle regions.
    _,mask_img=generate_gt_mask_from_intersection([global_mask_point1_relative,global_mask_point2_relative], yolo_boxes, input_image, mask_img,sam2_model, threshold_iou=0.01)
    cv2.imwrite("crop_mask.jpg", mask_img)

    print("Mask 已成功保存至 crop_mask.jpg")
    # Bounding box of the two mask corner points, in original-image coordinates.
    crop_x_min = min(mask_point1[0], mask_point2[0])
    crop_x_max = max(mask_point1[0], mask_point2[0])
    crop_y_min = min(mask_point1[1], mask_point2[1])
    crop_y_max = max(mask_point1[1], mask_point2[1])


    # NOTE(review): duplicates the module-level extract_lanes_in_crop and
    # shadows it inside this script block.
    def extract_lanes_in_crop(lane_data, crop_x_min, crop_x_max, crop_y_min, crop_y_max):
        """
        Filter TuSimple `lanes`, keeping only the parts inside the crop
        """
        cropped_lanes = []
        for lane in lane_data["lanes"]:
            cropped_lane = []
            for x, y in zip(lane, lane_data["h_samples"]):
                if x != -2 and crop_x_min <= x <= crop_x_max and crop_y_min <= y <= crop_y_max:
                    cropped_lane.append((x, y))
                    # new_x = x - crop_x_min
                    # new_y = y - crop_y_min
                    # cropped_lane.append((new_x, new_y))
            if cropped_lane:
                cropped_lanes.append(cropped_lane)

        return cropped_lanes

    # Keep only the lanes that fall inside the crop box.
    cropped_lanes = extract_lanes_in_crop(lane_data, crop_x_min, crop_x_max, crop_y_min, crop_y_max)
    # print(cropped_lanes)
    # def draw_lane_mask(image, lanes):
    #     """
    #     Draw the lane mask only inside the crop image
    #     """
    #     height, width, _ = image.shape
    #     lane_mask = np.zeros((height, width), dtype=np.uint8)

    #     for lane in lanes:
    #         points = np.array(lane, dtype=np.int32)
    #         cv2.polylines(lane_mask, [points], isClosed=False, color=255, thickness=5)

    #     return lane_mask

    # crop_image = load_image("crop.jpg").convert("RGB")
    # crop_image = np.array(crop_image)
    # lane_mask = draw_lane_mask(crop_image, cropped_lanes)
    def draw_lane_mask_on_original(image, cropped_lanes):
        """
        Draw, on the original image, ONLY the lanes contained in cropped_lanes
        """
        height, width, _ = image.shape
        lane_mask = np.zeros((height, width), dtype=np.uint8)

        for lane in cropped_lanes:
            points = np.array(lane, dtype=np.int32)
            cv2.polylines(lane_mask, [points], isClosed=False, color=255, thickness=10)

        return lane_mask

    def random_crop_lane(image, target_width, target_height, mask_point1, mask_point2):
        """Crop a region of the given width/height around the midpoint of two diagonal points, without going past the image border"""
        # NOTE(review): defined but never called in this script.

        # Make sure image is a NumPy array.
        if isinstance(image, Image.Image):
            image = np.array(image)

        height, width = image.shape[:2]  # size of the NumPy array

        # Midpoint of the two diagonal points.
        center_x = (mask_point1[0] + mask_point2[0]) // 2
        center_y = (mask_point1[1] + mask_point2[1]) // 2

        # Top-left and bottom-right of the crop box.
        left = max(0, center_x - target_width // 2)
        top = max(0, center_y - target_height // 2)
        right = min(width, left + target_width)
        bottom = min(height, top + target_height)

        # Padding (if the crop would run past the border).
        top_padding = max(0, target_height - (bottom - top))
        left_padding = max(0, target_width - (right - left))

        # Crop with NumPy slicing.
        cropped_image = image[top:bottom, left:right]

        return cropped_image, top_padding, left_padding
    # Draw the lane mask on the original image, then crop it like the image.
    raw_image = np.array(load_image(image_path).convert("RGB"))
    lane_mask = draw_lane_mask_on_original(raw_image, cropped_lanes)
    lane_mask_pil = Image.fromarray(lane_mask)
    crop_image,top_padding,left_padding,global_mask_point1_relative,global_mask_point2_relative = random_crop(lane_mask_pil, 512, 512,mask_point1,mask_point2)

    # Save the cropped lane mask.
    crop_image.save("lane_mask_crop.jpg")
    print("✅ 车道 Mask 已保存为 lane_mask_crop.jpg")

    # Overlay the trigger mask (white areas) onto the cropped image.
    crop_img = cv2.imread("crop.jpg")  # original crop (BGR)
    mask_img = cv2.imread("crop_mask.jpg", cv2.IMREAD_GRAYSCALE)  # mask (grayscale)
    if crop_img.shape[:2] != mask_img.shape:
        print("⚠️ Resizing mask to match crop image size...")
        mask_img = cv2.resize(mask_img, (crop_img.shape[1], crop_img.shape[0]))
    white_overlay = np.ones_like(crop_img) * 255  # all-white image
    masked_result = np.where(mask_img[:, :, None] == 255, white_overlay, crop_img)  # replace only the white parts

    # Save the overlaid result.
    cv2.imwrite("crop_with_mask.jpg", masked_result)
    print("✅ 叠加后的 Mask 图像已保存至 crop_with_mask.jpg")
|
yolo11n.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0ebbc80d4a7680d14987a577cd21342b65ecfd94632bd9a8da63ae6417644ee1
|
| 3 |
+
size 5613764
|