File size: 2,869 Bytes
d1033d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""
デブリ生成器

言語モデルにノイズを入力してデブリ(言語断片)を生成する
単一責任原則(SRP)に従い、生成ロジックのみを担当
"""
import time
from dataclasses import dataclass
from typing import List, Optional

import torch

from ..models.base import BaseLanguageModel


@dataclass
class DebrisResult:
    """
    デブリ生成結果を保持するイミュータブルなデータクラス

    Attributes:
        debris: 生成されたトークン文字列のリスト
        seed: 使用した乱数シード
        noise: 入力ノイズテンソル
        logits: 生のlogitsテンソル
        corrupted_logits: ノイズ加算後のlogitsテンソル
    """
    debris: List[str]
    seed: int
    noise: torch.Tensor
    logits: torch.Tensor
    corrupted_logits: torch.Tensor


class DebrisGenerator:
    """
    デブリ生成器

    言語モデルを使用してランダムノイズから
    言語断片(デブリ)を生成する

    依存性逆転原則(DIP)に従い、具象クラスではなく
    BaseLanguageModel抽象クラスに依存する
    """

    # デフォルトのシーケンス長
    DEFAULT_SEQ_LEN = 32

    def __init__(self, model: BaseLanguageModel):
        """
        Args:
            model: 使用する言語モデル(BaseLanguageModelを実装)
        """
        self._model = model

    @property
    def model(self) -> BaseLanguageModel:
        """使用中のモデルを取得"""
        return self._model

    def generate(
        self,
        seed: Optional[int] = None,
        seq_len: int = DEFAULT_SEQ_LEN,
    ) -> DebrisResult:
        """
        デブリを生成

        Args:
            seed: 乱数シード(Noneの場合はナノ秒タイムスタンプを使用)
            seq_len: 生成するシーケンス長

        Returns:
            DebrisResult: 生成結果

        Raises:
            RuntimeError: モデルが未ロードの場合
        """
        # シードの設定
        if seed is None:
            seed = time.time_ns()
        torch.manual_seed(seed)

        # モデルがロードされていなければロード
        if not self._model.is_loaded:
            self._model.load()

        # ノイズ生成と順伝播
        noise = self._model.generate_noise(seq_len=seq_len)
        logits, corrupted_logits = self._model.forward_with_noise(noise)

        # argmaxでインデックス抽出
        indices = corrupted_logits.argmax(dim=-1).squeeze().tolist()

        # インデックスをトークン文字列にデコード
        debris = self._model.decode_indices(indices)

        return DebrisResult(
            debris=debris,
            seed=seed,
            noise=noise,
            logits=logits,
            corrupted_logits=corrupted_logits,
        )