Spaces:
Runtime error
Runtime error
jhj0517
commited on
Commit
·
8d52a7d
1
Parent(s):
f822d17
Divide model / predictors
Browse files- modules/sam_inference.py +15 -7
modules/sam_inference.py
CHANGED
|
@@ -71,14 +71,19 @@ class SamInference:
|
|
| 71 |
ckpt_path=model_path,
|
| 72 |
device=self.device
|
| 73 |
)
|
| 74 |
-
self.image_predictor = SAM2ImagePredictor(sam_model=self.model)
|
| 75 |
-
self.mask_generator = SAM2AutomaticMaskGenerator(
|
| 76 |
-
model=self.model,
|
| 77 |
-
**self.maskgen_hparams
|
| 78 |
-
)
|
| 79 |
except Exception as e:
|
| 80 |
print(f"Layer Divider Extension : Error while Loading SAM2 model! {e}")
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
def generate_mask(self,
|
| 83 |
image: np.ndarray):
|
| 84 |
return self.mask_generator.generate(image)
|
|
@@ -104,11 +109,14 @@ class SamInference:
|
|
| 104 |
output_file_name = f"result-{timestamp}.psd"
|
| 105 |
output_path = os.path.join(self.output_dir, "psd", output_file_name)
|
| 106 |
|
| 107 |
-
if self.model is None or self.
|
| 108 |
self.model_type = model_type
|
| 109 |
-
self.maskgen_hparams = maskgen_hparams
|
| 110 |
self.load_model()
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
masks = self.mask_generator.generate(image)
|
| 113 |
|
| 114 |
save_psd_with_masks(image, masks, output_path)
|
|
|
|
| 71 |
ckpt_path=model_path,
|
| 72 |
device=self.device
|
| 73 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
except Exception as e:
|
| 75 |
print(f"Layer Divider Extension : Error while Loading SAM2 model! {e}")
|
| 76 |
|
| 77 |
+
def set_predictors(self):
|
| 78 |
+
if self.model is None:
|
| 79 |
+
self.load_model()
|
| 80 |
+
|
| 81 |
+
self.image_predictor = SAM2ImagePredictor(sam_model=self.model)
|
| 82 |
+
self.mask_generator = SAM2AutomaticMaskGenerator(
|
| 83 |
+
model=self.model,
|
| 84 |
+
**self.maskgen_hparams
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
def generate_mask(self,
|
| 88 |
image: np.ndarray):
|
| 89 |
return self.mask_generator.generate(image)
|
|
|
|
| 109 |
output_file_name = f"result-{timestamp}.psd"
|
| 110 |
output_path = os.path.join(self.output_dir, "psd", output_file_name)
|
| 111 |
|
| 112 |
+
if self.model is None or self.model_type != model_type:
|
| 113 |
self.model_type = model_type
|
|
|
|
| 114 |
self.load_model()
|
| 115 |
|
| 116 |
+
if self.mask_generator is None or self.maskgen_hparams != maskgen_hparams:
|
| 117 |
+
self.maskgen_hparams = maskgen_hparams
|
| 118 |
+
self.set_predictors()
|
| 119 |
+
|
| 120 |
masks = self.mask_generator.generate(image)
|
| 121 |
|
| 122 |
save_psd_with_masks(image, masks, output_path)
|