# text-to-image-model / runner.py — JBlitzar (commit fc9acd0)
from factories import UNet_conditional
from wrapper import DiffusionManager, Schedule
import os
import re
import torch
from bert_vectorize import vectorize_text_with_bert, cleanup
import time
import torchvision
from logger import save_grid_with_label
from clip_score import select_top_n_images
from torchinfo import summary
# Experiment output directory and compute device selection.
EXPERIMENT_DIRECTORY = "runs/run_3_jxa_resumed"
# Prefer Apple-silicon MPS when available, otherwise fall back to CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"

# Ensure the output directory for inference results exists. Catch only the
# "already exists" case — a bare except would also hide permission errors,
# missing parent directories, etc.
try:
    os.mkdir(os.path.join(EXPERIMENT_DIRECTORY, "inferred"))
except FileExistsError:
    print("Skipping making directory, directory already exists")

# Conditional UNet conditioned on 768-dim BERT sentence embeddings.
net = UNet_conditional(num_classes=768)
net.to(device)
# map_location lets a checkpoint saved on another device (e.g. CUDA) load here.
net.load_state_dict(
    torch.load(
        os.path.join(EXPERIMENT_DIRECTORY, "ckpt/latest.pt"),
        map_location=device,
    )
)
def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters in *model*.

    Sums the Python ints from ``numel()`` directly instead of building an
    intermediate tensor; also returns an ``int`` 0 (not float) for models
    with no trainable parameters.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
# Report the trainable-parameter count of the restored model.
print(f"Parameters: {count_parameters(net)}")
# Diffusion sampling manager: 1000 noise steps with a linear beta schedule.
wrapper = DiffusionManager(net, device=device, noise_steps=1000)
wrapper.set_schedule(Schedule.LINEAR)
def infer(prompt, amt=1, topn=8):
    """Sample `amt` images for `prompt`, keep the CLIP-ranked best `topn`,
    and save them as a labelled grid under EXPERIMENT_DIRECTORY/inferred.

    The output filename is the prompt (letters/spaces only, spaces -> "_")
    suffixed with the current Unix timestamp.
    """
    slug = re.sub(r'[^a-zA-Z\s]', '', prompt).replace(" ", "_")
    out_path = os.path.join(
        EXPERIMENT_DIRECTORY,
        "inferred",
        slug + str(int(time.time())) + ".png",
    )
    # BERT embedding with a leading batch dimension for conditioning.
    embedding = vectorize_text_with_bert(prompt).unsqueeze(0)
    samples = wrapper.sample(64, embedding, amt=amt).detach().cpu()
    samples, _ = select_top_n_images(samples, prompt, n=topn)
    grid = torchvision.utils.make_grid(samples)
    save_grid_with_label(grid, prompt + f"({topn} best of {amt})", out_path)
def run_jobs():
    """Process prompts queued in inference_jobs.txt until none remain.

    Repeatedly reads the job file, runs `infer` on every prompt not yet
    processed, and re-reads (new jobs may have been appended while working).
    When a full pass yields no unseen prompts, releases BERT resources via
    `cleanup()` and returns.
    """
    n = 8        # images sampled per prompt
    bestof = 32  # images kept after CLIP ranking
    # NOTE(review): "best 32 of 8" selects more images than are generated —
    # confirm whether n and bestof should be swapped. Left as-is to preserve
    # current behavior.
    print(f"using best {bestof} of {n}")

    processed_tasks = set()

    def read_jobs():
        """Return stripped lines of the job file, or [] if it doesn't exist."""
        try:
            with open("inference_jobs.txt", 'r') as file:
                return [task.strip() for task in file.readlines()]
        except FileNotFoundError:
            return []

    def unseen():
        """Jobs currently in the file that have not been processed yet."""
        return [task for task in read_jobs() if task not in processed_tasks]

    new_tasks = unseen()
    # The original had a redundant `if new_tasks:` inside this loop — the
    # while condition already guarantees it; removed.
    while new_tasks:
        for task in new_tasks:
            infer(task, n, bestof)
            processed_tasks.add(task)
        new_tasks = unseen()

    cleanup()
# Script entry point: drain the inference job queue.
if __name__ == "__main__":
    # Alternative single-prompt mode, kept for manual use:
    #infer(input("Prompt? "), 8)
    run_jobs()