File size: 967 Bytes
4ac0357 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
from fastai.text.all import *
from pathlib import Path
import pandas as pd
import tiktoken
# BPE encoding shared by tokenizer() below; o200k_base is loaded once at import time.
enc = tiktoken.get_encoding("o200k_base")
def tokenizer(s):
    """Tokenize *s* into a list of token strings via the module-level BPE encoding.

    Encodes the text to token ids, then decodes each id individually so the
    caller receives one surface string per BPE token.
    """
    return [enc.decode([token_id]) for token_id in enc.encode(s)]
def main():
    """Train an AWD-LSTM language model from scratch on one text file,
    export the fitted learner, and print a short sample generation."""
    path = Path('data/chat_data.txt')
    text = path.read_text(encoding='utf-8')
    # Whole corpus goes into a single-row DataFrame with one 'text' cell.
    # NOTE(review): with only one row, fastai's default validation split may
    # end up empty or identical to training — confirm against the fastai
    # version in use.
    dls = TextDataLoaders.from_df(
        pd.DataFrame({'text':[text]}),
        text_col='text',
        is_lm=True,  # language-modeling task: targets are next-token shifted inputs
        tok_func=tokenizer,  # NOTE(review): current fastai docs name this kwarg `tok=`; verify `tok_func` is accepted here
        seq_len=256
    )
    learn = language_model_learner(
        dls,
        arch=AWD_LSTM,
        metrics=[accuracy, Perplexity()],
        pretrained=False  # custom tiktoken vocab — no compatible pretrained weights
    ).to_fp16()  # mixed-precision training
    # 5000 one-cycle epochs at lr=1e-3 — presumably intentional for a tiny
    # corpus, but extremely long; TODO confirm this is not a typo for 50/500.
    learn.fit_one_cycle(5000, 1e-3)
    # Export full learner (architecture + weights + vocab)
    learn.export('model.pkl')
    TEXT = """Hi!"""
    # predict(text, n_words, ...) — generate 200 tokens of continuation.
    generated = learn.predict(TEXT, 200, temperature=0.9)
    print("\nGenerated text:\n", generated)
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|