|
|
from fastai.text.all import * |
|
|
from pathlib import Path |
|
|
import pandas as pd |
|
|
import tiktoken |
|
|
|
|
|
# Module-level BPE encoder; "o200k_base" is the encoding name passed to tiktoken.
enc = tiktoken.get_encoding("o200k_base")
|
|
|
|
|
def tokenizer(s):
    """Tokenize *s* with the module-level tiktoken encoder.

    Encodes the string to BPE token ids, then decodes each id back to its
    surface text, returning the list of per-token strings.
    """
    return [enc.decode([token_id]) for token_id in enc.encode(s)]
|
|
|
|
|
def main(
    data_path='data/chat_data.txt',
    seq_len=256,
    epochs=5000,
    lr=1e-3,
    export_path='model.pkl',
    prompt="""Hi!""",
    n_words=200,
    temperature=0.9,
):
    """Train an AWD-LSTM language model from scratch on a text file.

    Reads the whole file at *data_path* as one document, builds fastai
    language-model dataloaders over it, trains with one-cycle scheduling,
    exports the learner, and prints a sample generation from *prompt*.

    Parameters
    ----------
    data_path : str or Path
        UTF-8 text file used as the (single-document) training corpus.
    seq_len : int
        Sequence length for the language-model dataloaders.
    epochs : int
        Number of one-cycle epochs passed to ``fit_one_cycle``.
    lr : float
        Maximum learning rate for one-cycle training.
    export_path : str
        Destination for ``learn.export``.
    prompt : str
        Seed text for the post-training sample generation.
    n_words : int
        Number of words to generate from *prompt*.
    temperature : float
        Sampling temperature for generation.
    """
    text = Path(data_path).read_text(encoding='utf-8')

    # NOTE(review): fastai conventionally treats `tok_func` as a tokenizer
    # *factory* (it is instantiated before use); passing a plain per-string
    # function here may not be called as intended — verify against the
    # installed fastai version (a `tok=` argument may be the right spelling).
    dls = TextDataLoaders.from_df(
        pd.DataFrame({'text': [text]}),
        text_col='text',
        is_lm=True,
        tok_func=tokenizer,
        seq_len=seq_len,
    )

    # pretrained=False: train from random init (the tiktoken vocabulary does
    # not match fastai's pretrained embeddings). to_fp16 speeds training on
    # GPUs with tensor cores; it is a no-op concern on CPU.
    learn = language_model_learner(
        dls,
        arch=AWD_LSTM,
        metrics=[accuracy, Perplexity()],
        pretrained=False,
    ).to_fp16()

    learn.fit_one_cycle(epochs, lr)

    # Export the full learner (model + dataloader config) for later load_learner.
    learn.export(export_path)

    generated = learn.predict(prompt, n_words, temperature=temperature)
    print("\nGenerated text:\n", generated)
|
|
|
|
|
if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()
|
|
|