openagi-agi committed on
Commit
4ac0357
·
verified ·
1 Parent(s): eefca2c

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +42 -0
train.py CHANGED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastai.text.all import *
2
+ from pathlib import Path
3
+ import pandas as pd
4
+ import tiktoken
5
+
6
# Module-level tiktoken encoder, created once at import time and used by
# tokenizer() below.
_ENCODING_NAME = "o200k_base"
enc = tiktoken.get_encoding(_ENCODING_NAME)
7
+
8
def tokenizer(s):
    """Split *s* into the text pieces of its individual tiktoken tokens.

    The string is encoded with the module-level ``enc`` encoder and each
    resulting token id is decoded on its own, so the returned list holds
    one string per token id, in encoding order.
    """
    # NOTE(review): each id is decoded in isolation. tiktoken documents
    # ``decode`` as lossy when the bytes are not valid UTF-8, so a token that
    # covers only part of a multi-byte character may come back as a
    # replacement character -- confirm this is acceptable for the corpus.
    return [enc.decode([token_id]) for token_id in enc.encode(s)]
12
+
13
def main():
    """Train a language model on the chat corpus, export it, and sample.

    Steps, in order:

    1. Read ``data/chat_data.txt`` as UTF-8.
    2. Wrap the whole file as a single row of a one-column DataFrame and
       build language-model dataloaders from it (``seq_len=256``).
    3. Create an ``AWD_LSTM`` learner with ``pretrained=False``, tracking
       accuracy and perplexity, and switch it to fp16.
    4. Call ``fit_one_cycle(5000, 1e-3)``.
    5. Export the learner as ``model.pkl``.
    6. Call ``predict`` on the prompt ``Hi!`` and print the result.
    """
    corpus_file = Path('data/chat_data.txt')
    corpus = corpus_file.read_text(encoding='utf-8')

    # The entire corpus becomes the only row of the 'text' column.
    # NOTE(review): with a single row, check how the default validation
    # split behaves -- confirm the metrics below have data to run on.
    frame = pd.DataFrame({'text': [corpus]})

    # NOTE(review): ``tok_func`` is not among the documented keyword
    # arguments of TextDataLoaders.from_df (the tokenizer hook there is
    # ``tok_tfm``) -- confirm this argument has the intended effect.
    dls = TextDataLoaders.from_df(
        frame,
        text_col='text',
        is_lm=True,
        tok_func=tokenizer,
        seq_len=256,
    )

    learn = language_model_learner(
        dls,
        arch=AWD_LSTM,
        metrics=[accuracy, Perplexity()],
        pretrained=False,
    )
    learn = learn.to_fp16()

    learn.fit_one_cycle(5000, 1e-3)

    # Export full learner (architecture + weights + vocab)
    learn.export('model.pkl')

    prompt = """Hi!"""
    sample = learn.predict(prompt, 200, temperature=0.9)
    print("\nGenerated text:\n", sample)
40
+
41
# Start training only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()