| import os |
| import re |
| import time |
| from dataclasses import dataclass |
| from glob import iglob |
| from mmgp import offload as offload |
| import torch |
| from shared.utils.utils import calculate_new_dimensions |
| from .sampling import denoise, get_schedule, prepare_kontext, prepare_prompt, prepare_multi_ip, unpack, resizeinput |
| from .modules.layers import get_linear_split_map |
| from transformers import SiglipVisionModel, SiglipImageProcessor |
| import torchvision.transforms.functional as TVF |
| import math |
| from shared.utils.utils import convert_image_to_tensor, convert_tensor_to_image |
| from shared.utils import files_locator as fl |
| from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, AutoTokenizer, Qwen2VLProcessor |
|
|
| from .util import ( |
| aspect_ratio_to_height_width, |
| load_ae, |
| load_clip, |
| load_flow_model, |
| load_t5, |
| save_image, |
| ) |
|
|
| from PIL import Image |
| def preprocess_ref(raw_image: Image.Image, long_size: int = 512): |
| |
| image_w, image_h = raw_image.size |
|
|
| |
| if image_w >= image_h: |
| new_w = long_size |
| new_h = int((long_size / image_w) * image_h) |
| else: |
| new_h = long_size |
| new_w = int((long_size / image_h) * image_w) |
|
|
| |
| raw_image = raw_image.resize((new_w, new_h), resample=Image.LANCZOS) |
| target_w = new_w // 16 * 16 |
| target_h = new_h // 16 * 16 |
|
|
| |
| left = (new_w - target_w) // 2 |
| top = (new_h - target_h) // 2 |
| right = left + target_w |
| bottom = top + target_h |
|
|
| |
| raw_image = raw_image.crop((left, top, right, bottom)) |
|
|
| |
| raw_image = raw_image.convert("RGB") |
| return raw_image |
|
|
| def stitch_images(img1, img2): |
| |
| width1, height1 = img1.size |
| width2, height2 = img2.size |
| new_width2 = int(width2 * height1 / height2) |
| img2_resized = img2.resize((new_width2, height1), Image.Resampling.LANCZOS) |
| |
| stitched = Image.new('RGB', (width1 + new_width2, height1)) |
| stitched.paste(img1, (0, 0)) |
| stitched.paste(img2_resized, (width1, 0)) |
| return stitched |
|
|
| class model_factory: |
| def __init__( |
| self, |
| checkpoint_dir, |
| model_filename = None, |
| model_type = None, |
| model_def = None, |
| base_model_type = None, |
| text_encoder_filename = None, |
| quantizeTransformer = False, |
| save_quantized = False, |
| dtype = torch.bfloat16, |
| VAE_dtype = torch.float32, |
| mixed_precision_transformer = False |
| ): |
| self.device = torch.device(f"cuda") |
| self.VAE_dtype = VAE_dtype |
| self.dtype = dtype |
| torch_device = "cpu" |
| self.guidance_max_phases = model_def.get("guidance_max_phases", 0) |
|
|
| |
| |
| self.t5 = load_t5(torch_device, text_encoder_filename, max_length=512) |
| self.clip = load_clip(torch_device) |
| self.name = model_def.get("flux-model", "flux-dev") |
| |
| |
| |
| source = model_def.get("source", None) |
| self.model = load_flow_model(self.name, model_filename[0] if source is None else source, torch_device) |
| self.model_def = model_def |
| self.vae = load_ae(self.name, device=torch_device) |
|
|
| siglip_processor = siglip_model = feature_embedder = None |
| if self.name == 'flux-dev-uso': |
| siglip_path = fl.locate_folder("siglip-so400m-patch14-384") |
| siglip_processor = SiglipImageProcessor.from_pretrained(siglip_path) |
| siglip_model = SiglipVisionModel.from_pretrained(siglip_path) |
| siglip_model.eval().to("cpu") |
| if len(model_filename) > 1: |
| from .modules.layers import SigLIPMultiFeatProjModel |
| feature_embedder = SigLIPMultiFeatProjModel( |
| siglip_token_nums=729, |
| style_token_nums=64, |
| siglip_token_dims=1152, |
| hidden_size=3072, |
| context_layer_norm=True, |
| ) |
| offload.load_model_data(feature_embedder, model_filename[1]) |
| self.vision_encoder = siglip_model |
| self.vision_encoder_processor = siglip_processor |
| self.feature_embedder = feature_embedder |
|
|
| if self.name in ['flux-dev-kontext-dreamomni2']: |
| self.processor = Qwen2VLProcessor.from_pretrained(fl.locate_folder("Qwen2.5-VL-7B-DreamOmni2")) |
| self.vlm_model = offload.fast_load_transformers_model(fl.locate_file( os.path.join("Qwen2.5-VL-7B-DreamOmni2","Qwen2.5-VL-7B-DreamOmni2_quanto_bf16_int8.safetensors")), writable_tensors= True , modelClass=Qwen2_5_VLForConditionalGeneration, defaultConfigPath= fl.locate_file(os.path.join("Qwen2.5-VL-7B-DreamOmni2", "config.json"))) |
| else: |
| self.processor = None |
| self.vlm_model = None |
| |
| |
|
|
| if not source is None: |
| from wgp import save_model |
| save_model(self.model, model_type, dtype, None) |
|
|
| if save_quantized: |
| from wgp import save_quantized_model |
| save_quantized_model(self.model, model_type, model_filename[0], dtype, None) |
|
|
| split_linear_modules_map = get_linear_split_map() |
| self.model.split_linear_modules_map = split_linear_modules_map |
| offload.split_linear_modules(self.model, split_linear_modules_map ) |
|
|
| def infer_vlm(self, input_img_path,input_instruction,prefix): |
| tp=[] |
| for path in input_img_path: |
| tp.append({"type": "image", "image": path}) |
| tp.append({"type": "text", "text": input_instruction+prefix}) |
| messages = [ |
| { |
| "role": "user", |
| "content": tp, |
| } |
| ] |
|
|
| |
| text = self.processor.apply_chat_template( |
| messages, tokenize=False, add_generation_prompt=True |
| ) |
| |
| |
| inputs = self.processor( |
| text=[text], |
| images=input_img_path, |
| |
| |
| padding=True, |
| return_tensors="pt", |
| ) |
| inputs = inputs.to("cpu") |
|
|
| |
| generated_ids = self.vlm_model.generate(**inputs, do_sample=False, max_new_tokens=4096) |
| generated_ids_trimmed = [ |
| out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) |
| ] |
| output_text = self.processor.batch_decode( |
| generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False |
| ) |
| return output_text[0] |
|
|
| |
| def generate( |
| self, |
| seed: int | None = None, |
| input_prompt: str = "replace the logo with the text 'Black Forest Labs'", |
| n_prompt: str = None, |
| sampling_steps: int = 20, |
| input_ref_images = None, |
| input_frames= None, |
| input_masks= None, |
| width= 832, |
| height=480, |
| embedded_guidance_scale: float = 2.5, |
| guide_scale = 2.5, |
| fit_into_canvas = None, |
| callback = None, |
| loras_slists = None, |
| batch_size = 1, |
| video_prompt_type = "", |
| joint_pass = False, |
| image_refs_relative_size = 100, |
| denoising_strength = 1., |
| **bbargs |
| ): |
| if self._interrupt: |
| return None |
| if self.guidance_max_phases < 1: guide_scale = 1 |
| if n_prompt is None or len(n_prompt) == 0: n_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors" |
| device="cuda" |
| flux_dev_uso = self.name in ['flux-dev-uso'] |
| flux_dev_umo = self.name in ['flux-dev-umo'] |
| flux_kontext_dreamomni2 = self.name in ['flux-dev-kontext-dreamomni2'] |
| latent_stiching = flux_dev_uso or flux_dev_umo or flux_kontext_dreamomni2 |
|
|
| lock_dimensions= False |
|
|
| input_ref_images = [] if input_ref_images is None else input_ref_images[:] |
| if flux_dev_umo: |
| ref_long_side = 512 if len(input_ref_images) <= 1 else 320 |
| input_ref_images = [preprocess_ref(img, ref_long_side) for img in input_ref_images] |
| lock_dimensions = True |
|
|
| elif flux_kontext_dreamomni2: |
| for i, img in enumerate(input_ref_images): |
| input_ref_images[i] = resizeinput(img) |
| input_prompt= self.infer_vlm(input_ref_images,input_prompt, " It is editing task." if "K" in video_prompt_type else " It is generation task." ) |
| input_prompt = input_prompt[6:-7] |
| print(input_prompt) |
| lock_dimensions = True |
|
|
| ref_style_imgs = [] |
| if "I" in video_prompt_type and len(input_ref_images) > 0: |
| if flux_dev_uso : |
| if "J" in video_prompt_type: |
| ref_style_imgs = input_ref_images |
| input_ref_images = [] |
| elif len(input_ref_images) > 1 : |
| ref_style_imgs = input_ref_images[-1:] |
| input_ref_images = input_ref_images[:-1] |
|
|
| if latent_stiching: |
| |
| if not lock_dimensions : |
| for i in range(len(input_ref_images)): |
| w, h = input_ref_images[i].size |
| image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, 0) |
| input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS) |
| else: |
| |
| stiched = input_ref_images[0] |
| for new_img in input_ref_images[1:]: |
| stiched = stitch_images(stiched, new_img) |
| input_ref_images = [stiched] |
| elif input_frames is not None: |
| input_ref_images = [convert_tensor_to_image(input_frames) ] |
| else: |
| input_ref_images = None |
| image_mask = None if input_masks is None else convert_tensor_to_image(input_masks, mask_levels= True) |
| |
|
|
| if latent_stiching : |
| inp, height, width = prepare_multi_ip( |
| ae=self.vae, |
| img_cond_list=input_ref_images, |
| target_width=width, |
| target_height=height, |
| bs=batch_size, |
| seed=seed, |
| device=device, |
| res_match_output= flux_dev_uso or flux_dev_umo, |
| pe = 'w' if flux_kontext_dreamomni2 else 'd', |
| set_cond_index = flux_kontext_dreamomni2, |
| conditions_zero_start= flux_kontext_dreamomni2 |
| ) |
| else: |
| inp, height, width = prepare_kontext( |
| ae=self.vae, |
| img_cond_list=input_ref_images, |
| target_width=width, |
| target_height=height, |
| bs=batch_size, |
| seed=seed, |
| device=device, |
| img_mask=image_mask, |
| ) |
|
|
| inp.update(prepare_prompt(self.t5, self.clip, batch_size, input_prompt)) |
| if guide_scale != 1: |
| inp.update(prepare_prompt(self.t5, self.clip, batch_size, n_prompt, neg = True, device=device)) |
|
|
| timesteps = get_schedule(sampling_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell")) |
|
|
| ref_style_imgs = [self.vision_encoder_processor(img, return_tensors="pt").to(self.device) for img in ref_style_imgs] |
| if self.feature_embedder is not None and ref_style_imgs is not None and len(ref_style_imgs) > 0 and self.vision_encoder is not None: |
| |
| siglip_embedding = [self.vision_encoder(**emb, output_hidden_states=True) for emb in ref_style_imgs] |
| siglip_embedding = torch.cat([self.feature_embedder(emb) for emb in siglip_embedding], dim=1) |
| siglip_embedding_ids = torch.zeros( siglip_embedding.shape[0], siglip_embedding.shape[1], 3 ).to(device) |
| inp["siglip_embedding"] = siglip_embedding |
| inp["siglip_embedding_ids"] = siglip_embedding_ids |
|
|
| def unpack_latent(x): |
| return unpack(x.float(), height, width) |
|
|
| |
| x = denoise(self.model, **inp, timesteps=timesteps, guidance=embedded_guidance_scale, real_guidance_scale =guide_scale, callback=callback, pipeline=self, loras_slists= loras_slists, unpack_latent = unpack_latent, joint_pass = joint_pass, denoising_strength = denoising_strength) |
| if x==None: return None |
| |
| x = unpack_latent(x) |
| with torch.autocast(device_type=device, dtype=torch.bfloat16): |
| x = self.vae.decode(x) |
|
|
| if image_mask is not None: |
| img_msk_rebuilt = inp["img_msk_rebuilt"] |
| img= input_frames.squeeze(1).unsqueeze(0) |
| x = img * (1 - img_msk_rebuilt) + x.to(img) * img_msk_rebuilt |
|
|
| x = x.clamp(-1, 1) |
| x = x.transpose(0, 1) |
| return x |
|
|
| def get_loras_transformer(self, get_model_recursive_prop, model_type, model_mode, video_prompt_type, **kwargs): |
| if model_type != "flux_dev_kontext_dreamomni2": return [], [] |
|
|
| preloadURLs = get_model_recursive_prop(model_type, "preload_URLs") |
| if len(preloadURLs) < 2: return [], [] |
| edit = "K" in video_prompt_type |
| return [ fl.locate_file(os.path.basename(preloadURLs[0 if edit else 1]))] , [1] |