from factories import UNet_conditional from wrapper import DiffusionManager, Schedule import os import re import torch from bert_vectorize import vectorize_text_with_bert import time import torchvision from logger import save_grid_with_label EXPERIMENT_DIRECTORY = "runs/run_3_jxa" device = "mps" if torch.backends.mps.is_available() else "cpu" try: os.mkdir(os.path.join(EXPERIMENT_DIRECTORY, "inferred")) except: print("Skipping making directory, directory already exists") net = UNet_conditional(num_classes=768) net.to(device) net.load_state_dict(torch.load(os.path.join(EXPERIMENT_DIRECTORY, "ckpt/latest.pt"),weights_only=True)) wrapper = DiffusionManager(net, device=device, noise_steps=1000) wrapper.set_schedule(Schedule.LINEAR) def generate_sample_save_images(prompt, amt=1): path = os.path.join(EXPERIMENT_DIRECTORY, "inferred", re.sub(r'[^a-zA-Z\s]', '', prompt).replace(" ", "_")+str(int(time.time()))+".png") vprompt = vectorize_text_with_bert(prompt).unsqueeze(0) generated = wrapper.sample(64, vprompt, amt=amt).detach().cpu() save_grid_with_label(torchvision.utils.make_grid(generated),prompt, path) if __name__ == "__main__": generate_sample_save_images(input("Prompt? "), 8)