""" デブリ生成器 言語モデルにノイズを入力してデブリ(言語断片)を生成する 単一責任原則(SRP)に従い、生成ロジックのみを担当 """ import time from dataclasses import dataclass from typing import List, Optional import torch from ..models.base import BaseLanguageModel @dataclass class DebrisResult: """ デブリ生成結果を保持するイミュータブルなデータクラス Attributes: debris: 生成されたトークン文字列のリスト seed: 使用した乱数シード noise: 入力ノイズテンソル logits: 生のlogitsテンソル corrupted_logits: ノイズ加算後のlogitsテンソル """ debris: List[str] seed: int noise: torch.Tensor logits: torch.Tensor corrupted_logits: torch.Tensor class DebrisGenerator: """ デブリ生成器 言語モデルを使用してランダムノイズから 言語断片(デブリ)を生成する 依存性逆転原則(DIP)に従い、具象クラスではなく BaseLanguageModel抽象クラスに依存する """ # デフォルトのシーケンス長 DEFAULT_SEQ_LEN = 32 def __init__(self, model: BaseLanguageModel): """ Args: model: 使用する言語モデル(BaseLanguageModelを実装) """ self._model = model @property def model(self) -> BaseLanguageModel: """使用中のモデルを取得""" return self._model def generate( self, seed: Optional[int] = None, seq_len: int = DEFAULT_SEQ_LEN, ) -> DebrisResult: """ デブリを生成 Args: seed: 乱数シード(Noneの場合はナノ秒タイムスタンプを使用) seq_len: 生成するシーケンス長 Returns: DebrisResult: 生成結果 Raises: RuntimeError: モデルが未ロードの場合 """ # シードの設定 if seed is None: seed = time.time_ns() torch.manual_seed(seed) # モデルがロードされていなければロード if not self._model.is_loaded: self._model.load() # ノイズ生成と順伝播 noise = self._model.generate_noise(seq_len=seq_len) logits, corrupted_logits = self._model.forward_with_noise(noise) # argmaxでインデックス抽出 indices = corrupted_logits.argmax(dim=-1).squeeze().tolist() # インデックスをトークン文字列にデコード debris = self._model.decode_indices(indices) return DebrisResult( debris=debris, seed=seed, noise=noise, logits=logits, corrupted_logits=corrupted_logits, )