Spaces:
Runtime error
Runtime error
| """some utils for api""" | |
| import random | |
| from typing import List | |
| import numpy | |
| from fastapi import Response | |
| from fastapi.security import APIKeyHeader | |
| from fastapi import HTTPException, Security | |
| from fooocusapi.models.common.base import EnhanceCtrlNets, ImagePrompt | |
| from modules import constants, flags | |
| from modules import config | |
| from modules.sdxl_styles import legal_style_names | |
| from fooocusapi.args import args | |
| from fooocusapi.utils.img_utils import read_input_image | |
| from fooocusapi.utils.file_utils import ( | |
| get_file_serve_url, | |
| output_file_to_base64img, | |
| output_file_to_bytesimg | |
| ) | |
| from fooocusapi.utils.logger import logger | |
| from fooocusapi.models.common.requests import ( | |
| CommonRequest as Text2ImgRequest | |
| ) | |
| from fooocusapi.models.common.response import ( | |
| AsyncJobResponse, | |
| AsyncJobStage, | |
| GeneratedImageResult | |
| ) | |
| from fooocusapi.models.requests_v1 import ( | |
| ImageEnhanceRequest, ImgInpaintOrOutpaintRequest, | |
| ImgPromptRequest, | |
| ImgUpscaleOrVaryRequest | |
| ) | |
| from fooocusapi.models.requests_v2 import ( | |
| ImageEnhanceRequestJson, Text2ImgRequestWithPrompt, | |
| ImgInpaintOrOutpaintRequestJson, | |
| ImgUpscaleOrVaryRequestJson, | |
| ImgPromptRequestJson | |
| ) | |
| from fooocusapi.models.common.task import ( | |
| ImageGenerationResult, | |
| GenerationFinishReason | |
| ) | |
| from fooocusapi.configs.default import ( | |
| default_inpaint_engine_version, | |
| default_sampler, | |
| default_scheduler, | |
| default_base_model_name, | |
| default_refiner_model_name | |
| ) | |
| from fooocusapi.parameters import ImageGenerationParams | |
| from fooocusapi.task_queue import QueueTask | |
| from modules.util import HWC3 | |
# Reads the optional "X-API-KEY" request header. With auto_error=False a missing
# header is passed on as None, leaving the accept/reject decision to api_key_auth.
api_key_header = APIKeyHeader(auto_error=False, name="X-API-KEY")
def refresh_seed(seed_string: int | str | None) -> int:
    """
    Refresh and check seed number.

    :param seed_string: requested seed as an int or numeric string. None, a
        non-numeric string, -1, or any value outside
        [constants.MIN_SEED, constants.MAX_SEED] yields a random seed.
    :return: seed number within [constants.MIN_SEED, constants.MAX_SEED]
    """
    try:
        seed_value = int(seed_string)
    except (TypeError, ValueError, OverflowError):
        # int(None) raises TypeError, int("abc") ValueError, int(float("inf"))
        # OverflowError -- all mean "no usable seed", so fall back to random.
        return random.randint(constants.MIN_SEED, constants.MAX_SEED)
    # -1 explicitly requests a random seed; compare the parsed value so the
    # string "-1" is treated the same as the int -1.
    if seed_value == -1 or not constants.MIN_SEED <= seed_value <= constants.MAX_SEED:
        return random.randint(constants.MIN_SEED, constants.MAX_SEED)
    return seed_value
def check_models_exist(file_name: str, model_type: str) -> str:
    """
    Validate a requested model file name against the known model files.

    Args:
        file_name: requested file name; None or 'None' means "no model".
        model_type: 'base', 'refiner' or 'lora'.
    Returns:
        'None' for None/'None' input; file_name when it is a known file of the
        matching kind; otherwise (after logging a warning) the default base or
        refiner model for 'base'/'refiner', or 'None' for any other type.
    """
    if file_name in (None, 'None'):
        return 'None'
    # Refresh config's filename lists before checking (presumably rescans the
    # model folders so newly added files are found).
    config.update_files()
    if model_type == 'lora':
        known_files = config.lora_filenames
    elif model_type in ('base', 'refiner'):
        known_files = config.model_filenames
    else:
        # Unknown kind: keep the previous permissive check against both lists.
        known_files = config.model_filenames + config.lora_filenames
    if file_name in known_files:
        return file_name
    logger.std_warn(f"[Warning] Wrong {model_type} model input: {file_name}, using default")
    if model_type == 'base':
        return default_base_model_name
    if model_type == 'refiner':
        return default_refiner_model_name
    return 'None'
