Yuchan committed on
Commit
be37607
·
verified ·
1 Parent(s): 63439a6

Update Mo_jax.py

Browse files
Files changed (1) hide show
  1. Mo_jax.py +21 -1
Mo_jax.py CHANGED
@@ -8,7 +8,7 @@ import numpy as np
8
  import sentencepiece as spm
9
  from functools import partial
10
  from typing import Any, Callable, Optional, Tuple, Sequence
11
-
12
  import jax
13
  import jax.numpy as jnp
14
  from jax import random
@@ -17,6 +17,13 @@ from flax.training import train_state, checkpoints
17
  import optax
18
  import tqdm
19
 
 
 
 
 
 
 
 
20
  # ------------------
21
  # Config
22
  # ------------------
@@ -32,6 +39,18 @@ SEED = 42
32
  LEARNING_RATE = 1e-4
33
  EPOCHS = 1
34
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  # Derived
36
  NUM_DEVICES = jax.device_count()
37
  assert GLOBAL_BATCH % NUM_DEVICES == 0, "GLOBAL_BATCH must be divisible by device count"
@@ -379,3 +398,4 @@ def generate_text(state, prompt: str, max_gen=256, p=0.9, temperature=0.8, min_l
379
  # quick generate
380
  print("\n\n===== 생성 결과 =====")
381
  print(generate_text(state, "지난 2년 동안 출연연이 국가가 필요한 연구를", p=0.9))
 
 
8
  import sentencepiece as spm
9
  from functools import partial
10
  from typing import Any, Callable, Optional, Tuple, Sequence
11
+ import requests
12
  import jax
13
  import jax.numpy as jnp
14
  from jax import random
 
17
  import optax
18
  import tqdm
19
 
20
+ def download_file(url, save_path):
21
+ r = requests.get(url, stream=True)
22
+ r.raise_for_status()
23
+ with open(save_path, "wb") as f:
24
+ for chunk in r.iter_content(8192*2):
25
+ f.write(chunk)
26
+ print(f"✅ {save_path} 저장됨")
27
  # ------------------
28
  # Config
29
  # ------------------
 
39
  LEARNING_RATE = 1e-4
40
  EPOCHS = 1
41
 
42
+ if not os.path.exists(CORPUS_PATH):
43
+ download_file(
44
+ "https://huggingface.co/datasets/Yuchan5386/Prototype/resolve/main/corpus_ko.txt?download=true",
45
+ CORPUS_PATH
46
+ )
47
+
48
+ if not os.path.exists(VOCAB_MODEL):
49
+ download_file(
50
+ "https://huggingface.co/Yuchan5386/inlam-100m/resolve/main/ko_unigram.model?download=true",
51
+ VOCAB_MODEL
52
+ )
53
+
54
  # Derived
55
  NUM_DEVICES = jax.device_count()
56
  assert GLOBAL_BATCH % NUM_DEVICES == 0, "GLOBAL_BATCH must be divisible by device count"
 
398
  # quick generate
399
  print("\n\n===== 생성 결과 =====")
400
  print(generate_text(state, "지난 2년 동안 출연연이 국가가 필요한 연구를", p=0.9))
401
+