will / src /generators /debris_generator.py
matt1847's picture
リファクタ: srcディレクトリ構造への移行とDocker対応
d1033d4
"""
デブリ生成器
言語モデルにノイズを入力してデブリ(言語断片)を生成する
単一責任原則(SRP)に従い、生成ロジックのみを担当
"""
import time
from dataclasses import dataclass
from typing import List, Optional
import torch
from ..models.base import BaseLanguageModel
@dataclass
class DebrisResult:
"""
デブリ生成結果を保持するイミュータブルなデータクラス
Attributes:
debris: 生成されたトークン文字列のリスト
seed: 使用した乱数シード
noise: 入力ノイズテンソル
logits: 生のlogitsテンソル
corrupted_logits: ノイズ加算後のlogitsテンソル
"""
debris: List[str]
seed: int
noise: torch.Tensor
logits: torch.Tensor
corrupted_logits: torch.Tensor
class DebrisGenerator:
"""
デブリ生成器
言語モデルを使用してランダムノイズから
言語断片(デブリ)を生成する
依存性逆転原則(DIP)に従い、具象クラスではなく
BaseLanguageModel抽象クラスに依存する
"""
# デフォルトのシーケンス長
DEFAULT_SEQ_LEN = 32
def __init__(self, model: BaseLanguageModel):
"""
Args:
model: 使用する言語モデル(BaseLanguageModelを実装)
"""
self._model = model
@property
def model(self) -> BaseLanguageModel:
"""使用中のモデルを取得"""
return self._model
def generate(
self,
seed: Optional[int] = None,
seq_len: int = DEFAULT_SEQ_LEN,
) -> DebrisResult:
"""
デブリを生成
Args:
seed: 乱数シード(Noneの場合はナノ秒タイムスタンプを使用)
seq_len: 生成するシーケンス長
Returns:
DebrisResult: 生成結果
Raises:
RuntimeError: モデルが未ロードの場合
"""
# シードの設定
if seed is None:
seed = time.time_ns()
torch.manual_seed(seed)
# モデルがロードされていなければロード
if not self._model.is_loaded:
self._model.load()
# ノイズ生成と順伝播
noise = self._model.generate_noise(seq_len=seq_len)
logits, corrupted_logits = self._model.forward_with_noise(noise)
# argmaxでインデックス抽出
indices = corrupted_logits.argmax(dim=-1).squeeze().tolist()
# インデックスをトークン文字列にデコード
debris = self._model.decode_indices(indices)
return DebrisResult(
debris=debris,
seed=seed,
noise=noise,
logits=logits,
corrupted_logits=corrupted_logits,
)