will / src /models /base.py
matt1847's picture
リファクタ: srcディレクトリ構造への移行とDocker対応
d1033d4
"""
抽象基底クラス - すべての言語モデルの共通インターフェース
リスコフ置換原則(LSP)に準拠し、どのモデル実装も
同じインターフェースで置換可能にする
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Tuple
import torch
@dataclass(frozen=True)
class ModelConfig:
"""
モデル設定を保持するイミュータブルなデータクラス
Attributes:
name: UI表示名
model_id: HuggingFace model ID
embedding_dim: embedding次元数
vocab_size: 語彙サイズ
"""
name: str
model_id: str
embedding_dim: int
vocab_size: int
class BaseLanguageModel(ABC):
"""
言語モデルの抽象基底クラス
すべてのモデル実装はこのクラスを継承し、
定義されたインターフェースを実装する必要がある
"""
def __init__(self, config: ModelConfig):
"""
Args:
config: モデル設定
"""
self._config = config
self._model = None
self._tokenizer = None
self._is_loaded = False
@property
def config(self) -> ModelConfig:
"""モデル設定を取得"""
return self._config
@property
def is_loaded(self) -> bool:
"""モデルがロード済みかどうか"""
return self._is_loaded
@abstractmethod
def load(self) -> None:
"""
モデルとトークナイザーをロードする
Raises:
RuntimeError: モデルのロードに失敗した場合
"""
pass
@abstractmethod
def forward_with_noise(
self, noise: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
ノイズを入力として順伝播を実行
Args:
noise: 入力ノイズテンソル [batch, seq_len, embedding_dim]
Returns:
Tuple[logits, corrupted_logits]:
- logits: 生のlogits
- corrupted_logits: ノイズ加算後のlogits
"""
pass
@abstractmethod
def decode_indices(self, indices: List[int]) -> List[str]:
"""
トークンインデックスをデコードして文字列リストに変換
Args:
indices: トークンインデックスのリスト
Returns:
デコードされた文字列のリスト
"""
pass
def generate_noise(self, seq_len: int = 32, batch_size: int = 1) -> torch.Tensor:
"""
入力用のランダムノイズを生成
Args:
seq_len: シーケンス長
batch_size: バッチサイズ
Returns:
ノイズテンソル [batch_size, seq_len, embedding_dim]
"""
return torch.randn(batch_size, seq_len, self._config.embedding_dim)