File size: 4,401 Bytes
d382778
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import random
import torch
import json 

class RandomNumberIterator:
    """Infinite iterator yielding (conditioning, unconditional_conditioning) pairs
    for uniformly random class labels.

    Each step draws ``batch_size`` random class ids in ``[0, n_classes - 1]``,
    embeds them via ``model.get_learned_conditioning``, and — when the guidance
    ``scale`` differs from 1.0 — also embeds the label ``n_classes`` for every
    batch element (presumably the model's "null"/unconditional class id; confirm
    against the model's training setup).
    """

    def __init__(self, model, scale, batch_size, n_classes=1000):
        self.model = model
        self.scale = scale
        self.batch_size = batch_size
        self.n_classes = n_classes

    def __iter__(self):
        return self

    def __next__(self):
        # One independent random class id per batch element.
        sampled = [random.randint(0, self.n_classes - 1) for _ in range(self.batch_size)]
        labels = torch.LongTensor(sampled).to(self.model.device)
        conditioning = self.model.get_learned_conditioning(
            {self.model.cond_stage_key: labels}
        )

        unconditional = None
        if self.scale != 1.0:
            # Classifier-free guidance path: embed the out-of-range label
            # n_classes for the whole batch.
            null_labels = torch.LongTensor([self.n_classes] * self.batch_size)
            unconditional = self.model.get_learned_conditioning(
                {self.model.cond_stage_key: null_labels.to(self.model.device)}
            )

        return conditioning, unconditional
    
class UniformNumberIterator:
    """Infinite iterator that walks through class ids in order, emitting
    conditioning for the current class until (at least) ``num_samples_per_class``
    samples have been produced for it, then advancing to the next id modulo
    ``n_classes``.

    NOTE: when ``batch_size`` does not divide ``num_samples_per_class`` the last
    batch of a class overshoots the per-class quota — each class gets at least,
    not exactly, ``num_samples_per_class`` samples.
    """

    def __init__(self, model, scale, batch_size, num_samples_per_class, n_classes=1000):
        self.model = model
        self.scale = scale
        self.batch_size = batch_size
        self.num_samples_per_class = num_samples_per_class
        self.n_classes = n_classes
        # Class id currently being emitted, and how many samples of it so far.
        self.current_value = 0
        self.current_num_cls_sample = 0

    def __iter__(self):
        return self

    def __next__(self):
        # Whole batch shares the current class id.
        batch_ids = [self.current_value] * self.batch_size

        # Book-keeping: advance to the next class (wrapping at n_classes)
        # once this class has met its quota.
        self.current_num_cls_sample += self.batch_size
        if self.current_num_cls_sample >= self.num_samples_per_class:
            self.current_value = (self.current_value + 1) % self.n_classes
            self.current_num_cls_sample = 0

        labels = torch.LongTensor(batch_ids).to(self.model.device)
        conditioning = self.model.get_learned_conditioning(
            {self.model.cond_stage_key: labels}
        )

        unconditional = None
        if self.scale != 1.0:
            # Classifier-free guidance: label n_classes plays the role of the
            # unconditional class (assumed from the model's training setup).
            null_labels = torch.LongTensor([self.n_classes] * self.batch_size)
            unconditional = self.model.get_learned_conditioning(
                {self.model.cond_stage_key: null_labels.to(self.model.device)}
            )

        return conditioning, unconditional
    
class TextFileIterator:
    """Finite iterator yielding (conditioning, unconditional_conditioning) pairs
    for text prompts loaded from a file.

    Prompts come either from an MS-COCO-style JSON file (the ``caption`` field
    of each entry under ``annotations``) or from a plain text file with one
    prompt per line. Each prompt is repeated ``n_samples_per_prompt`` times and
    served in batches of up to ``batch_size``; iteration ends (StopIteration)
    when all prompts are exhausted.

    Note: unlike the label iterators in this module, the unconditional
    conditioning (empty-prompt embedding) is computed and returned regardless
    of ``scale``.
    """

    def __init__(self, model, scale, file_path, batch_size, max_prompts=None, n_samples_per_prompt=1):
        self.model = model
        self.scale = scale
        # Embed the empty prompt once; reused (tiled) for every batch.
        self.unconditional_conditioning = self.model.get_learned_conditioning([""])

        self.file_path = file_path
        self.batch_size = batch_size
        self.max_prompts = max_prompts
        self.n_samples_per_prompt = n_samples_per_prompt
        self.prompt_index = 0
        self.prompts = self._load_prompts()

    def __iter__(self):
        return self

    def __next__(self):
        if self.prompt_index >= len(self.prompts):
            raise StopIteration

        # Last batch may be smaller than batch_size.
        batch_prompts = self.prompts[self.prompt_index:self.prompt_index + self.batch_size]
        self.prompt_index += len(batch_prompts)

        conditioning = self.model.get_learned_conditioning(batch_prompts)
        # Tile the cached empty-prompt embedding to match the batch size.
        conditioned_unconditioning = self.unconditional_conditioning.repeat(len(batch_prompts), 1, 1)
        return conditioning, conditioned_unconditioning

    def _load_prompts(self):
        """Read prompts from ``self.file_path``.

        Returns a list of prompt strings, truncated to ``max_prompts`` and with
        each prompt duplicated ``n_samples_per_prompt`` times. Returns an empty
        list (after printing a message) on file/JSON errors.
        """
        try:
            prompts = []
            if self.file_path.endswith('json'):
                with open(self.file_path, 'r', encoding='utf-8') as file:
                    mscoco_data = json.load(file)
                    for annotation in mscoco_data['annotations']:
                        prompts.append(annotation['caption'])
            else:
                # BUG FIX: the original re-opened and fully re-read the file
                # once per line (O(n^2)) and never closed the handles. Read it
                # once, under a context manager.
                with open(self.file_path, 'r', encoding='utf-8') as file:
                    prompts = [line.strip() for line in file]
            if self.max_prompts is not None:
                prompts = prompts[:self.max_prompts]
            # Duplicate each prompt so consecutive samples share a prompt.
            prompts = [prompt for prompt in prompts for _ in range(self.n_samples_per_prompt)]
            return prompts

        except FileNotFoundError:
            print(f"File not found: {self.file_path}")
            return []
        except IOError as e:
            print(f"Error reading file {self.file_path}: {e}")
            return []
        except json.JSONDecodeError as e:
            print(f"Error decoding JSON in file {self.file_path}: {e}")
            return []