# Page header captured with this file: uploader "jayparmr", commit message "Upload 18 files", revision 4adca93.
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from pipelines.twoStepPipeline import two_step_pipeline
from util.commons import disable_safety_checker, download_image
class Text2Img:
    """Thin wrapper that loads a ``two_step_pipeline`` and forwards calls to it."""

    def load(self, model_dir: str):
        """Load the pipeline from ``model_dir`` and store it on ``self.pipe``.

        The weights are loaded as float16 and the pipeline is moved to the
        ``"cuda"`` device, so a CUDA device is required. xformers
        memory-efficient attention is then enabled and the pipeline is passed
        to ``disable_safety_checker``.
        """
        loaded = two_step_pipeline.from_pretrained(
            model_dir, torch_dtype=torch.float16
        )
        self.pipe = loaded.to("cuda")
        self.pipe.enable_xformers_memory_efficient_attention()
        # NOTE(review): presumably switches off the pipeline's safety checker;
        # confirm against util.commons.
        disable_safety_checker(self.pipe)

    @torch.inference_mode()
    def process(
        self,
        prompt: Union[str, List[str]] = None,
        modified_prompts: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        iteration: float = 3.0,
    ):
        """Run ``self.pipe.two_step_pipeline`` and return its ``.images``.

        Every argument is forwarded unchanged, by keyword, under the same
        name; this method adds no validation or defaults of its own beyond
        those in the signature. ``load`` must have been called first so that
        ``self.pipe`` exists.

        NOTE(review): the parameter names mirror the diffusers Stable
        Diffusion call signature, plus ``modified_prompts`` and ``iteration``;
        their exact meaning is defined by ``pipelines.twoStepPipeline`` --
        confirm there.

        Returns:
            The ``images`` attribute of the object returned by the pipeline.
        """
        # Gather the pass-through arguments once, in signature order.
        forwarded = {
            "prompt": prompt,
            "modified_prompts": modified_prompts,
            "height": height,
            "width": width,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "negative_prompt": negative_prompt,
            "num_images_per_prompt": num_images_per_prompt,
            "eta": eta,
            "generator": generator,
            "latents": latents,
            "prompt_embeds": prompt_embeds,
            "negative_prompt_embeds": negative_prompt_embeds,
            "output_type": output_type,
            "return_dict": return_dict,
            "callback": callback,
            "callback_steps": callback_steps,
            "cross_attention_kwargs": cross_attention_kwargs,
            "iteration": iteration,
        }
        result = self.pipe.two_step_pipeline(**forwarded)
        return result.images
class Img2Img:
    """Thin wrapper around diffusers' ``StableDiffusionImg2ImgPipeline``."""

    def load(self, model_dir: str):
        """Load the pipeline from ``model_dir`` and store it on ``self.pipe``.

        The weights are loaded as float16 and the pipeline is moved to the
        ``"cuda"`` device, so a CUDA device is required. xformers
        memory-efficient attention is then enabled and the pipeline is passed
        to ``disable_safety_checker``.
        """
        self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            model_dir, torch_dtype=torch.float16
        ).to("cuda")
        self.pipe.enable_xformers_memory_efficient_attention()
        # NOTE(review): presumably switches off the pipeline's safety checker;
        # confirm against util.commons.
        disable_safety_checker(self.pipe)

    @torch.inference_mode()
    def process(
        self,
        prompt: List[str],
        imageUrl: str,
        negative_prompt: List[str],
        steps: int,
        strength: float = 0.75,
        guidance_scale: float = 7.5,
        num_images_per_prompt: int = 1,
    ):
        """Fetch the image at ``imageUrl`` and run the img2img pipeline on it.

        ``load`` must have been called first so that ``self.pipe`` exists.

        Args:
            prompt: Forwarded to the pipeline as ``prompt``.
            imageUrl: Passed to ``download_image``; its result is forwarded
                to the pipeline as ``image``.
            negative_prompt: Forwarded to the pipeline as ``negative_prompt``.
            steps: Forwarded to the pipeline as ``num_inference_steps``.
            strength: Forwarded to the pipeline as ``strength``. Defaults to
                0.75, the value that was previously hard-coded.
            guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
                Defaults to 7.5, the value that was previously hard-coded.
            num_images_per_prompt: Forwarded to the pipeline as
                ``num_images_per_prompt``. Defaults to 1, the value that was
                previously hard-coded.

        Returns:
            The ``images`` attribute of the object returned by the pipeline.
        """
        image = download_image(imageUrl)
        # Call the pipeline object directly rather than through ``__call__``.
        return self.pipe(
            prompt=prompt,
            image=image,
            strength=strength,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            num_inference_steps=steps,
        ).images