Spaces:
Runtime error
Runtime error
jhj0517
commited on
Commit
·
2c719e3
1
Parent(s):
60434a4
integrate the function
Browse files- app.py +2 -2
- modules/sam_inference.py +75 -30
app.py
CHANGED
|
@@ -73,14 +73,14 @@ class App:
|
|
| 73 |
output_file = gr.File(label="Generated psd file", scale=9)
|
| 74 |
btn_open_folder = gr.Button("📁\nOpen PSD folder", scale=1)
|
| 75 |
|
| 76 |
-
sources = [img_input]
|
| 77 |
model_params = [dd_models]
|
| 78 |
auto_mask_hparams = [nb_points_per_side, nb_points_per_batch, sld_pred_iou_thresh,
|
| 79 |
sld_stability_score_thresh, sld_stability_score_offset, nb_crop_n_layers,
|
| 80 |
sld_box_nms_thresh, nb_crop_n_points_downscale_factor, nb_min_mask_region_area,
|
| 81 |
cb_use_m2m]
|
| 82 |
|
| 83 |
-
btn_generate.click(fn=self.sam_inf.
|
| 84 |
inputs=sources + model_params + auto_mask_hparams, outputs=[gallery_output, output_file])
|
| 85 |
btn_open_folder.click(fn=lambda: open_folder(os.path.join(OUTPUT_DIR)),
|
| 86 |
inputs=None, outputs=None)
|
|
|
|
| 73 |
output_file = gr.File(label="Generated psd file", scale=9)
|
| 74 |
btn_open_folder = gr.Button("📁\nOpen PSD folder", scale=1)
|
| 75 |
|
| 76 |
+
sources = [img_input, img_input_prompter, dd_input_modes]
|
| 77 |
model_params = [dd_models]
|
| 78 |
auto_mask_hparams = [nb_points_per_side, nb_points_per_batch, sld_pred_iou_thresh,
|
| 79 |
sld_stability_score_thresh, sld_stability_score_offset, nb_crop_n_layers,
|
| 80 |
sld_box_nms_thresh, nb_crop_n_points_downscale_factor, nb_min_mask_region_area,
|
| 81 |
cb_use_m2m]
|
| 82 |
|
| 83 |
+
btn_generate.click(fn=self.sam_inf.divide_layer,
|
| 84 |
inputs=sources + model_params + auto_mask_hparams, outputs=[gallery_output, output_file])
|
| 85 |
btn_open_folder.click(fn=lambda: open_folder(os.path.join(OUTPUT_DIR)),
|
| 86 |
inputs=None, outputs=None)
|
modules/sam_inference.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
|
| 2 |
from sam2.build_sam import build_sam2
|
| 3 |
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
|
|
|
| 4 |
import torch
|
| 5 |
import os
|
| 6 |
from datetime import datetime
|
|
@@ -12,6 +13,7 @@ from modules.model_downloader import (
|
|
| 12 |
download_sam_model_url
|
| 13 |
)
|
| 14 |
from modules.paths import SAM2_CONFIGS_DIR, MODELS_DIR
|
|
|
|
| 15 |
from modules.mask_utils import (
|
| 16 |
save_psd_with_masks,
|
| 17 |
create_mask_combined_images,
|
|
@@ -62,42 +64,85 @@ class SamInference:
|
|
| 62 |
print(f"Error while Loading SAM2 model! {e}")
|
| 63 |
|
| 64 |
def generate_mask(self,
|
| 65 |
-
image: np.ndarray
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
return self.mask_generator.generate(image)
|
| 67 |
|
| 68 |
-
def
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
maskgen_hparams = {
|
| 74 |
-
'points_per_side': int(params[0]),
|
| 75 |
-
'points_per_batch': int(params[1]),
|
| 76 |
-
'pred_iou_thresh': float(params[2]),
|
| 77 |
-
'stability_score_thresh': float(params[3]),
|
| 78 |
-
'stability_score_offset': float(params[4]),
|
| 79 |
-
'crop_n_layers': int(params[5]),
|
| 80 |
-
'box_nms_thresh': float(params[6]),
|
| 81 |
-
'crop_n_points_downscale_factor': int(params[7]),
|
| 82 |
-
'min_mask_region_area': int(params[8]),
|
| 83 |
-
'use_m2m': bool(params[9])
|
| 84 |
-
}
|
| 85 |
-
timestamp = datetime.now().strftime("%m%d%H%M%S")
|
| 86 |
-
output_file_name = f"result-{timestamp}.psd"
|
| 87 |
-
output_path = os.path.join(self.output_dir, "psd", output_file_name)
|
| 88 |
-
|
| 89 |
if self.model is None or self.model_type != model_type:
|
| 90 |
self.model_type = model_type
|
| 91 |
self.load_model()
|
|
|
|
|
|
|
| 92 |
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
-
|
|
|
|
|
|
|
| 99 |
|
| 100 |
-
|
| 101 |
-
combined_image = create_mask_combined_images(image, masks)
|
| 102 |
-
gallery = create_mask_gallery(image, masks)
|
| 103 |
-
return [combined_image] + gallery, output_path
|
|
|
|
| 1 |
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
|
| 2 |
from sam2.build_sam import build_sam2
|
| 3 |
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
| 4 |
+
from typing import Dict, List
|
| 5 |
import torch
|
| 6 |
import os
|
| 7 |
from datetime import datetime
|
|
|
|
| 13 |
download_sam_model_url
|
| 14 |
)
|
| 15 |
from modules.paths import SAM2_CONFIGS_DIR, MODELS_DIR
|
| 16 |
+
from modules.constants import BOX_PROMPT_MODE, AUTOMATIC_MODE
|
| 17 |
from modules.mask_utils import (
|
| 18 |
save_psd_with_masks,
|
| 19 |
create_mask_combined_images,
|
|
|
|
| 64 |
print(f"Error while Loading SAM2 model! {e}")
|
| 65 |
|
| 66 |
def generate_mask(self,
|
| 67 |
+
image: np.ndarray,
|
| 68 |
+
model_type: str,
|
| 69 |
+
**params):
|
| 70 |
+
if self.model is None or self.model_type != model_type:
|
| 71 |
+
self.model_type = model_type
|
| 72 |
+
self.load_model()
|
| 73 |
+
self.mask_generator = SAM2AutomaticMaskGenerator(
|
| 74 |
+
model=self.model,
|
| 75 |
+
**params
|
| 76 |
+
)
|
| 77 |
return self.mask_generator.generate(image)
|
| 78 |
|
| 79 |
+
def predict_image(self,
|
| 80 |
+
image: np.ndarray,
|
| 81 |
+
model_type: str,
|
| 82 |
+
box: np.ndarray,
|
| 83 |
+
**params):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
if self.model is None or self.model_type != model_type:
|
| 85 |
self.model_type = model_type
|
| 86 |
self.load_model()
|
| 87 |
+
self.image_predictor = SAM2ImagePredictor(sam_model=self.model)
|
| 88 |
+
self.image_predictor.set_image(image)
|
| 89 |
|
| 90 |
+
masks, scores, logits = self.image_predictor.predict(
|
| 91 |
+
box=box,
|
| 92 |
+
multimask_output=params["multimask_output"],
|
| 93 |
)
|
| 94 |
+
print(f"masks: {masks}")
|
| 95 |
+
print(f"scores: {scores}")
|
| 96 |
+
print(f"logits: {logits}")
|
| 97 |
+
return masks, scores, logits
|
| 98 |
+
|
| 99 |
+
def divide_layer(self,
|
| 100 |
+
image_input: np.ndarray,
|
| 101 |
+
image_prompt_input_data: Dict,
|
| 102 |
+
input_mode: str,
|
| 103 |
+
model_type: str,
|
| 104 |
+
*params):
|
| 105 |
+
timestamp = datetime.now().strftime("%m%d%H%M%S")
|
| 106 |
+
output_file_name = f"result-{timestamp}.psd"
|
| 107 |
+
output_path = os.path.join(self.output_dir, "psd", output_file_name)
|
| 108 |
+
|
| 109 |
+
if input_mode == AUTOMATIC_MODE:
|
| 110 |
+
image = image_input
|
| 111 |
+
maskgen_hparams = {
|
| 112 |
+
'points_per_side': int(params[0]),
|
| 113 |
+
'points_per_batch': int(params[1]),
|
| 114 |
+
'pred_iou_thresh': float(params[2]),
|
| 115 |
+
'stability_score_thresh': float(params[3]),
|
| 116 |
+
'stability_score_offset': float(params[4]),
|
| 117 |
+
'crop_n_layers': int(params[5]),
|
| 118 |
+
'box_nms_thresh': float(params[6]),
|
| 119 |
+
'crop_n_points_downscale_factor': int(params[7]),
|
| 120 |
+
'min_mask_region_area': int(params[8]),
|
| 121 |
+
'use_m2m': bool(params[9])
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
generated_masks = self.generate_mask(
|
| 125 |
+
image=image,
|
| 126 |
+
model_type=model_type,
|
| 127 |
+
**maskgen_hparams
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
elif input_mode == BOX_PROMPT_MODE:
|
| 131 |
+
image = image_prompt_input_data["image"]
|
| 132 |
+
box = image_prompt_input_data["points"]
|
| 133 |
+
predict_image_hparams = {
|
| 134 |
+
"multimask_output": params[0]
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
generated_masks, scores, logits = self.predict_image(
|
| 138 |
+
image=image,
|
| 139 |
+
model_type=model_type,
|
| 140 |
+
box=box,
|
| 141 |
+
**predict_image_hparams
|
| 142 |
+
)
|
| 143 |
|
| 144 |
+
save_psd_with_masks(image, generated_masks, output_path)
|
| 145 |
+
mask_combined_image = create_mask_combined_images(image, generated_masks)
|
| 146 |
+
gallery = create_mask_gallery(image, generated_masks)
|
| 147 |
|
| 148 |
+
return [mask_combined_image] + gallery, output_path
|
|
|
|
|
|
|
|
|