| """ |
| GPT-Neo モデル実装 |
| |
| EleutherAI GPT-Neo 125Mの実装を提供する |
| """ |
| from typing import List, Tuple |
|
|
| import torch |
| from transformers import GPTNeoForCausalLM, GPT2Tokenizer |
|
|
| from .base import BaseLanguageModel, ModelConfig |
|
|
|
|
| |
# Configuration for the 125M-parameter EleutherAI GPT-Neo checkpoint.
GPT_NEO_125M_CONFIG = ModelConfig(
    name="GPT-Neo 125M",
    model_id="EleutherAI/gpt-neo-125M",  # Hugging Face Hub model id
    embedding_dim=768,   # hidden size of the 125M checkpoint
    vocab_size=50257,    # GPT-2 BPE vocabulary size
)
|
|
|
|
class GPTNeoModel(BaseLanguageModel):
    """GPT-Neo model implementation.

    Wraps EleutherAI GPT-Neo behind the BaseLanguageModel interface.
    """

    # Scale factor for the Gaussian noise added to the logits in
    # forward_with_noise(): noise std = logits.std() * LOGITS_NOISE_SCALE.
    LOGITS_NOISE_SCALE = 10.0

    def load(self) -> None:
        """Load the model and tokenizer.

        Idempotent: returns immediately if already loaded.

        Raises:
            RuntimeError: if the model or tokenizer cannot be loaded.
        """
        if self._is_loaded:
            return

        try:
            self._model = GPTNeoForCausalLM.from_pretrained(self._config.model_id)
            self._tokenizer = GPT2Tokenizer.from_pretrained(self._config.model_id)
            # Inference-only usage: switch off dropout etc.
            self._model.eval()
            self._is_loaded = True
        except Exception as e:
            # Chain the original exception so the root cause stays visible
            # in the traceback instead of being flattened into the message.
            raise RuntimeError(
                f"Failed to load model {self._config.model_id}: {e}"
            ) from e

    def forward_with_noise(
        self, noise: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run a forward pass using *noise* as the input embeddings.

        Args:
            noise: tensor passed as ``inputs_embeds``; presumably shaped
                (batch, seq_len, embedding_dim) — TODO confirm with callers.

        Returns:
            Tuple of ``(logits, corrupted_logits)`` where the corrupted
            logits have additional Gaussian noise proportional to the
            logits' own standard deviation, scaled by LOGITS_NOISE_SCALE.

        Raises:
            RuntimeError: if the model has not been loaded yet.
        """
        if not self._is_loaded:
            raise RuntimeError("Model not loaded. Call load() first.")

        with torch.no_grad():
            outputs = self._model(inputs_embeds=noise)
            logits = outputs.logits

        # Corrupt the logits with zero-mean Gaussian noise whose scale
        # tracks the spread of the logits themselves.
        logits_noise = (
            torch.randn_like(logits) * logits.std() * self.LOGITS_NOISE_SCALE
        )
        corrupted_logits = logits + logits_noise

        return logits, corrupted_logits

    def decode_indices(self, indices: List[int]) -> List[str]:
        """Decode token indices into their string representations.

        Args:
            indices: token ids; each is decoded individually.

        Returns:
            List of decoded token strings, in input order.

        Raises:
            RuntimeError: if the model has not been loaded yet.
        """
        if not self._is_loaded:
            raise RuntimeError("Model not loaded. Call load() first.")

        return [self._tokenizer.decode([i]) for i in indices]
|