File size: 2,925 Bytes
d1033d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
"""
抽象基底クラス - すべての言語モデルの共通インターフェース
リスコフ置換原則(LSP)に準拠し、どのモデル実装も
同じインターフェースで置換可能にする
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Tuple
import torch
@dataclass(frozen=True)
class ModelConfig:
"""
モデル設定を保持するイミュータブルなデータクラス
Attributes:
name: UI表示名
model_id: HuggingFace model ID
embedding_dim: embedding次元数
vocab_size: 語彙サイズ
"""
name: str
model_id: str
embedding_dim: int
vocab_size: int
class BaseLanguageModel(ABC):
"""
言語モデルの抽象基底クラス
すべてのモデル実装はこのクラスを継承し、
定義されたインターフェースを実装する必要がある
"""
def __init__(self, config: ModelConfig):
"""
Args:
config: モデル設定
"""
self._config = config
self._model = None
self._tokenizer = None
self._is_loaded = False
@property
def config(self) -> ModelConfig:
"""モデル設定を取得"""
return self._config
@property
def is_loaded(self) -> bool:
"""モデルがロード済みかどうか"""
return self._is_loaded
@abstractmethod
def load(self) -> None:
"""
モデルとトークナイザーをロードする
Raises:
RuntimeError: モデルのロードに失敗した場合
"""
pass
@abstractmethod
def forward_with_noise(
self, noise: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
ノイズを入力として順伝播を実行
Args:
noise: 入力ノイズテンソル [batch, seq_len, embedding_dim]
Returns:
Tuple[logits, corrupted_logits]:
- logits: 生のlogits
- corrupted_logits: ノイズ加算後のlogits
"""
pass
@abstractmethod
def decode_indices(self, indices: List[int]) -> List[str]:
"""
トークンインデックスをデコードして文字列リストに変換
Args:
indices: トークンインデックスのリスト
Returns:
デコードされた文字列のリスト
"""
pass
def generate_noise(self, seq_len: int = 32, batch_size: int = 1) -> torch.Tensor:
"""
入力用のランダムノイズを生成
Args:
seq_len: シーケンス長
batch_size: バッチサイズ
Returns:
ノイズテンソル [batch_size, seq_len, embedding_dim]
"""
return torch.randn(batch_size, seq_len, self._config.embedding_dim)
|