# File size: 967 Bytes
# 4ac0357
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from fastai.text.all import *
from pathlib import Path
import pandas as pd
import tiktoken

# Module-level BPE encoder shared by tokenizer(); "o200k_base" is the
# tiktoken encoding used by GPT-4o-era models.
enc = tiktoken.get_encoding("o200k_base")

def tokenizer(s):
    """Split *s* into sub-word token strings using the module-level encoder.

    Encodes the text into BPE ids with ``enc.encode`` and decodes each id
    back individually, so the result is a list of token strings (the form
    fastai's tokenization pipeline expects) rather than raw integer ids.
    """
    return [enc.decode([token_id]) for token_id in enc.encode(s)]

def main():
    """Train an AWD-LSTM language model from scratch on a chat transcript,
    export it, then print a sample continuation of a short prompt."""
    corpus = Path('data/chat_data.txt').read_text(encoding='utf-8')

    # Single-row frame: fastai chunks the one long document into
    # seq_len-token windows for language modelling (is_lm=True).
    loaders = TextDataLoaders.from_df(
        pd.DataFrame({'text': [corpus]}),
        text_col='text',
        is_lm=True,
        tok_func=tokenizer,
        seq_len=256
    )

    # pretrained=False: train from random init; to_fp16 enables mixed precision.
    learn = language_model_learner(
        loaders,
        arch=AWD_LSTM,
        metrics=[accuracy, Perplexity()],
        pretrained=False
    ).to_fp16()

    learn.fit_one_cycle(5000, 1e-3)

    # Export full learner (architecture + weights + vocab)
    learn.export('model.pkl')

    prompt = """Hi!"""
    sample = learn.predict(prompt, 200, temperature=0.9)
    print("\nGenerated text:\n", sample)

# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()