# ChatLSTM / train.py
# Uploaded by openagi-agi; commit 4ac0357 ("Update train.py"), verified.
from fastai.text.all import *
from pathlib import Path
import pandas as pd
import tiktoken
# Shared tiktoken encoding ("o200k_base"), created once at import time and
# used by `tokenizer` below.
enc = tiktoken.get_encoding("o200k_base")
def tokenizer(s):
    """Split *s* into a list of string tokens using the module-level `enc`.

    Each returned string is the text of one BPE token, except where a
    multi-byte UTF-8 character is split across several tokens: those tokens
    are merged into a single string so that no character is lost.  Joining
    the returned list with ``""`` therefore reproduces *s*.

    Args:
        s: The text to tokenize.

    Returns:
        A list of non-empty strings, in order of appearance in *s*.
    """
    pieces = []
    pending = b""
    # disallowed_special=() makes literal occurrences of special-token text
    # (e.g. "<|endoftext|>" inside a chat log) encode as ordinary text
    # instead of raising ValueError.
    for token_id in enc.encode(s, disallowed_special=()):
        pending += enc.decode_single_token_bytes(token_id)
        try:
            piece = pending.decode("utf-8")
        except UnicodeDecodeError:
            # The token ended in the middle of a multi-byte character; keep
            # accumulating bytes until the character is complete.  Decoding
            # each token on its own would yield U+FFFD replacement chars.
            continue
        pieces.append(piece)
        pending = b""
    if pending:
        # Only reachable if the byte stream ends mid-character; keep the
        # remainder rather than silently dropping it.
        pieces.append(pending.decode("utf-8", errors="replace"))
    return pieces
def _tokenize_batch(items):
    """Adapt the single-string `tokenizer` to fastai's batch tokenizer protocol.

    fastai calls its tokenizer with an iterable of texts and expects an
    iterable of token lists back.  This is a module-level function rather
    than a lambda so that it stays picklable (needed when the learner is
    exported).
    """
    return (tokenizer(text) for text in items)


def main(
    data_path='data/chat_data.txt',
    epochs=5000,
    lr=1e-3,
    seq_len=256,
    export_name='model.pkl',
    prompt="Hi!",
    n_words=200,
    temperature=0.9,
):
    """Train an AWD-LSTM language model on a text file, export it, print a sample.

    All parameters default to the values that were previously hard-coded, so
    calling ``main()`` with no arguments keeps the original configuration.

    Args:
        data_path: Path to the UTF-8 training corpus.
        epochs: Number of epochs passed to ``fit_one_cycle``.
        lr: Maximum learning rate passed to ``fit_one_cycle``.
        seq_len: Sequence length used by the language-model data loaders.
        export_name: File name passed to ``learn.export``.
        prompt: Seed text for the sample generation.
        n_words: Number of tokens to generate after the prompt.
        temperature: Sampling temperature for generation.

    Raises:
        FileNotFoundError: If *data_path* does not exist.
        ValueError: If the corpus is empty or whitespace-only.
    """
    text = Path(data_path).read_text(encoding='utf-8')
    if not text.strip():
        raise ValueError(f"Training corpus {data_path} is empty")

    # The custom tokenizer must be passed as `tok_tfm`; `tok_func` is not a
    # parameter of `TextDataLoaders.from_df`, so it never reached the
    # tokenization step.
    # NOTE(review): the whole corpus is a single DataFrame row, so the default
    # random validation split probably leaves the validation set empty --
    # confirm that the reported metrics are meaningful.
    dls = TextDataLoaders.from_df(
        pd.DataFrame({'text': [text]}),
        text_col='text',
        is_lm=True,
        tok_tfm=_tokenize_batch,
        seq_len=seq_len,
    )
    learn = language_model_learner(
        dls,
        arch=AWD_LSTM,
        metrics=[accuracy, Perplexity()],
        pretrained=False,
    ).to_fp16()
    learn.fit_one_cycle(epochs, lr)

    # Export full learner (architecture + weights + vocab)
    learn.export(export_name)

    generated = learn.predict(prompt, n_words, temperature=temperature)
    print("\nGenerated text:\n", generated)
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()