# Uploaded by: dikdimon
# Commit message: "Upload exhm using SD-Hub extension"
# Commit: 194b4ef (verified)
from PIL import Image
from torch.cuda import OutOfMemoryError
from scripts.supported_preprocessor import Preprocessor
from scripts.utils import resize_image_with_pad
from scripts import sam
from modules import shared, sd_models, devices
from cn_sam_preprocessor.tools import convertImageIntoPILFormat, convertIntoCNImageFormat
from cn_sam_preprocessor.options import (getTemplate, getAutoSamOptions,
getSegmentAnythingModel, needAutoUnloadModels, avoidOOM,
)
def unloadSAM():
    """Drop the SAM extension's cached model(s), then run the webui GC helper.

    Called on explicit preprocessor unload, after each run when auto-unload
    is enabled, and between the avoid-OOM retry stages.
    """
    # SAM's cache is cleared first so that whatever it held is already
    # unreferenced when devices.torch_gc() runs (presumably letting the GC
    # helper actually hand that memory back to the device).
    for cleanup_step in (sam.clear_cache, devices.torch_gc):
        cleanup_step()
def processAutoSegmentAnything(image: Image.Image):
    """Segment *image* with SAM's automatic mode via ``sam.cnet_seg``.

    The SAM checkpoint comes from ``getSegmentAnythingModel`` (given the list
    of available models) and the auto-SAM parameters from
    ``getAutoSamOptions``. The chosen model, the options, and the second
    element of ``cnet_seg``'s result are printed to stdout.

    Returns:
        ``result[0][1]`` of ``sam.cnet_seg`` — presumably the segmentation
        map image; TODO confirm against scripts/sam.py.
    """
    model_name = getSegmentAnythingModel(sam.sam_model_list)
    # Positional fillers for the cnet_seg parameters that sit between the
    # input image and the auto-SAM options. NOTE(review): 'random' is
    # presumably the processor / colouring mode and the five Nones the
    # resolution / resize arguments — verify against cnet_seg's signature.
    filler_args = ('random',) + (None,) * 5
    auto_options = getAutoSamOptions()
    print(model_name, auto_options)
    seg_output = sam.cnet_seg(model_name, image, *filler_args, *auto_options)
    status_message = seg_output[1]
    print(status_message)
    output_images = seg_output[0]
    return output_images[1]
def processAutoSegmentAnything_cpu(image: Image.Image):
    """CPU fallback for ``processAutoSegmentAnything``.

    Temporarily points both the webui device (``devices.device``) and the SAM
    extension's device (``sam.sam_device``) at ``'cpu'``, runs the regular
    auto segmentation, then restores both previous values whether or not the
    run succeeded.
    """
    print('Using cpu for auto segmentation')
    saved_devices = (devices.device, sam.sam_device)
    # NOTE(review): the saved values are presumably torch.device objects while
    # a plain string is stored here; confirm every consumer accepts 'cpu'.
    devices.device = sam.sam_device = 'cpu'
    try:
        return processAutoSegmentAnything(image)
    finally:
        # Always restore, so later work does not silently stay on the CPU.
        devices.device, sam.sam_device = saved_devices
def processAutoSegmentAnything_avoidOOM(image: Image.Image):
    """Run auto segmentation, escalating step by step on CUDA out-of-memory.

    1. Try on the current device.
    2. On ``OutOfMemoryError``: clear SAM's cache, unload the Stable Diffusion
       weights, and retry.
    3. On a second ``OutOfMemoryError``: clear SAM's cache again and run on
       the CPU via ``processAutoSegmentAnything_cpu``.

    Any exception other than ``OutOfMemoryError`` propagates unchanged.
    """
    try:
        return processAutoSegmentAnything(image)
    except OutOfMemoryError:
        # Recovery deliberately happens *after* leaving the except clause.
        # While the handler is active, the exception's traceback keeps the
        # failed frames (and the CUDA tensors they reference) alive, so memory
        # "freed" inside the handler cannot actually be reclaimed.
        pass
    print("\nOut of Memory. Unload Stable Diffusion\n")
    unloadSAM()
    sd_models.unload_model_weights()

    try:
        return processAutoSegmentAnything(image)
    except OutOfMemoryError:
        pass  # same reason: release the traceback before recovering
    print("\nOut of Memory. Use CPU\n")
    unloadSAM()
    return processAutoSegmentAnything_cpu(image)
class PreprocessorSegmentAnything(Preprocessor):
    """ControlNet preprocessor producing a Segment Anything (auto mode) map.

    Identified by ``NAME`` and tagged "Segmentation". The input is resized
    and padded to the requested resolution, segmented with SAM, converted
    back to ControlNet's image format, and has the padding removed.
    """

    NAME = "segment_anything"

    def __init__(self):
        super().__init__(name=self.NAME)
        self.tags = ["Segmentation"]

    def unload(self) -> bool:
        """@Override

        Release the cached SAM model; always reports success.
        """
        unloadSAM()
        return True

    def __call__(
        self,
        input_image,
        resolution,
        slider_1=None,
        slider_2=None,
        slider_3=None,
        **kwargs
    ):
        """Segment *input_image* and return the map in ControlNet's format.

        Args:
            input_image: the image as handed over by ControlNet (whatever
                ``resize_image_with_pad`` accepts — presumably an HWC numpy
                array; verify).
            resolution: target size forwarded to ``resize_image_with_pad``.
            slider_1, slider_2, slider_3, **kwargs: accepted for interface
                compatibility with other preprocessors; unused here.

        Returns:
            The segmentation map with the resize padding stripped.
        """
        img, remove_pad = resize_image_with_pad(input_image, resolution)
        img = convertImageIntoPILFormat(img)
        try:
            if avoidOOM():
                result = processAutoSegmentAnything_avoidOOM(img)
            else:
                result = processAutoSegmentAnything(img)
        finally:
            # Honour the auto-unload setting even when segmentation raises
            # (e.g. CUDA OOM with the avoid-OOM fallback disabled); otherwise
            # the SAM weights would stay resident after a failed run.
            if needAutoUnloadModels():
                unloadSAM()
        result = convertIntoCNImageFormat(result)
        return remove_pad(result)
# Import-time registration.
# 1) Merge this extension's settings into the webui option templates. The
#    template comes from getTemplate(sam.sam_model_list) — presumably so a
#    model-choice setting can list the available SAM checkpoints; verify.
shared.options_templates.update(getTemplate(sam.sam_model_list))
# 2) Register the preprocessor only if none with this NAME is known yet, so a
#    repeated import (presumed reason) does not add a second instance.
if not Preprocessor.get_preprocessor(PreprocessorSegmentAnything.NAME):
    Preprocessor.add_supported_preprocessor(PreprocessorSegmentAnything())