| from os import getenv | |
| from typing import Union | |
| from loguru import logger | |
| from carvekit.web.schemas.config import WebAPIConfig, MLConfig, AuthConfig | |
| from carvekit.api.interface import Interface | |
| from carvekit.ml.wrap.fba_matting import FBAMatting | |
| from carvekit.ml.wrap.u2net import U2NET | |
| from carvekit.ml.wrap.deeplab_v3 import DeepLabV3 | |
| from carvekit.ml.wrap.basnet import BASNET | |
| from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 | |
| from carvekit.pipelines.postprocessing import MattingMethod | |
| from carvekit.pipelines.preprocessing import PreprocessingStub | |
| from carvekit.trimap.generator import TrimapGenerator | |
def init_config() -> WebAPIConfig:
    """Build the Web API configuration from ``CARVEKIT_*`` environment variables.

    Each setting is read from its environment variable and falls back to the
    corresponding attribute of a default-constructed ``WebAPIConfig`` when the
    variable is not set. Numeric settings are converted with ``int()``; the
    ``CARVEKIT_FP16`` and ``CARVEKIT_AUTH_ENABLE`` flags are parsed as integers
    and then converted to ``bool`` (so ``"0"`` is false, any other integer is
    true).

    ``CARVEKIT_ALLOWED_TOKENS`` is a comma-separated list. Surrounding
    whitespace is stripped from every entry and empty entries are discarded;
    if the variable is unset or contains no usable entry, the default token
    list is kept.

    Returns:
        The assembled ``WebAPIConfig``.

    Raises:
        ValueError: If a numeric or flag variable holds text that ``int()``
            cannot parse.
    """
    default_config = WebAPIConfig()
    ml_defaults = default_config.ml
    auth_defaults = default_config.auth

    # Read the token list once. A bare ``split(",")`` would turn an empty
    # variable into ``[""]`` (i.e. the empty string becomes an allowed token)
    # and would keep stray spaces / trailing-comma blanks, so normalise here.
    allowed_tokens = auth_defaults.allowed_tokens
    raw_tokens = getenv("CARVEKIT_ALLOWED_TOKENS")
    if raw_tokens is not None:
        parsed_tokens = [
            token.strip() for token in raw_tokens.split(",") if token.strip()
        ]
        if parsed_tokens:
            allowed_tokens = parsed_tokens

    ml_config = MLConfig(
        segmentation_network=getenv(
            "CARVEKIT_SEGMENTATION_NETWORK", ml_defaults.segmentation_network
        ),
        preprocessing_method=getenv(
            "CARVEKIT_PREPROCESSING_METHOD", ml_defaults.preprocessing_method
        ),
        postprocessing_method=getenv(
            "CARVEKIT_POSTPROCESSING_METHOD", ml_defaults.postprocessing_method
        ),
        device=getenv("CARVEKIT_DEVICE", ml_defaults.device),
        batch_size_seg=int(
            getenv("CARVEKIT_BATCH_SIZE_SEG", ml_defaults.batch_size_seg)
        ),
        batch_size_matting=int(
            getenv("CARVEKIT_BATCH_SIZE_MATTING", ml_defaults.batch_size_matting)
        ),
        seg_mask_size=int(
            getenv("CARVEKIT_SEG_MASK_SIZE", ml_defaults.seg_mask_size)
        ),
        matting_mask_size=int(
            getenv("CARVEKIT_MATTING_MASK_SIZE", ml_defaults.matting_mask_size)
        ),
        fp16=bool(int(getenv("CARVEKIT_FP16", ml_defaults.fp16))),
        trimap_prob_threshold=int(
            getenv(
                "CARVEKIT_TRIMAP_PROB_THRESHOLD",
                ml_defaults.trimap_prob_threshold,
            )
        ),
        trimap_dilation=int(
            getenv("CARVEKIT_TRIMAP_DILATION", ml_defaults.trimap_dilation)
        ),
        trimap_erosion=int(
            getenv("CARVEKIT_TRIMAP_EROSION", ml_defaults.trimap_erosion)
        ),
    )

    auth_config = AuthConfig(
        auth=bool(int(getenv("CARVEKIT_AUTH_ENABLE", auth_defaults.auth))),
        admin_token=getenv("CARVEKIT_ADMIN_TOKEN", auth_defaults.admin_token),
        allowed_tokens=allowed_tokens,
    )

    config = WebAPIConfig(
        port=int(getenv("CARVEKIT_PORT", default_config.port)),
        host=getenv("CARVEKIT_HOST", default_config.host),
        ml=ml_config,
        auth=auth_config,
    )

    # NOTE(review): this writes the admin token to the log in plain text;
    # behaviour is kept as-is, but confirm that this is intended.
    logger.info(f"Admin token for Web API is {config.auth.admin_token}")
    logger.debug(f"Running Web API with this config: {config.json()}")
    return config
def init_interface(config: Union[WebAPIConfig, MLConfig]) -> Interface:
    """Assemble an ``Interface`` from the ML settings in *config*.

    Args:
        config: Either a ``WebAPIConfig`` (its ``ml`` section is used) or an
            ``MLConfig`` directly.

    Returns:
        An ``Interface`` built from the selected segmentation network and the
        optional preprocessing / postprocessing stages.
    """
    ml_config = config.ml if isinstance(config, WebAPIConfig) else config

    # Every supported network takes the same constructor arguments, so pick
    # the class by name; unrecognised names fall back to TracerUniversalB7.
    segmentation_classes = {
        "u2net": U2NET,
        "deeplabv3": DeepLabV3,
        "basnet": BASNET,
        "tracer_b7": TracerUniversalB7,
    }
    segmentation_cls = segmentation_classes.get(
        ml_config.segmentation_network, TracerUniversalB7
    )
    seg_net = segmentation_cls(
        device=ml_config.device,
        batch_size=ml_config.batch_size_seg,
        input_image_size=ml_config.seg_mask_size,
        fp16=ml_config.fp16,
    )

    # Only "stub" enables a preprocessing stage; "none" and anything else
    # leave it disabled.
    preprocessing = (
        PreprocessingStub() if ml_config.preprocessing_method == "stub" else None
    )

    # Only "fba" enables a postprocessing stage; "none" and anything else
    # leave it disabled.
    postprocessing = None
    if ml_config.postprocessing_method == "fba":
        matting_net = FBAMatting(
            device=ml_config.device,
            batch_size=ml_config.batch_size_matting,
            input_tensor_size=ml_config.matting_mask_size,
            fp16=ml_config.fp16,
        )
        trimap = TrimapGenerator(
            prob_threshold=ml_config.trimap_prob_threshold,
            kernel_size=ml_config.trimap_dilation,
            erosion_iters=ml_config.trimap_erosion,
        )
        postprocessing = MattingMethod(
            device=ml_config.device,
            matting_module=matting_net,
            trimap_generator=trimap,
        )

    return Interface(
        pre_pipe=preprocessing,
        post_pipe=postprocessing,
        seg_pipe=seg_net,
        device=ml_config.device,
    )