def api_key_auth(apikey: str | None = Security(api_key_header)):
    """
    Check if the API key is valid, API key is not required if no API key is set

    Args:
        apikey: value of the X-API-KEY header, or None when the header is absent
    Returns:
        None if the API key is not configured or matches
    Raises:
        HTTPException(403) if an API key is configured and the supplied one
        is missing or does not match
    """
    import secrets  # only needed here, for a constant-time comparison

    if args.apikey is None:
        return  # Skip API key check if no API key is set
    # compare_digest does not short-circuit on the first differing byte, so
    # response timing does not reveal how much of a guessed key is correct.
    # Compare bytes so non-ASCII keys don't make compare_digest raise TypeError.
    if apikey is None or not secrets.compare_digest(
            apikey.encode("utf-8"), str(args.apikey).encode("utf-8")):
        raise HTTPException(status_code=403, detail="Forbidden")
def _collect_image_prompts(req: Text2ImgRequest, uov_input_image) -> list:
    """
    Build the image-prompt (ControlNet) tuple list for a request.

    Only request classes that carry ``image_prompts`` contribute entries, each
    as (image, stop, weight, cn_type). The list is then padded with disabled
    placeholders up to ``config.default_controlnet_image_count``.

    Side effects: may set a mixing flag on ``req.advanced_params`` and fills in
    default ``cn_stop`` / ``cn_weight`` on the request's image prompts.
    """
    image_prompts = []
    if isinstance(req, (ImgInpaintOrOutpaintRequestJson, ImgPromptRequest, ImgPromptRequestJson,
                        ImgUpscaleOrVaryRequestJson, Text2ImgRequestWithPrompt)):
        has_image_prompts = len(req.image_prompts) > 0
        # Auto set mixing_image_prompt_and_* so image prompts are combined with
        # the vary/upscale or inpaint input image.
        if has_image_prompts and uov_input_image is not None:
            print("[INFO] Mixing image prompt and vary upscale is set to True")
            req.advanced_params.mixing_image_prompt_and_vary_upscale = True
        elif (has_image_prompts and not isinstance(req, Text2ImgRequestWithPrompt)
              and req.input_image is not None):
            print("[INFO] Mixing image prompt and inpaint is set to True")
            req.advanced_params.mixing_image_prompt_and_inpaint = True
        for img_prompt in req.image_prompts:
            if img_prompt.cn_img is None:
                continue
            cn_img = read_input_image(img_prompt.cn_img)
            # 0 / None means "use the per-type default" (index 0 = stop, 1 = weight)
            if img_prompt.cn_stop is None or img_prompt.cn_stop == 0:
                img_prompt.cn_stop = flags.default_parameters[img_prompt.cn_type.value][0]
            if img_prompt.cn_weight is None or img_prompt.cn_weight == 0:
                img_prompt.cn_weight = flags.default_parameters[img_prompt.cn_type.value][1]
            image_prompts.append(
                (cn_img, img_prompt.cn_stop, img_prompt.cn_weight, img_prompt.cn_type.value))
    if len(image_prompts) < config.default_controlnet_image_count:
        # Empty placeholder slots so the list always has the configured length
        dp = (None, 0.5, 0.6, 'ImagePrompt')
        image_prompts += [dp] * (config.default_controlnet_image_count - len(image_prompts))
    return image_prompts


def _sanitize_advanced_params(adp):
    """
    Replace unsupported advanced options with defaults, in place.

    Checks refiner_swap_method, sampler_name, scheduler_name and inpaint_engine
    against their allowed values, printing a warning for each replacement.
    Returns the same (mutated) object.
    """
    if adp.refiner_swap_method not in ['joint', 'separate', 'vae']:
        print(f"[Warning] Wrong refiner_swap_method input: {adp.refiner_swap_method}, using default")
        adp.refiner_swap_method = 'joint'
    if adp.sampler_name not in flags.sampler_list:
        print(f"[Warning] Wrong sampler_name input: {adp.sampler_name}, using default")
        adp.sampler_name = default_sampler
    if adp.scheduler_name not in flags.scheduler_list:
        print(f"[Warning] Wrong scheduler_name input: {adp.scheduler_name}, using default")
        adp.scheduler_name = default_scheduler
    if adp.inpaint_engine not in flags.inpaint_engine_versions:
        print(f"[Warning] Wrong inpaint_engine input: {adp.inpaint_engine}, using default")
        adp.inpaint_engine = default_inpaint_engine_version
    return adp


