Spaces:
Build error
Build error
| """ | |
| project @ NTO-TCP-HF | |
| created @ 2024-10-29 | |
| author @ github.com/ishworrsubedii | |
| """ | |
| import base64 | |
| import time | |
| from io import BytesIO | |
| from typing import Optional | |
| import replicate | |
| import requests | |
| from PIL import Image | |
| from fastapi import APIRouter, UploadFile, File, Form | |
| from fastapi.responses import JSONResponse | |
| from src.utils.logger import logger | |
image_regeneration_router = APIRouter()
# NOTE(review): no route in this module is registered on this router
# (image_re_gen has no decorator here) — confirm it is wired up elsewhere.

# Pinned Replicate model version used by image_regeneration_replicate.
FOOOCUS_MODEL_VERSION = (
    "konieshadow/fooocus-api:"
    "fda927242b1db6affa1ece4f54c37f19b964666bf23b0d06ae2439067cd344a4"
)


def image_regeneration_replicate(input, *, model=FOOOCUS_MODEL_VERSION):
    """Run a Replicate model synchronously and return its raw output.

    Args:
        input: Dict of model inputs passed straight to ``replicate.run``.
            (Named ``input`` for backward compatibility with existing callers,
            even though it shadows the builtin.)
        model: Replicate ``owner/name:version`` identifier; defaults to the
            pinned Fooocus API version above.

    Returns:
        Whatever ``replicate.run`` returns. image_re_gen indexes ``[0]`` and
        passes it to ``requests.get``, so presumably a sequence of output
        URLs — confirm against the model's output schema.

    Note:
        This call blocks until the prediction finishes; call it from a worker
        thread when used inside async code.
    """
    return replicate.run(model, input=input)
# Upper bound (seconds) for downloading the generated image, so a stalled
# connection cannot hang the request indefinitely. Tune as needed.
_OUTPUT_DOWNLOAD_TIMEOUT_S = 120


def _encode_webp_data_uri(raw_bytes: bytes) -> str:
    """Decode image bytes, convert to RGB and return a base64 WEBP data URI.

    Raises whatever PIL raises when ``raw_bytes`` is not a decodable image.
    """
    rgb_image = Image.open(BytesIO(raw_bytes)).convert("RGB")
    buffer = BytesIO()
    # NOTE(review): Pillow's WEBP writer is lossy unless lossless=True; for
    # masks this may blur hard edges — confirm that is acceptable.
    rgb_image.save(buffer, format="WEBP")
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/WEBP;base64,{encoded}"


async def _upload_to_webp_data_uri(upload: Optional[UploadFile]) -> Optional[str]:
    """Read an optional upload and return it as a WEBP data URI, or None if absent."""
    if upload is None:
        return None
    return _encode_webp_data_uri(await upload.read())


