# NOTE: "Spaces: / Runtime error / Runtime error" was Hugging Face Spaces page
# chrome captured by the scrape — it is not part of the Python source.
| import asyncio | |
| import os | |
| import threading | |
| import time | |
| import traceback | |
| from pathlib import Path | |
| from typing import Optional, Dict, List | |
| import cv2 | |
| import numpy as np | |
| import socketio | |
| import torch | |
# Best-effort: disable TorchScript JIT fusers up front (private torch._C APIs
# vary across torch versions, hence the try/except).
try:
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(False)
except Exception:
    # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt still
    # propagate; missing private APIs are simply ignored.
    pass
| import uvicorn | |
| from PIL import Image | |
| from fastapi import APIRouter, FastAPI, Request, UploadFile | |
| from fastapi.encoders import jsonable_encoder | |
| from fastapi.exceptions import HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse, FileResponse, Response | |
| from fastapi.staticfiles import StaticFiles | |
| from loguru import logger | |
| from socketio import AsyncServer | |
| from iopaint.file_manager import FileManager | |
| from iopaint.helper import ( | |
| load_img, | |
| decode_base64_to_image, | |
| pil_to_bytes, | |
| numpy_to_bytes, | |
| concat_alpha_channel, | |
| gen_frontend_mask, | |
| adjust_mask, | |
| ) | |
| from iopaint.model.utils import torch_gc | |
| from iopaint.model_manager import ModelManager | |
| from iopaint.plugins import build_plugins, RealESRGANUpscaler, InteractiveSeg | |
| from iopaint.plugins.base_plugin import BasePlugin | |
| from iopaint.plugins.remove_bg import RemoveBG | |
| from iopaint.schema import ( | |
| GenInfoResponse, | |
| ApiConfig, | |
| ServerConfigResponse, | |
| SwitchModelRequest, | |
| InpaintRequest, | |
| RunPluginRequest, | |
| SDSampler, | |
| PluginInfo, | |
| AdjustMaskRequest, | |
| RemoveBGModel, | |
| SwitchPluginModelRequest, | |
| ModelInfo, | |
| InteractiveSegModel, | |
| RealESRGANModel, | |
| ) | |
# Directory containing this file; anchor for locating bundled assets.
CURRENT_DIR = Path(__file__).parent.absolute().resolve()
# Pre-built web frontend, served as static files at "/" (see Api.__init__).
WEB_APP_DIR = CURRENT_DIR / "web_app"
def api_middleware(app: FastAPI):
    """Install CORS and uniform JSON error handling on *app*.

    Every unhandled exception is converted to a JSON body of the form
    ``{"error", "detail", "body", "errors"}``.  When the environment variable
    ``WEBUI_RICH_EXCEPTIONS`` is set and ``rich`` is importable, tracebacks
    are pretty-printed with rich; otherwise the plain traceback is printed.
    """
    rich_available = False
    try:
        if os.environ.get("WEBUI_RICH_EXCEPTIONS", None) is not None:
            import anyio  # importing just so it can be placed on silent list
            import starlette  # importing just so it can be placed on silent list
            from rich.console import Console

            console = Console()
            rich_available = True
    except Exception:
        pass

    def handle_exception(request: Request, e: Exception):
        # Build a JSON-serializable error payload; exceptions without a
        # status_code attribute default to HTTP 500 below.
        err = {
            "error": type(e).__name__,
            "detail": vars(e).get("detail", ""),
            "body": vars(e).get("body", ""),
            "errors": str(e),
        }
        if not isinstance(
            e, HTTPException
        ):  # do not print backtrace on known httpexceptions
            message = f"API error: {request.method}: {request.url} {err}"
            if rich_available:
                print(message)
                console.print_exception(
                    show_locals=True,
                    max_frames=2,
                    extra_lines=1,
                    suppress=[anyio, starlette],
                    word_wrap=False,
                    width=min([console.width, 200]),
                )
            else:
                traceback.print_exc()
        return JSONResponse(
            status_code=vars(e).get("status_code", 500), content=jsonable_encoder(err)
        )

    # BUG FIX: these three handlers were defined but never attached to the
    # app, so FastAPI would never invoke them and errors fell through to the
    # default plain-text handler.  Register them explicitly.
    @app.middleware("http")
    async def exception_handling(request: Request, call_next):
        try:
            return await call_next(request)
        except Exception as e:
            return handle_exception(request, e)

    @app.exception_handler(Exception)
    async def fastapi_exception_handler(request: Request, e: Exception):
        return handle_exception(request, e)

    @app.exception_handler(HTTPException)
    async def http_exception_handler(request: Request, e: HTTPException):
        return handle_exception(request, e)

    cors_options = {
        "allow_methods": ["*"],
        "allow_headers": ["*"],
        "allow_origins": ["*"],
        "allow_credentials": True,
    }
    app.add_middleware(CORSMiddleware, **cors_options)
