File size: 3,851 Bytes
e0552b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import torch
from koja_diffuser.util import tensor_to_bytes
from koja_diffuser.model import Encoder, Decoder, DiffusionTranslate
from koja_diffuser.tokenizer.ko import KoreanTokenizer
from koja_diffuser.tokenizer.ja import JapaneseTokenizer
from koja_diffuser.config import Stage2Config, Config, get_config
from koja_diffuser.runtime.sample import sample_logits
from typing import Literal
from torch import Tensor
from pathlib import Path
from huggingface_hub import hf_hub_download
import dataclasses


def load_lang(lang: Literal["ko", "ja"], data, *, device: torch.device):
    """Build the tokenizer, encoder and decoder for one language.

    Args:
        lang: ``"ko"`` selects ``KoreanTokenizer``; any other value selects
            ``JapaneseTokenizer``.
        data: Mapping read with the keys ``"config"``, ``"parquet"``,
            ``"encoder"`` and ``"decoder"``. The last two are passed to
            ``load_state_dict``.
        device: Device both networks are moved to before their weights
            are loaded.

    Returns:
        A ``LangModels`` holding the tokenizer, both networks (switched to
        eval mode), the device and the language config.
    """
    lang_config = get_config(lang, data["config"])

    if lang == "ko":
        tokenizer_cls = KoreanTokenizer
    else:
        tokenizer_cls = JapaneseTokenizer
    # The tokenizer is constructed from the bytes recovered from the
    # "parquet" entry of the checkpoint.
    tok = tokenizer_cls(tensor_to_bytes(data["parquet"]))
    n_tokens = len(tok)

    enc = Encoder(
        vocab_size=n_tokens,
        latent_size=lang_config.latent_size,
        max_len=lang_config.max_len,
    )
    enc = enc.to(device)
    enc.load_state_dict(data["encoder"])
    enc.eval()

    dec = Decoder(vocab_size=n_tokens, max_len=lang_config.max_len)
    dec = dec.to(device)
    dec.load_state_dict(data["decoder"])
    dec.eval()

    return LangModels(
        tokenizer=tok,
        encoder=enc,
        decoder=dec,
        device=device,
        config=lang_config,
    )


@dataclasses.dataclass(frozen=True)
class LangModels:
    """Immutable bundle of one language's tokenizer, encoder and decoder.

    Also carries the device the networks live on and the language config
    (``config.max_len`` is used when tokenizing).
    """

    tokenizer: KoreanTokenizer | JapaneseTokenizer
    encoder: Encoder
    decoder: Decoder
    device: torch.device
    config: Config

    def encode(self, names: list[str], ages: list[int]):
        """Tokenize ``names`` and run them through the encoder with ``ages``.

        Args:
            names: Strings to tokenize; each is encoded with an EOS token
                and ``max_len=self.config.max_len``.
            ages: One integer per name (presumably an age bucket; TODO
                confirm). Values are clamped to the range ``[0, 9]``.

        Returns:
            A pair ``(encoder_output, token_ids)`` where ``token_ids`` is
            the list of per-name id lists produced by the tokenizer.
        """
        token_rows = []
        for name in names:
            token_rows.append(
                self.tokenizer.encode(
                    name,
                    add_eos=True,
                    max_len=self.config.max_len,
                )
            )

        # torch.tensor requires every row to have the same length.
        name_batch = torch.tensor(token_rows, dtype=torch.long)
        name_batch = name_batch.to(device=self.device, non_blocking=True)

        age_batch = torch.tensor(ages, dtype=torch.long, device=self.device)
        age_batch = age_batch.clamp(min=0, max=9)

        return self.encoder(name_batch, age_batch), token_rows

    def decode(
        self,
        names: Tensor,
        *,
        sampling_mode: Literal["greedy", "sample"] = "sample",
        temperature: float = 0.8,
        top_k: int = 20,
        top_p: float = 0.9,
        generator: torch.Generator | None = None,
    ):
        """Run the decoder on ``names`` and turn its output into strings.

        Args:
            names: Tensor passed straight to the decoder (presumably a
                batch of latents; TODO confirm).
            sampling_mode: ``"greedy"`` takes the argmax over the last
                dimension; any other value samples via ``sample_logits``.
            temperature: Forwarded to ``sample_logits`` (sampling only).
            top_k: Forwarded to ``sample_logits`` (sampling only).
            top_p: Forwarded to ``sample_logits`` (sampling only).
            generator: Forwarded to ``sample_logits`` (sampling only).

        Returns:
            A pair ``(strings, token_ids)``: the tokenizer-decoded text for
            each row and the raw id lists they were decoded from.
        """
        decoder_out = self.decoder(names)
        if sampling_mode == "greedy":
            picked = decoder_out.argmax(dim=-1)
        else:
            picked = sample_logits(
                decoder_out, temperature, top_k, top_p, generator
            )

        decoded: list[list[int]] = picked.cpu().tolist()
        texts: list[str] = [self.tokenizer.decode(ids) for ids in decoded]
        return texts, decoded


@dataclasses.dataclass(frozen=True)
class Models:
    """Immutable bundle of every model that ``load_model`` builds from one checkpoint."""

    # Per-language tokenizer/encoder/decoder bundles.
    ko: LangModels
    ja: LangModels
    # Latent translators. ``load_model`` sizes ``bridge_kj`` with the Korean
    # latent size as source and the Japanese one as target; ``bridge_jk`` is
    # the reverse.
    bridge_kj: DiffusionTranslate
    bridge_jk: DiffusionTranslate
    # Built by ``load_model`` from the checkpoint's top-level "config" entry.
    config: Stage2Config


def load_model(*, path="./dist/full.pt", device: torch.device):
    """Load the full checkpoint and assemble every model on ``device``.

    Args:
        path: Local checkpoint file. If it is not an existing file, the
            checkpoint ``full.pt`` is fetched from the
            ``MOKA-AYUMU/KoJaNameDiffuser`` Hugging Face repo instead.
        device: Device used as ``map_location`` and for every module.

    Returns:
        A ``Models`` bundle with both language models, both bridges (in
        eval mode) and the stage-2 config.
    """
    checkpoint_path = path
    if not Path(checkpoint_path).is_file():
        checkpoint_path = hf_hub_download(
            repo_id="MOKA-AYUMU/KoJaNameDiffuser",
            filename="full.pt",
            revision="main",
        )

    # NOTE(review): torch.load unpickles the file. Whether arbitrary pickled
    # objects are allowed depends on the installed torch version's
    # ``weights_only`` default; consider passing it explicitly, since the
    # file may come from the network.
    ckpt = torch.load(checkpoint_path, map_location=device)

    ko = load_lang("ko", ckpt["ko"], device=device)
    ja = load_lang("ja", ckpt["ja"], device=device)

    def build_bridge(src: str, dst: str, weights_key: str):
        # Size the translator from the two languages' latent sizes, then
        # load its weights and switch it to eval mode.
        bridge = DiffusionTranslate(
            source_latent_size=ckpt[src]["config"]["latent_size"],
            target_latent_size=ckpt[dst]["config"]["latent_size"],
        ).to(device=device)
        bridge.load_state_dict(ckpt[weights_key])
        bridge.eval()
        return bridge

    bridge_kj = build_bridge("ko", "ja", "bridge_kj")
    bridge_jk = build_bridge("ja", "ko", "bridge_jk")

    stage2_config = Stage2Config(**ckpt["config"])
    return Models(ko, ja, bridge_kj, bridge_jk, stage2_config)