Update README.md
Browse files
README.md
CHANGED
|
@@ -4,6 +4,114 @@ language:
|
|
| 4 |
- ja
|
| 5 |
---
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
## Environment
|
| 8 |
|
| 9 |
### Git Information
|
|
|
|
| 4 |
- ja
|
| 5 |
---
|
| 6 |
|
| 7 |
+
## nanochat-jp_base
|
| 8 |
+
|
| 9 |
+
karpathyさん(元スタンドード 元テスラ 元OpenAIの)教育目的PJである[nanochat](https://github.com/karpathy/nanochat)
|
| 10 |
+
のd20版を日本語データ([kajuma/ABEJA-CC-JA-edu 10%](https://huggingface.co/datasets/kajuma/ABEJA-CC-JA-edu)を使って事前学習させたモデルです。
|
| 11 |
+
|
| 12 |
+
事前学習のみのため、補完しかできませんが、補完できることは確認済です。
|
| 13 |
+
|
| 14 |
+
ホームディレクトリ(~/.cache/nanochat/)に
|
| 15 |
+
- base_checkpoints_jp
|
| 16 |
+
- tokenizer
|
| 17 |
+
を配置する事で中間学習、SFTを実行する事ができると思います。
|
| 18 |
+
|
| 19 |
+
中間学習、SFTは比較的軽い処理なのでバッチサイズを減らせばローカルPCで実行可能です。
|
| 20 |
+
|
| 21 |
+
## 謝辞
|
| 22 |
+
以下の方たちのお力添えがなければこのモデルは完成しませんでした。ありがとうございます!
|
| 23 |
+
- karpathyさん
|
| 24 |
+
- kajumaさん
|
| 25 |
+
- ABEJA社
|
| 26 |
+
- 日本語でブログやWebサイトを執筆してくださった皆様
|
| 27 |
+
|
| 28 |
+
## 単体動作確認スクリプト
|
| 29 |
+
|
| 30 |
+
Linux 前提です。文章の続きの補完のみです。
|
| 31 |
+
|
| 32 |
+
1. [nanochat](https://github.com/karpathy/nanochat)をクローン
|
| 33 |
+
2. speedrun.shを動かす(失敗するが.venvは作ってくれる)
|
| 34 |
+
3. source .venv/bin/activate
|
| 35 |
+
4. ~/.cache/nanochat/に本リポジトリのbase_checkpoints_jpとtokenizerを配置
|
| 36 |
+
5. 以下のスクリプトをnanochatディレクトリ配下で動かす
|
| 37 |
+
|
| 38 |
+
```
|
| 39 |
+
# test_pretrained_jp.py
|
| 40 |
+
|
| 41 |
+
import os
|
| 42 |
+
import sys
|
| 43 |
+
sys.path.append(os.getcwd())
|
| 44 |
+
|
| 45 |
+
import torch
|
| 46 |
+
from nanochat.common import get_base_dir
|
| 47 |
+
from nanochat.checkpoint_manager import load_model_from_dir
|
| 48 |
+
|
| 49 |
+
# --- 設定 ---
|
| 50 |
+
MODEL_DIR_NAME = "base_checkpoints_jp"
|
| 51 |
+
MODEL_TAG = "d20"
|
| 52 |
+
STEP = None
|
| 53 |
+
MAX_NEW_TOKENS = 100
|
| 54 |
+
TEMPERATURE = 0.7
|
| 55 |
+
TOP_K = 50
|
| 56 |
+
|
| 57 |
+
# --- メイン実行部 ---
|
| 58 |
+
if __name__ == "__main__":
|
| 59 |
+
print("--- 事前学習済み日本語モデル テストスクリプト ---")
|
| 60 |
+
|
| 61 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 62 |
+
print(f"使用デバイス: {device}")
|
| 63 |
+
|
| 64 |
+
use_bf16 = (device == 'cuda' and torch.cuda.is_bf16_supported())
|
| 65 |
+
autocast_ctx = torch.amp.autocast(device_type=device, dtype=torch.bfloat16, enabled=use_bf16)
|
| 66 |
+
if use_bf16:
|
| 67 |
+
print("bfloat16がサポートされています。混合精度で推論を実行します。")
|
| 68 |
+
|
| 69 |
+
base_dir = get_base_dir()
|
| 70 |
+
checkpoints_dir = os.path.join(base_dir, MODEL_DIR_NAME)
|
| 71 |
+
|
| 72 |
+
print(f"モデルを次のパスから読み込みます: {os.path.join(checkpoints_dir, MODEL_TAG)}")
|
| 73 |
+
if not os.path.exists(os.path.join(checkpoints_dir, MODEL_TAG)):
|
| 74 |
+
print("\nFATAL: モデルディレクトリが見つかりません。")
|
| 75 |
+
print(f"ローカルの '{checkpoints_dir}' 以下に 'd20' などのディレクトリとしてモデルファイルが配置されているか確認してください。")
|
| 76 |
+
sys.exit(1)
|
| 77 |
+
|
| 78 |
+
model, tokenizer, meta = load_model_from_dir(
|
| 79 |
+
checkpoints_dir, device, phase="eval", model_tag=MODEL_TAG, step=STEP
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
print("\nモデルとトークナイザーのロードが完了しました。")
|
| 83 |
+
print("事前学習済みモデルは、対話ではなく『文章の続き』を生成します。")
|
| 84 |
+
|
| 85 |
+
while True:
|
| 86 |
+
try:
|
| 87 |
+
user_input = input("\nプロンプトを入力してください (終了するにはCtrl+C): ")
|
| 88 |
+
if not user_input:
|
| 89 |
+
continue
|
| 90 |
+
|
| 91 |
+
prompt_tokens = tokenizer.encode(user_input, prepend=tokenizer.get_bos_token_id())
|
| 92 |
+
|
| 93 |
+
print("-" * 30)
|
| 94 |
+
print("生成開始...")
|
| 95 |
+
print("入力プロンプト: ", user_input, end="")
|
| 96 |
+
|
| 97 |
+
# ★★★ 修正点2: no_grad と autocast の両方で囲む ★★★
|
| 98 |
+
with torch.no_grad():
|
| 99 |
+
with autocast_ctx:
|
| 100 |
+
stream = model.generate(prompt_tokens, max_tokens=MAX_NEW_TOKENS, temperature=TEMPERATURE, top_k=TOP_K)
|
| 101 |
+
|
| 102 |
+
for token in stream:
|
| 103 |
+
print(tokenizer.decode([token]), end="", flush=True)
|
| 104 |
+
|
| 105 |
+
print("\n" + "-" * 30)
|
| 106 |
+
|
| 107 |
+
except KeyboardInterrupt:
|
| 108 |
+
print("\n終了します。")
|
| 109 |
+
break
|
| 110 |
+
except Exception as e:
|
| 111 |
+
print(f"\nエラーが発生しました: {e}")
|
| 112 |
+
break
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
## Environment
|
| 116 |
|
| 117 |
### Git Information
|