from __future__ import annotations import hashlib import re from collections import OrderedDict from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, List, Optional, Union import torch from PIL import Image, ImageEnhance, ImageFilter from transformers import AutoTokenizer from .model import UniversalHFTextToImageAdapter @dataclass class ImageGenPipelineOutput: images: List[Image.Image] tensors: Optional[torch.Tensor] = None conditioning: Optional[Any] = None prompt: Optional[Union[str, List[str]]] = None metadata: Dict[str, Any] = field(default_factory=dict) class ImageGenPipeline: """ Hugging Face-style pipeline for the trained ImageGen adapter. The pipeline preserves the adapter checkpoint exactly. It loads `adapter_model.pt`, routes text through the supplied/shared Qwen text model when available, and exposes a `DiffusionPipeline`-like `from_pretrained` and `__call__` surface. """ config_name = "model_index.json" def __init__( self, adapter: UniversalHFTextToImageAdapter, tokenizer: Optional[Any] = None, text_model: Optional[torch.nn.Module] = None, model_dir: Optional[Union[str, Path]] = None, ): self.adapter = adapter self.tokenizer = tokenizer self.text_model = text_model self.model_dir = Path(model_dir) if model_dir is not None else None self.sdxl_tokenizer = None self.sdxl_tokenizer_2 = None self.sdxl_text_encoder = None self.sdxl_text_encoder_2 = None self.local_text_embedding = None self._prompt_cache: "OrderedDict[str, Dict[str, Any]]" = OrderedDict() self._prompt_cache_capacity = 32 if self.text_model is not None: self.adapter.text_model = self.text_model self.adapter.freeze_text_model() @classmethod def from_pretrained( cls, model_dir: Union[str, Path], *, text_model: Optional[torch.nn.Module] = None, tokenizer: Optional[Any] = None, device: Optional[Union[str, torch.device]] = None, torch_dtype: Optional[torch.dtype] = None, local_files_only: bool = True, **_: Any, ) -> "ImageGenPipeline": model_path = Path(model_dir) if tokenizer is None: tokenizer_dir = model_path / "tokenizer" if tokenizer_dir.exists(): tokenizer = AutoTokenizer.from_pretrained( tokenizer_dir, use_fast=False, local_files_only=local_files_only, trust_remote_code=True, ) adapter = UniversalHFTextToImageAdapter.from_pretrained( model_path, text_model=text_model, device=device or "cpu", ) if torch_dtype is not None: adapter = adapter.to(dtype=torch_dtype) return cls(adapter=adapter, tokenizer=tokenizer, text_model=text_model, model_dir=model_path) def to(self, device: Union[str, torch.device], dtype: Optional[torch.dtype] = None) -> "ImageGenPipeline": self.adapter.to(device=device) if dtype is not None: self.adapter.to(dtype=dtype) if self.text_model is not None and hasattr(self.text_model, "to"): self.text_model.to(device=device) return self @property def device(self) -> torch.device: return next(self.adapter.parameters()).device @staticmethod def _normalize_prompt_text(prompt: str) -> str: text = re.sub(r"\s+", " ", str(prompt)).strip() if not text: return text pieces = [part.strip() for part in re.split(r"\s*(?:,|;|\n)\s*", text) if part.strip()] if len(pieces) <= 1: return text deduped: List[str] = [] seen: set[str] = set() for piece in pieces: key = re.sub(r"\s+", " ", piece).casefold() if key not in seen: seen.add(key) deduped.append(piece) return ", ".join(deduped) def _normalize_prompt(self, prompt: Union[str, List[str]]) -> Union[str, List[str]]: if isinstance(prompt, str): return self._normalize_prompt_text(prompt) return [self._normalize_prompt_text(item) for item in prompt] def _cache_fingerprint(self, prompt: Union[str, List[str]], encoded: Dict[str, torch.Tensor], **parts: Any) -> str: digest = hashlib.sha256() digest.update(repr(prompt).encode("utf-8")) for key in ("input_ids", "attention_mask"): tensor = encoded.get(key) if tensor is not None: digest.update(key.encode("utf-8")) digest.update(str(tuple(tensor.shape)).encode("ascii")) digest.update(str(tensor.dtype).encode("ascii")) digest.update(tensor.detach().cpu().contiguous().numpy().tobytes()) for key, value in sorted(parts.items()): digest.update(f"{key}={value!r}".encode("utf-8")) return digest.hexdigest() def _get_prompt_cache(self, key: str) -> Optional[Dict[str, Any]]: hit = self._prompt_cache.get(key) if hit is None: return None self._prompt_cache.move_to_end(key) return hit def _put_prompt_cache(self, key: str, value: Dict[str, Any]) -> None: self._prompt_cache[key] = value self._prompt_cache.move_to_end(key) while len(self._prompt_cache) > self._prompt_cache_capacity: self._prompt_cache.popitem(last=False) def _tokenize(self, prompt: Union[str, List[str]], max_length: int) -> Dict[str, torch.Tensor]: if self.tokenizer is None: raise ValueError("Tokenizer is not loaded. Pass tokenizer=... or include ImageGen/tokenizer.") prompts = [prompt] if isinstance(prompt, str) else prompt pad_tok = self.tokenizer.pad_token or "<|endoftext|>" prompts = [p if p != "" else pad_tok for p in prompts] encoded = self.tokenizer( prompts, return_tensors="pt", padding=True, truncation=True, max_length=max_length, ) return {k: v.to(self.device) for k, v in encoded.items()} def _ensure_local_text_embedding(self) -> torch.nn.Embedding: if self.local_text_embedding is not None: return self.local_text_embedding if self.model_dir is None: raise ValueError("No text_model is attached and no model_dir is available for local embeddings.") from safetensors.torch import load_file root = self.model_dir.parent state_path = root / "model.safetensors" if not state_path.exists(): raise ValueError(f"No text_model is attached and local embedding weights are missing: {state_path}") state = load_file(str(state_path), device="cpu") for key in ( "model.language_model.embed_tokens.weight", "language_model.embed_tokens.weight", "model.embed_tokens.weight", ): if key in state: dtype = next(self.adapter.parameters()).dtype self.local_text_embedding = torch.nn.Embedding.from_pretrained(state[key].to(dtype=dtype), freeze=True) self.local_text_embedding.to(device=self.device) return self.local_text_embedding raise KeyError("Could not find embed_tokens.weight in local model.safetensors.") def _resolve_sdxl_text_source(self) -> str: if self.model_dir is not None: local_text_stack = self.model_dir / "models" / "Phillnet-2-SDXL-TextEncoders" if ( local_text_stack.exists() and (local_text_stack / "tokenizer").exists() and (local_text_stack / "tokenizer_2").exists() and (local_text_stack / "text_encoder").exists() and (local_text_stack / "text_encoder_2").exists() ): return str(local_text_stack) backend = self.adapter.image_generator for attr in ("sdxl_text_encoder_model_name_or_path", "pretrained_unet_model_name_or_path", "vae_model_name_or_path"): value = getattr(backend, attr, None) if value: path = Path(str(value)) has_text_stack = ( path.exists() and (path / "tokenizer").exists() and (path / "tokenizer_2").exists() and (path / "text_encoder").exists() and (path / "text_encoder_2").exists() ) if has_text_stack: return str(path) return "stabilityai/sdxl-turbo" def _ensure_sdxl_text_stack(self) -> None: if self.sdxl_text_encoder is not None and self.sdxl_text_encoder_2 is not None: return from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer model_name = self._resolve_sdxl_text_source() self.sdxl_tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer") self.sdxl_tokenizer_2 = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer_2") self.sdxl_text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder") self.sdxl_text_encoder_2 = CLIPTextModelWithProjection.from_pretrained( model_name, subfolder="text_encoder_2", ) dtype = next(self.adapter.parameters()).dtype self.sdxl_text_encoder.to(device=self.device, dtype=dtype).eval() self.sdxl_text_encoder_2.to(device=self.device, dtype=dtype).eval() for module in (self.sdxl_text_encoder, self.sdxl_text_encoder_2): for param in module.parameters(): param.requires_grad_(False) def _encode_sdxl_prompt(self, prompt: Union[str, List[str]]) -> Dict[str, torch.Tensor]: self._ensure_sdxl_text_stack() prompts = [prompt] if isinstance(prompt, str) else prompt def encode(tokenizer: Any, encoder: torch.nn.Module) -> Any: tokens = tokenizer( prompts, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt", ) return encoder(tokens.input_ids.to(self.device), output_hidden_states=True) out_1 = encode(self.sdxl_tokenizer, self.sdxl_text_encoder) out_2 = encode(self.sdxl_tokenizer_2, self.sdxl_text_encoder_2) pooled = out_2.text_embeds if hasattr(out_2, "text_embeds") else out_2[0] return { "prompt_embeds": torch.cat([out_1.hidden_states[-2], out_2.hidden_states[-2]], dim=-1), "pooled_prompt_embeds": pooled, } @staticmethod def _tensor_to_pil(images: torch.Tensor) -> List[Image.Image]: images = images.detach().float().cpu().clamp(0, 1) if images.ndim == 3: images = images.unsqueeze(0) if images.shape[1] not in (1, 3, 4): raise ValueError(f"Expected image tensor [B,C,H,W], got {tuple(images.shape)}") if images.shape[1] == 1: images = images.repeat(1, 3, 1, 1) if images.shape[1] == 4: images = images[:, :3] images = (images.permute(0, 2, 3, 1).numpy() * 255).round().astype("uint8") return [Image.fromarray(image) for image in images] @staticmethod def _polish_image(image: Image.Image, strength: float = 0.22) -> Image.Image: strength = max(0.0, min(float(strength), 1.0)) if strength <= 0.0: return image base = image.convert("RGB") denoised = base.filter(ImageFilter.MedianFilter(size=3)) blended = Image.blend(base, denoised, strength) blended = ImageEnhance.Sharpness(blended).enhance(1.08) blended = ImageEnhance.Contrast(blended).enhance(1.03) return blended @torch.no_grad() def __call__( self, prompt: Union[str, List[str]], *, negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 512, width: int = 512, num_inference_steps: Optional[int] = None, guidance_scale: float = 0.0, seed: Optional[int] = None, generation_strategy: str = "prior", refinement_steps: int = 2, quality_strength: float = 1.0, contract_strength: float = 0.0, contract_maps: Optional[torch.Tensor] = None, refiner_lora_strength: float = 0.0, latent_refiner_strength: float = 0.0, structure_prior_strength: float = 0.0, reference_pass_steps: int = 0, reference_latent_strength: float = 0.75, image_quality_polish: bool = False, image_quality_polish_strength: float = 0.22, output_type: str = "pil", return_dict: bool = True, **kwargs: Any, ) -> Union[ImageGenPipelineOutput, List[Image.Image], torch.Tensor]: steps = int(num_inference_steps or self.adapter.image_generator.default_inference_steps) original_prompt = prompt prompt = self._normalize_prompt(prompt) prompt_was_normalized = prompt != original_prompt encoded = self._tokenize(prompt, max_length=int(self.adapter.max_condition_tokens)) strategy = generation_strategy.lower().strip() use_memory = kwargs.get("use_memory", True) cache_key = self._cache_fingerprint( prompt, encoded, strategy=strategy, use_memory=use_memory, device=str(self.device), dtype=str(next(self.adapter.parameters()).dtype), text_model_attached=self.text_model is not None, ) cached_prompt = self._get_prompt_cache(cache_key) call_kwargs: Dict[str, Any] = { "attention_mask": encoded.get("attention_mask"), "height": height, "width": width, "steps": steps, "guidance_scale": guidance_scale, "seed": seed, **kwargs, } cache_hit = cached_prompt is not None if cached_prompt is not None and "inputs_embeds" in cached_prompt: call_kwargs["inputs_embeds"] = cached_prompt["inputs_embeds"].to(self.device) elif self.text_model is not None and hasattr(self.text_model, "get_input_embeddings"): input_embeddings = self.text_model.get_input_embeddings()(encoded["input_ids"]) call_kwargs["inputs_embeds"] = input_embeddings else: call_kwargs["inputs_embeds"] = self._ensure_local_text_embedding()(encoded["input_ids"]) if guidance_scale > 1.0: if negative_prompt is None: if isinstance(prompt, str): negative_prompt = "" else: negative_prompt = [""] * len(prompt) negative_prompt = self._normalize_prompt(negative_prompt) neg_encoded = self._tokenize(negative_prompt, max_length=int(self.adapter.max_condition_tokens)) neg_kwargs = { "attention_mask": neg_encoded.get("attention_mask"), "use_memory": False, } if self.text_model is not None and hasattr(self.text_model, "get_input_embeddings"): neg_input_embeddings = self.text_model.get_input_embeddings()(neg_encoded["input_ids"]) neg_kwargs["inputs_embeds"] = neg_input_embeddings else: neg_kwargs["inputs_embeds"] = self._ensure_local_text_embedding()(neg_encoded["input_ids"]) negative_conditioning = self.adapter.encode_inputs(**neg_kwargs) call_kwargs["negative_conditioning"] = negative_conditioning if strategy in {"prior", "text_prior", "condition"}: condition_kwargs = { "attention_mask": call_kwargs["attention_mask"], "height": height, "width": width, "refinement_steps": refinement_steps, "quality_strength": quality_strength, "contract_strength": contract_strength, "contract_maps": contract_maps, "refiner_lora_strength": refiner_lora_strength, "latent_refiner_strength": latent_refiner_strength, "structure_prior_strength": structure_prior_strength, "use_memory": use_memory, } if "inputs_embeds" in call_kwargs: condition_kwargs["inputs_embeds"] = call_kwargs["inputs_embeds"] else: condition_kwargs["input_ids"] = encoded["input_ids"] generated = self.adapter.condition_to_image(**condition_kwargs) elif strategy in {"diffusion", "latent_diffusion"}: if cached_prompt is not None and "sdxl_conditioning" in cached_prompt: call_kwargs["sdxl_conditioning"] = cached_prompt["sdxl_conditioning"] else: call_kwargs["sdxl_conditioning"] = self._encode_sdxl_prompt(prompt) if guidance_scale > 1.0: call_kwargs["negative_sdxl_conditioning"] = self._encode_sdxl_prompt(negative_prompt) reference_latents = None if int(reference_pass_steps) > 0: reference_kwargs = dict(call_kwargs) reference_kwargs["steps"] = int(reference_pass_steps) reference_latents = self.adapter.generate( **reference_kwargs, return_latents=True, quality_strength=quality_strength, contract_strength=contract_strength, contract_maps=contract_maps, latent_refiner_strength=latent_refiner_strength, structure_prior_strength=structure_prior_strength, ) generated = self.adapter.generate( **call_kwargs, quality_strength=quality_strength, contract_strength=contract_strength, contract_maps=contract_maps, init_latents=reference_latents, init_latent_strength=reference_latent_strength, latent_refiner_strength=latent_refiner_strength, structure_prior_strength=structure_prior_strength, ) else: raise ValueError( "generation_strategy must be 'prior' or 'diffusion', " f"got {generation_strategy!r}." ) self._put_prompt_cache( cache_key, { "inputs_embeds": call_kwargs["inputs_embeds"].detach(), **({"sdxl_conditioning": call_kwargs["sdxl_conditioning"]} if "sdxl_conditioning" in call_kwargs else {}), }, ) metadata = { "prompt": prompt, "original_prompt": original_prompt, "prompt_was_normalized": prompt_was_normalized, "prompt_cache_hit": cache_hit, "prompt_cache_key": cache_key, "prompt_cache_entries": len(self._prompt_cache), "use_memory": use_memory, "generation_strategy": strategy, "reference_pass_steps": int(reference_pass_steps), "reference_latent_strength": reference_latent_strength, "used_reference_latents": int(reference_pass_steps) > 0, "image_quality_polish": bool(image_quality_polish), "image_quality_polish_strength": image_quality_polish_strength if image_quality_polish else 0.0, } if output_type == "pt": return ImageGenPipelineOutput(images=[], tensors=generated, prompt=prompt, metadata=metadata) if return_dict else generated images = self._tensor_to_pil(generated) if image_quality_polish: images = [self._polish_image(image, image_quality_polish_strength) for image in images] return ImageGenPipelineOutput(images=images, tensors=generated, prompt=prompt, metadata=metadata) if return_dict else images