JBlitzar
geaojiger
3802079
from factories import UNet_conditional
from wrapper import DiffusionManager, Schedule
import os
import re
import torch
from bert_vectorize import vectorize_text_with_bert
import time
import torchvision
from logger import save_grid_with_label
# Training run whose checkpoint is loaded and under which samples are saved.
EXPERIMENT_DIRECTORY = "runs/run_3_jxa"

# Use Apple's Metal Performance Shaders backend when it is available,
# otherwise run everything on the CPU.
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
# Generated grids are written to <EXPERIMENT_DIRECTORY>/inferred.
# makedirs(exist_ok=True) succeeds quietly when the directory already exists.
# It replaces a bare `except:` that also hid unrelated failures, such as a
# missing parent directory or a permission error, behind an "already exists"
# message.
os.makedirs(os.path.join(EXPERIMENT_DIRECTORY, "inferred"), exist_ok=True)
# Build the conditional UNet. The 768-wide conditioning is presumably the
# embedding width returned by vectorize_text_with_bert (TODO confirm).
net = UNet_conditional(num_classes=768)
net.to(device)
# map_location remaps stored tensors onto the local device. Without it,
# torch.load raises when the checkpoint was saved on a device that is not
# present here (e.g. a CUDA-trained checkpoint loaded on an MPS/CPU machine).
net.load_state_dict(
    torch.load(
        os.path.join(EXPERIMENT_DIRECTORY, "ckpt/latest.pt"),
        map_location=device,
        weights_only=True,
    )
)
# This script only runs inference, so put any dropout/normalization layers
# into evaluation mode. This is harmless if DiffusionManager also does it.
net.eval()

# The sampler wraps the network with a 1000-step linear noise schedule.
wrapper = DiffusionManager(net, device=device, noise_steps=1000)
wrapper.set_schedule(Schedule.LINEAR)
def generate_sample_save_images(prompt, amt=1, image_size=64):
    """Sample ``amt`` images conditioned on ``prompt`` and save them as one grid.

    The prompt is embedded with ``vectorize_text_with_bert`` and sampled
    through the module-level ``wrapper``. The resulting batch is tiled with
    ``torchvision.utils.make_grid`` and written with the prompt as a label to
    ``<EXPERIMENT_DIRECTORY>/inferred/<slug><unix-timestamp>.png``.

    Args:
        prompt: Text to condition on. A sanitized form of it becomes the
            file-name prefix.
        amt: Number of samples to draw into the grid.
        image_size: First argument forwarded to ``wrapper.sample``. It
            defaults to the previously hard-coded 64 and presumably must match
            the resolution the checkpoint was trained at (TODO confirm).

    Returns:
        The path of the saved grid image.
    """
    # Keep only ASCII letters and whitespace. Map every whitespace character
    # (tabs and newlines as well as spaces) to "_", and cap the length so a
    # long prompt cannot exceed filesystem file-name limits.
    slug = re.sub(r"\s", "_", re.sub(r"[^a-zA-Z\s]", "", prompt))[:100]
    path = os.path.join(
        EXPERIMENT_DIRECTORY, "inferred", slug + str(int(time.time())) + ".png"
    )

    # This is pure inference. Disabling autograd stops the iterative sampling
    # loop from building a graph across denoising steps. That is harmless if
    # wrapper.sample already disables gradients internally.
    with torch.no_grad():
        # unsqueeze(0) adds a leading batch axis. This assumes the vectorizer
        # returns a single un-batched embedding (TODO confirm).
        vprompt = vectorize_text_with_bert(prompt).unsqueeze(0)
        generated = wrapper.sample(image_size, vprompt, amt=amt).detach().cpu()

    save_grid_with_label(torchvision.utils.make_grid(generated), prompt, path)
    return path
if __name__ == "__main__":
    # Interactive use: read one prompt from stdin and render a grid of 8 samples.
    user_prompt = input("Prompt? ")
    generate_sample_save_images(user_prompt, amt=8)