File size: 2,023 Bytes
9b57ce7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import os
import inspect
from PIL import Image
from tqdm import tqdm
from utils_cohp.utils import init_pipelines

# Disable Pillow's decompression-bomb guard (no pixel-count limit), so very
# large images do not trigger DecompressionBombWarning/DecompressionBombError.
# NOTE: this is a process-wide setting on the PIL.Image module, not local to
# this file.
Image.MAX_IMAGE_PIXELS = None


class Generator:
    """Wrap a pipeline built by ``init_pipelines`` and save its outputs as PNGs.

    Attributes:
        pipe_names: The ``pipe_name`` passed to the constructor (historical
            attribute name, kept for existing callers).
        pipe_name: Same value as ``pipe_names``, under the constructor's name.
        pipe_type: ``"t2i"`` (results expose ``.images``) or ``"t2v"``
            (results expose ``.frames``; only ``frames[0]`` is saved).
        pipe_init_kwargs: Keyword arguments forwarded to ``init_pipelines``.
        pipelines: The callable returned by ``init_pipelines``.
    """

    # Output kinds generate_imgs knows how to unpack.
    _SUPPORTED_PIPE_TYPES = ("t2i", "t2v")
    # Size kwargs dropped (with a printed warning) when the pipeline's call
    # signature does not declare them.
    _OPTIONAL_SIZE_KWARGS = ("height", "width")

    def __init__(
        self, pipe_name, pipe_type, pipe_init_kwargs, device=None
    ):
        # `pipe_names` is the historical attribute; `pipe_name` mirrors the
        # constructor argument. Both are kept so existing readers keep working.
        self.pipe_names = pipe_name
        self.pipe_name = pipe_name
        self.pipe_type = pipe_type
        self.pipe_init_kwargs = pipe_init_kwargs
        # NOTE(review): what `device=None` means is decided by init_pipelines
        # (not visible here) — confirm there.
        self.pipelines = init_pipelines(
            pipe_name, pipe_init_kwargs, device
        )

    def _drop_unsupported_size_kwargs(self, generation_kwargs):
        """Return a copy of *generation_kwargs* without size keys the pipeline lacks.

        The caller's dict is never modified, so the same kwargs can be reused
        across calls. A warning is printed only for keys actually removed.
        """
        pipeline_params = inspect.signature(self.pipelines).parameters
        filtered = dict(generation_kwargs or {})
        for key in self._OPTIONAL_SIZE_KWARGS:
            if key not in pipeline_params and key in filtered:
                del filtered[key]
                print(f"Warning: Pipeline does not support '{key}' parameter, removing from kwargs")
        return filtered

    def generate_imgs(
        self,
        batch_size,
        generation_path,
        info_dict,
        device,
        weight_dtype,
        seed,
        generation_kwargs,
    ):
        """Run the pipeline on one caption and save every output image as PNG.

        Args:
            batch_size: Passed to the pipeline as ``num_images_per_prompt``.
            generation_path: Output directory; created if missing.
            info_dict: Mapping with ``"caption"`` (the prompt) and
                ``"save_name"`` (file-name stem; files are written as
                ``{save_name}_{idx}.png``).
            device: Anything ``torch.device`` accepts (e.g. ``"cuda:0"``,
                ``"cpu"``). A CUDA device is made the current CUDA device.
            weight_dtype: Currently unused; kept for interface compatibility.
            seed: Seed for the ``torch.Generator`` handed to the pipeline.
            generation_kwargs: Extra keyword arguments for the pipeline call.
                ``height``/``width`` are dropped when the pipeline's signature
                lacks them. The dict itself is not modified.

        Returns:
            List of written file paths, in output order.

        Raises:
            ValueError: If ``self.pipe_type`` is not ``"t2i"`` or ``"t2v"``
                (checked before the pipeline is run).
        """
        if self.pipe_type not in self._SUPPORTED_PIPE_TYPES:
            raise ValueError(
                f"Unsupported pipe_type {self.pipe_type!r}; "
                f"expected one of {self._SUPPORTED_PIPE_TYPES}"
            )

        device = torch.device(device)
        # torch.cuda.set_device rejects non-CUDA devices, so only switch the
        # current CUDA device when we were actually given one.
        if device.type == "cuda":
            torch.cuda.set_device(device)
        # Generator created on the default (CPU) device and seeded explicitly.
        generator = torch.Generator().manual_seed(seed)
        call_kwargs = self._drop_unsupported_size_kwargs(generation_kwargs)

        outputs = self.pipelines(
            prompt=info_dict["caption"],
            generator=generator,
            num_images_per_prompt=batch_size,
            **call_kwargs,
        )
        # t2i results expose `.images`; for t2v only the first video's frames
        # are saved.
        images = outputs.images if self.pipe_type == "t2i" else outputs.frames[0]

        os.makedirs(generation_path, exist_ok=True)
        image_paths = []
        for idx, image in enumerate(images):
            img_path = os.path.join(
                generation_path, info_dict["save_name"] + f"_{idx}.png"
            )
            image.save(img_path)
            image_paths.append(img_path)
        return image_paths