|
|
import random |
|
|
import torch |
|
|
import json |
|
|
|
|
|
class RandomNumberIterator:
    """Endless iterator producing conditioning for randomly drawn class labels.

    Every step draws ``batch_size`` labels uniformly from
    ``[0, n_classes - 1]`` using the :mod:`random` module, encodes them with
    ``model.get_learned_conditioning`` under ``model.cond_stage_key`` and
    returns ``(conditioning, unconditional_conditioning)``. The second element
    is the encoding of a batch filled with label ``n_classes`` (one past the
    last real class) when ``scale != 1.0``, and ``None`` otherwise.
    """

    def __init__(self, model, scale, batch_size, n_classes=1000):
        self.model = model
        self.scale = scale
        self.batch_size = batch_size
        self.n_classes = n_classes

    def __iter__(self):
        return self

    def _encode(self, labels):
        """Turn a list of integer labels into model conditioning."""
        label_tensor = torch.tensor(labels, dtype=torch.long, device=self.model.device)
        return self.model.get_learned_conditioning({self.model.cond_stage_key: label_tensor})

    def __next__(self):
        highest = self.n_classes - 1
        drawn = [random.randint(0, highest) for _ in range(self.batch_size)]
        cond = self._encode(drawn)

        uncond = None
        if self.scale != 1.0:
            # The out-of-range label n_classes stands in for "no class".
            uncond = self._encode([self.n_classes] * self.batch_size)

        return cond, uncond
|
|
|
|
|
class UniformNumberIterator:
    """Endless iterator that cycles through class labels, one class per batch.

    Each batch contains ``batch_size`` copies of the current label. Samples
    are counted in whole batches: once at least ``num_samples_per_class`` have
    been emitted for the current label, the iterator moves to the next label,
    wrapping from ``n_classes - 1`` back to ``0``. A class therefore receives
    ``num_samples_per_class`` rounded up to a multiple of ``batch_size``.

    ``next()`` returns ``(conditioning, unconditional_conditioning)``. The
    second element is the encoding of label ``n_classes`` repeated
    ``batch_size`` times when ``scale != 1.0``, and ``None`` otherwise.
    """

    def __init__(self, model, scale, batch_size, num_samples_per_class, n_classes=1000):
        self.model = model
        self.scale = scale
        self.batch_size = batch_size
        self.num_samples_per_class = num_samples_per_class
        self.n_classes = n_classes
        # Label emitted in the next batch.
        self.current_value = 0
        # Samples emitted so far for current_value.
        self.current_num_cls_sample = 0

    def __iter__(self):
        return self

    def _encode(self, labels):
        """Turn a list of integer labels into model conditioning."""
        label_tensor = torch.tensor(labels, dtype=torch.long, device=self.model.device)
        return self.model.get_learned_conditioning({self.model.cond_stage_key: label_tensor})

    def _advance(self, emitted):
        """Account for ``emitted`` samples; switch class once the quota is met."""
        self.current_num_cls_sample += emitted
        if self.current_num_cls_sample < self.num_samples_per_class:
            return
        self.current_value = (self.current_value + 1) % self.n_classes
        self.current_num_cls_sample = 0

    def __next__(self):
        # Build the batch from the label *before* advancing the counters.
        labels = [self.current_value] * self.batch_size
        self._advance(self.batch_size)

        cond = self._encode(labels)
        if self.scale == 1.0:
            return cond, None
        return cond, self._encode([self.n_classes] * self.batch_size)
|
|
|
|
|
class TextFileIterator:
    """Finite iterator yielding text-prompt conditioning in batches.

    Prompts are loaded eagerly from ``file_path`` in ``__init__``:

    * a path whose name ends in ``json`` is parsed as an MS-COCO-style
      annotation file, taking ``annotation['caption']`` for every entry of
      ``data['annotations']``;
    * any other path is read as UTF-8 text, one prompt per line, with
      surrounding whitespace stripped.

    The prompt list is then truncated to ``max_prompts`` (if given), and each
    prompt is repeated ``n_samples_per_prompt`` times consecutively. A missing,
    unreadable or malformed-JSON file is reported with ``print`` and yields no
    prompts.

    Each ``next()`` returns ``(conditioning, unconditional_conditioning)`` for
    up to ``batch_size`` prompts; the final batch may be smaller. The
    unconditional part is the encoding of the empty prompt, tiled per batch
    element. Unlike the class-label iterators, it is returned regardless of
    ``scale``.
    """

    def __init__(self, model, scale, file_path, batch_size, max_prompts=None, n_samples_per_prompt=1):
        self.model = model
        # Stored for API parity with the other iterators; not read by this class.
        self.scale = scale
        # Encode the empty prompt once; __next__ tiles it to the batch size.
        self.unconditional_conditioning = self.model.get_learned_conditioning([""])
        self.file_path = file_path
        self.batch_size = batch_size
        self.max_prompts = max_prompts
        self.n_samples_per_prompt = n_samples_per_prompt
        self.prompt_index = 0
        self.prompts = self._load_prompts()

    def __iter__(self):
        return self

    def __next__(self):
        if self.prompt_index >= len(self.prompts):
            raise StopIteration

        batch_prompts = self.prompts[self.prompt_index:self.prompt_index + self.batch_size]
        self.prompt_index += len(batch_prompts)

        conditioning = self.model.get_learned_conditioning(batch_prompts)
        # repeat(n, 1, 1) assumes the empty-prompt encoding is 3-D
        # (e.g. 1 x tokens x dim) — TODO confirm for the model in use.
        conditioned_unconditioning = self.unconditional_conditioning.repeat(len(batch_prompts), 1, 1)
        return conditioning, conditioned_unconditioning

    def _read_json_captions(self):
        """Return every caption listed under ``annotations`` in the JSON file."""
        with open(self.file_path, 'r', encoding='utf-8') as file:
            mscoco_data = json.load(file)
        return [annotation['caption'] for annotation in mscoco_data['annotations']]

    def _read_text_lines(self):
        """Return the stripped lines of the text file, read in a single pass."""
        with open(self.file_path, 'r', encoding='utf-8') as file:
            return [line.strip() for line in file]

    def _load_prompts(self):
        """Load, truncate and repeat the prompts; return [] on read/parse errors."""
        try:
            # str() lets callers pass a pathlib.Path as well as a plain string.
            if str(self.file_path).endswith('json'):
                prompts = self._read_json_captions()
            else:
                prompts = self._read_text_lines()
        except FileNotFoundError:
            print(f"File not found: {self.file_path}")
            return []
        except IOError as e:
            print(f"Error reading file {self.file_path}: {e}")
            return []
        except json.JSONDecodeError as e:
            print(f"Error decoding JSON in file {self.file_path}: {e}")
            return []

        if self.max_prompts is not None:
            prompts = prompts[:self.max_prompts]

        return [prompt for prompt in prompts for _ in range(self.n_samples_per_prompt)]
|
|
|