|
|
""" |
|
|
抽象基底クラス - すべての言語モデルの共通インターフェース |
|
|
|
|
|
リスコフ置換原則(LSP)に準拠し、どのモデル実装も |
|
|
同じインターフェースで置換可能にする |
|
|
""" |
|
|
from abc import ABC, abstractmethod |
|
|
from dataclasses import dataclass |
|
|
from typing import List, Tuple |
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
@dataclass(frozen=True) |
|
|
class ModelConfig: |
|
|
""" |
|
|
モデル設定を保持するイミュータブルなデータクラス |
|
|
|
|
|
Attributes: |
|
|
name: UI表示名 |
|
|
model_id: HuggingFace model ID |
|
|
embedding_dim: embedding次元数 |
|
|
vocab_size: 語彙サイズ |
|
|
""" |
|
|
name: str |
|
|
model_id: str |
|
|
embedding_dim: int |
|
|
vocab_size: int |
|
|
|
|
|
|
|
|
class BaseLanguageModel(ABC): |
|
|
""" |
|
|
言語モデルの抽象基底クラス |
|
|
|
|
|
すべてのモデル実装はこのクラスを継承し、 |
|
|
定義されたインターフェースを実装する必要がある |
|
|
""" |
|
|
|
|
|
def __init__(self, config: ModelConfig): |
|
|
""" |
|
|
Args: |
|
|
config: モデル設定 |
|
|
""" |
|
|
self._config = config |
|
|
self._model = None |
|
|
self._tokenizer = None |
|
|
self._is_loaded = False |
|
|
|
|
|
@property |
|
|
def config(self) -> ModelConfig: |
|
|
"""モデル設定を取得""" |
|
|
return self._config |
|
|
|
|
|
@property |
|
|
def is_loaded(self) -> bool: |
|
|
"""モデルがロード済みかどうか""" |
|
|
return self._is_loaded |
|
|
|
|
|
@abstractmethod |
|
|
def load(self) -> None: |
|
|
""" |
|
|
モデルとトークナイザーをロードする |
|
|
|
|
|
Raises: |
|
|
RuntimeError: モデルのロードに失敗した場合 |
|
|
""" |
|
|
pass |
|
|
|
|
|
@abstractmethod |
|
|
def forward_with_noise( |
|
|
self, noise: torch.Tensor |
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
""" |
|
|
ノイズを入力として順伝播を実行 |
|
|
|
|
|
Args: |
|
|
noise: 入力ノイズテンソル [batch, seq_len, embedding_dim] |
|
|
|
|
|
Returns: |
|
|
Tuple[logits, corrupted_logits]: |
|
|
- logits: 生のlogits |
|
|
- corrupted_logits: ノイズ加算後のlogits |
|
|
""" |
|
|
pass |
|
|
|
|
|
@abstractmethod |
|
|
def decode_indices(self, indices: List[int]) -> List[str]: |
|
|
""" |
|
|
トークンインデックスをデコードして文字列リストに変換 |
|
|
|
|
|
Args: |
|
|
indices: トークンインデックスのリスト |
|
|
|
|
|
Returns: |
|
|
デコードされた文字列のリスト |
|
|
""" |
|
|
pass |
|
|
|
|
|
def generate_noise(self, seq_len: int = 32, batch_size: int = 1) -> torch.Tensor: |
|
|
""" |
|
|
入力用のランダムノイズを生成 |
|
|
|
|
|
Args: |
|
|
seq_len: シーケンス長 |
|
|
batch_size: バッチサイズ |
|
|
|
|
|
Returns: |
|
|
ノイズテンソル [batch_size, seq_len, embedding_dim] |
|
|
""" |
|
|
return torch.randn(batch_size, seq_len, self._config.embedding_dim) |
|
|
|