| global_sio: AsyncServer = None | |
def diffuser_callback(pipe, step: int, timestep: int, callback_kwargs: Optional[Dict] = None):
    """Per-step diffusion-pipeline callback: push progress to the frontend.

    Signature follows diffusers' callback_on_step_end protocol:
    (pipe: DiffusionPipeline, step: int, timestep: int, callback_kwargs).
    Returns an empty dict (no tensor overrides for the pipeline).
    """
    # BUG FIX: the default was a shared mutable dict (`callback_kwargs={}`);
    # use a None sentinel instead.  Behavior is unchanged — the dict is never
    # read — but shared-state leaks are ruled out.
    if callback_kwargs is None:
        callback_kwargs = {}
    # We use an asyncio loop for task processing.  Perhaps in the future we
    # can add a processing queue similar to InvokeAI, but for now let's just
    # start a separate event loop.  It shouldn't make a difference for
    # single-person use.
    asyncio.run(global_sio.emit("diffusion_progress", {"step": step}))
    return {}
class Api:
    """REST + Socket.IO server for IOPaint.

    Owns the model manager, optional plugins and file manager, registers all
    /api/v1 routes on the given FastAPI app, mounts the bundled web frontend
    at "/" and a Socket.IO server at "/ws".
    """

    def __init__(self, app: FastAPI, config: ApiConfig):
        """Wire routes, build managers/plugins and mount the Socket.IO app."""
        self.app = app
        self.config = config
        self.router = APIRouter()
        # NOTE(review): created but never acquired anywhere in this file —
        # presumably intended to serialize model inference; confirm usage.
        self.queue_lock = threading.Lock()
        api_middleware(self.app)
        self.file_manager = self._build_file_manager()
        self.plugins = self._build_plugins()
        self.model_manager = self._build_model_manager()
        # fmt: off
        self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
        self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=ServerConfigResponse)
        self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
        self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
        self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
        self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
        self.add_api_route("/api/v1/switch_plugin_model", self.api_switch_plugin_model, methods=["POST"])
        self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
        self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
        self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
        self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"])
        self.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"])
        self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
        # fmt: on

        global global_sio
        self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
        self.combined_asgi_app = socketio.ASGIApp(self.sio, self.app)
        self.app.mount("/ws", self.combined_asgi_app)
        # Publish the Socket.IO server module-wide so diffuser_callback
        # (a free function) can emit progress events.
        global_sio = self.sio

    def add_api_route(self, path: str, endpoint, **kwargs):
        """Thin wrapper around FastAPI's ``add_api_route``."""
        return self.app.add_api_route(path, endpoint, **kwargs)

    def api_save_image(self, file: UploadFile):
        """Persist an uploaded image into the configured output directory.

        NOTE(review): ``file.filename`` comes straight from the client; a
        crafted name (e.g. ``../x``) could escape output_dir — consider
        sanitizing before joining.
        """
        filename = file.filename
        origin_image_bytes = file.file.read()
        with open(self.config.output_dir / filename, "wb") as fw:
            fw.write(origin_image_bytes)

    def api_current_model(self) -> ModelInfo:
        """Return info for the currently loaded inpainting model."""
        return self.model_manager.current_model

    def api_switch_model(self, req: SwitchModelRequest) -> ModelInfo:
        """Switch to another model; no-op if it is already active."""
        if req.name == self.model_manager.name:
            return self.model_manager.current_model
        self.model_manager.switch(req.name)
        return self.model_manager.current_model

    def api_switch_plugin_model(self, req: SwitchPluginModelRequest):
        """Swap one plugin's backing model and record the choice in config."""
        if req.plugin_name in self.plugins:
            self.plugins[req.plugin_name].switch_model(req.model_name)
            if req.plugin_name == RemoveBG.name:
                self.config.remove_bg_model = req.model_name
            if req.plugin_name == RealESRGANUpscaler.name:
                self.config.realesrgan_model = req.model_name
            if req.plugin_name == InteractiveSeg.name:
                self.config.interactive_seg_model = req.model_name
            # Release memory held by the previously loaded plugin model.
            torch_gc()

    def api_server_config(self) -> ServerConfigResponse:
        """Describe server capabilities and model choices for the frontend."""
        plugins = []
        for it in self.plugins.values():
            plugins.append(
                PluginInfo(
                    name=it.name,
                    support_gen_image=it.support_gen_image,
                    support_gen_mask=it.support_gen_mask,
                )
            )
        return ServerConfigResponse(
            plugins=plugins,
            modelInfos=self.model_manager.scan_models(),
            removeBGModel=self.config.remove_bg_model,
            removeBGModels=RemoveBGModel.values(),
            realesrganModel=self.config.realesrgan_model,
            realesrganModels=RealESRGANModel.values(),
            interactiveSegModel=self.config.interactive_seg_model,
            interactiveSegModels=InteractiveSegModel.values(),
            enableFileManager=self.file_manager is not None,
            enableAutoSaving=self.config.output_dir is not None,
            enableControlnet=self.model_manager.enable_controlnet,
            controlnetMethod=self.model_manager.controlnet_method,
            disableModelSwitch=False,
            isDesktop=False,
            samplers=self.api_samplers(),
        )

    def api_input_image(self) -> FileResponse:
        """Serve the file passed via config.input, or 404 if unset/not a file."""
        if self.config.input and self.config.input.is_file():
            return FileResponse(self.config.input)
        raise HTTPException(status_code=404, detail="Input image not found")

    def api_geninfo(self, file: UploadFile) -> GenInfoResponse:
        """Extract prompt / negative prompt from an image's "parameters" metadata."""
        _, _, info = load_img(file.file.read(), return_info=True)
        parts = info.get("parameters", "").split("Negative prompt: ")
        prompt = parts[0].strip()
        negative_prompt = ""
        if len(parts) > 1:
            # The negative prompt runs until the first newline after its marker.
            negative_prompt = parts[1].split("\n")[0].strip()
        return GenInfoResponse(prompt=prompt, negative_prompt=negative_prompt)

    def api_inpaint(self, req: InpaintRequest):
        """Run inpainting on a base64 image + mask and return the PNG result."""
        image, alpha_channel, infos = decode_base64_to_image(req.image)
        mask, _, _ = decode_base64_to_image(req.mask, gray=True)
        # Binarize the mask: values above 127 become 255, the rest 0.
        mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
        if image.shape[:2] != mask.shape[:2]:
            raise HTTPException(
                400,
                detail=f"Image size({image.shape[:2]}) and mask size({mask.shape[:2]}) not match.",
            )
        if req.paint_by_example_example_image:
            # NOTE(review): decoded but never used below in this file —
            # confirm whether it should be forwarded to the model call.
            paint_by_example_image, _, _ = decode_base64_to_image(
                req.paint_by_example_example_image
            )
        start = time.time()
        rgb_np_img = self.model_manager(image, mask, req)
        logger.info(f"process time: {(time.time() - start) * 1000:.2f}ms")
        torch_gc()
        rgb_np_img = cv2.cvtColor(rgb_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
        # Re-attach the original alpha channel (if any) to the result.
        rgb_res = concat_alpha_channel(rgb_np_img, alpha_channel)
        ext = "png"
        res_img_bytes = pil_to_bytes(
            Image.fromarray(rgb_res),
            ext=ext,
            quality=self.config.quality,
            infos=infos,
        )
        # Tell the frontend over Socket.IO that this diffusion run finished.
        asyncio.run(self.sio.emit("diffusion_finish"))
        return Response(
            content=res_img_bytes,
            media_type=f"image/{ext}",
            headers={"X-Seed": str(req.sd_seed)},
        )

    def api_run_plugin_gen_image(self, req: RunPluginRequest):
        """Run an image-producing plugin on a base64 image; return PNG bytes."""
        ext = "png"
        if req.name not in self.plugins:
            raise HTTPException(status_code=422, detail="Plugin not found")
        if not self.plugins[req.name].support_gen_image:
            raise HTTPException(
                status_code=422, detail="Plugin does not support output image"
            )
        rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
        bgr_or_rgba_np_img = self.plugins[req.name].gen_image(rgb_np_img, req)
        torch_gc()
        if bgr_or_rgba_np_img.shape[2] == 4:
            # Plugin already produced RGBA; pass it through unchanged.
            rgba_np_img = bgr_or_rgba_np_img
        else:
            rgba_np_img = cv2.cvtColor(bgr_or_rgba_np_img, cv2.COLOR_BGR2RGB)
            rgba_np_img = concat_alpha_channel(rgba_np_img, alpha_channel)
        return Response(
            content=pil_to_bytes(
                Image.fromarray(rgba_np_img),
                ext=ext,
                quality=self.config.quality,
                infos=infos,
            ),
            media_type=f"image/{ext}",
        )

    def api_run_plugin_gen_mask(self, req: RunPluginRequest):
        """Run a mask-producing plugin on a base64 image; return a PNG mask."""
        if req.name not in self.plugins:
            raise HTTPException(status_code=422, detail="Plugin not found")
        if not self.plugins[req.name].support_gen_mask:
            raise HTTPException(
                status_code=422, detail="Plugin does not support output image"
            )
        rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
        bgr_or_gray_mask = self.plugins[req.name].gen_mask(rgb_np_img, req)
        torch_gc()
        # Convert the raw plugin mask into the format the frontend expects.
        res_mask = gen_frontend_mask(bgr_or_gray_mask)
        return Response(
            content=numpy_to_bytes(res_mask, "png"),
            media_type="image/png",
        )

    def api_samplers(self) -> List[str]:
        """List the names of all available diffusion samplers."""
        return [member.value for member in SDSampler.__members__.values()]

    def api_adjust_mask(self, req: AdjustMaskRequest):
        """Morphologically adjust (expand/shrink) a mask; return it as PNG."""
        mask, _, _ = decode_base64_to_image(req.mask, gray=True)
        mask = adjust_mask(mask, req.kernel_size, req.operate)
        return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png")

    def launch(self):
        """Start uvicorn serving the combined FastAPI + Socket.IO ASGI app."""
        self.app.include_router(self.router)
        uvicorn.run(
            self.combined_asgi_app,
            host=self.config.host,
            port=self.config.port,
            # Effectively disables keep-alive timeout for long-running requests.
            timeout_keep_alive=999999999,
        )

    def _build_file_manager(self) -> Optional[FileManager]:
        """Create a FileManager when config.input is a directory, else None."""
        if self.config.input and self.config.input.is_dir():
            logger.info(
                f"Input is directory, initialize file manager {self.config.input}"
            )
            return FileManager(
                app=self.app,
                input_dir=self.config.input,
                output_dir=self.config.output_dir,
            )
        return None

    def _build_plugins(self) -> Dict[str, BasePlugin]:
        """Instantiate the enabled plugins, keyed by plugin name."""
        return build_plugins(
            self.config.enable_interactive_seg,
            self.config.interactive_seg_model,
            self.config.interactive_seg_device,
            self.config.enable_remove_bg,
            self.config.remove_bg_model,
            self.config.enable_anime_seg,
            self.config.enable_realesrgan,
            self.config.realesrgan_device,
            self.config.realesrgan_model,
            self.config.enable_gfpgan,
            self.config.gfpgan_device,
            self.config.enable_restoreformer,
            self.config.restoreformer_device,
            self.config.no_half,
        )

    def _build_model_manager(self):
        """Load the initial inpainting model on the configured device."""
        return ModelManager(
            name=self.config.model,
            device=torch.device(self.config.device),
            no_half=self.config.no_half,
            low_mem=self.config.low_mem,
            disable_nsfw=self.config.disable_nsfw_checker,
            sd_cpu_textencoder=self.config.cpu_textencoder,
            local_files_only=self.config.local_files_only,
            cpu_offload=self.config.cpu_offload,
            callback=diffuser_callback,
        )