|
|
from factories import UNet_conditional |
|
|
from wrapper import DiffusionManager, Schedule |
|
|
import os |
|
|
import re |
|
|
import torch |
|
|
from bert_vectorize import vectorize_text_with_bert, cleanup |
|
|
import time |
|
|
import torchvision |
|
|
from logger import save_grid_with_label |
|
|
from clip_score import select_top_n_images |
|
|
from torchinfo import summary |
|
|
|
|
|
|
|
|
|
|
|
# Run directory: the checkpoint is read from "ckpt/latest.pt" inside it and
# generated grids are written to its "inferred/" sub-directory.
EXPERIMENT_DIRECTORY = "runs/run_3_jxa_resumed"

# Use the MPS backend when PyTorch reports it as available, else the CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"

try:
    os.mkdir(os.path.join(EXPERIMENT_DIRECTORY, "inferred"))
except FileExistsError:
    # Only the "already there" case is expected. Any other OSError (missing
    # run directory, permissions, ...) must surface instead of being
    # misreported as an existing directory.
    print("Skipping making directory, directory already exists")

net = UNet_conditional(num_classes=768)
net.to(device)
# map_location remaps the stored tensors onto the local device, so a
# checkpoint saved on a different device (e.g. CUDA) can still be restored.
net.load_state_dict(
    torch.load(
        os.path.join(EXPERIMENT_DIRECTORY, "ckpt/latest.pt"),
        map_location=device,
    )
)
|
|
|
|
|
|
|
|
def count_parameters(model):
    """Return the number of trainable parameters in *model*.

    Sums ``numel()`` over every parameter whose ``requires_grad`` is true.
    The result is always a Python ``int`` (``0`` when nothing is trainable).
    """
    # Builtin sum avoids building a throwaway tensor and keeps the empty case
    # an int (torch.tensor([]) would be float32 and yield 0.0).
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
# Report how many trainable parameters the loaded network has.
print(f"Parameters: {count_parameters(net)}")

# Diffusion wrapper around the network: 1000 noise steps, linear schedule.
wrapper = DiffusionManager(
    net,
    device=device,
    noise_steps=1000,
)
wrapper.set_schedule(Schedule.LINEAR)
|
|
|
|
|
|
|
|
def infer(prompt, amt=1, topn=8):
    """Generate images for *prompt* and save the best ones as a labelled grid.

    Args:
        prompt: Text prompt; embedded with ``vectorize_text_with_bert`` and
            also passed to ``select_top_n_images`` for ranking.
        amt: Number of images requested from ``wrapper.sample``.
        topn: Number of images to keep after ranking; capped at ``amt``.

    The grid is written to ``<EXPERIMENT_DIRECTORY>/inferred/`` under a name
    built from the prompt's ASCII letters followed by the current Unix time
    in seconds. Nothing is returned.
    """
    # Cannot keep more images than were sampled (the defaults ask for 8 of 1).
    topn = min(topn, amt)

    # File name: drop everything except ASCII letters and whitespace, then
    # turn every whitespace character (not just spaces) into an underscore so
    # tabs or newlines never end up in the path.
    letters_only = re.sub(r'[^a-zA-Z\s]', '', prompt)
    stem = re.sub(r'\s', '_', letters_only)
    path = os.path.join(
        EXPERIMENT_DIRECTORY,
        "inferred",
        stem + str(int(time.time())) + ".png",
    )

    # unsqueeze(0) adds a leading dimension to the prompt embedding.
    vprompt = vectorize_text_with_bert(prompt).unsqueeze(0)

    # NOTE(review): 64 is forwarded as-is to wrapper.sample -- presumably the
    # image side length; confirm against DiffusionManager.sample.
    generated = wrapper.sample(64, vprompt, amt=amt).detach().cpu()

    # Keep the topn images selected by select_top_n_images; its second return
    # value is not needed here.
    generated, _ = select_top_n_images(generated, prompt, n=topn)

    grid = torchvision.utils.make_grid(generated)
    save_grid_with_label(grid, prompt + f"({topn} best of {amt})", path)
|
|
|
|
|
|
|
|
def run_jobs():
    """Run every prompt listed in ``inference_jobs.txt`` through ``infer``.

    The job file holds one prompt per line. It is re-read after each batch so
    prompts appended while a batch was running are picked up as well; the
    function returns once a read yields no prompt that is still unprocessed.
    Blank lines are ignored and each distinct prompt is run at most once.
    A missing job file is treated as an empty job list.

    ``cleanup()`` is always called before returning, including when a job
    raises (the exception then propagates to the caller).
    """
    n = 8  # images kept per prompt
    bestof = 32  # images sampled per prompt
    print(f"using best {n} of {bestof}")
    processed_tasks = set()

    def read_jobs():
        """Return the stripped, non-blank lines of the job file ([] if absent)."""
        try:
            with open("inference_jobs.txt", "r", encoding="utf-8") as file:
                stripped = (line.strip() for line in file)
                return [task for task in stripped if task]
        except FileNotFoundError:
            return []

    def pending_jobs():
        """Return not-yet-processed prompts in file order, without duplicates."""
        # dict.fromkeys de-duplicates while preserving first-seen order.
        return [
            task
            for task in dict.fromkeys(read_jobs())
            if task not in processed_tasks
        ]

    try:
        new_tasks = pending_jobs()
        while new_tasks:
            for task in new_tasks:
                # infer's positional order is (prompt, amt, topn). Passing by
                # keyword keeps "best n of bestof" from being swapped: sample
                # `bestof` images, keep the top `n`.
                infer(task, amt=bestof, topn=n)
                processed_tasks.add(task)
            new_tasks = pending_jobs()
    finally:
        cleanup()
|
|
|
|
|
# Script entry point: process the queued prompts only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    run_jobs()