# Spaces: Runtime error
| from factories import UNet_conditional | |
| from wrapper import DiffusionManager, Schedule | |
| import os | |
| import re | |
| import torch | |
| from bert_vectorize import vectorize_text_with_bert, cleanup | |
| import time | |
| import torchvision | |
| from logger import save_grid_with_label | |
| from clip_score import select_top_n_images | |
| from torchinfo import summary | |
# Experiment whose checkpoint is loaded; generated grids go to its "inferred" folder.
EXPERIMENT_DIRECTORY = "runs/run_3_jxa_resumed"
device = "cpu"

# Create the output folder if needed. exist_ok=True makes this idempotent, and
# makedirs also creates missing parents. The previous bare `except:` reported
# every failure (permissions, missing EXPERIMENT_DIRECTORY, ...) as
# "directory already exists".
os.makedirs(os.path.join(EXPERIMENT_DIRECTORY, "inferred"), exist_ok=True)

# num_classes=768 presumably matches the width of the vectorize_text_with_bert
# embedding used as the conditioning vector -- TODO confirm.
net = UNet_conditional(num_classes=768)
net.to(device)
# map_location remaps tensors saved on another device (e.g. a CUDA training run)
# onto `device`. Without it, torch.load fails on a CPU-only host.
net.load_state_dict(
    torch.load(
        os.path.join(EXPERIMENT_DIRECTORY, "ckpt/latest.pt"),
        map_location=device,
    )
)
# This script only samples, so put any dropout/BatchNorm layers into inference mode.
net.eval()
def count_parameters(model):
    """Return the number of trainable (``requires_grad``) parameters in *model*.

    Uses the builtin ``sum``, so the result is always a Python ``int``. The
    previous ``torch.tensor([...]).sum().item()`` built a throwaway tensor and
    returned the float ``0.0`` for a model with no trainable parameters.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
# Report the trainable-parameter count, then wrap the UNet in the diffusion
# sampler: 1000 noise steps using the Schedule.LINEAR noise schedule.
print(f"Parameters: {count_parameters(net)}")
wrapper = DiffusionManager(
    net,
    noise_steps=1000,
    device=device,
)
wrapper.set_schedule(Schedule.LINEAR)
def infer(prompt, amt=1, topn=8):
    """Sample ``amt`` images for *prompt*, keep the ``topn`` best, and save them as a labeled grid.

    Args:
        prompt: Text prompt. It is embedded with ``vectorize_text_with_bert`` to
            condition sampling. It is also passed to ``select_top_n_images``
            (presumably CLIP-based ranking -- confirm) and used as the grid label.
        amt: Number of images to sample.
        topn: Number of images to keep. Clamped to ``amt``, because no more
            images can be kept than were sampled.

    The grid is written to
    ``<EXPERIMENT_DIRECTORY>/inferred/<slug><unix-seconds>.png``.
    """
    # Filename-safe slug: keep letters and whitespace, map every whitespace
    # character (not just spaces) to "_", and cap the length so long prompts
    # stay within filesystem name limits.
    slug = re.sub(r"\s", "_", re.sub(r"[^a-zA-Z\s]", "", prompt))[:100]
    path = os.path.join(
        EXPERIMENT_DIRECTORY, "inferred", slug + str(int(time.time())) + ".png"
    )

    topn = min(topn, amt)

    vprompt = vectorize_text_with_bert(prompt).unsqueeze(0).to(device)
    # 64 is presumably the output image resolution -- TODO confirm against
    # DiffusionManager.sample.
    generated = wrapper.sample(64, vprompt, amt=amt).detach().cpu()
    generated, _ = select_top_n_images(generated, prompt, n=topn)
    save_grid_with_label(
        torchvision.utils.make_grid(generated),
        prompt + f"({topn} best of {amt})",
        path,
    )
def run_jobs(n=8, bestof=32, jobs_file="inference_jobs.txt"):
    """Render every prompt listed in *jobs_file* (one prompt per line).

    For each prompt, ``bestof`` images are sampled and the ``n`` best are kept
    (see ``infer``). The file is re-read after each pass, so prompts appended
    while jobs run are picked up. The function returns once a pass finds no
    unprocessed prompts. Each distinct prompt is rendered at most once, and
    blank lines are ignored.

    Args:
        n: Number of images kept per prompt.
        bestof: Number of images sampled per prompt.
        jobs_file: Path of the text file listing prompts.
    """
    print(f"using best {n} of {bestof}")
    processed_tasks = set()

    def read_jobs():
        # A missing file means "no jobs". Blank lines are dropped so they do not
        # run as empty prompts.
        try:
            with open(jobs_file, "r", encoding="utf-8") as file:
                return [line.strip() for line in file if line.strip()]
        except FileNotFoundError:
            return []

    try:
        while True:
            new_tasks = [task for task in read_jobs() if task not in processed_tasks]
            if not new_tasks:
                break
            for task in new_tasks:
                # A prompt listed twice in the same pass is rendered only once.
                if task in processed_tasks:
                    continue
                # amt = images sampled, topn = images kept.
                infer(task, amt=bestof, topn=n)
                processed_tasks.add(task)
    finally:
        # Run bert_vectorize.cleanup() even if a job raises.
        cleanup()
if __name__ == "__main__":
    # Batch mode: render every prompt listed in the jobs file.
    # For a single interactive prompt, call infer(input("Prompt? "), 8) instead.
    run_jobs()