File size: 1,874 Bytes
f86c7c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# Prediction interface for Cog ⚙️
# https://cog.run/python

from cog import BasePredictor, Input, Path
import os
from factories import UNet_conditional
from wrapper import DiffusionManager, Schedule
import torch
import re
from bert_vectorize import vectorize_text_with_bert
from logger import save_grid_with_label
import torchvision
import time


class Predictor(BasePredictor):
    """Cog predictor that samples images from a text-conditional diffusion model."""

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        # CPU-only inference; the checkpoint below was exported for CPU.
        device = "cpu"
        model_dir = "runs/run_3_jxa"
        self.device = device
        self.model_dir = model_dir

        # Create the output directory for inferred images if it does not exist.
        os.makedirs(os.path.join(model_dir, "inferred"), exist_ok=True)

        # Load model. num_classes=768 matches BERT's hidden size — the
        # conditioning "class" vector is a BERT sentence embedding
        # (see vectorize_text_with_bert in predict()).
        self.net = UNet_conditional(num_classes=768, device=device)
        self.net.to(self.device)
        self.net.load_state_dict(
            torch.load(os.path.join(model_dir, "ckpt/latest_cpu.pt"), weights_only=False)
        )

        # Set up the diffusion sampler with a linear noise schedule.
        self.wrapper = DiffusionManager(self.net, device=self.device, noise_steps=1000)
        self.wrapper.set_schedule(Schedule.LINEAR)

    def predict(
        self,
        prompt: str = Input(description="Text prompt"),
        amt: int = Input(description="Amt", default=8),
    ) -> Path:
        """Generate `amt` images conditioned on `prompt`.

        Returns a Path to a PNG grid of the generated samples, as required
        by the declared Cog output type.
        """
        # Vectorize the prompt with BERT; unsqueeze(0) adds the batch dim.
        vprompt = vectorize_text_with_bert(prompt).unsqueeze(0)

        # Sample 64x64 images from the diffusion model.
        generated = self.wrapper.sample(64, vprompt, amt=amt).detach().cpu()

        # BUG FIX: the original returned a numpy array from make_grid(),
        # which contradicts the declared `-> Path` return type and breaks
        # Cog's output handling (Cog serves file outputs via Path).
        # Save the grid to disk and return its path instead.
        grid = torchvision.utils.make_grid(generated)
        if grid.dtype == torch.uint8:
            # save_image expects floats in [0, 1]; rescale if the sampler
            # produced uint8 pixels. TODO confirm sampler output dtype.
            grid = grid.float() / 255.0
        out_path = os.path.join(self.model_dir, "inferred", f"grid_{int(time.time())}.png")
        torchvision.utils.save_image(grid, out_path)
        return Path(out_path)