"""FitCheck virtual try-on workflow built on ComfyUI.

Loads a model (person) image and a garment image, generates or loads a
clothing mask, and inpaints the garment onto the person using a Flux
fill/redux pipeline with depth ControlNet guidance.

Module import has side effects: it locates the ComfyUI checkout, appends it
to ``sys.path``, and loads the optional ``extra_model_paths.yaml`` config —
the class below cannot import its node classes without that setup.
"""

import os
import random
import sys
import time
from typing import Any, Mapping, Sequence, Union

import numpy as np
import torch
from PIL import Image


def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return the value at the given index of a sequence or mapping.

    ComfyUI node methods return either a tuple-like sequence or a dict with
    a ``"result"`` key; this helper handles both shapes uniformly.
    """
    try:
        return obj[index]
    except KeyError:
        return obj["result"][index]


def find_path(name: str, path: str = None) -> str:
    """Recursively look at parent folders to find the given name.

    Starts from ``path`` (default: current working directory) and walks
    upward until ``name`` is found or the filesystem root is reached.
    Returns the full path if found, otherwise ``None``.
    """
    if path is None:
        path = os.getcwd()
    if name in os.listdir(path):
        path_name = os.path.join(path, name)
        print(f"{name} found: {path_name}")
        return path_name
    parent_directory = os.path.dirname(path)
    if parent_directory == path:
        # Reached the filesystem root without finding `name`.
        return None
    return find_path(name, parent_directory)


def add_comfyui_directory_to_sys_path() -> None:
    """Add the 'ComfyUI' directory to sys.path so its modules import."""
    comfyui_path = find_path("ComfyUI")
    if comfyui_path is not None and os.path.isdir(comfyui_path):
        sys.path.append(comfyui_path)
        print(f"'{comfyui_path}' added to sys.path")


def add_extra_model_paths() -> None:
    """Parse the optional extra_model_paths.yaml file and register its paths."""
    try:
        from main import load_extra_path_config
    except ImportError:
        # Newer ComfyUI versions moved the helper out of main.py.
        print(
            "Could not import load_extra_path_config from main.py.\n"
            "Looking in utils.extra_config instead."
        )
        from utils.extra_config import load_extra_path_config

    extra_model_paths = find_path("extra_model_paths.yaml")
    if extra_model_paths is not None:
        load_extra_path_config(extra_model_paths)
    else:
        print("Could not find the extra_model_paths config file.")


# Deliberate import-time side effects: the `from nodes import ...` below and
# the node classes used by FitCheckWorkflow require ComfyUI on sys.path.
add_comfyui_directory_to_sys_path()
add_extra_model_paths()


def import_custom_nodes() -> None:
    """Find all ComfyUI custom nodes and initialize them.

    Creates a PromptServer on a fresh asyncio loop because some custom
    nodes expect a live server instance during initialization.
    """
    import asyncio

    import execution
    import server
    from nodes import init_extra_nodes

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    server_instance = server.PromptServer(loop)
    execution.PromptQueue(server_instance)
    init_extra_nodes()


from nodes import NODE_CLASS_MAPPINGS


class FitCheckWorkflow:
    """End-to-end garment-swap workflow.

    ``__init__`` loads every model once (CLIP, VAE, Flux fill UNet, redux
    style model, ControlNet, upscaler, SAM/GroundingDINO, mask models) so
    that repeated calls only pay for inference.

    ``__call__`` runs one try-on:
        api_key   -- Gemini API key (default: GEMINI_API_KEY env var)
        swap_type -- "Upper-body" | "Lower-body" | "Dresses" | "Manual"
        mode      -- "speed" | "balanced" | "quality" (turbo-LoRA strength
                     and sampler step count trade-off)
        seed      -- sampler seed (default: random)
    Returns a list of PIL images; also writes "fitcheck_output.png".
    """

    def __init__(self):
        import_custom_nodes()
        with torch.inference_mode():
            # --- Node class instances (one per node type used below) ---
            self.loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
            self.comfyuivtonmaskloader = NODE_CLASS_MAPPINGS["ComfyUIVtonMaskLoader"]()
            self.emptyimage = NODE_CLASS_MAPPINGS["EmptyImage"]()
            self.rmbg = NODE_CLASS_MAPPINGS["RMBG"]()
            self.layerutility_imageremovealpha = NODE_CLASS_MAPPINGS["LayerUtility: ImageRemoveAlpha"]()
            self.inpaintcropimproved = NODE_CLASS_MAPPINGS["InpaintCropImproved"]()
            self.geminiflash = NODE_CLASS_MAPPINGS["GeminiFlash"]()
            self.stringfunctionpysssss = NODE_CLASS_MAPPINGS["StringFunction|pysssss"]()
            self.cr_text_replace = NODE_CLASS_MAPPINGS["CR Text Replace"]()
            self.dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
            self.cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
            self.vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
            self.unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
            self.stylemodelloader = NODE_CLASS_MAPPINGS["StyleModelLoader"]()
            self.clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
            self.clipvisionencode = NODE_CLASS_MAPPINGS["CLIPVisionEncode"]()
            self.loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()
            self.fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
            self.stylemodelapply = NODE_CLASS_MAPPINGS["StyleModelApply"]()
            self.conditioningzeroout = NODE_CLASS_MAPPINGS["ConditioningZeroOut"]()
            self.controlnetloader = NODE_CLASS_MAPPINGS["ControlNetLoader"]()
            self.setunioncontrolnettype = NODE_CLASS_MAPPINGS["SetUnionControlNetType"]()
            self.upscalemodelloader = NODE_CLASS_MAPPINGS["UpscaleModelLoader"]()
            self.imageupscalewithmodel = NODE_CLASS_MAPPINGS["ImageUpscaleWithModel"]()
            self.imageresize = NODE_CLASS_MAPPINGS["ImageResize+"]()
            self.comfyuivtonmaskgenerator = NODE_CLASS_MAPPINGS["ComfyUIVtonMaskGenerator"]()
            self.imagetomask = NODE_CLASS_MAPPINGS["ImageToMask"]()
            self.layermask_maskgrow = NODE_CLASS_MAPPINGS["LayerMask: MaskGrow"]()
            self.loadimagemask = NODE_CLASS_MAPPINGS["LoadImageMask"]()
            self.mask_fill_holes = NODE_CLASS_MAPPINGS["Mask Fill Holes"]()
            self.resizemask = NODE_CLASS_MAPPINGS["ResizeMask"]()
            self.imageconcanate = NODE_CLASS_MAPPINGS["ImageConcanate"]()
            self.getimagesize = NODE_CLASS_MAPPINGS["GetImageSize+"]()
            self.pixelperfectresolution = NODE_CLASS_MAPPINGS["PixelPerfectResolution"]()
            self.aio_preprocessor = NODE_CLASS_MAPPINGS["AIO_Preprocessor"]()
            self.layerutility_purgevram_v2 = NODE_CLASS_MAPPINGS["LayerUtility: PurgeVRAM V2"]()
            self.controlnetapplyadvanced = NODE_CLASS_MAPPINGS["ControlNetApplyAdvanced"]()
            self.getimagesizeandcount = NODE_CLASS_MAPPINGS["GetImageSizeAndCount"]()
            self.sammodelloader_segment_anything = NODE_CLASS_MAPPINGS["SAMModelLoader (segment anything)"]()
            self.groundingdinomodelloader_segment_anything = NODE_CLASS_MAPPINGS["GroundingDinoModelLoader (segment anything)"]()
            self.groundingdinosamsegment_segment_anything = NODE_CLASS_MAPPINGS["GroundingDinoSAMSegment (segment anything)"]()
            self.maskcomposite = NODE_CLASS_MAPPINGS["MaskComposite"]()
            self.apersonmaskgenerator = NODE_CLASS_MAPPINGS["APersonMaskGenerator"]()
            self.masktoimage = NODE_CLASS_MAPPINGS["MaskToImage"]()
            self.inpaintmodelconditioning = NODE_CLASS_MAPPINGS["InpaintModelConditioning"]()
            self.differentialdiffusion = NODE_CLASS_MAPPINGS["DifferentialDiffusion"]()
            self.ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
            self.vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
            self.imagecrop = NODE_CLASS_MAPPINGS["ImageCrop+"]()
            self.inpaintstitchimproved = NODE_CLASS_MAPPINGS["InpaintStitchImproved"]()
            self.showtextpysssss = NODE_CLASS_MAPPINGS["ShowText|pysssss"]()

            # --- Models / constants loaded once and reused per call ---
            self.comfyuivtonmaskloader_983 = self.comfyuivtonmaskloader.load_mask_model(device="cpu")
            self.emptyimage_1015 = self.emptyimage.generate(
                width=768, height=1024, batch_size=1, color=0
            )
            self.dualcliploader_1024 = self.dualcliploader.load_clip(
                clip_name1="clip_l.safetensors",
                clip_name2="t5xxl_fp8_e4m3fn.safetensors",
                type="flux",
                device="default",
            )
            self.vaeloader_1023 = self.vaeloader.load_vae(vae_name="ae.safetensors")
            self.unetloader_1025 = self.unetloader.load_unet(
                unet_name="flux1-fill-dev.safetensors", weight_dtype="fp8_e4m3fn"
            )
            self.stylemodelloader_1026 = self.stylemodelloader.load_style_model(
                style_model_name="flux1-redux-dev.safetensors"
            )
            self.clipvisionloader_1151 = self.clipvisionloader.load_clip(
                clip_name="sigclip_vision_patch14_384.safetensors"
            )
            self.controlnetloader_1042 = self.controlnetloader.load_controlnet(
                control_net_name="flux-union-pro-v2.safetensors"
            )
            self.setunioncontrolnettype_1041 = self.setunioncontrolnettype.set_controlnet_type(
                type="depth",
                control_net=get_value_at_index(self.controlnetloader_1042, 0),
            )
            self.upscalemodelloader_1155 = self.upscalemodelloader.load_model(
                model_name="RealESRGAN_x2.pth"
            )
            self.comfyuivtonmaskloader_1173 = self.comfyuivtonmaskloader.load_mask_model(device="cpu")
            self.sammodelloader_segment_anything_1167 = self.sammodelloader_segment_anything.main(
                model_name="sam_vit_h (2.56GB)"
            )
            self.groundingdinomodelloader_segment_anything_1168 = (
                self.groundingdinomodelloader_segment_anything.main(
                    model_name="GroundingDINO_SwinT_OGC (694MB)"
                )
            )

    @torch.inference_mode()
    def __call__(self, *args, **kwargs):
        start = time.time()

        # Extract parameters from kwargs with defaults.
        # SECURITY: the API key must never be hardcoded in source — it is a
        # credential. Read it from the environment when not passed in.
        api_key = kwargs.get("api_key", os.environ.get("GEMINI_API_KEY", ""))
        swap_type = kwargs.get("swap_type", "Dresses")
        mode = kwargs.get("mode", "balanced")
        seed = kwargs.get("seed", random.randint(1, 2**64))

        # Validate parameters before doing any expensive work.
        valid_swap_types = ["Upper-body", "Lower-body", "Dresses", "Manual"]
        valid_modes = ["speed", "balanced", "quality"]
        if swap_type not in valid_swap_types:
            raise ValueError(f"swap_type must be one of {valid_swap_types}")
        if mode not in valid_modes:
            raise ValueError(f"mode must be one of {valid_modes}")

        print(f"Running FitCheck with swap_type: {swap_type}, mode: {mode}")

        # Load the person image and the garment image (fixed input names).
        loadimage_904 = self.loadimage.load_image(image="model_img.png")
        loadimage_909 = self.loadimage.load_image(image="cloth_img.png")

        # Remove the garment's background, then flatten alpha onto black.
        rmbg_1160 = self.rmbg.process_image(
            model="RMBG-2.0",
            sensitivity=1,
            process_res=1024,
            mask_blur=0,
            mask_offset=0,
            invert_output=False,
            refine_foreground=True,
            background="Alpha",
            background_color="#000000",
            image=get_value_at_index(loadimage_909, 0),
        )
        layerutility_imageremovealpha_1158 = self.layerutility_imageremovealpha.image_remove_alpha(
            fill_background=True,
            background_color="#000000",
            RGBA_image=get_value_at_index(loadimage_909, 0),
            mask=get_value_at_index(rmbg_1160, 1),
        )

        # Crop the garment to a 768x1024 working canvas around its mask.
        inpaintcropimproved_1003 = self.inpaintcropimproved.inpaint_crop(
            downscale_algorithm="bilinear",
            upscale_algorithm="bicubic",
            preresize=False,
            preresize_mode="ensure minimum resolution",
            preresize_min_width=1024,
            preresize_min_height=1024,
            preresize_max_width=16384,
            preresize_max_height=16384,
            mask_fill_holes=True,
            mask_expand_pixels=0,
            mask_invert=False,
            mask_blend_pixels=0,
            mask_hipass_filter=0.1,
            extend_for_outpainting=False,
            extend_up_factor=1,
            extend_down_factor=1,
            extend_left_factor=1,
            extend_right_factor=1,
            context_from_mask_extend_factor=1.1500000000000001,
            output_resize_to_target_size=True,
            output_target_width=768,
            output_target_height=1024,
            output_padding="0",
            image=get_value_at_index(layerutility_imageremovealpha_1158, 0),
            mask=get_value_at_index(rmbg_1160, 1),
        )

        # Ask Gemini for a short textual description of the garment.
        geminiflash_1120 = self.geminiflash.generate_content(
            prompt="What kind of outfit is this,models size like slim,plus size etc,and describe it clearly in short, return to the point combined prompt in plain text",
            input_type="image",
            model_version="gemini-2.0-flash",
            operation_mode="analysis",
            chat_mode=False,
            clear_history=True,
            Additional_Context="",
            api_key=api_key,
            max_output_tokens=8192,
            temperature=0.4,
            structured_output=False,
            max_images=6,
            batch_count=1,
            seed=random.randint(1, 2**64),
            images=get_value_at_index(inpaintcropimproved_1003, 1),
        )

        # Build the diffusion prompt from a template plus the description.
        stringfunctionpysssss_1110 = self.stringfunctionpysssss.exec(
            action="append",
            tidy_tags="no",
            text_a="The fashion model wearing the [outfit]\n",
            text_b="The 2 shirts on both sides are exactly the same, same color, same logo, same text, same features",
            text_c="",
        )
        cr_text_replace_1119 = self.cr_text_replace.replace_text(
            find1="[outfit]",
            replace1=get_value_at_index(geminiflash_1120, 0),
            find2="",
            replace2="",
            find3="",
            replace3="",
            text=get_value_at_index(stringfunctionpysssss_1110, 0),
        )

        print("\n=================\n\n\n")
        print("Generated prompt:\n", get_value_at_index(cr_text_replace_1119, 0))
        print("\n\n\n=================\n")

        cliptextencode_1022 = self.cliptextencode.encode(
            text=get_value_at_index(cr_text_replace_1119, 0),
            clip=get_value_at_index(self.dualcliploader_1024, 0),
        )
        clipvisionencode_1027 = self.clipvisionencode.encode(
            crop="none",
            clip_vision=get_value_at_index(self.clipvisionloader_1151, 0),
            image=get_value_at_index(inpaintcropimproved_1003, 1),
        )

        # Always load the cat-vton LoRA first; mode decides the turbo LoRA.
        loraloadermodelonly_1032 = self.loraloadermodelonly.load_lora_model_only(
            lora_name="cat-vton.safetensors",
            strength_model=1,
            model=get_value_at_index(self.unetloader_1025, 0),
        )

        # Mode trades speed for quality via turbo-LoRA strength + steps.
        if mode == "speed":
            loraloadermodelonly_1031 = self.loraloadermodelonly.load_lora_model_only(
                lora_name="turbo.safetensors",
                strength_model=1.0,
                model=get_value_at_index(loraloadermodelonly_1032, 0),
            )
            current_model = get_value_at_index(loraloadermodelonly_1031, 0)
            steps = 11
        elif mode == "balanced":
            loraloadermodelonly_1031 = self.loraloadermodelonly.load_lora_model_only(
                lora_name="turbo.safetensors",
                strength_model=0.5,
                model=get_value_at_index(loraloadermodelonly_1032, 0),
            )
            current_model = get_value_at_index(loraloadermodelonly_1031, 0)
            steps = 17
        else:  # quality: no turbo LoRA, full step count
            current_model = get_value_at_index(loraloadermodelonly_1032, 0)
            steps = 34

        fluxguidance_1020 = self.fluxguidance.append(
            guidance=50, conditioning=get_value_at_index(cliptextencode_1022, 0)
        )
        stylemodelapply_1019 = self.stylemodelapply.apply_stylemodel(
            strength=1,
            strength_type="multiply",
            conditioning=get_value_at_index(fluxguidance_1020, 0),
            style_model=get_value_at_index(self.stylemodelloader_1026, 0),
            clip_vision_output=get_value_at_index(clipvisionencode_1027, 0),
        )
        conditioningzeroout_1021 = self.conditioningzeroout.zero_out(
            conditioning=get_value_at_index(fluxguidance_1020, 0)
        )

        # Upscale the person image, then bound it to 1536px.
        imageupscalewithmodel_1156 = self.imageupscalewithmodel.upscale(
            upscale_model=get_value_at_index(self.upscalemodelloader_1155, 0),
            image=get_value_at_index(loadimage_904, 0),
        )
        imageresize_1058 = self.imageresize.execute(
            width=1536,
            height=1536,
            interpolation="nearest",
            method="keep proportion",
            condition="always",
            multiple_of=0,
            image=get_value_at_index(imageupscalewithmodel_1156, 0),
        )

        # Clothing-region mask: auto-generated per category, or user-supplied.
        if swap_type != "Manual":
            comfyuivtonmaskgenerator_982 = self.comfyuivtonmaskgenerator.generate_mask(
                category=swap_type,
                offset_top=0,
                offset_bottom=0,
                offset_left=0,
                offset_right=0,
                mask_model=get_value_at_index(self.comfyuivtonmaskloader_983, 0),
                vton_image=get_value_at_index(imageresize_1058, 0),
            )
            imagetomask_990 = self.imagetomask.image_to_mask(
                channel="red",
                image=get_value_at_index(comfyuivtonmaskgenerator_982, 1),
            )
            layermask_maskgrow_891 = self.layermask_maskgrow.mask_grow(
                invert_mask=False,
                grow=0,
                blur=3,
                mask=get_value_at_index(imagetomask_990, 0),
            )
            resize_mask_source = get_value_at_index(layermask_maskgrow_891, 0)
        else:
            # Manual mode: load the user-provided mask image.
            loadimage_manual_mask = self.loadimage.load_image(image="mask_img.png")
            imagetomask_manual = self.imagetomask.image_to_mask(
                channel="red",
                image=get_value_at_index(loadimage_manual_mask, 0),
            )
            resize_mask_source = get_value_at_index(imagetomask_manual, 0)

        resizemask_1059 = self.resizemask.resize(
            width=get_value_at_index(imageresize_1058, 1),
            height=get_value_at_index(imageresize_1058, 2),
            keep_proportions=False,
            upscale_method="nearest-exact",
            crop="disabled",
            mask=resize_mask_source,
        )

        # Crop the person to a padded 768x1024 region around the mask.
        inpaintcropimproved_999 = self.inpaintcropimproved.inpaint_crop(
            downscale_algorithm="nearest",
            upscale_algorithm="nearest",
            preresize=False,
            preresize_mode="ensure minimum resolution",
            preresize_min_width=1024,
            preresize_min_height=1024,
            preresize_max_width=16384,
            preresize_max_height=16384,
            mask_fill_holes=True,
            mask_expand_pixels=8,
            mask_invert=False,
            mask_blend_pixels=20,
            mask_hipass_filter=0.1,
            extend_for_outpainting=False,
            extend_up_factor=1,
            extend_down_factor=1,
            extend_left_factor=1,
            extend_right_factor=1,
            context_from_mask_extend_factor=1.0500000000000003,
            output_resize_to_target_size=True,
            output_target_width=768,
            output_target_height=1024,
            output_padding="64",
            image=get_value_at_index(imageresize_1058, 0),
            mask=get_value_at_index(resizemask_1059, 0),
        )

        # Depth ControlNet input: person crop concatenated with an empty pane.
        imageconcanate_1044 = self.imageconcanate.concatenate(
            direction="left",
            match_image_size=True,
            image1=get_value_at_index(inpaintcropimproved_999, 1),
            image2=get_value_at_index(self.emptyimage_1015, 0),
        )
        getimagesize_1047 = self.getimagesize.execute(
            image=get_value_at_index(imageconcanate_1044, 0)
        )
        pixelperfectresolution_1049 = self.pixelperfectresolution.execute(
            image_gen_width=get_value_at_index(getimagesize_1047, 0),
            image_gen_height=get_value_at_index(getimagesize_1047, 1),
            resize_mode="Just Resize",
            original_image=get_value_at_index(imageconcanate_1044, 0),
        )
        aio_preprocessor_1046 = self.aio_preprocessor.execute(
            preprocessor="Zoe_DepthAnythingPreprocessor",
            resolution=get_value_at_index(pixelperfectresolution_1049, 0),
            image=get_value_at_index(imageconcanate_1044, 0),
        )
        layerutility_purgevram_v2_1191 = self.layerutility_purgevram_v2.purge_vram_v2(
            purge_cache=True,
            purge_models=True,
            anything=get_value_at_index(aio_preprocessor_1046, 0),
        )
        controlnetapplyadvanced_1043 = self.controlnetapplyadvanced.apply_controlnet(
            strength=0.7000000000000002,
            start_percent=0,
            end_percent=0.5000000000000001,
            positive=get_value_at_index(stylemodelapply_1019, 0),
            negative=get_value_at_index(conditioningzeroout_1021, 0),
            control_net=get_value_at_index(self.setunioncontrolnettype_1041, 0),
            image=get_value_at_index(layerutility_purgevram_v2_1191, 0),
            vae=get_value_at_index(self.vaeloader_1023, 0),
        )

        # Side-by-side person crop + garment crop fed to the inpaint model.
        imageconcanate_1013 = self.imageconcanate.concatenate(
            direction="left",
            match_image_size=True,
            image1=get_value_at_index(inpaintcropimproved_999, 1),
            image2=get_value_at_index(inpaintcropimproved_1003, 1),
        )

        # Refine the inpaint mask on the cropped region: regenerate per
        # category, subtract hands (GroundingDINO+SAM), face, and hair so
        # those areas are preserved. Manual mode uses the crop's own mask.
        if swap_type != "Manual":
            getimagesizeandcount_1165 = self.getimagesizeandcount.getsize(
                image=get_value_at_index(inpaintcropimproved_999, 1)
            )
            comfyuivtonmaskgenerator_1179 = self.comfyuivtonmaskgenerator.generate_mask(
                category=swap_type,
                offset_top=0,
                offset_bottom=0,
                offset_left=0,
                offset_right=0,
                mask_model=get_value_at_index(self.comfyuivtonmaskloader_1173, 0),
                vton_image=get_value_at_index(getimagesizeandcount_1165, 0),
            )
            imagetomask_1175 = self.imagetomask.image_to_mask(
                channel="red",
                image=get_value_at_index(comfyuivtonmaskgenerator_1179, 1),
            )
            groundingdinosamsegment_segment_anything_1176 = self.groundingdinosamsegment_segment_anything.main(
                prompt="hand",
                threshold=0.28,
                sam_model=get_value_at_index(self.sammodelloader_segment_anything_1167, 0),
                grounding_dino_model=get_value_at_index(
                    self.groundingdinomodelloader_segment_anything_1168, 0
                ),
                image=get_value_at_index(getimagesizeandcount_1165, 0),
            )
            layerutility_purgevram_v2_1192 = self.layerutility_purgevram_v2.purge_vram_v2(
                purge_cache=True,
                purge_models=True,
                anything=get_value_at_index(groundingdinosamsegment_segment_anything_1176, 1),
            )
            maskcomposite_1174 = self.maskcomposite.combine(
                x=0,
                y=0,
                operation="subtract",
                destination=get_value_at_index(imagetomask_1175, 0),
                source=get_value_at_index(layerutility_purgevram_v2_1192, 0),
            )
            apersonmaskgenerator_1181 = self.apersonmaskgenerator.generate_mask(
                face_mask=True,
                background_mask=False,
                hair_mask=False,
                body_mask=False,
                clothes_mask=False,
                confidence=0.4,
                refine_mask=True,
                images=get_value_at_index(getimagesizeandcount_1165, 0),
            )
            apersonmaskgenerator_1177 = self.apersonmaskgenerator.generate_mask(
                face_mask=False,
                background_mask=False,
                hair_mask=True,
                body_mask=False,
                clothes_mask=False,
                confidence=0.4,
                refine_mask=True,
                images=get_value_at_index(getimagesizeandcount_1165, 0),
            )
            maskcomposite_1171 = self.maskcomposite.combine(
                x=0,
                y=0,
                operation="add",
                destination=get_value_at_index(apersonmaskgenerator_1181, 0),
                source=get_value_at_index(apersonmaskgenerator_1177, 0),
            )
            maskcomposite_1169 = self.maskcomposite.combine(
                x=0,
                y=0,
                operation="subtract",
                destination=get_value_at_index(maskcomposite_1174, 0),
                source=get_value_at_index(maskcomposite_1171, 0),
            )
            layermask_maskgrow_1178 = self.layermask_maskgrow.mask_grow(
                invert_mask=False,
                grow=0,
                blur=3,
                mask=get_value_at_index(maskcomposite_1169, 0),
            )
            masktoimage_mask_source = get_value_at_index(layermask_maskgrow_1178, 0)
        else:
            masktoimage_mask_source = get_value_at_index(inpaintcropimproved_999, 2)

        masktoimage_1017 = self.masktoimage.mask_to_image(mask=masktoimage_mask_source)
        imageconcanate_1016 = self.imageconcanate.concatenate(
            direction="left",
            match_image_size=True,
            image1=get_value_at_index(masktoimage_1017, 0),
            image2=get_value_at_index(self.emptyimage_1015, 0),
        )
        imagetomask_1035 = self.imagetomask.image_to_mask(
            channel="red", image=get_value_at_index(imageconcanate_1016, 0)
        )

        inpaintmodelconditioning_1033 = self.inpaintmodelconditioning.encode(
            noise_mask=True,
            positive=get_value_at_index(controlnetapplyadvanced_1043, 0),
            negative=get_value_at_index(controlnetapplyadvanced_1043, 1),
            vae=get_value_at_index(self.vaeloader_1023, 0),
            pixels=get_value_at_index(imageconcanate_1013, 0),
            mask=get_value_at_index(imagetomask_1035, 0),
        )
        differentialdiffusion_1040 = self.differentialdiffusion.apply(model=current_model)

        ksampler_1030 = self.ksampler.sample(
            seed=seed,
            steps=steps,
            cfg=1,
            sampler_name="euler",
            scheduler="simple",
            denoise=1,
            model=get_value_at_index(differentialdiffusion_1040, 0),
            positive=get_value_at_index(inpaintmodelconditioning_1033, 0),
            negative=get_value_at_index(inpaintmodelconditioning_1033, 1),
            latent_image=get_value_at_index(inpaintmodelconditioning_1033, 2),
        )
        vaedecode_1036 = self.vaedecode.decode(
            samples=get_value_at_index(ksampler_1030, 0),
            vae=get_value_at_index(self.vaeloader_1023, 0),
        )

        # Keep only the right (person) half of the side-by-side canvas,
        # upscale it, and stitch it back into the original person image.
        imagecrop_1055 = self.imagecrop.execute(
            width=768,
            height=1024,
            position="top-right",
            x_offset=0,
            y_offset=0,
            image=get_value_at_index(vaedecode_1036, 0),
        )
        imageupscalewithmodel_1188 = self.imageupscalewithmodel.upscale(
            upscale_model=get_value_at_index(self.upscalemodelloader_1155, 0),
            image=get_value_at_index(imagecrop_1055, 0),
        )
        layerutility_purgevram_v2_1187 = self.layerutility_purgevram_v2.purge_vram_v2(
            purge_cache=True,
            purge_models=True,
            anything=get_value_at_index(imageupscalewithmodel_1188, 0),
        )
        inpaintstitchimproved_1054 = self.inpaintstitchimproved.inpaint_stitch(
            stitcher=get_value_at_index(inpaintcropimproved_999, 0),
            inpainted_image=get_value_at_index(layerutility_purgevram_v2_1187, 0),
        )
        showtextpysssss_1111 = self.showtextpysssss.notify(
            text=get_value_at_index(cr_text_replace_1119, 0),
            unique_id=16351491204491641391,
        )

        # Convert the output tensor(s) to PIL images and save.
        # NOTE(review): every image is written to the same filename, so a
        # batch larger than one overwrites earlier results — confirm whether
        # per-index filenames are wanted before changing the contract.
        imgs = []
        for res in inpaintstitchimproved_1054[0]:
            img = Image.fromarray(
                np.clip(255.0 * res.detach().cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
            )
            img.save("fitcheck_output.png")
            imgs.append(img)

        stop = time.time()
        print(f"Total time: {stop - start:.2f} seconds")
        return imgs

    def cleanup(self):
        """Best-effort VRAM and cache cleanup after inference."""
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                print("VRAM cleanup completed")
        except Exception as e:
            # Deliberately best-effort: never let cleanup raise.
            print(f"Cleanup warning: {e}")


# Example usage:
# generator = FitCheckWorkflow()
# imgs = generator(api_key="your_api_key", swap_type="Dresses", mode="balanced")