def req_to_params(req: Text2ImgRequest) -> ImageGenerationParams:
    """
    Convert Request to ImageGenerationParams

    Args:
        req: Request, Text2ImgRequest and classes inherited from Text2ImgRequest
    Returns:
        ImageGenerationParams
    Side effects:
        May mutate ``req.image_prompts`` entries and ``req.advanced_params``
        (defaults / mixing flags / sanitised values).
    """
    is_uov = isinstance(req, (ImgUpscaleOrVaryRequest, ImgUpscaleOrVaryRequestJson))
    is_inpaint = isinstance(req, (ImgInpaintOrOutpaintRequest, ImgInpaintOrOutpaintRequestJson))
    is_enhance = isinstance(req, (ImageEnhanceRequest, ImageEnhanceRequestJson))

    style_selections = [s for s in req.style_selections if s in legal_style_names]
    image_seed = refresh_seed(req.image_seed)
    base_model_name = check_models_exist(req.base_model_name, 'base')
    # Normalise an empty refiner name to 'None' *before* validation; otherwise
    # check_models_exist would treat '' as an unknown file and substitute the
    # default refiner instead of disabling it.
    refiner_model_name = check_models_exist(req.refiner_model_name or 'None', 'refiner')
    if refiner_model_name == '':
        refiner_model_name = 'None'
    loras = [(lora.enabled, check_models_exist(lora.model_name, 'lora'), lora.weight) for lora in req.loras]

    # Upscale / vary
    uov_input_image = None
    if is_uov and not isinstance(req, Text2ImgRequestWithPrompt):
        uov_input_image = read_input_image(req.input_image)
    uov_method = req.uov_method.value if is_uov else flags.disabled
    upscale_value = req.upscale_value if is_uov else None

    # Outpaint
    if is_inpaint:
        outpaint_selections = [s.value for s in req.outpaint_selections]
        outpaint_distance_left = req.outpaint_distance_left
        outpaint_distance_right = req.outpaint_distance_right
        outpaint_distance_top = req.outpaint_distance_top
        outpaint_distance_bottom = req.outpaint_distance_bottom
    else:
        outpaint_selections = []
        outpaint_distance_left = 0
        outpaint_distance_right = 0
        outpaint_distance_top = 0
        outpaint_distance_bottom = 0

    # Inpaint
    inpaint_input_image = dict(image=None, mask=None)
    inpaint_additional_prompt = None
    if is_inpaint and req.input_image is not None:
        inpaint_additional_prompt = req.inpaint_additional_prompt
        input_image = read_input_image(req.input_image)
        if req.input_mask is not None:
            input_mask = HWC3(read_input_image(req.input_mask))
        else:
            # No mask supplied: an all-zero mask sized from the image's first
            # two dimensions (presumably height x width).
            input_mask = HWC3(numpy.zeros(input_image.shape[:2], dtype=numpy.uint8))
        inpaint_input_image = {
            'image': input_image,
            'mask': input_mask
        }

    image_prompts = _collect_image_prompts(req, uov_input_image)

    # Enhance
    if is_enhance:
        enhance_checkbox = True
        enhance_input_image = read_input_image(req.enhance_input_image)
        enhance_uov_method = req.enhance_uov_method
        enhance_uov_processing_order = req.enhance_uov_processing_order
        enhance_uov_prompt_type = req.enhance_uov_prompt_type
        save_final_enhanced_image_only = True
        enhance_ctrlnets = req.enhance_ctrlnets
    else:
        enhance_checkbox = False
        enhance_input_image = None
        enhance_uov_method = flags.disabled
        enhance_uov_processing_order = "Before First Enhancement"
        enhance_uov_prompt_type = "Original Prompts"
        save_final_enhanced_image_only = False
        # One distinct EnhanceCtrlNets per tab; `[obj] * n` would alias a
        # single mutable instance across every tab.
        enhance_ctrlnets = [EnhanceCtrlNets() for _ in range(config.default_enhance_tabs)]

    advanced_params = None
    if req.advanced_params is not None:
        advanced_params = _sanitize_advanced_params(req.advanced_params)

    return ImageGenerationParams(
        prompt=req.prompt,
        negative_prompt=req.negative_prompt,
        style_selections=style_selections,
        performance_selection=req.performance_selection.value,
        aspect_ratios_selection=req.aspect_ratios_selection,
        image_number=req.image_number,
        image_seed=image_seed,
        sharpness=req.sharpness,
        guidance_scale=req.guidance_scale,
        base_model_name=base_model_name,
        refiner_model_name=refiner_model_name,
        refiner_switch=req.refiner_switch,
        loras=loras,
        uov_input_image=uov_input_image,
        uov_method=uov_method,
        upscale_value=upscale_value,
        outpaint_selections=outpaint_selections,
        outpaint_distance_left=outpaint_distance_left,
        outpaint_distance_right=outpaint_distance_right,
        outpaint_distance_top=outpaint_distance_top,
        outpaint_distance_bottom=outpaint_distance_bottom,
        inpaint_input_image=inpaint_input_image,
        inpaint_additional_prompt=inpaint_additional_prompt,
        enhance_input_image=enhance_input_image,
        enhance_checkbox=enhance_checkbox,
        enhance_uov_method=enhance_uov_method,
        enhance_uov_processing_order=enhance_uov_processing_order,
        enhance_uov_prompt_type=enhance_uov_prompt_type,
        save_final_enhanced_image_only=save_final_enhanced_image_only,
        enhance_ctrlnets=enhance_ctrlnets,
        read_wildcards_in_order=req.read_wildcards_in_order,
        image_prompts=image_prompts,
        advanced_params=advanced_params,
        save_meta=req.save_meta,
        meta_scheme=req.meta_scheme,
        save_name=req.save_name,
        save_extension=req.save_extension,
        require_base64=req.require_base64,
    )
