from typing import Any, Callable, Dict, List, Optional, Union

import torch
from diffusers import StableDiffusionImg2ImgPipeline

from pipelines.twoStepPipeline import two_step_pipeline
from util.commons import disable_safety_checker, download_image


class Text2Img:
    """Text-to-image generation backed by the project's two-step pipeline."""

    def load(self, model_dir: str) -> None:
        """Load the two-step pipeline from ``model_dir`` in fp16 onto CUDA.

        Also enables xformers memory-efficient attention and disables the
        safety checker via ``util.commons.disable_safety_checker``.
        """
        self.pipe = two_step_pipeline.from_pretrained(
            model_dir, torch_dtype=torch.float16
        ).to("cuda")
        self.pipe.enable_xformers_memory_efficient_attention()
        disable_safety_checker(self.pipe)

    @torch.inference_mode()
    def process(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        modified_prompts: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        iteration: float = 3.0,
    ):
        """Run the two-step pipeline and return the generated images.

        All arguments are forwarded unchanged to
        ``self.pipe.two_step_pipeline``; see that method for their semantics.
        Returns the ``.images`` list of the pipeline output.
        """
        # NOTE(review): this invokes a ``two_step_pipeline`` *method* on the
        # pipeline instance rather than calling the instance itself — confirm
        # the custom pipeline class actually defines a method by that name.
        return self.pipe.two_step_pipeline(
            prompt=prompt,
            modified_prompts=modified_prompts,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            cross_attention_kwargs=cross_attention_kwargs,
            iteration=iteration,
        ).images


class Img2Img:
    """Image-to-image generation backed by StableDiffusionImg2ImgPipeline."""

    def load(self, model_dir: str) -> None:
        """Load the img2img pipeline from ``model_dir`` in fp16 onto CUDA.

        Also enables xformers memory-efficient attention and disables the
        safety checker via ``util.commons.disable_safety_checker``.
        """
        self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            model_dir, torch_dtype=torch.float16
        ).to("cuda")
        self.pipe.enable_xformers_memory_efficient_attention()
        disable_safety_checker(self.pipe)

    @torch.inference_mode()
    def process(
        self, prompt: List[str], imageUrl: str, negative_prompt: List[str], steps: int
    ):
        """Download the image at ``imageUrl`` and run img2img over it.

        Strength (0.75), guidance scale (7.5), and images-per-prompt (1) are
        fixed; ``steps`` controls the number of inference steps. Returns the
        ``.images`` list of the pipeline output.
        """
        image = download_image(imageUrl)
        # Call the pipeline directly instead of the redundant ``__call__``.
        return self.pipe(
            prompt=prompt,
            image=image,
            strength=0.75,
            negative_prompt=negative_prompt,
            guidance_scale=7.5,
            num_images_per_prompt=1,
            num_inference_steps=steps,
        ).images