""" 抽象基底クラス - すべての言語モデルの共通インターフェース リスコフ置換原則(LSP)に準拠し、どのモデル実装も 同じインターフェースで置換可能にする """ from abc import ABC, abstractmethod from dataclasses import dataclass from typing import List, Tuple import torch @dataclass(frozen=True) class ModelConfig: """ モデル設定を保持するイミュータブルなデータクラス Attributes: name: UI表示名 model_id: HuggingFace model ID embedding_dim: embedding次元数 vocab_size: 語彙サイズ """ name: str model_id: str embedding_dim: int vocab_size: int class BaseLanguageModel(ABC): """ 言語モデルの抽象基底クラス すべてのモデル実装はこのクラスを継承し、 定義されたインターフェースを実装する必要がある """ def __init__(self, config: ModelConfig): """ Args: config: モデル設定 """ self._config = config self._model = None self._tokenizer = None self._is_loaded = False @property def config(self) -> ModelConfig: """モデル設定を取得""" return self._config @property def is_loaded(self) -> bool: """モデルがロード済みかどうか""" return self._is_loaded @abstractmethod def load(self) -> None: """ モデルとトークナイザーをロードする Raises: RuntimeError: モデルのロードに失敗した場合 """ pass @abstractmethod def forward_with_noise( self, noise: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """ ノイズを入力として順伝播を実行 Args: noise: 入力ノイズテンソル [batch, seq_len, embedding_dim] Returns: Tuple[logits, corrupted_logits]: - logits: 生のlogits - corrupted_logits: ノイズ加算後のlogits """ pass @abstractmethod def decode_indices(self, indices: List[int]) -> List[str]: """ トークンインデックスをデコードして文字列リストに変換 Args: indices: トークンインデックスのリスト Returns: デコードされた文字列のリスト """ pass def generate_noise(self, seq_len: int = 32, batch_size: int = 1) -> torch.Tensor: """ 入力用のランダムノイズを生成 Args: seq_len: シーケンス長 batch_size: バッチサイズ Returns: ノイズテンソル [batch_size, seq_len, embedding_dim] """ return torch.randn(batch_size, seq_len, self._config.embedding_dim)