def generate_async_output(
        task: QueueTask,
        require_step_preview: bool = False) -> AsyncJobResponse:
    """
    Build the async-job status response for a queued task.

    Stage resolution: waiting while task.start_mills is 0, otherwise running;
    once the task is finished it becomes error (finish_with_error) or success
    (a task_result exists, which is then converted into job_result).

    Arguments:
        task: QueueTask
        require_step_preview: include task.task_step_preview when True
    Returns:
        AsyncJobResponse
    """
    job_stage = AsyncJobStage.waiting if task.start_mills == 0 else AsyncJobStage.running
    job_result = None
    finished = task.is_finished
    if finished and task.finish_with_error:
        job_stage = AsyncJobStage.error
    elif finished and task.task_result is not None:
        job_stage = AsyncJobStage.success
        job_result = generate_image_result_output(task.task_result, task.req_param.require_base64)
    return AsyncJobResponse(
        job_id=task.job_id,
        job_type=task.task_type,
        job_stage=job_stage,
        job_progress=task.finish_progress,
        job_status=task.task_status,
        job_step_preview=task.task_step_preview if require_step_preview else None,
        job_result=job_result,
    )
def generate_streaming_output(results: List[ImageGenerationResult]) -> Response:
    """
    Turn generation results into a raw-bytes image HTTP response.

    Only the first result is used. An empty list gives a bare 500; a
    queue-full / user-cancel / error finish reason gives 409 / 400 / 500 with
    the reason's value as the body; otherwise the image bytes are returned
    as image/png.

    Args:
        results (List[ImageGenerationResult]): List of image generation results.
    Returns:
        Response: Streaming response object, bytes image.
    """
    if not results:
        return Response(status_code=500)
    first = results[0]
    failure_statuses = (
        (GenerationFinishReason.queue_is_full, 409),
        (GenerationFinishReason.user_cancel, 400),
        (GenerationFinishReason.error, 500),
    )
    for reason, status in failure_statuses:
        if first.finish_reason == reason:
            return Response(status_code=status, content=first.finish_reason.value)
    return Response(output_file_to_bytesimg(first.im), media_type='image/png')
def generate_image_result_output(
        results: List[ImageGenerationResult],
        require_base64: bool) -> List[GeneratedImageResult]:
    """
    Convert internal generation results into API result objects.

    Arguments:
        results: List[ImageGenerationResult]
        require_base64: when True, also embed each image as base64
    Returns:
        List[GeneratedImageResult], one per input result, in the same order
    """
    def _convert(item: ImageGenerationResult) -> GeneratedImageResult:
        encoded = output_file_to_base64img(item.im) if require_base64 else None
        return GeneratedImageResult(
            base64=encoded,
            url=get_file_serve_url(item.im),
            seed=str(item.seed),
            finish_reason=item.finish_reason,
        )

    return [_convert(item) for item in results]