| import torch | |
| import os | |
| import inspect | |
| from PIL import Image | |
| from tqdm import tqdm | |
| from utils_cohp.utils import init_pipelines | |
| Image.MAX_IMAGE_PIXELS = None | |
| class Generator: | |
| def __init__( | |
| self, pipe_name, pipe_type, pipe_init_kwargs, device=None | |
| ): | |
| self.pipe_names = pipe_name | |
| self.pipe_type = pipe_type | |
| self.pipe_init_kwargs = pipe_init_kwargs | |
| self.pipelines = init_pipelines( | |
| pipe_name, pipe_init_kwargs, device | |
| ) | |
| def generate_imgs( | |
| self, | |
| batch_size, | |
| generation_path, | |
| info_dict, | |
| device, | |
| weight_dtype, | |
| seed, | |
| generation_kwargs, | |
| ): | |
| torch.cuda.set_device(device) | |
| device = torch.device(device) | |
| generator = torch.Generator().manual_seed(seed) | |
| pipeline_signature = inspect.signature(self.pipelines) | |
| pipeline_params = pipeline_signature.parameters.keys() | |
| if 'height' not in pipeline_params: | |
| generation_kwargs.pop('height', None) | |
| print(f"Warning: Pipeline does not support 'height' parameter, removing from kwargs") | |
| if 'width' not in pipeline_params: | |
| generation_kwargs.pop('width', None) | |
| print(f"Warning: Pipeline does not support 'width' parameter, removing from kwargs") | |
| outputs = self.pipelines( | |
| prompt=info_dict['caption'], generator=generator,num_images_per_prompt = batch_size, **generation_kwargs | |
| ) | |
| if self.pipe_type == "t2i": | |
| images = outputs.images | |
| elif self.pipe_type == "t2v": | |
| images = outputs.frames[0] | |
| image_paths = [] | |
| for idx, image in enumerate(images): | |
| img_path = os.path.join( | |
| generation_path, info_dict["save_name"] + f"_{idx}.png" | |
| ) | |
| os.makedirs(generation_path,exist_ok=True) | |
| image.save(img_path) | |
| image_paths.append(img_path) | |
| return image_paths |