File size: 4,887 Bytes
fcfea15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional
import hashlib
import os

from PIL import Image
import torch

from infer_runtime.infer_config import InferConfig, load_infer_config_class_from_pyfile
from infer_runtime.prompt_rewrite import rewrite_prompt
from infer_runtime.settings import InferSettings
from modules.models import load_dit, load_pipeline
from modules.utils import _dynamic_resize_from_bucket, seed_everything


@dataclass
class InferenceParams:
    """Per-request generation settings consumed by ``EditModel.infer``.

    ``image`` selects the mode. When it is ``None``, generation runs at
    ``height`` x ``width``. Otherwise the image is resized via the bucket
    helper (using ``basesize``), and the resized dimensions are used instead
    of ``height``/``width``.
    """

    # Text prompt for the generation / edit instruction.
    prompt: str
    # Source image to edit, or None for text-only generation.
    image: Image.Image | None
    # Output size in pixels; only consulted when ``image`` is None.
    height: int
    width: int
    # Forwarded to the pipeline as ``num_inference_steps``.
    steps: int
    # Forwarded to the pipeline as ``guidance_scale``.
    guidance_scale: float
    # Seed for the sampling ``torch.Generator``.
    seed: int
    # Negative prompt text.
    neg_prompt: str
    # Base size forwarded to the bucketed resize of ``image``.
    basesize: int


