# File: HPSv3/hpsv3/cohp/generator.py
# Repository page header (author avatar: sdsdgwe; last commit: "update", 9b57ce7)
import torch
import os
import inspect
from PIL import Image
from tqdm import tqdm
from utils_cohp.utils import init_pipelines
# Setting MAX_IMAGE_PIXELS to None turns off Pillow's pixel-count limit (its
# decompression-bomb check) for the whole process, so very large images are
# not rejected. NOTE(review): this is a module-level side effect that affects
# every PIL user in the process — confirm only trusted images are handled.
Image.MAX_IMAGE_PIXELS = None
class Generator:
    """Run a generation pipeline for a prompt and save its outputs as PNG files.

    The pipeline object is created once in ``__init__`` through
    ``init_pipelines`` and is later invoked directly as a callable by
    :meth:`generate_imgs`.
    """

    # Pipeline kinds that ``generate_imgs`` knows how to read outputs from.
    _SUPPORTED_PIPE_TYPES = ("t2i", "t2v")

    def __init__(self, pipe_name, pipe_type, pipe_init_kwargs, device=None):
        """Store the configuration and build the pipeline.

        Args:
            pipe_name: Pipeline identifier, forwarded to ``init_pipelines``.
            pipe_type: Output kind; ``generate_imgs`` supports ``"t2i"``
                (reads ``outputs.images``) and ``"t2v"`` (reads
                ``outputs.frames[0]``).
            pipe_init_kwargs: Forwarded to ``init_pipelines``.
            device: Forwarded to ``init_pipelines``; defaults to ``None``.
        """
        self.pipe_names = pipe_name
        self.pipe_type = pipe_type
        self.pipe_init_kwargs = pipe_init_kwargs
        # NOTE(review): despite the plural name this is used as one callable
        # pipeline in ``generate_imgs`` — confirm against ``init_pipelines``.
        self.pipelines = init_pipelines(pipe_name, pipe_init_kwargs, device)

    def generate_imgs(
        self,
        batch_size,
        generation_path,
        info_dict,
        device,
        weight_dtype,
        seed,
        generation_kwargs,
    ):
        """Generate outputs for one prompt and write each one to a PNG file.

        Args:
            batch_size: Passed to the pipeline as ``num_images_per_prompt``.
            generation_path: Directory the files are written to; created
                (including parents) if it does not exist.
            info_dict: Mapping with a ``"caption"`` entry (the prompt) and a
                ``"save_name"`` string used as the file-name stem.
            device: Device to make current when it is a CUDA device; other
                device types are accepted and left alone.
            weight_dtype: Currently unused; kept for interface compatibility.
            seed: Seed for the ``torch.Generator`` handed to the pipeline.
            generation_kwargs: Extra keyword arguments for the pipeline call.
                Never modified; ``None`` is treated as empty.

        Returns:
            List of written file paths, ``<save_name>_<idx>.png`` in output
            order.

        Raises:
            ValueError: If ``self.pipe_type`` is not ``"t2i"`` or ``"t2v"``.
        """
        # Fail fast: reject an unknown pipe_type before running the pipeline
        # instead of hitting an unbound ``images`` variable afterwards.
        if self.pipe_type not in self._SUPPORTED_PIPE_TYPES:
            raise ValueError(
                f"Unsupported pipe_type {self.pipe_type!r}; "
                f"expected one of {self._SUPPORTED_PIPE_TYPES}"
            )

        # torch.cuda.set_device rejects non-CUDA devices, so only call it
        # when the requested device actually is a CUDA device.
        if torch.device(device).type == "cuda":
            torch.cuda.set_device(device)

        generator = torch.Generator().manual_seed(seed)

        # Work on a copy so the caller's dict is never mutated.
        call_kwargs = dict(generation_kwargs or {})
        accepted_params = inspect.signature(self.pipelines).parameters
        for size_key in ("height", "width"):
            if size_key in call_kwargs and size_key not in accepted_params:
                del call_kwargs[size_key]
                print(
                    f"Warning: Pipeline does not support '{size_key}' "
                    "parameter, removing from kwargs"
                )

        outputs = self.pipelines(
            prompt=info_dict["caption"],
            generator=generator,
            num_images_per_prompt=batch_size,
            **call_kwargs,
        )

        if self.pipe_type == "t2i":
            images = outputs.images
        else:  # "t2v": only the first entry of ``frames`` is saved
            images = outputs.frames[0]

        # Create the target directory once rather than on every iteration.
        os.makedirs(generation_path, exist_ok=True)
        save_stem = info_dict["save_name"]

        image_paths = []
        for idx, image in enumerate(images):
            img_path = os.path.join(generation_path, save_stem + f"_{idx}.png")
            image.save(img_path)
            image_paths.append(img_path)
        return image_paths