from PIL import Image from torch.cuda import OutOfMemoryError from scripts.supported_preprocessor import Preprocessor from scripts.utils import resize_image_with_pad from scripts import sam from modules import shared, sd_models, devices from cn_sam_preprocessor.tools import convertImageIntoPILFormat, convertIntoCNImageFormat from cn_sam_preprocessor.options import (getTemplate, getAutoSamOptions, getSegmentAnythingModel, needAutoUnloadModels, avoidOOM, ) def unloadSAM(): sam.clear_cache() devices.torch_gc() def processAutoSegmentAnything(image: Image.Image): sam_model_name = getSegmentAnythingModel(sam.sam_model_list) nones = ['random', None, None, None, None, None] auto_sam_config = getAutoSamOptions() print(sam_model_name, auto_sam_config) result = sam.cnet_seg(sam_model_name, image, *nones, *auto_sam_config) print(result[1]) return result[0][1] def processAutoSegmentAnything_cpu(image: Image.Image): print('Using cpu for auto segmentation') oldDevice = devices.device oldSamDevice = sam.sam_device devices.device = 'cpu' sam.sam_device = 'cpu' try: return processAutoSegmentAnything(image) finally: devices.device = oldDevice sam.sam_device = oldSamDevice def processAutoSegmentAnything_avoidOOM(image: Image.Image): try: result = processAutoSegmentAnything(image) except OutOfMemoryError: print("\nOut of Memory. Unload Stable Diffusion\n") unloadSAM() sd_models.unload_model_weights() try: result = processAutoSegmentAnything(image) except OutOfMemoryError: print("\nOut of Memory. Use CPU\n") unloadSAM() result = processAutoSegmentAnything_cpu(image) return result class PreprocessorSegmentAnything(Preprocessor): NAME = "segment_anything" def __init__(self): super().__init__(name=self.NAME) self.tags = ["Segmentation"] def unload(self) -> bool: """@Override""" unloadSAM() return True def __call__( self, input_image, resolution, slider_1=None, slider_2=None, slider_3=None, **kwargs ): img, remove_pad = resize_image_with_pad(input_image, resolution) img = convertImageIntoPILFormat(img) if avoidOOM(): result = processAutoSegmentAnything_avoidOOM(img) else: result = processAutoSegmentAnything(img) if needAutoUnloadModels(): unloadSAM() result = convertIntoCNImageFormat(result) result = remove_pad(result) return result shared.options_templates.update(getTemplate(sam.sam_model_list)) if not Preprocessor.get_preprocessor(PreprocessorSegmentAnything.NAME): Preprocessor.add_supported_preprocessor(PreprocessorSegmentAnything())