Spaces:
Runtime error
Runtime error
jhj0517
commited on
Commit
·
cfa5142
1
Parent(s):
baa2a55
Add inference script
Browse files- modules/sam_inference.py +119 -0
modules/sam_inference.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
|
| 2 |
+
from sam2.build_sam import build_sam2
|
| 3 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
| 4 |
+
import torch
|
| 5 |
+
import os
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from modules.model_downloader import (
|
| 10 |
+
AVAILABLE_MODELS,
|
| 11 |
+
DEFAULT_MODEL_TYPE,
|
| 12 |
+
OUTPUT_DIR,
|
| 13 |
+
is_sam_exist,
|
| 14 |
+
download_sam_model_url
|
| 15 |
+
)
|
| 16 |
+
from modules.paths import SAM2_CONFIGS_DIR, MODELS_DIR
|
| 17 |
+
from modules.mask_utils import (
|
| 18 |
+
save_psd_with_masks,
|
| 19 |
+
create_mask_combined_images,
|
| 20 |
+
create_mask_gallery
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
CONFIGS = {
|
| 24 |
+
"sam2_hiera_tiny": os.path.join(SAM2_CONFIGS_DIR, "sam2_hiera_t.yaml"),
|
| 25 |
+
"sam2_hiera_small": os.path.join(SAM2_CONFIGS_DIR, "sam2_hiera_s.yaml"),
|
| 26 |
+
"sam2_hiera_base_plus": os.path.join(SAM2_CONFIGS_DIR, "sam2_hiera_b+.yaml"),
|
| 27 |
+
"sam2_hiera_large": os.path.join(SAM2_CONFIGS_DIR, "sam2_hiera_l.yaml"),
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class SamInference:
|
| 32 |
+
def __init__(self,
|
| 33 |
+
model_dir: str = MODELS_DIR,
|
| 34 |
+
output_dir: str = OUTPUT_DIR
|
| 35 |
+
):
|
| 36 |
+
self.model = None
|
| 37 |
+
self.available_models = list(AVAILABLE_MODELS.keys())
|
| 38 |
+
self.model_type = DEFAULT_MODEL_TYPE
|
| 39 |
+
self.model_dir = model_dir
|
| 40 |
+
self.output_dir = output_dir
|
| 41 |
+
self.model_path = os.path.join(self.model_dir, AVAILABLE_MODELS[DEFAULT_MODEL_TYPE][0])
|
| 42 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 43 |
+
self.mask_generator = None
|
| 44 |
+
self.image_predictor = None
|
| 45 |
+
|
| 46 |
+
# Tunable Parameters , All default values by https://github.com/facebookresearch/segment-anything-2/blob/main/notebooks/automatic_mask_generator_example.ipynb
|
| 47 |
+
self.maskgen_hparams = {
|
| 48 |
+
"points_per_side": 64,
|
| 49 |
+
"points_per_batch": 128,
|
| 50 |
+
"pred_iou_thresh": 0.7,
|
| 51 |
+
"stability_score_thresh": 0.92,
|
| 52 |
+
"stability_score_offset": 0.7,
|
| 53 |
+
"crop_n_layers": 1,
|
| 54 |
+
"box_nms_thresh": 0.7,
|
| 55 |
+
"crop_n_points_downscale_factor": 2,
|
| 56 |
+
"min_mask_region_area": 25.0,
|
| 57 |
+
"use_m2m": True,
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
def load_model(self):
|
| 61 |
+
config = CONFIGS[self.model_type]
|
| 62 |
+
filename, url = AVAILABLE_MODELS[self.model_type]
|
| 63 |
+
model_path = os.path.join(self.model_dir, filename)
|
| 64 |
+
|
| 65 |
+
if not is_sam_exist(self.model_type):
|
| 66 |
+
print(f"\nLayer Divider Extension : No SAM2 model found, downloading {self.model_type} model...")
|
| 67 |
+
download_sam_model_url(self.model_type)
|
| 68 |
+
print("\nLayer Divider Extension : applying configs to model..")
|
| 69 |
+
|
| 70 |
+
try:
|
| 71 |
+
self.model = build_sam2(
|
| 72 |
+
config_file=config,
|
| 73 |
+
ckpt_path=model_path,
|
| 74 |
+
device=self.device
|
| 75 |
+
)
|
| 76 |
+
self.image_predictor = SAM2ImagePredictor(sam_model=self.model)
|
| 77 |
+
self.mask_generator = SAM2AutomaticMaskGenerator(
|
| 78 |
+
model=self.model,
|
| 79 |
+
**self.maskgen_hparams
|
| 80 |
+
)
|
| 81 |
+
except Exception as e:
|
| 82 |
+
print(f"Layer Divider Extension : Error while Loading SAM2 model! {e}")
|
| 83 |
+
|
| 84 |
+
def generate_mask(self,
|
| 85 |
+
image: np.ndarray):
|
| 86 |
+
return self.mask_generator.generate(image)
|
| 87 |
+
|
| 88 |
+
def generate_mask_app(self,
|
| 89 |
+
image: np.ndarray,
|
| 90 |
+
model_type: str,
|
| 91 |
+
*params
|
| 92 |
+
):
|
| 93 |
+
maskgen_hparams = {
|
| 94 |
+
'points_per_side': int(params[0]),
|
| 95 |
+
'points_per_batch': int(params[1]),
|
| 96 |
+
'pred_iou_thresh': float(params[2]),
|
| 97 |
+
'stability_score_thresh': float(params[3]),
|
| 98 |
+
'stability_score_offset': float(params[4]),
|
| 99 |
+
'crop_n_layers': int(params[5]),
|
| 100 |
+
'box_nms_thresh': float(params[6]),
|
| 101 |
+
'crop_n_points_downscale_factor': int(params[7]),
|
| 102 |
+
'min_mask_region_area': int(params[8]),
|
| 103 |
+
'use_m2m': bool(params[9])
|
| 104 |
+
}
|
| 105 |
+
timestamp = datetime.now().strftime("%m%d%H%M%S")
|
| 106 |
+
output_file_name = f"result-{timestamp}.psd"
|
| 107 |
+
output_path = os.path.join(self.output_dir, "psd", output_file_name)
|
| 108 |
+
|
| 109 |
+
if self.model is None or self.mask_generator is None or self.model_type != model_type or self.maskgen_hparams != maskgen_hparams:
|
| 110 |
+
self.model_type = model_type
|
| 111 |
+
self.maskgen_hparams = maskgen_hparams
|
| 112 |
+
self.load_model()
|
| 113 |
+
|
| 114 |
+
masks = self.mask_generator.generate(image)
|
| 115 |
+
|
| 116 |
+
save_psd_with_masks(image, masks, output_path)
|
| 117 |
+
combined_image = create_mask_combined_images(image, masks)
|
| 118 |
+
gallery = create_mask_gallery(image, masks)
|
| 119 |
+
return [combined_image] + gallery, output_path
|