File size: 2,884 Bytes
194b4ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
from PIL import Image
from torch.cuda import OutOfMemoryError
from scripts.supported_preprocessor import Preprocessor
from scripts.utils import resize_image_with_pad
from scripts import sam
from modules import shared, sd_models, devices
from cn_sam_preprocessor.tools import convertImageIntoPILFormat, convertIntoCNImageFormat
from cn_sam_preprocessor.options import (getTemplate, getAutoSamOptions,
getSegmentAnythingModel, needAutoUnloadModels, avoidOOM,
)
def unloadSAM():
    """Drop the segment-anything extension's cached model(s), then run the WebUI GC.

    First ``sam.clear_cache()`` forgets whatever the sam module has cached,
    then ``devices.torch_gc()`` is invoked so the freed memory can be released.
    """
    sam.clear_cache()
    devices.torch_gc()
def processAutoSegmentAnything(image: Image.Image):
    """Run the segment-anything extension's automatic segmentation on ``image``.

    The checkpoint name comes from ``getSegmentAnythingModel`` and the
    automatic-mask options from ``getAutoSamOptions``. Both are printed before
    ``sam.cnet_seg`` runs, and the second element of its result is printed
    afterwards.

    Returns:
        ``result[0][1]`` of ``sam.cnet_seg``, i.e. the second item of its first
        output. NOTE(review): presumably the segmentation map image; confirm
        against ``sam.cnet_seg``.
    """
    model_name = getSegmentAnythingModel(sam.sam_model_list)
    # Positional arguments that sam.cnet_seg takes between the image and the
    # auto-SAM options: the string 'random' followed by five Nones.
    # NOTE(review): the meaning of these slots is defined by sam.cnet_seg; confirm there.
    leading_args = ['random'] + [None] * 5
    auto_options = getAutoSamOptions()
    print(model_name, auto_options)

    outputs = sam.cnet_seg(model_name, image, *leading_args, *auto_options)
    status = outputs[1]
    print(status)
    return outputs[0][1]
def processAutoSegmentAnything_cpu(image: Image.Image):
    """Run ``processAutoSegmentAnything`` with both device settings forced to CPU.

    ``devices.device`` and ``sam.sam_device`` are temporarily set to ``'cpu'``.
    Their previous values are restored when the call finishes, including when
    it raises.
    """
    print('Using cpu for auto segmentation')
    saved_devices = (devices.device, sam.sam_device)
    # Chained assignment binds left to right: devices.device first, then sam.sam_device.
    devices.device = sam.sam_device = 'cpu'
    try:
        return processAutoSegmentAnything(image)
    finally:
        devices.device, sam.sam_device = saved_devices
def processAutoSegmentAnything_avoidOOM(image: Image.Image):
    """Run auto segmentation, escalating through fallbacks on CUDA out-of-memory.

    1. Call ``processAutoSegmentAnything`` as configured.
    2. On ``OutOfMemoryError``: drop the cached SAM model, unload the Stable
       Diffusion weights via ``sd_models.unload_model_weights()``, and retry.
    3. On a second ``OutOfMemoryError``: drop the cached SAM model again and
       run ``processAutoSegmentAnything_cpu``.

    This function does not reload the Stable Diffusion weights afterwards.

    Returns:
        Whatever the successful attempt returns.
    """
    # Every recovery step deliberately runs *after* its ``except`` clause has
    # exited. While an exception is being handled, its traceback keeps the
    # failing frames (and any CUDA tensors they reference) alive. Freeing
    # memory and retrying from inside the handler would therefore leave that
    # memory allocated and make a second OOM much more likely.
    try:
        return processAutoSegmentAnything(image)
    except OutOfMemoryError:
        pass  # handled below, outside the except clause

    print("\nOut of Memory. Unload Stable Diffusion\n")
    unloadSAM()
    sd_models.unload_model_weights()
    try:
        return processAutoSegmentAnything(image)
    except OutOfMemoryError:
        pass  # handled below, outside the except clause

    print("\nOut of Memory. Use CPU\n")
    unloadSAM()
    return processAutoSegmentAnything_cpu(image)
class PreprocessorSegmentAnything(Preprocessor):
    """ControlNet preprocessor that returns a segment-anything auto segmentation."""

    NAME = "segment_anything"

    def __init__(self):
        super().__init__(name=self.NAME)
        self.tags = ["Segmentation"]

    def unload(self) -> bool:
        """@Override

        Drop the cached SAM model via ``unloadSAM``; always returns ``True``.
        """
        unloadSAM()
        return True

    def __call__(
        self,
        input_image,
        resolution,
        slider_1=None,
        slider_2=None,
        slider_3=None,
        **kwargs
    ):
        """Segment ``input_image`` at ``resolution`` and return the result.

        Steps:
        1. Resize/pad the image with ``resize_image_with_pad``.
        2. Convert it to PIL.
        3. Segment it, using the OOM-fallback path when ``avoidOOM()`` is on.
        4. Convert the result back with ``convertIntoCNImageFormat``.
        5. Strip the padding.

        ``slider_1``..``slider_3`` and ``**kwargs`` are accepted for interface
        compatibility and are not used.
        """
        img, remove_pad = resize_image_with_pad(input_image, resolution)
        img = convertImageIntoPILFormat(img)
        try:
            if avoidOOM():
                result = processAutoSegmentAnything_avoidOOM(img)
            else:
                result = processAutoSegmentAnything(img)
        finally:
            # Honour the auto-unload setting even when segmentation raises, so a
            # failed run (e.g. an OOM without the avoidOOM fallback) does not
            # leave the SAM model cached.
            if needAutoUnloadModels():
                unloadSAM()
        result = convertIntoCNImageFormat(result)
        return remove_pad(result)
# Merge this extension's option templates (built by getTemplate from
# sam.sam_model_list) into shared.options_templates at import time.
shared.options_templates.update(getTemplate(sam.sam_model_list))
# Register one PreprocessorSegmentAnything instance. Registration is skipped
# when get_preprocessor already finds a preprocessor under the same NAME.
if not Preprocessor.get_preprocessor(PreprocessorSegmentAnything.NAME):
    Preprocessor.add_supported_preprocessor(PreprocessorSegmentAnything())
|