# text-to-image-model / predict.py
# Author: JBlitzar
# Commit: f86c7c7
# Prediction interface for Cog ⚙️
# https://cog.run/python
from cog import BasePredictor, Input, Path
import os
from factories import UNet_conditional
from wrapper import DiffusionManager, Schedule
import torch
import re
from bert_vectorize import vectorize_text_with_bert
from logger import save_grid_with_label
import torchvision
import time
class Predictor(BasePredictor):
    """Cog predictor that generates an image grid from a text prompt
    using a class-conditional diffusion UNet guided by BERT text embeddings.
    """

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient.

        Loads the conditional UNet checkpoint, moves it to the target device,
        and wires up the DiffusionManager with a linear noise schedule.
        """
        device = "cpu"
        model_dir = "runs/run_3_jxa"
        self.device = device
        self.model_dir = model_dir

        # Output directory for generated images (reused by predict()).
        os.makedirs(os.path.join(model_dir, "inferred"), exist_ok=True)

        # num_classes=768 matches the BERT embedding dimension used for conditioning.
        self.net = UNet_conditional(num_classes=768, device=device)
        self.net.to(self.device)
        # NOTE(review): weights_only=False unpickles arbitrary objects — safe only
        # because the checkpoint ships with this repo; never point it at untrusted files.
        self.net.load_state_dict(
            torch.load(
                os.path.join(model_dir, "ckpt/latest_cpu.pt"), weights_only=False
            )
        )

        # DiffusionManager drives the reverse-diffusion sampling loop.
        self.wrapper = DiffusionManager(self.net, device=self.device, noise_steps=1000)
        self.wrapper.set_schedule(Schedule.LINEAR)

    def predict(
        self,
        prompt: str = Input(description="Text prompt"),
        amt: int = Input(description="Amt", default=8),
    ) -> Path:
        """Run a single prediction on the model.

        Args:
            prompt: Text prompt to condition generation on.
            amt: Number of images to sample and tile into one grid.

        Returns:
            Path to a PNG file containing the generated image grid.
        """
        # Embed the prompt with BERT; unsqueeze adds the batch dimension
        # expected by the sampler's conditioning input.
        vprompt = vectorize_text_with_bert(prompt).unsqueeze(0)
        # 64 is the image resolution the UNet was trained at.
        generated = self.wrapper.sample(64, vprompt, amt=amt).detach().cpu()

        # BUG FIX: the original returned a numpy array from make_grid(), but the
        # declared Cog contract is -> Path. Save the grid to disk and return its path.
        grid = torchvision.utils.make_grid(generated)
        out_path = os.path.join(
            self.model_dir, "inferred", f"output_{int(time.time())}.png"
        )
        torchvision.utils.save_image(grid, out_path)
        return Path(out_path)