File size: 3,016 Bytes
ef06509
 
371a997
 
ef06509
 
 
 
 
371a997
ef06509
5721855
371a997
 
ef06509
 
5721855
ef06509
 
371a997
ef06509
 
 
371a997
5721855
 
 
 
 
371a997
ef06509
371a997
5721855
ef06509
 
 
 
5721855
 
ef06509
 
5721855
371a997
5721855
 
ef06509
5721855
ef06509
 
 
 
 
 
 
 
 
 
371a997
ef06509
 
 
 
 
 
 
 
 
371a997
 
ef06509
 
 
 
 
 
 
 
 
 
 
5721855
ef06509
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import pathlib
import shlex
import subprocess
import sys

import huggingface_hub
import numpy as np
import torch
from torch import nn

# On Hugging Face Spaces, patch the vendored StyleSwin submodule in place.
# NOTE(review): the sed commands delete fixed line ranges from the two custom-op
# modules — presumably the CUDA-extension compilation blocks, so the pure-PyTorch
# fallbacks get used; verify against StyleSwin/op/*.py before changing the ranges.
if os.getenv("SYSTEM") == "spaces":
    subprocess.run(shlex.split("sed -i '14,21d' StyleSwin/op/fused_act.py"), check=False)  # noqa: S603
    subprocess.run(shlex.split("sed -i '12,19d' StyleSwin/op/upfirdn2d.py"), check=False)  # noqa: S603

# Make the vendored StyleSwin checkout importable as top-level packages
# (e.g. `models.generator`) by prepending it to sys.path.
current_dir = pathlib.Path(__file__).parent
submodule_dir = current_dir / "StyleSwin"
sys.path.insert(0, submodule_dir.as_posix())

# Imported after the path manipulation above, hence the E402 suppression.
from models.generator import Generator  # noqa: E402  # pyright: ignore[reportMissingImports]


class Model:
    """Wrapper around the StyleSwin generators hosted on the Hugging Face Hub.

    Pre-fetches every known checkpoint, keeps exactly one generator loaded at
    a time, and exposes deterministic seed-based image generation.
    """

    # Checkpoint base names under models/ in the public-data/StyleSwin repo.
    MODEL_NAMES = (
        "CelebAHQ_256",
        "FFHQ_256",
        "LSUNChurch_256",
        "CelebAHQ_1024",
        "FFHQ_1024",
    )

    def __init__(self) -> None:
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self._download_all_models()
        self.model_name = self.MODEL_NAMES[3]
        self.model = self._load_model(self.model_name)

        # Per-channel statistics used by postprocess() to de-normalize the
        # generator output back to [0, 1] before uint8 conversion.
        self.std = torch.FloatTensor([0.229, 0.224, 0.225])[None, :, None, None].to(self.device)
        self.mean = torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None].to(self.device)

    def _load_model(self, model_name: str) -> nn.Module:
        """Build a Generator for *model_name*, load its checkpoint, and return it in eval mode."""
        # Names encode the output resolution as the suffix, e.g. "FFHQ_1024".
        size = int(model_name.split("_")[1])
        channel_multiplier = 1 if size == 1024 else 2  # noqa: PLR2004
        model = Generator(size, style_dim=512, n_mlp=8, channel_multiplier=channel_multiplier)
        ckpt_path = huggingface_hub.hf_hub_download("public-data/StyleSwin", f"models/{model_name}.pt")
        # map_location keeps CUDA-saved checkpoints loadable on CPU-only hosts.
        # NOTE(review): checkpoint comes from a fixed public repo; consider
        # weights_only=True once the minimum supported torch version allows it.
        ckpt = torch.load(ckpt_path, map_location=self.device)
        model.load_state_dict(ckpt["g_ema"])
        model.to(self.device)
        model.eval()
        return model

    def set_model(self, model_name: str) -> None:
        """Switch the active generator, reloading only when the name changes."""
        if model_name == self.model_name:
            return
        self.model_name = model_name
        self.model = self._load_model(model_name)

    def _download_all_models(self) -> None:
        """Pre-fetch every checkpoint file so later model switches are fast.

        Only downloads the .pt files; instantiating a full Generator per name
        (as _load_model does) would needlessly burn time and memory here.
        """
        for name in self.MODEL_NAMES:
            huggingface_hub.hf_hub_download("public-data/StyleSwin", f"models/{name}.pt")

    def generate_z(self, seed: int) -> torch.Tensor:
        """Return a deterministic (1, 512) float32 latent for *seed* on self.device."""
        # Clamp into RandomState's valid seed range instead of raising.
        seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
        z = np.random.RandomState(seed).randn(1, 512)
        return torch.from_numpy(z).float().to(self.device)

    def postprocess(self, tensors: torch.Tensor) -> np.ndarray:
        """De-normalize a (N, C, H, W) batch and return uint8 (N, H, W, C) arrays.

        Raises:
            ValueError: if *tensors* is not 4-dimensional.
        """
        if tensors.dim() != 4:  # noqa: PLR2004
            raise ValueError("tensors must be 4-dimensional")
        tensors = tensors * self.std + self.mean
        tensors = (tensors * 255).clamp(0, 255).to(torch.uint8)
        return tensors.permute(0, 2, 3, 1).cpu().numpy()

    @torch.inference_mode()
    def generate_image(self, seed: int) -> np.ndarray:
        """Generate one (H, W, C) uint8 image from *seed* with the current model."""
        z = self.generate_z(seed)
        out, _ = self.model(z)
        out = self.postprocess(out)
        return out[0]

    def set_model_and_generate_image(self, model_name: str, seed: int) -> np.ndarray:
        """Select *model_name* (loading it if necessary), then generate an image."""
        self.set_model(model_name)
        return self.generate_image(seed)