| from PIL import Image, ExifTags |
| import numpy as np |
| import torch |
| from torch import Tensor |
|
|
| from einops import rearrange |
| import uuid |
| import os |
|
|
| from .modules.layers import ( |
| SingleStreamBlockProcessor, |
| DoubleStreamBlockProcessor, |
| SingleStreamBlockLoraProcessor, |
| DoubleStreamBlockLoraProcessor, |
| IPDoubleStreamBlockProcessor, |
| ImageProjModel, |
| ) |
| from .sampling import denoise, denoise_controlnet, get_noise, get_schedule, prepare, unpack |
| from .util import ( |
| load_ae, |
| load_clip, |
| load_flow_model, |
| load_t5, |
| load_controlnet, |
| load_flow_model_quintized, |
| Annotator, |
| get_lora_rank, |
| load_checkpoint |
| ) |
|
|
| from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor |
|
|
class XFluxPipeline:
    """Flux text-to-image pipeline with optional LoRA, IP-Adapter and ControlNet.

    CLIP, T5, the autoencoder and the flow model are loaded at construction
    time. LoRA, IP-Adapter and ControlNet weights are attached afterwards with
    ``set_lora`` / ``set_lora_from_collection``, ``set_ip`` and
    ``set_controlnet``.
    """

    def __init__(self, model_type, device, offload: bool = False):
        """Load the text encoders, autoencoder and flow model.

        Args:
            model_type: model identifier forwarded to the ``load_*`` helpers; a
                name containing ``"fp8"`` selects ``load_flow_model_quintized``.
            device: anything accepted by ``torch.device``.
            offload: when True the autoencoder and flow model are loaded on the
                CPU and only moved to ``device`` while ``forward`` needs them.
        """
        self.device = torch.device(device)
        self.offload = offload
        self.model_type = model_type

        # CLIP/T5 always start on the target device; the two large modules
        # start on the CPU when offloading.
        load_device = "cpu" if offload else self.device
        self.clip = load_clip(self.device)
        self.t5 = load_t5(self.device, max_length=512)
        self.ae = load_ae(model_type, device=load_device)
        if "fp8" in model_type:
            self.model = load_flow_model_quintized(model_type, device=load_device)
        else:
            self.model = load_flow_model(model_type, device=load_device)

        self.image_encoder_path = "openai/clip-vit-large-patch14"
        self.hf_lora_collection = "XLabs-AI/flux-lora-collection"
        self.lora_types_to_names = {
            "realism": "lora.safetensors",
        }
        self.controlnet_loaded = False
        self.ip_loaded = False

    @staticmethod
    def _sub_state_dict(checkpoint, prefix, scale=None):
        """Return the entries of ``checkpoint`` whose key starts with ``prefix + "."``.

        The prefix and its trailing dot are stripped from the returned keys, so
        the result can be fed to a sub-module's ``load_state_dict``. When
        ``scale`` is given, every value is multiplied by it.
        """
        head = f"{prefix}."
        return {
            key[len(head):]: (value if scale is None else value * scale)
            for key, value in checkpoint.items()
            if key.startswith(head)
        }

    @staticmethod
    def _blank_image_prompt(width, height):
        """Return an all-black uint8 RGB image of shape ``(height, width, 3)``."""
        # Array images are row-major: the row count (height) comes first.
        return np.zeros((height, width, 3), dtype=np.uint8)

    def set_ip(self, local_path: str = None, repo_id=None, name: str = None):
        """Load IP-Adapter weights and install IP attention processors.

        ``local_path``, ``repo_id`` and ``name`` are forwarded to
        ``load_checkpoint``. Blocks that have IP weights in the checkpoint get a
        fresh ``IPDoubleStreamBlockProcessor`` (replacing whatever processor
        they had, including a LoRA one); all other blocks keep their processor.
        """
        self.model.to(self.device)

        checkpoint = load_checkpoint(local_path, repo_id, name)
        proj = self._sub_state_dict(checkpoint, "ip_adapter_proj_model")

        # The vision encoder always comes from the same path, so an encoder
        # loaded by an earlier set_ip() call can be reused.
        if getattr(self, "image_encoder", None) is None:
            self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
                self.image_encoder_path
            ).to(self.device, dtype=torch.float16)
            self.clip_image_processor = CLIPImageProcessor()

        self.improj = ImageProjModel(4096, 768, 4)
        self.improj.load_state_dict(proj)
        self.improj = self.improj.to(self.device, dtype=torch.bfloat16)

        ip_attn_procs = {}
        for proc_name, current_proc in self.model.attn_processors.items():
            ip_state_dict = self._sub_state_dict(checkpoint, proc_name)
            if ip_state_dict:
                ip_proc = IPDoubleStreamBlockProcessor(4096, 3072)
                ip_proc.load_state_dict(ip_state_dict)
                ip_proc.to(self.device, dtype=torch.bfloat16)
                ip_attn_procs[proc_name] = ip_proc
            else:
                ip_attn_procs[proc_name] = current_proc

        self.model.set_attn_processor(ip_attn_procs)
        self.ip_loaded = True

    def set_lora(self, local_path: str = None, repo_id: str = None,
                 name: str = None, lora_weight: float = 0.7):
        """Load a LoRA checkpoint (via ``load_checkpoint``) and apply it with ``lora_weight``."""
        checkpoint = load_checkpoint(local_path, repo_id, name)
        self.update_model_with_lora(checkpoint, lora_weight)

    def set_lora_from_collection(self, lora_type: str = "realism", lora_weight: float = 0.7):
        """Load a named LoRA from ``self.hf_lora_collection`` and apply it.

        Raises:
            KeyError: if ``lora_type`` is not a key of ``self.lora_types_to_names``.
        """
        checkpoint = load_checkpoint(
            None, self.hf_lora_collection, self.lora_types_to_names[lora_type]
        )
        self.update_model_with_lora(checkpoint, lora_weight)

    def update_model_with_lora(self, checkpoint, lora_weight):
        """Replace every attention processor, using LoRA processors where weights exist.

        Blocks whose name prefixes keys in ``checkpoint`` get a LoRA processor
        loaded with those weights; every other block gets a plain processor.
        Because *all* processors are replaced, any IP-Adapter processors are
        discarded and ``self.ip_loaded`` is reset.
        """
        rank = get_lora_rank(checkpoint)
        lora_attn_procs = {}

        for proc_name in self.model.attn_processors:
            is_single = proc_name.startswith("single_blocks")
            # NOTE(review): every tensor is multiplied by lora_weight; if the
            # state dict holds both factors of each low-rank pair, the effective
            # strength is lora_weight ** 2 — confirm this is intended.
            lora_state_dict = self._sub_state_dict(checkpoint, proc_name, scale=lora_weight)
            if lora_state_dict:
                lora_cls = SingleStreamBlockLoraProcessor if is_single else DoubleStreamBlockLoraProcessor
                proc = lora_cls(dim=3072, rank=rank)
                proc.load_state_dict(lora_state_dict)
                proc.to(self.device)
            else:
                proc = SingleStreamBlockProcessor() if is_single else DoubleStreamBlockProcessor()
            lora_attn_procs[proc_name] = proc

        self.model.set_attn_processor(lora_attn_procs)
        # IP-Adapter processors were just overwritten; force callers to reload them.
        self.ip_loaded = False

    def set_controlnet(self, control_type: str, local_path: str = None, repo_id: str = None, name: str = None):
        """Load a ControlNet plus its ``Annotator`` for ``control_type``.

        The checkpoint is loaded with ``strict=False``, so missing or unexpected
        keys are not reported.
        """
        self.model.to(self.device)
        self.controlnet = load_controlnet(self.model_type, self.device).to(torch.bfloat16)

        checkpoint = load_checkpoint(local_path, repo_id, name)
        self.controlnet.load_state_dict(checkpoint, strict=False)
        self.annotator = Annotator(control_type, self.device)
        self.controlnet_loaded = True
        self.control_type = control_type

    def get_image_proj(
        self,
        image_prompt: Tensor,
    ):
        """Encode an image prompt with the CLIP vision encoder and ``self.improj``.

        Requires ``set_ip`` to have been called (it creates the processor,
        encoder and projection model used here). Returns a bfloat16 tensor on
        ``self.device``.
        """
        pixel_values = self.clip_image_processor(
            images=image_prompt,
            return_tensors="pt"
        ).pixel_values
        pixel_values = pixel_values.to(self.image_encoder.device)
        image_prompt_embeds = self.image_encoder(pixel_values).image_embeds.to(
            device=self.device, dtype=torch.bfloat16,
        )
        return self.improj(image_prompt_embeds)

    def _prepare_controlnet_image(self, controlnet_image, width, height):
        """Annotate a control image and convert it to a (1, C, H, W) bfloat16 tensor in [-1, 1].

        Raises:
            ValueError: if ``controlnet_image`` is None.
        """
        if controlnet_image is None:
            raise ValueError("A ControlNet is loaded, so `controlnet_image` must be provided")
        annotated = self.annotator(controlnet_image, width, height)
        # Map uint8 pixel values 0..255 to -1..1.
        tensor = torch.from_numpy((np.array(annotated) / 127.5) - 1)
        return tensor.permute(2, 0, 1).unsqueeze(0).to(torch.bfloat16).to(self.device)

    def __call__(self,
                 prompt: str,
                 image_prompt: "Image.Image" = None,
                 controlnet_image: "Image.Image" = None,
                 width: int = 512,
                 height: int = 512,
                 guidance: float = 4,
                 num_steps: int = 50,
                 seed: int = 123456789,
                 true_gs: float = 3,
                 control_weight: float = 0.9,
                 ip_scale: float = 1.0,
                 neg_ip_scale: float = 1.0,
                 neg_prompt: str = '',
                 neg_image_prompt: "Image.Image" = None,
                 timestep_to_start_cfg: int = 0,
                 ):
        """Generate one image and return it as a PIL image.

        ``width`` and ``height`` are rounded down to multiples of 16. If either
        image prompt is given, the IP-Adapter must be loaded; a missing one is
        replaced by a black image. If a ControlNet is loaded, ``controlnet_image``
        is required and is run through the annotator first.
        """
        width = 16 * (width // 16)
        height = 16 * (height // 16)
        image_proj = None
        neg_image_proj = None
        if image_prompt is not None or neg_image_prompt is not None:
            assert self.ip_loaded, 'You must setup IP-Adapter to add image prompt as input'

            if image_prompt is None:
                image_prompt = self._blank_image_prompt(width, height)
            if neg_image_prompt is None:
                neg_image_prompt = self._blank_image_prompt(width, height)

            image_proj = self.get_image_proj(image_prompt)
            neg_image_proj = self.get_image_proj(neg_image_prompt)

        if self.controlnet_loaded:
            controlnet_image = self._prepare_controlnet_image(controlnet_image, width, height)

        return self.forward(
            prompt,
            width,
            height,
            guidance,
            num_steps,
            seed,
            controlnet_image,
            timestep_to_start_cfg=timestep_to_start_cfg,
            true_gs=true_gs,
            control_weight=control_weight,
            neg_prompt=neg_prompt,
            image_proj=image_proj,
            neg_image_proj=neg_image_proj,
            ip_scale=ip_scale,
            neg_ip_scale=neg_ip_scale,
        )

    def generate_from_feat(self,
                           prompt: str,
                           image_proj: Tensor = None,
                           controlnet_image: "Image.Image" = None,
                           width: int = 512,
                           height: int = 512,
                           guidance: float = 4,
                           num_steps: int = 50,
                           seed: int = 123456789,
                           true_gs: float = 3,
                           control_weight: float = 0.9,
                           ip_scale: float = 1.0,
                           neg_ip_scale: float = 1.0,
                           neg_prompt: str = '',
                           neg_image_proj: Tensor = None,
                           timestep_to_start_cfg: int = 0,
                           ):
        """Like ``__call__`` but takes precomputed image-prompt projections.

        ``image_proj`` / ``neg_image_proj`` are passed to ``forward`` unchanged
        (presumably outputs of ``get_image_proj``). When a ControlNet is loaded,
        a control image that is not already a tensor is annotated and converted
        the same way ``__call__`` does; tensors are passed through as-is.
        """
        width = 16 * (width // 16)
        height = 16 * (height // 16)

        if self.controlnet_loaded and not isinstance(controlnet_image, Tensor):
            controlnet_image = self._prepare_controlnet_image(controlnet_image, width, height)

        return self.forward(
            prompt,
            width,
            height,
            guidance,
            num_steps,
            seed,
            controlnet_image,
            timestep_to_start_cfg=timestep_to_start_cfg,
            true_gs=true_gs,
            control_weight=control_weight,
            neg_prompt=neg_prompt,
            image_proj=image_proj,
            neg_image_proj=neg_image_proj,
            ip_scale=ip_scale,
            neg_ip_scale=neg_ip_scale,
        )

    @torch.inference_mode()
    def gradio_generate(self, prompt, image_prompt, controlnet_image, width, height, guidance,
                        num_steps, seed, true_gs, ip_scale, neg_ip_scale, neg_prompt,
                        neg_image_prompt, timestep_to_start_cfg, control_type, control_weight,
                        lora_weight, local_path, lora_local_path, ip_local_path):
        """Gradio entry point: load adapters on demand, generate, and save a JPEG.

        Array inputs are converted with ``Image.fromarray``. The ControlNet is
        (re)loaded when none is loaded or ``control_type`` changed; the LoRA is
        applied on every call that passes ``lora_local_path``; the IP-Adapter is
        loaded when an image prompt is given and it is not currently loaded. A
        seed of -1 picks a random seed. The image is saved under
        ``output/gradio/`` with EXIF Make/Model tags.

        Returns:
            ``(image, filename)``.
        """
        if controlnet_image is not None:
            controlnet_image = Image.fromarray(controlnet_image)
            if not self.controlnet_loaded or control_type != self.control_type:
                if local_path is not None:
                    self.set_controlnet(control_type, local_path=local_path)
                else:
                    self.set_controlnet(control_type, local_path=None,
                                        repo_id=f"xlabs-ai/flux-controlnet-{control_type}-v3",
                                        name=f"flux-{control_type}-controlnet-v3.safetensors")
        if lora_local_path is not None:
            # NOTE(review): this reloads the LoRA on every call.
            self.set_lora(local_path=lora_local_path, lora_weight=lora_weight)
        if image_prompt is not None:
            image_prompt = Image.fromarray(image_prompt)
            if neg_image_prompt is not None:
                neg_image_prompt = Image.fromarray(neg_image_prompt)
            if not self.ip_loaded:
                if ip_local_path is not None:
                    self.set_ip(local_path=ip_local_path)
                else:
                    self.set_ip(repo_id="xlabs-ai/flux-ip-adapter",
                                name="flux-ip-adapter.safetensors")
        seed = int(seed)
        if seed == -1:
            seed = torch.Generator(device="cpu").seed()

        img = self(
            prompt,
            image_prompt=image_prompt,
            controlnet_image=controlnet_image,
            width=width,
            height=height,
            guidance=guidance,
            num_steps=num_steps,
            seed=seed,
            true_gs=true_gs,
            control_weight=control_weight,
            ip_scale=ip_scale,
            neg_ip_scale=neg_ip_scale,
            neg_prompt=neg_prompt,
            neg_image_prompt=neg_image_prompt,
            timestep_to_start_cfg=timestep_to_start_cfg,
        )

        filename = f"output/gradio/{uuid.uuid4()}.jpg"
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        exif_data = Image.Exif()
        exif_data[ExifTags.Base.Make] = "XLabs AI"
        exif_data[ExifTags.Base.Model] = self.model_type
        img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0)
        return img, filename

    def forward(
        self,
        prompt,
        width,
        height,
        guidance,
        num_steps,
        seed,
        controlnet_image=None,
        timestep_to_start_cfg=0,
        true_gs=3.5,
        control_weight=0.9,
        neg_prompt="",
        image_proj=None,
        neg_image_proj=None,
        ip_scale=1.0,
        neg_ip_scale=1.0,
    ):
        """Run the sampler for a single image (batch size 1) and decode it.

        Uses ``denoise_controlnet`` when a ControlNet is loaded, otherwise
        ``denoise``. With ``offload`` enabled, the text encoders, flow model
        and decoder are moved to ``self.device`` only while used. Returns the
        last decoded image of the batch as a PIL image.
        """
        x = get_noise(
            1, height, width, device=self.device,
            dtype=torch.bfloat16, seed=seed
        )
        # NOTE(review): if get_schedule expects the packed image-token count,
        # that is (height // 16) * (width // 16); this expression evaluates to
        # 1/64 of that. Changing it alters outputs — confirm before touching.
        timesteps = get_schedule(
            num_steps,
            (width // 8) * (height // 8) // (16 * 16),
            shift=True,
        )
        torch.manual_seed(seed)
        with torch.no_grad():
            if self.offload:
                self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
            inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=prompt)
            neg_inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt)

            if self.offload:
                self.offload_model_to_cpu(self.t5, self.clip)
                self.model = self.model.to(self.device)

            # Arguments shared by both denoising paths.
            common = dict(
                timesteps=timesteps,
                guidance=guidance,
                timestep_to_start_cfg=timestep_to_start_cfg,
                neg_txt=neg_inp_cond['txt'],
                neg_txt_ids=neg_inp_cond['txt_ids'],
                neg_vec=neg_inp_cond['vec'],
                true_gs=true_gs,
                image_proj=image_proj,
                neg_image_proj=neg_image_proj,
                ip_scale=ip_scale,
                neg_ip_scale=neg_ip_scale,
            )
            if self.controlnet_loaded:
                x = denoise_controlnet(
                    self.model,
                    **inp_cond,
                    controlnet=self.controlnet,
                    controlnet_cond=controlnet_image,
                    controlnet_gs=control_weight,
                    **common,
                )
            else:
                x = denoise(self.model, **inp_cond, **common)

            if self.offload:
                self.offload_model_to_cpu(self.model)
                self.ae.decoder.to(x.device)
            x = unpack(x.float(), height, width)
            x = self.ae.decode(x)
            self.offload_model_to_cpu(self.ae.decoder)

        x1 = x.clamp(-1, 1)
        x1 = rearrange(x1[-1], "c h w -> h w c")
        # Map [-1, 1] back to 0..255 bytes.
        output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy())
        return output_img

    def offload_model_to_cpu(self, *models):
        """Move ``models`` to the CPU and empty the CUDA cache; no-op unless ``self.offload``."""
        if not self.offload:
            return
        for model in models:
            model.cpu()
            torch.cuda.empty_cache()
|
|
|
class XFluxSampler(XFluxPipeline):
    """``XFluxPipeline`` built from already-loaded components.

    Unlike the parent constructor, nothing is loaded here and offloading is
    disabled.
    """

    def __init__(self, clip, t5, ae, model, device, model_type=None):
        """Wrap existing modules.

        Args:
            clip, t5, ae, model: already-constructed components; ``model`` is
                switched to eval mode.
            device: anything accepted by ``torch.device``.
            model_type: optional model identifier; needed by the inherited
                ``set_controlnet`` and ``gradio_generate``.
        """
        self.clip = clip
        self.t5 = t5
        self.ae = ae
        self.model = model
        self.model.eval()
        # Normalized the same way the parent constructor does.
        self.device = torch.device(device)
        self.model_type = model_type
        self.controlnet_loaded = False
        self.ip_loaded = False
        self.offload = False
        # The parent constructor sets these too; the inherited set_ip() and
        # set_lora_from_collection() read them and would otherwise fail.
        self.image_encoder_path = "openai/clip-vit-large-patch14"
        self.hf_lora_collection = "XLabs-AI/flux-lora-collection"
        self.lora_types_to_names = {
            "realism": "lora.safetensors",
        }
|
|