|
|
""" |
|
|
デブリ生成器 |
|
|
|
|
|
言語モデルにノイズを入力してデブリ(言語断片)を生成する |
|
|
単一責任原則(SRP)に従い、生成ロジックのみを担当 |
|
|
""" |
|
|
import time |
|
|
from dataclasses import dataclass |
|
|
from typing import List, Optional |
|
|
|
|
|
import torch |
|
|
|
|
|
from ..models.base import BaseLanguageModel |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class DebrisResult: |
|
|
""" |
|
|
デブリ生成結果を保持するイミュータブルなデータクラス |
|
|
|
|
|
Attributes: |
|
|
debris: 生成されたトークン文字列のリスト |
|
|
seed: 使用した乱数シード |
|
|
noise: 入力ノイズテンソル |
|
|
logits: 生のlogitsテンソル |
|
|
corrupted_logits: ノイズ加算後のlogitsテンソル |
|
|
""" |
|
|
debris: List[str] |
|
|
seed: int |
|
|
noise: torch.Tensor |
|
|
logits: torch.Tensor |
|
|
corrupted_logits: torch.Tensor |
|
|
|
|
|
|
|
|
class DebrisGenerator: |
|
|
""" |
|
|
デブリ生成器 |
|
|
|
|
|
言語モデルを使用してランダムノイズから |
|
|
言語断片(デブリ)を生成する |
|
|
|
|
|
依存性逆転原則(DIP)に従い、具象クラスではなく |
|
|
BaseLanguageModel抽象クラスに依存する |
|
|
""" |
|
|
|
|
|
|
|
|
DEFAULT_SEQ_LEN = 32 |
|
|
|
|
|
def __init__(self, model: BaseLanguageModel): |
|
|
""" |
|
|
Args: |
|
|
model: 使用する言語モデル(BaseLanguageModelを実装) |
|
|
""" |
|
|
self._model = model |
|
|
|
|
|
@property |
|
|
def model(self) -> BaseLanguageModel: |
|
|
"""使用中のモデルを取得""" |
|
|
return self._model |
|
|
|
|
|
def generate( |
|
|
self, |
|
|
seed: Optional[int] = None, |
|
|
seq_len: int = DEFAULT_SEQ_LEN, |
|
|
) -> DebrisResult: |
|
|
""" |
|
|
デブリを生成 |
|
|
|
|
|
Args: |
|
|
seed: 乱数シード(Noneの場合はナノ秒タイムスタンプを使用) |
|
|
seq_len: 生成するシーケンス長 |
|
|
|
|
|
Returns: |
|
|
DebrisResult: 生成結果 |
|
|
|
|
|
Raises: |
|
|
RuntimeError: モデルが未ロードの場合 |
|
|
""" |
|
|
|
|
|
if seed is None: |
|
|
seed = time.time_ns() |
|
|
torch.manual_seed(seed) |
|
|
|
|
|
|
|
|
if not self._model.is_loaded: |
|
|
self._model.load() |
|
|
|
|
|
|
|
|
noise = self._model.generate_noise(seq_len=seq_len) |
|
|
logits, corrupted_logits = self._model.forward_with_noise(noise) |
|
|
|
|
|
|
|
|
indices = corrupted_logits.argmax(dim=-1).squeeze().tolist() |
|
|
|
|
|
|
|
|
debris = self._model.decode_indices(indices) |
|
|
|
|
|
return DebrisResult( |
|
|
debris=debris, |
|
|
seed=seed, |
|
|
noise=noise, |
|
|
logits=logits, |
|
|
corrupted_logits=corrupted_logits, |
|
|
) |
|
|
|