async def image_re_gen(
    prompt: str = Form(...),
    negative_prompt: str = Form(""),
    image: UploadFile = File(...),
    mask_image: Optional[UploadFile] = File(default=None),
    reference_image_c1: Optional[UploadFile] = File(default=None),
    reference_image_c1_type: Optional[str] = Form(default=""),
    reference_image_c1_weight: Optional[float] = Form(default=0.0),
    reference_image_c1_stop: Optional[float] = Form(default=0.0),
    reference_image_c2: Optional[UploadFile] = File(default=None),
    reference_image_c2_type: Optional[str] = Form(default=""),
    reference_image_c2_weight: Optional[float] = Form(default=0.0),
    reference_image_c2_stop: Optional[float] = Form(default=0.0),
    reference_image_c3: Optional[UploadFile] = File(default=None),
    reference_image_c3_type: Optional[str] = Form(default=""),
    reference_image_c3_weight: Optional[float] = Form(default=0.0),
    reference_image_c3_stop: Optional[float] = Form(default=0.0),
    reference_image_c4: Optional[UploadFile] = File(default=None),
    reference_image_c4_type: Optional[str] = Form(default=""),
    reference_image_c4_weight: Optional[float] = Form(default=0.0),
    reference_image_c4_stop: Optional[float] = Form(default=0.0),
):
    """Regenerate an uploaded image via image_regeneration_replicate.

    The main image (sent as ``inpaint_input_image``), the optional mask
    (``inpaint_input_mask``) and up to four optional reference images
    (``cn_img1``..``cn_img4``) are each re-encoded as RGB WEBP data URIs.
    For each reference image that is supplied, its ``_type`` is forwarded as
    ``cn_typeN`` only when non-empty, and its ``_weight`` / ``_stop`` as
    ``cn_weightN`` / ``cn_stopN`` only when set and non-zero. Parameters for a
    missing reference image are ignored. ``negative_prompt`` is forwarded only
    when non-empty.

    Returns:
        JSONResponse 200 with ``{"output": <data URI>, "inference_time":
        <seconds>, "code": 200}``, or 500 with ``{"error": <message>,
        "code": 500}`` when any stage fails.
    """
    # Lets the blocking Replicate and HTTP calls below run off the event loop.
    from fastapi.concurrency import run_in_threadpool

    logger.info("-" * 50)
    logger.info(">>> IMAGE REDESIGN STARTED <<<")
    start_time = time.perf_counter()

    try:
        image_data_uri = _encode_webp_data_uri(await image.read())
        logger.info(">>> MAIN IMAGE PROCESSED SUCCESSFULLY <<<")
    except Exception as e:
        logger.error(f">>> MAIN IMAGE PROCESSING ERROR: {str(e)} <<<")
        return JSONResponse(status_code=500,
                            content={"error": f"Error processing main image: {str(e)}", "code": 500})

    # (upload, type, weight, stop) for reference slots 1-4, in cn_* index order.
    reference_slots = (
        (reference_image_c1, reference_image_c1_type, reference_image_c1_weight, reference_image_c1_stop),
        (reference_image_c2, reference_image_c2_type, reference_image_c2_weight, reference_image_c2_stop),
        (reference_image_c3, reference_image_c3_type, reference_image_c3_weight, reference_image_c3_stop),
        (reference_image_c4, reference_image_c4_type, reference_image_c4_weight, reference_image_c4_stop),
    )

    try:
        reference_images = [await _upload_to_webp_data_uri(slot[0]) for slot in reference_slots]
        logger.info(">>> REFERENCE IMAGES PROCESSED SUCCESSFULLY <<<")
    except Exception as e:
        logger.error(f">>> REFERENCE IMAGES PROCESSING ERROR: {str(e)} <<<")
        return JSONResponse(status_code=500,
                            content={"error": f"Error processing reference images: {str(e)}", "code": 500})

    try:
        input_data = {
            "prompt": prompt,
            "inpaint_input_image": image_data_uri,
            "sharpness": 2,
            "guidance_scale": 4,
            "refiner_switch": 0.5,
            "performance_selection": "Quality",
            "aspect_ratios_selection": "1024*1024",
        }
        if negative_prompt:
            input_data["negative_prompt"] = negative_prompt
        if mask_image is not None:
            input_data["inpaint_input_mask"] = _encode_webp_data_uri(await mask_image.read())
        logger.info(">>> INPUT DATA PREPARED SUCCESSFULLY <<<")
    except Exception as e:
        logger.error(f">>> INPUT DATA PREPARATION ERROR: {str(e)} <<<")
        return JSONResponse(status_code=500,
                            content={"error": f"Error preparing input data: {str(e)}", "code": 500})

    for slot_number, (data_uri, (_, cn_type, cn_weight, cn_stop)) in enumerate(
            zip(reference_images, reference_slots), start=1):
        if data_uri is None:
            continue
        input_data[f"cn_img{slot_number}"] = data_uri
        if cn_type:
            input_data[f"cn_type{slot_number}"] = cn_type
        # None / 0 mean "not set" and are omitted from the request.
        if cn_weight is not None and cn_weight != 0:
            input_data[f"cn_weight{slot_number}"] = cn_weight
        if cn_stop is not None and cn_stop != 0:
            input_data[f"cn_stop{slot_number}"] = cn_stop
    logger.info(">>> REFERENCE IMAGE PARAMETERS PROCESSED <<<")

    try:
        # Both calls block for the whole generation/download; a worker thread
        # keeps the event loop free to serve other requests meanwhile.
        output = await run_in_threadpool(image_regeneration_replicate, input_data)
        download = await run_in_threadpool(requests.get, output[0],
                                           timeout=_OUTPUT_DOWNLOAD_TIMEOUT_S)
        # Fail instead of base64-encoding an HTTP error body as the "image".
        download.raise_for_status()
        output_base64 = base64.b64encode(download.content).decode("utf-8")
        logger.info(">>> IMAGE REGENERATION COMPLETED <<<")
    except Exception as e:
        logger.error(f">>> IMAGE REGENERATION ERROR: {str(e)} <<<")
        return JSONResponse(status_code=500,
                            content={"error": f"Error generating image: {str(e)}", "code": 500})

    # NOTE(review): the prefix reuses the input's "data:image/WEBP;base64,"
    # label regardless of the format actually returned — confirm clients
    # do not depend on it being accurate.
    base64_prefix = image_data_uri.split(",")[0] + ","
    inference_time = round(time.perf_counter() - start_time, 2)
    result = {
        "output": f"{base64_prefix}{output_base64}",
        "inference_time": inference_time,
        "code": 200,
    }
    logger.info(f">>> TOTAL INFERENCE TIME: {inference_time}s <<<")
    logger.info(">>> REQUEST COMPLETED SUCCESSFULLY <<<")
    logger.info("-" * 50)
    return JSONResponse(content=result, status_code=200)