File size: 2,869 Bytes
d1033d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
"""
デブリ生成器
言語モデルにノイズを入力してデブリ(言語断片)を生成する
単一責任原則(SRP)に従い、生成ロジックのみを担当
"""
import time
from dataclasses import dataclass
from typing import List, Optional
import torch
from ..models.base import BaseLanguageModel
@dataclass
class DebrisResult:
"""
デブリ生成結果を保持するイミュータブルなデータクラス
Attributes:
debris: 生成されたトークン文字列のリスト
seed: 使用した乱数シード
noise: 入力ノイズテンソル
logits: 生のlogitsテンソル
corrupted_logits: ノイズ加算後のlogitsテンソル
"""
debris: List[str]
seed: int
noise: torch.Tensor
logits: torch.Tensor
corrupted_logits: torch.Tensor
class DebrisGenerator:
"""
デブリ生成器
言語モデルを使用してランダムノイズから
言語断片(デブリ)を生成する
依存性逆転原則(DIP)に従い、具象クラスではなく
BaseLanguageModel抽象クラスに依存する
"""
# デフォルトのシーケンス長
DEFAULT_SEQ_LEN = 32
def __init__(self, model: BaseLanguageModel):
"""
Args:
model: 使用する言語モデル(BaseLanguageModelを実装)
"""
self._model = model
@property
def model(self) -> BaseLanguageModel:
"""使用中のモデルを取得"""
return self._model
def generate(
self,
seed: Optional[int] = None,
seq_len: int = DEFAULT_SEQ_LEN,
) -> DebrisResult:
"""
デブリを生成
Args:
seed: 乱数シード(Noneの場合はナノ秒タイムスタンプを使用)
seq_len: 生成するシーケンス長
Returns:
DebrisResult: 生成結果
Raises:
RuntimeError: モデルが未ロードの場合
"""
# シードの設定
if seed is None:
seed = time.time_ns()
torch.manual_seed(seed)
# モデルがロードされていなければロード
if not self._model.is_loaded:
self._model.load()
# ノイズ生成と順伝播
noise = self._model.generate_noise(seq_len=seq_len)
logits, corrupted_logits = self._model.forward_with_noise(noise)
# argmaxでインデックス抽出
indices = corrupted_logits.argmax(dim=-1).squeeze().tolist()
# インデックスをトークン文字列にデコード
debris = self._model.decode_indices(indices)
return DebrisResult(
debris=debris,
seed=seed,
noise=noise,
logits=logits,
corrupted_logits=corrupted_logits,
)
|