Yuchan
commited on
Update Mo_jax.py
Browse files
Mo_jax.py
CHANGED
|
@@ -8,7 +8,7 @@ import numpy as np
|
|
| 8 |
import sentencepiece as spm
|
| 9 |
from functools import partial
|
| 10 |
from typing import Any, Callable, Optional, Tuple, Sequence
|
| 11 |
-
|
| 12 |
import jax
|
| 13 |
import jax.numpy as jnp
|
| 14 |
from jax import random
|
|
@@ -17,6 +17,13 @@ from flax.training import train_state, checkpoints
|
|
| 17 |
import optax
|
| 18 |
import tqdm
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
# ------------------
|
| 21 |
# Config
|
| 22 |
# ------------------
|
|
@@ -32,6 +39,18 @@ SEED = 42
|
|
| 32 |
LEARNING_RATE = 1e-4
|
| 33 |
EPOCHS = 1
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
# Derived
|
| 36 |
NUM_DEVICES = jax.device_count()
|
| 37 |
assert GLOBAL_BATCH % NUM_DEVICES == 0, "GLOBAL_BATCH must be divisible by device count"
|
|
@@ -379,3 +398,4 @@ def generate_text(state, prompt: str, max_gen=256, p=0.9, temperature=0.8, min_l
|
|
| 379 |
# quick generate
|
| 380 |
print("\n\n===== ์์ฑ ๊ฒฐ๊ณผ =====")
|
| 381 |
print(generate_text(state, "์ง๋ 2๋
๋์ ์ถ์ฐ์ฐ์ด ๊ตญ๊ฐ๊ฐ ํ์ํ ์ฐ๊ตฌ๋ฅผ", p=0.9))
|
|
|
|
|
|
| 8 |
import sentencepiece as spm
|
| 9 |
from functools import partial
|
| 10 |
from typing import Any, Callable, Optional, Tuple, Sequence
|
| 11 |
+
import requests
|
| 12 |
import jax
|
| 13 |
import jax.numpy as jnp
|
| 14 |
from jax import random
|
|
|
|
| 17 |
import optax
|
| 18 |
import tqdm
|
| 19 |
|
| 20 |
+
def download_file(url, save_path):
|
| 21 |
+
r = requests.get(url, stream=True)
|
| 22 |
+
r.raise_for_status()
|
| 23 |
+
with open(save_path, "wb") as f:
|
| 24 |
+
for chunk in r.iter_content(8192*2):
|
| 25 |
+
f.write(chunk)
|
| 26 |
+
print(f"โ
{save_path} ์ ์ฅ๋จ")
|
| 27 |
# ------------------
|
| 28 |
# Config
|
| 29 |
# ------------------
|
|
|
|
| 39 |
LEARNING_RATE = 1e-4
|
| 40 |
EPOCHS = 1
|
| 41 |
|
| 42 |
+
if not os.path.exists(CORPUS_PATH):
|
| 43 |
+
download_file(
|
| 44 |
+
"https://huggingface.co/datasets/Yuchan5386/Prototype/resolve/main/corpus_ko.txt?download=true",
|
| 45 |
+
CORPUS_PATH
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
if not os.path.exists(VOCAB_MODEL):
|
| 49 |
+
download_file(
|
| 50 |
+
"https://huggingface.co/Yuchan5386/inlam-100m/resolve/main/ko_unigram.model?download=true",
|
| 51 |
+
VOCAB_MODEL
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
# Derived
|
| 55 |
NUM_DEVICES = jax.device_count()
|
| 56 |
assert GLOBAL_BATCH % NUM_DEVICES == 0, "GLOBAL_BATCH must be divisible by device count"
|
|
|
|
| 398 |
# quick generate
|
| 399 |
print("\n\n===== ์์ฑ ๊ฒฐ๊ณผ =====")
|
| 400 |
print(generate_text(state, "์ง๋ 2๋
๋์ ์ถ์ฐ์ฐ์ด ๊ตญ๊ฐ๊ฐ ํ์ํ ์ฐ๊ตฌ๋ฅผ", p=0.9))
|
| 401 |
+
|