class EditModel:
    """Inference wrapper around a DiT model and its sampling pipeline.

    Loads the infer config class from ``settings.config_path``, points it at
    ``settings.ckpt_path``, builds the DiT and pipeline on ``device``, and
    exposes single-image generation/editing through :meth:`infer`. Prompt
    rewrites are memoized per instance in ``_rewrite_cache``.
    """

    def __init__(
        self,
        settings: InferSettings,
        device: torch.device,
        hsdp_shard_dim_override: int | None = None,
    ):
        """Load the config, DiT weights and pipeline.

        Args:
            settings: Runtime settings (config path, checkpoint path,
                rewrite-model credentials).
            device: Device the DiT and pipeline are loaded onto.
            hsdp_shard_dim_override: When not None, replaces
                ``cfg.hsdp_shard_dim`` from the config file.
        """
        self.settings = settings
        self.device = device
        # Memoized prompt rewrites, keyed by prompt + image content.
        # Unbounded: grows with the number of distinct rewrite requests.
        self._rewrite_cache: dict[str, str] = {}

        config_class = load_infer_config_class_from_pyfile(settings.config_path)
        self.cfg: InferConfig = config_class()
        self.cfg.dit_ckpt = settings.ckpt_path
        self.cfg.training_mode = False
        if hsdp_shard_dim_override is not None:
            self.cfg.hsdp_shard_dim = hsdp_shard_dim_override
        # Sharded (FSDP) inference only makes sense with several processes
        # (WORLD_SIZE > 1) and a shard dimension greater than one.
        if int(os.environ.get('WORLD_SIZE', '1')) > 1 and self.cfg.hsdp_shard_dim > 1:
            self.cfg.use_fsdp_inference = True

        self.dit = load_dit(self.cfg, device=self.device)
        self.dit.requires_grad_(False)
        self.dit.eval()
        self.pipeline = load_pipeline(self.cfg, self.dit, self.device)

    def current_device(self) -> torch.device:
        """Return the device the model and pipeline currently live on."""
        return self.device

    def move_to_device(self, device: torch.device) -> torch.device:
        """Move the DiT and pipeline to ``device`` (no-op if already there).

        NOTE(review): not verified to be safe when ``cfg.use_fsdp_inference``
        is enabled — confirm before moving a sharded model.

        Returns:
            The device now in use.
        """
        target = torch.device(device)
        if self.device == target:
            return self.device

        self.dit = self.dit.to(device=target)
        self.pipeline = self.pipeline.to(target)
        self.device = target
        return self.device

    def move_to_cpu(self) -> torch.device:
        """Move the model to CPU and return the new device."""
        return self.move_to_device(torch.device('cpu'))

    def move_to_gpu(self, device: torch.device | None = None) -> torch.device:
        """Move the model to ``device`` (default: ``cuda``) and return it."""
        target = torch.device(device) if device is not None else torch.device('cuda')
        return self.move_to_device(target)

    @staticmethod
    def _rewrite_cache_key(prompt: str, image: Optional[Image.Image]) -> str:
        """Build the rewrite-cache key from the stripped prompt and the image content."""
        key = f"prompt={prompt.strip()}"
        if image is not None:
            # Key on the pixel data, not just the size: the rewrite is
            # conditioned on the image, so two different images of the same
            # dimensions must not share a cached rewrite.
            digest = hashlib.blake2b(image.tobytes(), digest_size=16).hexdigest()
            key += f"|image={image.size[0]}x{image.size[1]}|mode={image.mode}|digest={digest}"
        return key

    def maybe_rewrite_prompt(self, prompt: str, image: Optional[Image.Image], enabled: bool) -> str:
        """Return ``prompt`` rewritten by the rewrite model when ``enabled``.

        A ``None`` prompt is treated as ``''`` in both modes. Results are
        memoized per (prompt, image content), so identical requests call
        ``rewrite_prompt`` only once.

        Args:
            prompt: User prompt (``None`` tolerated).
            image: Optional source image passed to the rewriter.
            enabled: When False, the prompt is returned unchanged.
        """
        text = str(prompt or '')
        if not enabled:
            return text
        key = self._rewrite_cache_key(text, image)
        if key not in self._rewrite_cache:
            self._rewrite_cache[key] = rewrite_prompt(
                text,
                image,
                model=self.settings.rewrite_model,
                api_key=self.settings.openai_api_key,
                base_url=self.settings.openai_base_url,
            )
        return self._rewrite_cache[key]

    @staticmethod
    def _generator_device(device: torch.device) -> torch.device:
        """Pick the device for the sampling RNG.

        CUDA devices keep their index. A generator created on bare ``'cuda'``
        binds to the current default CUDA device, which may not be the device
        the model lives on; torch rejects such a device mismatch. Other
        device types fall back to CPU.
        """
        device = torch.device(device)
        if device.type == 'cuda':
            return device
        return torch.device('cpu')

    @staticmethod
    def _frame_to_uint8_hwc(frame: torch.Tensor) -> torch.Tensor:
        """Convert a channel-first float frame (values expected in [0, 1]) to a CPU HWC uint8 tensor.

        Values are computed in float32, clamped to [0, 1] so out-of-range
        samples cannot wrap around in the uint8 cast, and rounded rather
        than truncated.
        """
        scaled = frame.detach().float().clamp(0.0, 1.0).mul(255.0).round()
        return scaled.to(torch.uint8).permute(1, 2, 0).contiguous().cpu()

    @torch.no_grad()
    def infer(self, params: InferenceParams) -> Image.Image:
        """Generate (``params.image is None``) or edit an image.

        For edits, the source image is resized with the bucket helper. Its
        resized size sets the output size, and it is referenced in the
        prompt via an ``<image>`` token in a chat-style template.

        Returns:
            The generated image as a PIL image.
        """
        if params.image is None:
            prompts = [params.prompt]
            negative_prompt = [params.neg_prompt]
            images = None
            height = params.height
            width = params.width
        else:
            processed = _dynamic_resize_from_bucket(params.image, basesize=params.basesize)
            width, height = processed.size
            image_tokens = '<image>\n'
            prompts = [f"<|im_start|>user\n{image_tokens}{params.prompt}<|im_end|>\n"]
            negative_prompt = [f"<|im_start|>user\n{image_tokens}{params.neg_prompt}<|im_end|>\n"]
            images = [processed]

        generator = torch.Generator(device=self._generator_device(self.device)).manual_seed(int(params.seed))
        output = self.pipeline(
            prompt=prompts,
            negative_prompt=negative_prompt,
            images=images,
            height=height,
            width=width,
            num_frames=1,
            num_inference_steps=params.steps,
            guidance_scale=params.guidance_scale,
            generator=generator,
            num_videos_per_prompt=1,
            output_type='pt',
            return_dict=False,
        )
        # NOTE(review): assumes the pipeline returns a tensor (not a tuple)
        # whose [0, -1, 0] slice is a single (C, H, W) frame — confirm
        # against load_pipeline.
        frame = self._frame_to_uint8_hwc(output[0, -1, 0])
        return Image.fromarray(frame.numpy())


def build_model(
    settings: InferSettings,
    device: torch.device | None = None,
    hsdp_shard_dim_override: int | None = None,
) -> EditModel:
    """Seed global RNGs from the settings and construct an :class:`EditModel`.

    Args:
        settings: Inference settings; ``settings.default_seed`` seeds RNGs.
        device: Target device. When omitted, ``cuda`` is used if available,
            otherwise ``cpu``.
        hsdp_shard_dim_override: Forwarded unchanged to :class:`EditModel`.

    Returns:
        The constructed model wrapper.
    """
    # Seeding happens first, before any model construction.
    seed_everything(settings.default_seed)

    target = device
    if target is None:
        target = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    model = EditModel(
        settings=settings,
        device=target,
        hsdp_shard_dim_override=hsdp_shard_dim_override,
    )
    return model