Spaces:
Runtime error
Runtime error
| # Prediction interface for Cog ⚙️ | |
| # https://cog.run/python | |
| from cog import BasePredictor, Input, Path | |
| import os | |
| from factories import UNet_conditional | |
| from wrapper import DiffusionManager, Schedule | |
| import torch | |
| import re | |
| from bert_vectorize import vectorize_text_with_bert | |
| from logger import save_grid_with_label | |
| import torchvision | |
| import time | |
class Predictor(BasePredictor):
    """Cog predictor that samples text-conditioned images from a diffusion UNet."""

    # Inference device. The checkpoint name ("latest_cpu.pt") suggests the
    # weights were exported for CPU use.
    DEVICE = "cpu"
    # Training-run directory holding the checkpoint; generated grids are
    # written to its "inferred" subdirectory.
    MODEL_DIR = "runs/run_3_jxa"
    # Checkpoint path, relative to MODEL_DIR.
    CHECKPOINT = "ckpt/latest_cpu.pt"
    # First positional argument of DiffusionManager.sample (presumably the
    # square image side length in pixels — TODO confirm).
    IMAGE_SIZE = 64
    # Number of diffusion noise steps configured on the DiffusionManager.
    NOISE_STEPS = 1000
    # Passed to UNet_conditional as num_classes; presumably the width of the
    # BERT prompt embedding (768 for BERT-base) — TODO confirm.
    EMBED_DIM = 768

    def setup(self) -> None:
        """Load the model into memory once so repeated predictions are cheap.

        Builds the conditional UNet, loads its weights from
        ``<MODEL_DIR>/ckpt/latest_cpu.pt`` and wraps it in a DiffusionManager
        with a linear noise schedule. Also makes sure ``<MODEL_DIR>/inferred``
        exists, because ``predict`` writes its output images there.
        """
        self.device = self.DEVICE
        self.model_dir = self.MODEL_DIR
        os.makedirs(os.path.join(self.model_dir, "inferred"), exist_ok=True)

        self.net = UNet_conditional(num_classes=self.EMBED_DIM, device=self.device)
        self.net.to(self.device)
        # map_location keeps loading working on a CPU-only host even if some
        # tensors in the checkpoint were saved from a GPU.
        # NOTE(review): weights_only=False unpickles arbitrary objects — only
        # load checkpoints you trust. A plain state_dict should also load with
        # weights_only=True.
        state_dict = torch.load(
            os.path.join(self.model_dir, self.CHECKPOINT),
            map_location=self.device,
            weights_only=False,
        )
        self.net.load_state_dict(state_dict)
        # Inference only: put modules such as dropout / batch-norm (if the
        # UNet has any) into evaluation mode.
        self.net.eval()

        self.wrapper = DiffusionManager(
            self.net, device=self.device, noise_steps=self.NOISE_STEPS
        )
        self.wrapper.set_schedule(Schedule.LINEAR)

    def predict(
        self,
        prompt: str = Input(description="Text prompt"),
        amt: int = Input(description="Amt", default=8, ge=1),
    ) -> Path:
        """Generate ``amt`` images for ``prompt`` and return them as one PNG grid.

        The prompt is embedded with ``vectorize_text_with_bert``. The embedding
        conditions ``DiffusionManager.sample``, and the sampled batch is tiled
        into a single grid image (``make_grid`` defaults: 8 per row, 2px
        padding).

        Args:
            prompt: Text used to condition generation.
            amt: Number of images to sample (at least 1).

        Returns:
            Path to the PNG grid written under ``<MODEL_DIR>/inferred``.
            Files accumulate there; nothing here cleans them up.
        """
        # Add a batch dimension and place the embedding next to the model.
        vprompt = vectorize_text_with_bert(prompt).unsqueeze(0).to(self.device)

        # Sampling needs no gradients. If `sample` does not already disable
        # autograd, this stops a graph being built across every denoising step.
        with torch.no_grad():
            generated = self.wrapper.sample(self.IMAGE_SIZE, vprompt, amt=amt)
        generated = generated.detach().cpu()

        # save_image expects floats in [0, 1]; rescale if the sampler already
        # produced 0-255 integer pixels.
        # NOTE(review): float output is assumed to be in [0, 1] (values outside
        # are clipped) — confirm the output range of DiffusionManager.sample.
        if not generated.is_floating_point():
            generated = generated.float() / 255

        # Cog serves `Path` outputs as files, so the grid must be written to
        # disk rather than returned as an array (the original returned a numpy
        # array, which contradicts the `-> Path` annotation).
        slug = re.sub(r"[^A-Za-z0-9_-]+", "_", prompt).strip("_")[:50] or "prompt"
        out_path = os.path.join(
            self.model_dir, "inferred", f"{slug}_{time.time_ns()}.png"
        )
        torchvision.utils.save_image(generated, out_path)
        return Path(out_path)