# Prediction interface for Cog ⚙️
# https://cog.run/python

from cog import BasePredictor, Input, Path
import os
from factories import UNet_conditional
from wrapper import DiffusionManager, Schedule
import torch
import re
from bert_vectorize import vectorize_text_with_bert
from logger import save_grid_with_label
import torchvision
import time


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient."""
        device = "cpu"
        model_dir = "runs/run_3_jxa"
        self.device = device
        self.model_dir = model_dir

        # Output directory for saved inference grids.
        os.makedirs(os.path.join(model_dir, "inferred"), exist_ok=True)

        # Conditional UNet; num_classes=768 matches the width of the BERT
        # embedding produced by vectorize_text_with_bert in predict().
        self.net = UNet_conditional(num_classes=768, device=device)
        self.net.to(self.device)
        # NOTE(review): weights_only=False unpickles arbitrary objects — acceptable
        # only because this checkpoint ships with the model and is trusted.
        self.net.load_state_dict(
            torch.load(os.path.join(model_dir, "ckpt/latest_cpu.pt"), weights_only=False)
        )

        # Diffusion sampler wrapping the UNet.
        self.wrapper = DiffusionManager(self.net, device=self.device, noise_steps=1000)
        self.wrapper.set_schedule(Schedule.LINEAR)

    def predict(
        self,
        prompt: str = Input(description="Text prompt"),
        amt: int = Input(description="Amt", default=8),
    ) -> Path:
        """Generate `amt` samples conditioned on `prompt`; return a Path to a PNG grid.

        Bug fix: the previous implementation returned a numpy array, which
        contradicts the declared `-> Path` return type that Cog requires for
        file outputs. The grid is now written to disk and its path returned.
        """
        # Condition vector: BERT embedding of the prompt with a leading batch dim.
        vprompt = vectorize_text_with_bert(prompt).unsqueeze(0)

        # Sample `amt` images at resolution 64 and move them off any accelerator.
        generated = self.wrapper.sample(64, vprompt, amt=amt).detach().cpu()
        grid = torchvision.utils.make_grid(generated)

        # save_image expects floats in [0, 1]; diffusion samplers commonly emit
        # uint8 in [0, 255] — normalize if so. (Assumes one of those two forms;
        # TODO confirm against DiffusionManager.sample's output dtype.)
        if grid.dtype == torch.uint8:
            grid = grid.float() / 255.0

        out_path = os.path.join(
            self.model_dir, "inferred", f"pred_{int(time.time())}.png"
        )
        torchvision.utils.save_image(grid, out_path)
        return Path(out_path)