# NOTE(review): the three lines below were file-viewer scraping artifacts
# (file size, commit hash, and a line-number dump), not source code; they
# have been commented out so the module parses as valid Python.
# File size: 1,874 Bytes | f86c7c7
# Prediction interface for Cog ⚙️
# https://cog.run/python
from cog import BasePredictor, Input, Path
import os
from factories import UNet_conditional
from wrapper import DiffusionManager, Schedule
import torch
import re
from bert_vectorize import vectorize_text_with_bert
from logger import save_grid_with_label
import torchvision
import time
class Predictor(BasePredictor):
    """Cog predictor wrapping a conditional UNet diffusion model.

    ``setup`` loads the checkpoint once so repeated predictions are cheap;
    ``predict`` turns a text prompt into a grid of sampled images and
    returns the path of the saved grid image.
    """

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient."""
        device = "cpu"
        model_dir = "runs/run_3_jxa"
        self.device = device
        self.model_dir = model_dir
        # Output directory for generated image grids.
        os.makedirs(os.path.join(model_dir, "inferred"), exist_ok=True)
        # num_classes=768 — presumably the BERT embedding width fed in by
        # vectorize_text_with_bert; confirm against UNet_conditional.
        self.net = UNet_conditional(num_classes=768, device=device)
        self.net.to(self.device)
        # NOTE(review): weights_only=False lets torch.load execute arbitrary
        # pickled code; if the checkpoint is a plain state dict, prefer
        # weights_only=True — confirm before changing.
        self.net.load_state_dict(
            torch.load(
                os.path.join(model_dir, "ckpt/latest_cpu.pt"),
                weights_only=False,
            )
        )
        # Diffusion sampler around the network; schedule choice mirrors training.
        self.wrapper = DiffusionManager(self.net, device=self.device, noise_steps=1000)
        self.wrapper.set_schedule(Schedule.LINEAR)

    def predict(
        self,
        prompt: str = Input(description="Text prompt"),
        amt: int = Input(description="Amt", default=8),
    ) -> Path:
        """Run a single prediction on the model.

        Vectorizes ``prompt`` with BERT, samples ``amt`` images from the
        diffusion model (first sample() argument is 64 — presumably the
        image resolution; confirm against DiffusionManager.sample), saves
        them as one grid PNG, and returns its path.

        Returns:
            Path to the saved grid image. (Fix: the previous code returned
            a numpy array from make_grid despite the ``-> Path`` annotation,
            which Cog cannot serve as a file output.)
        """
        # Batch dimension is required by the sampler's conditioning input.
        vprompt = vectorize_text_with_bert(prompt).unsqueeze(0)
        generated = self.wrapper.sample(64, vprompt, amt=amt).detach().cpu()
        # save_image builds the grid internally (same layout as make_grid)
        # and writes a PNG we can hand back to Cog as a Path. Timestamped
        # name avoids clobbering earlier outputs.
        out_path = os.path.join(self.model_dir, "inferred", f"{int(time.time())}.png")
        torchvision.utils.save_image(generated, out_path)
        return Path(out_path)
# (removed trailing file-viewer artifact "|")