# Phillnet-2 / ImageGen / pipeline.py
# Uploaded by ayjays132 ("Upload 470 files", commit ad2ce18, verified).
from __future__ import annotations
import hashlib
import re
from collections import OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import torch
from PIL import Image, ImageEnhance, ImageFilter
from transformers import AutoTokenizer
from .model import UniversalHFTextToImageAdapter
@dataclass
class ImageGenPipelineOutput:
    """Result bundle produced by ``ImageGenPipeline.__call__`` when ``return_dict=True``.

    Fields are declared in constructor (positional) order; only ``images`` is required.
    """

    # Decoded PIL images; the pipeline leaves this empty when output_type == "pt".
    images: List[Image.Image]
    # Raw tensor returned by the adapter's generation call.
    tensors: Optional[torch.Tensor] = field(default=None)
    # Optional conditioning payload; ImageGenPipeline.__call__ does not set it.
    conditioning: Optional[Any] = field(default=None)
    # Prompt(s) as passed to the adapter (i.e. after prompt normalization).
    prompt: Optional[Union[str, List[str]]] = field(default=None)
    # Per-run details such as cache hits, strategy and polish settings.
    metadata: Dict[str, Any] = field(default_factory=dict)
class ImageGenPipeline:
    """
    Hugging Face-style pipeline for the trained ImageGen adapter.

    The pipeline preserves the adapter checkpoint exactly: it is loaded through
    ``UniversalHFTextToImageAdapter.from_pretrained`` (the checkpoint file is
    ``adapter_model.pt``). Text is routed through the supplied/shared Qwen text
    model when one is attached; otherwise the token-embedding table is read from
    ``<model_dir>/../model.safetensors``. The class exposes a
    ``DiffusionPipeline``-like ``from_pretrained`` and ``__call__`` surface.

    Prompt embeddings (and SDXL conditioning for diffusion runs) are kept in a
    small LRU cache. The cache key covers the tokenized prompt plus the strategy,
    memory flag, device and dtype.
    """

    config_name = "model_index.json"

    # Accepted ``generation_strategy`` aliases (compared after lower()/strip()).
    _PRIOR_STRATEGIES = frozenset({"prior", "text_prior", "condition"})
    _DIFFUSION_STRATEGIES = frozenset({"diffusion", "latent_diffusion"})

    def __init__(
        self,
        adapter: UniversalHFTextToImageAdapter,
        tokenizer: Optional[Any] = None,
        text_model: Optional[torch.nn.Module] = None,
        model_dir: Optional[Union[str, Path]] = None,
        prompt_cache_capacity: int = 32,
    ):
        """Wrap an already-loaded adapter.

        Args:
            adapter: The trained text-to-image adapter.
            tokenizer: Tokenizer for the conditioning text model (required by ``__call__``).
            text_model: Optional text model whose input embeddings condition the adapter.
                When given, it is attached to the adapter and frozen.
            model_dir: Directory the pipeline was loaded from. It is used to locate the
                local embedding weights and a bundled SDXL text-encoder stack.
            prompt_cache_capacity: Maximum number of prompt-cache entries kept (LRU eviction).
        """
        self.adapter = adapter
        self.tokenizer = tokenizer
        self.text_model = text_model
        self.model_dir = Path(model_dir) if model_dir is not None else None
        # SDXL CLIP tokenizers/encoders are loaded lazily by _ensure_sdxl_text_stack().
        self.sdxl_tokenizer = None
        self.sdxl_tokenizer_2 = None
        self.sdxl_text_encoder = None
        self.sdxl_text_encoder_2 = None
        # Frozen embedding table built lazily by _ensure_local_text_embedding().
        self.local_text_embedding = None
        self._prompt_cache: "OrderedDict[str, Dict[str, Any]]" = OrderedDict()
        self._prompt_cache_capacity = int(prompt_cache_capacity)
        if self.text_model is not None:
            self.adapter.text_model = self.text_model
            self.adapter.freeze_text_model()

    @classmethod
    def from_pretrained(
        cls,
        model_dir: Union[str, Path],
        *,
        text_model: Optional[torch.nn.Module] = None,
        tokenizer: Optional[Any] = None,
        device: Optional[Union[str, torch.device]] = None,
        torch_dtype: Optional[torch.dtype] = None,
        local_files_only: bool = True,
        **_: Any,
    ) -> "ImageGenPipeline":
        """Load the adapter (and, if present, ``<model_dir>/tokenizer``) into a pipeline.

        Args:
            model_dir: Directory containing the adapter checkpoint.
            text_model: Optional shared text model to attach.
            tokenizer: Tokenizer to use. When omitted, ``<model_dir>/tokenizer`` is loaded
                if that directory exists.
            device: Target device for the adapter (defaults to ``"cpu"``).
            torch_dtype: Optional dtype the adapter is cast to after loading.
            local_files_only: Forwarded to ``AutoTokenizer.from_pretrained``.
            **_: Extra keyword arguments are accepted and ignored, for API compatibility.
        """
        model_path = Path(model_dir)
        if tokenizer is None:
            tokenizer_dir = model_path / "tokenizer"
            if tokenizer_dir.exists():
                tokenizer = AutoTokenizer.from_pretrained(
                    tokenizer_dir,
                    use_fast=False,
                    local_files_only=local_files_only,
                    trust_remote_code=True,
                )
        adapter = UniversalHFTextToImageAdapter.from_pretrained(
            model_path,
            text_model=text_model,
            device=device or "cpu",
        )
        if torch_dtype is not None:
            adapter = adapter.to(dtype=torch_dtype)
        return cls(adapter=adapter, tokenizer=tokenizer, text_model=text_model, model_dir=model_path)

    def to(self, device: Union[str, torch.device], dtype: Optional[torch.dtype] = None) -> "ImageGenPipeline":
        """Move the pipeline to ``device`` and optionally cast it to ``dtype``.

        The adapter and any lazily-built helpers (local embedding table, SDXL text
        encoders) are moved and cast together. Those helpers are created with the
        adapter's device/dtype and consume tensors placed on ``self.device``, so
        leaving them behind would cause device/dtype mismatches. An attached
        ``text_model`` is moved but not cast.

        Returns:
            ``self``, so calls can be chained.
        """
        self.adapter.to(device=device)
        if dtype is not None:
            self.adapter.to(dtype=dtype)
        if self.text_model is not None and hasattr(self.text_model, "to"):
            self.text_model.to(device=device)
        for helper in (self.local_text_embedding, self.sdxl_text_encoder, self.sdxl_text_encoder_2):
            if helper is None:
                continue
            helper.to(device=device)
            if dtype is not None:
                helper.to(dtype=dtype)
        return self

    @property
    def device(self) -> torch.device:
        """Device of the adapter's first parameter."""
        return next(self.adapter.parameters()).device

    @staticmethod
    def _normalize_prompt_text(prompt: str) -> str:
        """Collapse whitespace and drop case-insensitive duplicate prompt fragments.

        Fragments are separated by commas or semicolons. When there are at least two
        fragments, they are de-duplicated (first occurrence wins) and re-joined with
        ", ". Otherwise the whitespace-collapsed text is returned unchanged.
        """
        text = re.sub(r"\s+", " ", str(prompt)).strip()
        if not text:
            return text
        # Newlines were collapsed to spaces above, so only "," and ";" can separate fragments.
        pieces = [part.strip() for part in re.split(r"\s*[,;]\s*", text) if part.strip()]
        if len(pieces) <= 1:
            return text
        deduped: List[str] = []
        seen = set()
        for piece in pieces:
            # Whitespace inside `piece` is already single spaces, so casefold alone is the key.
            key = piece.casefold()
            if key not in seen:
                seen.add(key)
                deduped.append(piece)
        return ", ".join(deduped)

    def _normalize_prompt(self, prompt: Union[str, List[str]]) -> Union[str, List[str]]:
        """Apply ``_normalize_prompt_text`` to a single prompt or to each prompt in a list."""
        if isinstance(prompt, str):
            return self._normalize_prompt_text(prompt)
        return [self._normalize_prompt_text(item) for item in prompt]

    def _cache_fingerprint(self, prompt: Union[str, List[str]], encoded: Dict[str, torch.Tensor], **parts: Any) -> str:
        """Return a SHA-256 hex digest identifying a prompt and its run configuration.

        The digest covers ``repr(prompt)``, the shape/dtype/bytes of ``input_ids`` and
        ``attention_mask`` (when present in ``encoded``), and every extra keyword in
        ``parts``, sorted by name.
        """
        digest = hashlib.sha256()
        digest.update(repr(prompt).encode("utf-8"))
        for key in ("input_ids", "attention_mask"):
            tensor = encoded.get(key)
            if tensor is not None:
                digest.update(key.encode("utf-8"))
                digest.update(str(tuple(tensor.shape)).encode("ascii"))
                digest.update(str(tensor.dtype).encode("ascii"))
                digest.update(tensor.detach().cpu().contiguous().numpy().tobytes())
        for key, value in sorted(parts.items()):
            digest.update(f"{key}={value!r}".encode("utf-8"))
        return digest.hexdigest()

    def _get_prompt_cache(self, key: str) -> Optional[Dict[str, Any]]:
        """Return the cached entry for ``key`` (marking it most recently used), or None."""
        hit = self._prompt_cache.get(key)
        if hit is None:
            return None
        self._prompt_cache.move_to_end(key)
        return hit

    def _put_prompt_cache(self, key: str, value: Dict[str, Any]) -> None:
        """Insert or refresh ``key``, then evict least-recently-used entries beyond capacity."""
        self._prompt_cache[key] = value
        self._prompt_cache.move_to_end(key)
        while len(self._prompt_cache) > self._prompt_cache_capacity:
            self._prompt_cache.popitem(last=False)

    def _tokenize(self, prompt: Union[str, List[str]], max_length: int) -> Dict[str, torch.Tensor]:
        """Tokenize one prompt or a batch into padded/truncated tensors on ``self.device``.

        Raises:
            ValueError: if no tokenizer is loaded or the prompt list is empty.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer is not loaded. Pass tokenizer=... or include ImageGen/tokenizer.")
        prompts = [prompt] if isinstance(prompt, str) else list(prompt)
        if not prompts:
            raise ValueError("prompt must be a non-empty string or a non-empty list of strings.")
        pad_tok = self.tokenizer.pad_token or "<|endoftext|>"
        # Empty strings are replaced by the pad token, presumably so no row tokenizes to zero tokens.
        prompts = [p if p != "" else pad_tok for p in prompts]
        encoded = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )
        return {k: v.to(self.device) for k, v in encoded.items()}

    def _ensure_local_text_embedding(self) -> torch.nn.Embedding:
        """Build (once) a frozen embedding table from ``<model_dir>/../model.safetensors``.

        Only the embedding tensor is read from the file, not the whole state dict. It is
        cast to the adapter's dtype and placed on ``self.device``.

        Raises:
            ValueError: if ``model_dir`` is unknown or the safetensors file is missing.
            KeyError: if none of the known ``embed_tokens.weight`` keys is present.
        """
        if self.local_text_embedding is not None:
            return self.local_text_embedding
        if self.model_dir is None:
            raise ValueError("No text_model is attached and no model_dir is available for local embeddings.")
        from safetensors import safe_open

        root = self.model_dir.parent
        state_path = root / "model.safetensors"
        if not state_path.exists():
            raise ValueError(f"No text_model is attached and local embedding weights are missing: {state_path}")
        with safe_open(str(state_path), framework="pt", device="cpu") as handle:
            available = set(handle.keys())
            for key in (
                "model.language_model.embed_tokens.weight",
                "language_model.embed_tokens.weight",
                "model.embed_tokens.weight",
            ):
                if key in available:
                    weight = handle.get_tensor(key)
                    break
            else:
                raise KeyError("Could not find embed_tokens.weight in local model.safetensors.")
        dtype = next(self.adapter.parameters()).dtype
        embedding = torch.nn.Embedding.from_pretrained(weight.to(dtype=dtype), freeze=True)
        embedding.to(device=self.device)
        self.local_text_embedding = embedding
        return embedding

    def _embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up input embeddings through the attached text model, else the local table."""
        if self.text_model is not None and hasattr(self.text_model, "get_input_embeddings"):
            return self.text_model.get_input_embeddings()(input_ids)
        return self._ensure_local_text_embedding()(input_ids)

    @staticmethod
    def _has_sdxl_text_stack(path: Path) -> bool:
        """True when ``path`` contains tokenizer, tokenizer_2, text_encoder and text_encoder_2."""
        return path.exists() and all(
            (path / name).exists() for name in ("tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2")
        )

    def _resolve_sdxl_text_source(self) -> str:
        """Pick where to load the SDXL CLIP text stack from.

        The candidates are tried in order:
        1. ``<model_dir>/models/Phillnet-2-SDXL-TextEncoders``;
        2. path-like attributes of ``adapter.image_generator``;
        3. the hub id ``"stabilityai/sdxl-turbo"``.
        """
        if self.model_dir is not None:
            local_text_stack = self.model_dir / "models" / "Phillnet-2-SDXL-TextEncoders"
            if self._has_sdxl_text_stack(local_text_stack):
                return str(local_text_stack)
        backend = self.adapter.image_generator
        for attr in ("sdxl_text_encoder_model_name_or_path", "pretrained_unet_model_name_or_path", "vae_model_name_or_path"):
            value = getattr(backend, attr, None)
            if value:
                path = Path(str(value))
                if self._has_sdxl_text_stack(path):
                    return str(path)
        return "stabilityai/sdxl-turbo"

    def _ensure_sdxl_text_stack(self) -> None:
        """Load both SDXL CLIP tokenizers/encoders once, frozen, on the adapter's device/dtype."""
        if self.sdxl_text_encoder is not None and self.sdxl_text_encoder_2 is not None:
            return
        from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

        model_name = self._resolve_sdxl_text_source()
        self.sdxl_tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
        self.sdxl_tokenizer_2 = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer_2")
        self.sdxl_text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder")
        self.sdxl_text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
            model_name,
            subfolder="text_encoder_2",
        )
        dtype = next(self.adapter.parameters()).dtype
        self.sdxl_text_encoder.to(device=self.device, dtype=dtype).eval()
        self.sdxl_text_encoder_2.to(device=self.device, dtype=dtype).eval()
        for module in (self.sdxl_text_encoder, self.sdxl_text_encoder_2):
            for param in module.parameters():
                param.requires_grad_(False)

    def _encode_sdxl_prompt(self, prompt: Union[str, List[str]]) -> Dict[str, torch.Tensor]:
        """Encode prompt(s) with both SDXL CLIP encoders.

        Returns a dict with:
        - ``"prompt_embeds"``: the penultimate hidden states of both encoders,
          concatenated along the last dimension;
        - ``"pooled_prompt_embeds"``: encoder 2's ``text_embeds``, or its first output.
        """
        self._ensure_sdxl_text_stack()
        prompts = [prompt] if isinstance(prompt, str) else prompt

        def encode(tokenizer: Any, encoder: torch.nn.Module) -> Any:
            tokens = tokenizer(
                prompts,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            return encoder(tokens.input_ids.to(self.device), output_hidden_states=True)

        out_1 = encode(self.sdxl_tokenizer, self.sdxl_text_encoder)
        out_2 = encode(self.sdxl_tokenizer_2, self.sdxl_text_encoder_2)
        pooled = out_2.text_embeds if hasattr(out_2, "text_embeds") else out_2[0]
        return {
            "prompt_embeds": torch.cat([out_1.hidden_states[-2], out_2.hidden_states[-2]], dim=-1),
            "pooled_prompt_embeds": pooled,
        }

    @staticmethod
    def _tensor_to_pil(images: torch.Tensor) -> List[Image.Image]:
        """Convert a [C,H,W] or [B,C,H,W] tensor with values in [0, 1] to RGB PIL images.

        Values are clamped to [0, 1]. A single channel is replicated to RGB, and a fourth
        (alpha) channel is dropped.

        Raises:
            ValueError: if the tensor is not 3-D/4-D with 1, 3 or 4 channels.
        """
        images = images.detach().float().cpu().clamp(0, 1)
        if images.ndim == 3:
            images = images.unsqueeze(0)
        if images.ndim != 4 or images.shape[1] not in (1, 3, 4):
            raise ValueError(f"Expected image tensor [B,C,H,W], got {tuple(images.shape)}")
        if images.shape[1] == 1:
            images = images.repeat(1, 3, 1, 1)
        if images.shape[1] == 4:
            images = images[:, :3]
        images = (images.permute(0, 2, 3, 1).numpy() * 255).round().astype("uint8")
        return [Image.fromarray(image) for image in images]

    @staticmethod
    def _polish_image(image: Image.Image, strength: float = 0.22) -> Image.Image:
        """Blend in a 3x3 median-filtered copy, then apply mild sharpness/contrast boosts.

        ``strength`` is clamped to [0, 1]. At 0 the input image is returned unchanged.
        """
        strength = max(0.0, min(float(strength), 1.0))
        if strength <= 0.0:
            return image
        base = image.convert("RGB")
        denoised = base.filter(ImageFilter.MedianFilter(size=3))
        blended = Image.blend(base, denoised, strength)
        blended = ImageEnhance.Sharpness(blended).enhance(1.08)
        blended = ImageEnhance.Contrast(blended).enhance(1.03)
        return blended

    @staticmethod
    def _broadcast_negative_prompt(
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]],
    ) -> Union[str, List[str]]:
        """Shape ``negative_prompt`` to match the batch implied by ``prompt``.

        ``None`` becomes empty prompt(s). A single string paired with a list prompt is
        repeated once per prompt.

        Raises:
            ValueError: if a list negative prompt's length differs from the prompt batch.
        """
        if isinstance(prompt, str):
            return "" if negative_prompt is None else negative_prompt
        batch_size = len(prompt)
        if negative_prompt is None:
            return [""] * batch_size
        if isinstance(negative_prompt, str):
            return [negative_prompt] * batch_size
        negatives = list(negative_prompt)
        if len(negatives) != batch_size:
            raise ValueError(
                f"negative_prompt has {len(negatives)} entries but prompt has {batch_size}."
            )
        return negatives

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        *,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 512,
        width: int = 512,
        num_inference_steps: Optional[int] = None,
        guidance_scale: float = 0.0,
        seed: Optional[int] = None,
        generation_strategy: str = "prior",
        refinement_steps: int = 2,
        quality_strength: float = 1.0,
        contract_strength: float = 0.0,
        contract_maps: Optional[torch.Tensor] = None,
        refiner_lora_strength: float = 0.0,
        latent_refiner_strength: float = 0.0,
        structure_prior_strength: float = 0.0,
        reference_pass_steps: int = 0,
        reference_latent_strength: float = 0.75,
        image_quality_polish: bool = False,
        image_quality_polish_strength: float = 0.22,
        output_type: str = "pil",
        return_dict: bool = True,
        **kwargs: Any,
    ) -> Union[ImageGenPipelineOutput, List[Image.Image], torch.Tensor]:
        """Generate images for ``prompt``.

        Strategies: ``"prior"``/``"text_prior"``/``"condition"`` call
        ``adapter.condition_to_image``. ``"diffusion"``/``"latent_diffusion"`` call
        ``adapter.generate`` with SDXL conditioning; that path can optionally run a short
        reference pass whose latents seed the main pass. Negative conditioning is built
        only for the diffusion strategies and only when ``guidance_scale > 1``; the prior
        path never consumes it. Extra ``kwargs`` are forwarded to ``adapter.generate``,
        and ``kwargs["use_memory"]`` (default True) also feeds the cache key and the prior path.

        Returns:
            ``ImageGenPipelineOutput`` when ``return_dict``. Otherwise the raw tensor
            (``output_type == "pt"``) or a list of PIL images.

        Raises:
            ValueError: for an unknown ``generation_strategy``, a missing tokenizer, an
                empty prompt list, or a mismatched negative-prompt batch.
        """
        strategy = generation_strategy.lower().strip()
        # Validate up front so an invalid strategy fails before any embedding work.
        if strategy not in self._PRIOR_STRATEGIES and strategy not in self._DIFFUSION_STRATEGIES:
            raise ValueError(
                "generation_strategy must be 'prior' or 'diffusion', "
                f"got {generation_strategy!r}."
            )
        steps = int(num_inference_steps or self.adapter.image_generator.default_inference_steps)
        original_prompt = prompt
        prompt = self._normalize_prompt(prompt)
        prompt_was_normalized = prompt != original_prompt
        max_condition_tokens = int(self.adapter.max_condition_tokens)
        encoded = self._tokenize(prompt, max_length=max_condition_tokens)
        use_memory = kwargs.get("use_memory", True)
        cache_key = self._cache_fingerprint(
            prompt,
            encoded,
            strategy=strategy,
            use_memory=use_memory,
            device=str(self.device),
            dtype=str(next(self.adapter.parameters()).dtype),
            text_model_attached=self.text_model is not None,
        )
        cached_prompt = self._get_prompt_cache(cache_key)
        cache_hit = cached_prompt is not None
        call_kwargs: Dict[str, Any] = {
            "attention_mask": encoded.get("attention_mask"),
            "height": height,
            "width": width,
            "steps": steps,
            "guidance_scale": guidance_scale,
            "seed": seed,
            **kwargs,
        }
        if cached_prompt is not None and "inputs_embeds" in cached_prompt:
            call_kwargs["inputs_embeds"] = cached_prompt["inputs_embeds"].to(self.device)
        else:
            call_kwargs["inputs_embeds"] = self._embed_input_ids(encoded["input_ids"])

        if strategy in self._PRIOR_STRATEGIES:
            generated = self.adapter.condition_to_image(
                inputs_embeds=call_kwargs["inputs_embeds"],
                attention_mask=call_kwargs["attention_mask"],
                height=height,
                width=width,
                refinement_steps=refinement_steps,
                quality_strength=quality_strength,
                contract_strength=contract_strength,
                contract_maps=contract_maps,
                refiner_lora_strength=refiner_lora_strength,
                latent_refiner_strength=latent_refiner_strength,
                structure_prior_strength=structure_prior_strength,
                use_memory=use_memory,
            )
        else:
            use_guidance = guidance_scale > 1.0
            if use_guidance:
                negative_prompt = self._normalize_prompt(
                    self._broadcast_negative_prompt(prompt, negative_prompt)
                )
                neg_encoded = self._tokenize(negative_prompt, max_length=max_condition_tokens)
                call_kwargs["negative_conditioning"] = self.adapter.encode_inputs(
                    attention_mask=neg_encoded.get("attention_mask"),
                    use_memory=False,
                    inputs_embeds=self._embed_input_ids(neg_encoded["input_ids"]),
                )
            if cached_prompt is not None and "sdxl_conditioning" in cached_prompt:
                call_kwargs["sdxl_conditioning"] = cached_prompt["sdxl_conditioning"]
            else:
                call_kwargs["sdxl_conditioning"] = self._encode_sdxl_prompt(prompt)
            if use_guidance:
                call_kwargs["negative_sdxl_conditioning"] = self._encode_sdxl_prompt(negative_prompt)
            reference_latents = None
            if int(reference_pass_steps) > 0:
                # A short preliminary pass whose latents initialise the main pass.
                reference_kwargs = dict(call_kwargs)
                reference_kwargs["steps"] = int(reference_pass_steps)
                reference_latents = self.adapter.generate(
                    **reference_kwargs,
                    return_latents=True,
                    quality_strength=quality_strength,
                    contract_strength=contract_strength,
                    contract_maps=contract_maps,
                    latent_refiner_strength=latent_refiner_strength,
                    structure_prior_strength=structure_prior_strength,
                )
            generated = self.adapter.generate(
                **call_kwargs,
                quality_strength=quality_strength,
                contract_strength=contract_strength,
                contract_maps=contract_maps,
                init_latents=reference_latents,
                init_latent_strength=reference_latent_strength,
                latent_refiner_strength=latent_refiner_strength,
                structure_prior_strength=structure_prior_strength,
            )

        cache_entry: Dict[str, Any] = {"inputs_embeds": call_kwargs["inputs_embeds"].detach()}
        if "sdxl_conditioning" in call_kwargs:
            cache_entry["sdxl_conditioning"] = call_kwargs["sdxl_conditioning"]
        self._put_prompt_cache(cache_key, cache_entry)

        metadata = {
            "prompt": prompt,
            "original_prompt": original_prompt,
            "prompt_was_normalized": prompt_was_normalized,
            "prompt_cache_hit": cache_hit,
            "prompt_cache_key": cache_key,
            "prompt_cache_entries": len(self._prompt_cache),
            "use_memory": use_memory,
            "generation_strategy": strategy,
            "reference_pass_steps": int(reference_pass_steps),
            "reference_latent_strength": reference_latent_strength,
            "used_reference_latents": int(reference_pass_steps) > 0,
            "image_quality_polish": bool(image_quality_polish),
            "image_quality_polish_strength": image_quality_polish_strength if image_quality_polish else 0.0,
        }
        if output_type == "pt":
            if return_dict:
                return ImageGenPipelineOutput(images=[], tensors=generated, prompt=prompt, metadata=metadata)
            return generated
        images = self._tensor_to_pil(generated)
        if image_quality_polish:
            images = [self._polish_image(image, image_quality_polish_strength) for image in images]
        if return_dict:
            return ImageGenPipelineOutput(images=images, tensors=generated, prompt=prompt, metadata=metadata)
        return images