#!/usr/bin/env python3

import hashlib
import os

# Set before torch is imported so unsupported MPS ops fall back to the CPU.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import imghdr
import io
import logging
import multiprocessing
import random
import time
from pathlib import Path

import cv2
import numpy as np
import torch
from PIL import Image
from loguru import logger

from lama_cleaner.const import SD15_MODELS
from lama_cleaner.file_manager import FileManager
from lama_cleaner.model.utils import torch_gc
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.plugins import (
    InteractiveSeg,
    RemoveBG,
    RealESRGANUpscaler,
    MakeGIF,
    GFPGANPlugin,
    RestoreFormerPlugin,
)
from lama_cleaner.schema import Config
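
# The TorchScript fusers are disabled below as a defensive workaround: some
# checkpoints have tripped JIT-fusion bugs on certain torch builds, and
# inference works without fusion (editorial note -- the original code carries
# no explanation here). The try/except keeps imports working on torch builds
# that lack these private torch._C hooks.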
try:
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(False)
except Exception:
    pass

from flask import (
    Flask,
    request,
    send_file,
    cli,
    make_response,
    send_from_directory,
    jsonify,
)

# Disable Flask's warning about running a development server in production.
# https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356
cli.show_server_banner = lambda *_: None

from flask_cors import CORS

from lama_cleaner.helper import (
    load_img,
    numpy_to_bytes,
    resize_max_size,
    pil_to_bytes,
)

NUM_THREADS = str(multiprocessing.cpu_count())

# Fix the libomp problem on Windows: https://github.com/Sanster/lama-cleaner/issues/56
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

os.environ["OMP_NUM_THREADS"] = NUM_THREADS
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
| if os.environ.get("CACHE_DIR"): | |
| os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"] | |
| BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build") | |

# Drop the keep-alive polling noise that flaskwebgui generates in GUI mode.
class NoFlaskwebgui(logging.Filter):
    def filter(self, record):
        return "flaskwebgui-keep-server-alive" not in record.getMessage()

logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())

app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
app.config["JSON_AS_ASCII"] = False
CORS(app, expose_headers=["Content-Disposition"])

model: ModelManager = None
thumb: FileManager = None
output_dir: str = None
device = None
input_image_path: str = None
is_disable_model_switch: bool = False
is_controlnet: bool = False
is_enable_file_manager: bool = False
is_enable_auto_saving: bool = False
is_desktop: bool = False
image_quality: int = 95
plugins = {}
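
# These module-level singletons are populated once by main() at startup; the
# request handlers below read them directly.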

def get_image_ext(img_bytes):
    w = imghdr.what("", img_bytes)
    if w is None:
        w = "jpeg"
    return w
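
# Usage sketch:
#   with open("photo.png", "rb") as f:
#       get_image_ext(f.read())  # -> "png"; unknown formats fall back to "jpeg"
# NOTE: imghdr is deprecated since Python 3.11 and removed in 3.13; Pillow or
# puremagic are common replacements.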

def diffuser_callback(i, t, latents):
    # Placeholder hook for per-step progress reporting from diffusers pipelines.
    pass
    # socketio.emit('diffusion_step', {'diffusion_step': step})

def save_image():
    if output_dir is None:
        return "--output-dir is None", 500

    files = request.files
    filename = request.form["filename"]
    origin_image_bytes = files["image"].read()  # RGB
    image, _ = load_img(origin_image_bytes)
    if image.shape[2] == 3:
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGRA)

    save_path = os.path.join(output_dir, filename)
    # imencode + tofile (instead of cv2.imwrite) handles non-ASCII paths,
    # which cv2.imwrite cannot do on Windows.
    cv2.imencode(Path(save_path).suffix, image)[1].tofile(save_path)
    return "ok", 200

def medias(tab):
    if tab == "image":
        response = make_response(jsonify(thumb.media_names), 200)
    else:
        response = make_response(jsonify(thumb.output_media_names), 200)
    # response.last_modified = thumb.modified_time[tab]
    # response.cache_control.no_cache = True
    # response.cache_control.max_age = 0
    # response.make_conditional(request)
    return response

def media_file(tab, filename):
    if tab == "image":
        return send_from_directory(thumb.root_directory, filename)
    return send_from_directory(thumb.output_dir, filename)

def media_thumbnail_file(tab, filename):
    args = request.args
    width = args.get("width")
    height = args.get("height")
    if width is None and height is None:
        width = 256
    if width:
        width = int(float(width))
    if height:
        height = int(float(height))

    directory = thumb.root_directory
    if tab == "output":
        directory = thumb.output_dir
    thumb_filename, (width, height) = thumb.get_thumbnail(
        directory, filename, width, height
    )
    thumb_filepath = f"{app.config['THUMBNAIL_MEDIA_THUMBNAIL_ROOT']}{thumb_filename}"

    response = make_response(send_file(thumb_filepath))
    response.headers["X-Width"] = str(width)
    response.headers["X-Height"] = str(height)
    return response
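
# The custom X-Width/X-Height headers report the thumbnail's final dimensions
# (get_thumbnail may adjust the requested size, e.g. to preserve aspect ratio),
# presumably so the frontend can reserve gallery layout space before the image
# bytes finish decoding.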

def process():
    files = request.files
    # RGB
    origin_image_bytes = files["image"].read()
    image, alpha_channel, exif = load_img(origin_image_bytes, return_exif=True)

    mask, _ = load_img(files["mask"].read(), gray=True)
    mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]

    if image.shape[:2] != mask.shape[:2]:
        return (
            f"Mask shape {mask.shape[:2]} not equal to image shape {image.shape[:2]}",
            400,
        )

    original_shape = image.shape
    interpolation = cv2.INTER_CUBIC

    form = request.form
    size_limit = max(image.shape)

    if "paintByExampleImage" in files:
        paint_by_example_example_image, _ = load_img(
            files["paintByExampleImage"].read()
        )
        paint_by_example_example_image = Image.fromarray(paint_by_example_example_image)
    else:
        paint_by_example_example_image = None

    config = Config(
        ldm_steps=form["ldmSteps"],
        ldm_sampler=form["ldmSampler"],
        hd_strategy=form["hdStrategy"],
        zits_wireframe=form["zitsWireframe"],
        hd_strategy_crop_margin=form["hdStrategyCropMargin"],
        # (sic: the misspelled form key matches what the frontend sends)
        hd_strategy_crop_trigger_size=form["hdStrategyCropTrigerSize"],
        hd_strategy_resize_limit=form["hdStrategyResizeLimit"],
        prompt=form["prompt"],
        negative_prompt=form["negativePrompt"],
        use_croper=form["useCroper"],
        croper_x=form["croperX"],
        croper_y=form["croperY"],
        croper_height=form["croperHeight"],
        croper_width=form["croperWidth"],
        sd_scale=form["sdScale"],
        sd_mask_blur=form["sdMaskBlur"],
        sd_strength=form["sdStrength"],
        sd_steps=form["sdSteps"],
        sd_guidance_scale=form["sdGuidanceScale"],
        sd_sampler=form["sdSampler"],
        sd_seed=form["sdSeed"],
        sd_match_histograms=form["sdMatchHistograms"],
        cv2_flag=form["cv2Flag"],
        cv2_radius=form["cv2Radius"],
        paint_by_example_steps=form["paintByExampleSteps"],
        paint_by_example_guidance_scale=form["paintByExampleGuidanceScale"],
        paint_by_example_mask_blur=form["paintByExampleMaskBlur"],
        paint_by_example_seed=form["paintByExampleSeed"],
        paint_by_example_match_histograms=form["paintByExampleMatchHistograms"],
        paint_by_example_example_image=paint_by_example_example_image,
        p2p_steps=form["p2pSteps"],
        p2p_image_guidance_scale=form["p2pImageGuidanceScale"],
        p2p_guidance_scale=form["p2pGuidanceScale"],
        controlnet_conditioning_scale=form["controlnet_conditioning_scale"],
    )

    if config.sd_seed == -1:
        config.sd_seed = random.randint(1, 999999999)
    if config.paint_by_example_seed == -1:
        config.paint_by_example_seed = random.randint(1, 999999999)
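
    # sd_seed / paint_by_example_seed == -1 is the frontend's "randomize"
    # sentinel; the concrete seed picked above is echoed back to the client in
    # the X-Seed response header so results can be reproduced.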
| logger.info(f"Origin image shape: {original_shape}") | |
| image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation) | |
| logger.info(f"Resized image shape: {image.shape}") | |
| mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation) | |
| start = time.time() | |
| try: | |
| res_np_img = model(image, mask, config) | |
| except RuntimeError as e: | |
| torch.cuda.empty_cache() | |
| if "CUDA out of memory. " in str(e): | |
| # NOTE: the string may change? | |
| return "CUDA out of memory", 500 | |
| else: | |
| logger.exception(e) | |
| return "Internal Server Error", 500 | |
| finally: | |
| logger.info(f"process time: {(time.time() - start) * 1000}ms") | |
| torch.cuda.empty_cache() | |
| res_np_img = cv2.cvtColor(res_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB) | |
| if alpha_channel is not None: | |
| if alpha_channel.shape[:2] != res_np_img.shape[:2]: | |
| alpha_channel = cv2.resize( | |
| alpha_channel, dsize=(res_np_img.shape[1], res_np_img.shape[0]) | |
| ) | |
| res_np_img = np.concatenate( | |
| (res_np_img, alpha_channel[:, :, np.newaxis]), axis=-1 | |
| ) | |

    ext = get_image_ext(origin_image_bytes)

    # fmt: off
    if exif is not None:
        bytes_io = io.BytesIO(pil_to_bytes(Image.fromarray(res_np_img), ext, quality=image_quality, exif=exif))
    else:
        bytes_io = io.BytesIO(pil_to_bytes(Image.fromarray(res_np_img), ext, quality=image_quality))
    # fmt: on

    response = make_response(
        send_file(
            # io.BytesIO(numpy_to_bytes(res_np_img, ext)),
            bytes_io,
            mimetype=f"image/{ext}",
        )
    )
    response.headers["X-Seed"] = str(config.sd_seed)
    return response
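
# Minimal client sketch for process() (assumptions: the route decorator is not
# part of this listing, and the lama-cleaner frontend posts to /inpaint; the
# form dict must contain every form[...] key read above -- only a few are shown):
#
#   import requests
#   r = requests.post(
#       "http://127.0.0.1:8080/inpaint",
#       data={"ldmSteps": 25, "prompt": "", "sdSeed": -1},  # plus remaining keys
#       files={"image": open("photo.jpg", "rb"), "mask": open("mask.png", "rb")},
#   )
#   open("result.jpg", "wb").write(r.content)
#   print("seed:", r.headers["X-Seed"])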

def run_plugin():
    form = request.form
    files = request.files
    name = form["name"]
    if name not in plugins:
        return "Plugin not found", 500

    origin_image_bytes = files["image"].read()  # RGB
    rgb_np_img, alpha_channel, exif = load_img(origin_image_bytes, return_exif=True)

    start = time.time()
    try:
        form = dict(form)
        if name == InteractiveSeg.name:
            img_md5 = hashlib.md5(origin_image_bytes).hexdigest()
            form["img_md5"] = img_md5
        bgr_res = plugins[name](rgb_np_img, files, form)
    except RuntimeError as e:
        torch.cuda.empty_cache()
        if "CUDA out of memory. " in str(e):
            # NOTE: the exact error string may change between torch versions.
            return "CUDA out of memory", 500
        else:
            logger.exception(e)
            return "Internal Server Error", 500

    logger.info(f"{name} process time: {(time.time() - start) * 1000}ms")
    torch_gc()

    if name == MakeGIF.name:
        return send_file(
            io.BytesIO(bgr_res),
            mimetype="image/gif",
            as_attachment=True,
            download_name=form["filename"],
        )
    if name == InteractiveSeg.name:
        return make_response(
            send_file(
                io.BytesIO(numpy_to_bytes(bgr_res, "png")),
                mimetype="image/png",
            )
        )

    if name == RemoveBG.name:
        # RemoveBG already produces an alpha channel, so keep it and force PNG.
        rgb_res = cv2.cvtColor(bgr_res, cv2.COLOR_BGRA2RGBA)
        ext = "png"
    else:
        rgb_res = cv2.cvtColor(bgr_res, cv2.COLOR_BGR2RGB)
        ext = get_image_ext(origin_image_bytes)
        if alpha_channel is not None:
            if alpha_channel.shape[:2] != rgb_res.shape[:2]:
                alpha_channel = cv2.resize(
                    alpha_channel, dsize=(rgb_res.shape[1], rgb_res.shape[0])
                )
            rgb_res = np.concatenate(
                (rgb_res, alpha_channel[:, :, np.newaxis]), axis=-1
            )

    response = make_response(
        send_file(
            io.BytesIO(
                pil_to_bytes(
                    Image.fromarray(rgb_res), ext, quality=image_quality, exif=exif
                )
            ),
            mimetype=f"image/{ext}",
        )
    )
    return response
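
# Minimal client sketch for run_plugin (assumption: the route decorator is not
# part of this listing; plugin jobs are posted to a single endpoint such as
# /run_plugin and dispatched by the "name" form field):
#
#   import requests
#   r = requests.post(
#       "http://127.0.0.1:8080/run_plugin",
#       data={"name": RemoveBG.name},
#       files={"image": open("photo.jpg", "rb")},
#   )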

def get_server_config():
    return {
        "isControlNet": is_controlnet,
        "isDisableModelSwitchState": is_disable_model_switch,
        "isEnableAutoSaving": is_enable_auto_saving,
        "enableFileManager": is_enable_file_manager,
        "plugins": list(plugins.keys()),
    }, 200

def current_model():
    return model.name, 200

def model_downloaded(name):
    return str(model.is_downloaded(name)), 200

def get_is_desktop():
    return str(is_desktop), 200

def switch_model():
    if is_disable_model_switch:
        return "Switch model is disabled", 400

    new_name = request.form.get("name")
    if new_name == model.name:
        return "Same model", 200

    try:
        model.switch(new_name)
    except NotImplementedError:
        return f"{new_name} not implemented", 403
    return f"ok, switch to {new_name}", 200

def index():
    return send_file(os.path.join(BUILD_DIR, "index.html"))

def set_input_photo():
    if input_image_path:
        with open(input_image_path, "rb") as f:
            image_in_bytes = f.read()
        return send_file(
            input_image_path,
            as_attachment=True,
            # Flask >= 2.0 renamed attachment_filename to download_name (the
            # MakeGIF branch above already uses the new name).
            download_name=Path(input_image_path).name,
            mimetype=f"image/{get_image_ext(image_in_bytes)}",
        )
    else:
        return "No Input Image"

def build_plugins(args):
    global plugins
    if args.enable_interactive_seg:
        logger.info(f"Initialize {InteractiveSeg.name} plugin")
        plugins[InteractiveSeg.name] = InteractiveSeg(
            args.interactive_seg_model, args.interactive_seg_device
        )
    if args.enable_remove_bg:
        logger.info(f"Initialize {RemoveBG.name} plugin")
        plugins[RemoveBG.name] = RemoveBG()
    if args.enable_realesrgan:
        logger.info(
            f"Initialize {RealESRGANUpscaler.name} plugin: {args.realesrgan_model}, {args.realesrgan_device}"
        )
        plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(
            args.realesrgan_model,
            args.realesrgan_device,
            no_half=args.realesrgan_no_half,
        )
    if args.enable_gfpgan:
        logger.info(f"Initialize {GFPGANPlugin.name} plugin")
        if args.enable_realesrgan:
            logger.info("Use RealESRGAN as the GFPGAN background upscaler")
        else:
            logger.info(
                "GFPGAN has no background upscaler; use --enable-realesrgan to enable one"
            )
        plugins[GFPGANPlugin.name] = GFPGANPlugin(
            args.gfpgan_device, upscaler=plugins.get(RealESRGANUpscaler.name, None)
        )
    if args.enable_restoreformer:
        logger.info(f"Initialize {RestoreFormerPlugin.name} plugin")
        plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin(
            args.restoreformer_device,
            upscaler=plugins.get(RealESRGANUpscaler.name, None),
        )
    if args.enable_gif:
        logger.info("Initialize GIF plugin")
        plugins[MakeGIF.name] = MakeGIF()
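
# main() consumes an argparse-style namespace; the parser itself lives
# elsewhere in the package and is not part of this file. For orientation, the
# fields read below include: args.model, args.device, args.host, args.port,
# args.gui, args.gui_size, args.input, args.output_dir, args.quality,
# args.sd_controlnet, args.no_half, args.cpu_offload, and the --enable-*
# plugin flags consumed by build_plugins() above.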

def main(args):
    global model
    global device
    global input_image_path
    global is_disable_model_switch
    global is_enable_file_manager
    global is_desktop
    global thumb
    global output_dir
    global is_enable_auto_saving
    global is_controlnet
    global image_quality

    build_plugins(args)

    image_quality = args.quality

    if args.sd_controlnet and args.model in SD15_MODELS:
        is_controlnet = True

    output_dir = args.output_dir
    if output_dir:
        is_enable_auto_saving = True

    device = torch.device(args.device)
    is_disable_model_switch = args.disable_model_switch
    is_desktop = args.gui
    if is_disable_model_switch:
        logger.info(
            "Started with --disable-model-switch, model switching on the frontend is disabled"
        )

    if args.input and os.path.isdir(args.input):
        logger.info("Initialize file manager")
        thumb = FileManager(app)
        is_enable_file_manager = True
        app.config["THUMBNAIL_MEDIA_ROOT"] = args.input
        app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(
            args.output_dir, "lama_cleaner_thumbnails"
        )
        thumb.output_dir = Path(args.output_dir)
        # thumb.start()
        # try:
        #     while True:
        #         time.sleep(1)
        # finally:
        #     thumb.image_dir_observer.stop()
        #     thumb.image_dir_observer.join()
        #     thumb.output_dir_observer.stop()
        #     thumb.output_dir_observer.join()
    else:
        input_image_path = args.input

    model = ModelManager(
        name=args.model,
        sd_controlnet=args.sd_controlnet,
        device=device,
        no_half=args.no_half,
        hf_access_token=args.hf_access_token,
        disable_nsfw=args.sd_disable_nsfw or args.disable_nsfw,
        sd_cpu_textencoder=args.sd_cpu_textencoder,
        sd_run_local=args.sd_run_local,
        sd_local_model_path=args.sd_local_model_path,
        local_files_only=args.local_files_only,
        cpu_offload=args.cpu_offload,
        enable_xformers=args.sd_enable_xformers or args.enable_xformers,
        callback=diffuser_callback,
    )

    if args.gui:
        app_width, app_height = args.gui_size
        from flaskwebgui import FlaskUI

        ui = FlaskUI(
            app,
            width=app_width,
            height=app_height,
            host=args.host,
            port=args.port,
            close_server_on_exit=not args.no_gui_auto_close,
        )
        ui.run()
    else:
        app.run(host=args.host, port=args.port, debug=args.debug)
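
# Typical launch goes through the lama-cleaner CLI, which builds `args` and
# calls main(), e.g.:
#
#   lama-cleaner --model=lama --device=cpu --port=8080
#
# (Flag names follow the upstream README; in --gui mode the Flask app is
# wrapped in flaskwebgui as shown above.)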