Update train.py
Browse files
train.py
CHANGED
|
@@ -2,8 +2,11 @@ import os, pickle, json, torch
|
|
| 2 |
import torch.nn as nn
|
| 3 |
from model import GPT, GPTConfig
|
| 4 |
|
| 5 |
-
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
chars = sorted(list(set(text)))
|
| 9 |
vocab_size = len(chars)
|
|
|
|
| 2 |
import torch.nn as nn
|
| 3 |
from model import GPT, GPTConfig
|
| 4 |
|
| 5 |
+
# Load both original and extra data
|
| 6 |
+
with open("data/ai_gf/input.txt", "r", encoding="utf-8") as f1, \
|
| 7 |
+
open("data/ai_gf/input_extra.txt", "r", encoding="utf-8") as f2:
|
| 8 |
+
text = f1.read() + "\n\n" + f2.read()
|
| 9 |
+
|
| 10 |
|
| 11 |
chars = sorted(list(set(text)))
|
| 12 |
vocab_size = len(chars)
|