from dataclasses import dataclass
from typing import Optional, Union

import torch
from PIL import Image
from tqdm import tqdm
import torchvision.transforms as T

from .hunyuanimage_pipeline import HunyuanImagePipeline, HunyuanImagePipelineConfig
from hyimage.models.model_zoo import (
    HUNYUANIMAGE_REFINER_DIT,
    HUNYUANIMAGE_REFINER_VAE_32x,
    HUNYUANIMAGE_REFINER_TEXT_ENCODER,
)


@dataclass
class HunYuanImageRefinerPipelineConfig(HunyuanImagePipelineConfig):
    """
    Configuration class for the HunyuanImage refiner pipeline.

    Inherits from HunyuanImagePipelineConfig and overrides specific
    parameters for the refiner functionality.
    """

    default_sampling_steps: int = 4
    shift: int = 1
    version: str = "v1.0"
    cfg_mode: str = ""

    @classmethod
    def create_default(
        cls,
        version: str = "v1.0",
        use_distilled: bool = False,  # accepted for API parity; unused by the refiner
        **kwargs,
    ):
        dit_config = HUNYUANIMAGE_REFINER_DIT()
        vae_config = HUNYUANIMAGE_REFINER_VAE_32x()
        text_encoder_config = HUNYUANIMAGE_REFINER_TEXT_ENCODER()
        return cls(
            dit_config=dit_config,
            vae_config=vae_config,
            text_encoder_config=text_encoder_config,
            reprompt_config=None,
            version=version,
            **kwargs,
        )


class HunYuanImageRefinerPipeline(HunyuanImagePipeline):
    """A refiner pipeline for HunyuanImage that inherits from the main pipeline.

    This pipeline refines existing images using the same model architecture
    but with different default parameters and an image input.
    """

    def __init__(self, config: HunYuanImageRefinerPipelineConfig, **kwargs):
        """Initialize the refiner pipeline.

        Args:
            config: Refiner-specific configuration.
            **kwargs: Additional arguments passed to the parent class.
        """
        assert isinstance(config, HunYuanImageRefinerPipelineConfig)
        super().__init__(config, **kwargs)
        assert self.cfg_distilled

    def _condition_aug(self, latents, noise=None, strength=0.3):
        """Apply conditioning augmentation for the refiner.

        Args:
            latents: Input latents tensor.
            noise: Optional noise tensor; generated if None.
            strength: Augmentation strength factor.

        Returns:
            Augmented latents tensor.
        """
        if noise is None:
            noise = torch.randn_like(latents)
        return strength * noise + (1 - strength) * latents

    def __call__(
        self,
        prompt: str,
        negative_prompt: str = "",
        width: int = 2048,
        height: int = 2048,
        use_reprompt: bool = False,
        num_inference_steps: Optional[int] = None,
        guidance_scale: Optional[float] = None,
        shift: int = 4,
        seed: Optional[int] = 42,
        image: Optional[Image.Image] = None,
        **kwargs,
    ) -> Image.Image:
        """Refine an existing image using text guidance.

        Args:
            prompt: Text prompt describing the desired refinement.
            negative_prompt: Negative prompt for guidance.
            width: Image width.
            height: Image height.
            use_reprompt: Whether to use reprompt (ignored by the refiner).
            num_inference_steps: Number of denoising steps (overrides the config if provided).
            guidance_scale: Strength of classifier-free guidance (overrides the config if provided).
            shift: Timestep shift for the sampling schedule.
            seed: Random seed for reproducibility.
            image: Image to be refined (required for the refiner).
            **kwargs: Additional arguments.

        Returns:
            Refined PIL Image.
        """
        if image is None:
            raise ValueError("The `image` parameter is required for the refiner pipeline")

        if seed is not None:
            generator = torch.Generator(device="cpu").manual_seed(seed)
        else:
            generator = None

        sampling_steps = (
            num_inference_steps
            if num_inference_steps is not None
            else self.default_sampling_steps
        )
        guidance_scale = (
            guidance_scale if guidance_scale is not None else self.default_guidance_scale
        )
        shift = shift if shift is not None else self.shift

        # Log the current refinement task
        print("=" * 60)
        print("🔧 HunyuanImage Refinement Task")
        print("-" * 60)
        print(f"Prompt: {prompt}")
        print(f"Guidance Scale: {guidance_scale}")
        print(f"Shift: {shift}")  # log the effective shift, not the config default
        print(f"Seed: {seed}")
        print(f"Image Size: {width} x {height}")
        print(f"Sampling Steps: {sampling_steps}")
        print("=" * 60)

        # Encode prompts
        pos_text_emb, pos_text_mask = self._encode_text(prompt)

        latents = self._prepare_latents(width, height, generator=generator)

        _pil_to_tensor = T.Compose(
            [
                T.ToTensor(),  # convert to tensor and normalize to [0, 1]
                T.Normalize([0.5], [0.5]),  # transform to [-1, 1]
            ]
        )
        # Move the conditioning image to the pipeline device in a single transfer
        image_tensor = (
            _pil_to_tensor(image).unsqueeze(0).to(self.device, dtype=self.vae.dtype)
        )

        cond_latents = self.vae.encode(image_tensor).latent_dist.sample()
        if (
            hasattr(self.vae.config, "shift_factor")
            and self.vae.config.shift_factor
        ):
            cond_latents.sub_(self.vae.config.shift_factor).mul_(
                self.vae.config.scaling_factor
            )
        else:
            cond_latents.mul_(self.vae.config.scaling_factor)

        # Add frame dimension for the refiner model
        cond_latents = cond_latents.unsqueeze(2)  # (b, c, 1, h, w)

        # Apply conditioning augmentation
        cond_latents = self._condition_aug(cond_latents)

        timesteps, sigmas = self.get_timesteps_sigmas(sampling_steps, shift)

        text_emb = pos_text_emb
        text_mask = pos_text_mask

        for i, t in enumerate(tqdm(timesteps, desc="Refining", total=len(timesteps))):
            # Concatenate noise latents with condition latents for the refiner input
            latent_model_input = torch.cat([latents, cond_latents], dim=1)
            t_expand = t.repeat(latent_model_input.shape[0])

            # Predict noise with guidance
            noise_pred = self._denoise_step(
                latent_model_input,
                t_expand,
                text_emb,
                text_mask,
                None,
                None,
                guidance_scale,
                timesteps_r=None,
            )

            latents = self.step(latents, noise_pred, sigmas, i)

        refined_image = self._decode_latents(latents)

        # Convert to a PIL Image; clamp to [0, 1] and move to CPU before numpy conversion
        refined_image = (
            (refined_image.squeeze(0).permute(1, 2, 0).clamp(0, 1) * 255)
            .round()
            .byte()
            .cpu()
            .numpy()
        )
        return Image.fromarray(refined_image)

    @classmethod
    def from_pretrained(
        cls,
        model_name: str = "hunyuanimage-refiner",
        use_distilled: bool = False,
        **kwargs,
    ):
        """Create a refiner pipeline from a pretrained model.

        Args:
            model_name: Model name; currently only "hunyuanimage-refiner" is supported.
            use_distilled: Whether to use the distilled model (unused by the refiner).
            **kwargs: Additional configuration options.
        """
        if model_name == "hunyuanimage-refiner":
            version = "v1.0"
        else:
            raise ValueError(
                f"Unsupported refiner model name: {model_name}. Supported names: 'hunyuanimage-refiner'"
            )
        config = HunYuanImageRefinerPipelineConfig.create_default(
            version=version, **kwargs
        )
        return cls(config=config)

    @classmethod
    def from_config(cls, config: Union[HunYuanImageRefinerPipelineConfig, HunyuanImagePipelineConfig]):
        """Create a refiner pipeline from a configuration object.

        Args:
            config: Configuration object for the pipeline.

        Returns:
            Initialized refiner pipeline instance.
        """
        return cls(config=config)


# Convenience function for easy access
def RefinerPipeline(
    model_name: str = "hunyuanimage-refiner",
    **kwargs,
):
    """Factory function to create a HunYuanImageRefinerPipeline.

    Args:
        model_name: Model name; currently only "hunyuanimage-refiner" is supported.
        **kwargs: Additional configuration options.

    Returns:
        Initialized refiner pipeline instance.
    """
    return HunYuanImageRefinerPipeline.from_pretrained(model_name, **kwargs)
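

# --- Usage sketch (illustrative only, not part of the pipeline) ---
# A minimal example of running the refiner on an image produced by the base
# pipeline. It assumes the checkpoints referenced by HUNYUANIMAGE_REFINER_*
# are available locally; "base_output.png" is a hypothetical file path.
if __name__ == "__main__":
    refiner = RefinerPipeline("hunyuanimage-refiner")
    base_image = Image.open("base_output.png").convert("RGB")
    refined = refiner(
        prompt="a photorealistic portrait, sharp details",
        image=base_image,
        width=2048,
        height=2048,
        seed=42,
    )
    refined.save("refined_output.png")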