|
|
|
|
|
|
|
|
|
|
|
from cog import BasePredictor, Input, Path |
|
|
import os |
|
|
from factories import UNet_conditional |
|
|
from wrapper import DiffusionManager, Schedule |
|
|
import torch |
|
|
import re |
|
|
from bert_vectorize import vectorize_text_with_bert |
|
|
from logger import save_grid_with_label |
|
|
import torchvision |
|
|
import time |
|
|
|
|
|
|
|
|
class Predictor(BasePredictor):
    """Cog predictor that samples images from a text-conditioned diffusion UNet."""

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient.

        Builds a ``UNet_conditional``, loads ``<model_dir>/ckpt/latest_cpu.pt``
        into it, and wraps it in a ``DiffusionManager`` (1000 noise steps,
        linear schedule). Also creates ``<model_dir>/inferred``, the directory
        ``predict`` writes its output images to.
        """
        device = "cpu"
        model_dir = "runs/run_3_jxa"

        self.device = device
        self.model_dir = model_dir

        # Destination for the PNG grids produced by predict().
        os.makedirs(os.path.join(model_dir, "inferred"), exist_ok=True)

        # num_classes=768 presumably matches the width of the BERT embedding
        # returned by vectorize_text_with_bert -- TODO confirm.
        self.net = UNet_conditional(num_classes=768, device=device)
        self.net.to(self.device)
        # map_location makes loading independent of the device the tensors
        # were saved from. weights_only=False lets torch.load unpickle arbitrary
        # objects; acceptable only because the checkpoint ships with the model,
        # never point this at an untrusted file.
        state_dict = torch.load(
            os.path.join(model_dir, "ckpt", "latest_cpu.pt"),
            map_location=self.device,
            weights_only=False,
        )
        self.net.load_state_dict(state_dict)
        # Inference only: put any dropout / batch-norm layers in eval mode.
        self.net.eval()

        self.wrapper = DiffusionManager(self.net, device=self.device, noise_steps=1000)
        self.wrapper.set_schedule(Schedule.LINEAR)

    def predict(
        self,
        prompt: str = Input(description="Text prompt"),
        amt: int = Input(description="Amt", default=8, ge=1),
    ) -> Path:
        """Run a single prediction on the model.

        Encodes ``prompt`` with ``vectorize_text_with_bert``, asks the diffusion
        wrapper for ``amt`` samples conditioned on it, tiles them into a grid
        (torchvision defaults: 8 per row, 2px padding) and writes the grid as a
        PNG under ``<model_dir>/inferred``.

        Args:
            prompt: Text used to condition the sampler.
            amt: Number of samples to generate; must be at least 1.

        Returns:
            Path to the written PNG file.
        """
        vprompt = vectorize_text_with_bert(prompt).unsqueeze(0).to(self.device)

        # No gradients are needed for sampling; avoids building an autograd graph.
        with torch.no_grad():
            # 64 is forwarded unchanged to DiffusionManager.sample; presumably
            # the output image size -- TODO confirm.
            generated = self.wrapper.sample(64, vprompt, amt=amt).detach().cpu()

        # save_image expects floats in [0, 1]. Integer samples are treated as
        # 0-255 pixel values. NOTE(review): floating samples are assumed to be
        # in [0, 1] already (save_image clamps anything outside) -- confirm
        # against DiffusionManager.sample.
        if not generated.is_floating_point():
            generated = generated.float() / 255.0

        # Build a filesystem-safe, unique filename; truncation keeps long
        # prompts under typical filename length limits.
        safe_prompt = re.sub(r"[^A-Za-z0-9_-]+", "_", prompt).strip("_")[:50] or "sample"
        out_path = os.path.join(
            self.model_dir, "inferred", f"{time.time_ns()}_{safe_prompt}.png"
        )
        # save_image tiles with make_grid's defaults, matching the old grid layout.
        torchvision.utils.save_image(generated, out_path)
        return Path(out_path)
|
|
|