pierjoe committed on
Commit
e2eaa11
·
verified ·
1 Parent(s): 8cf5d34

Upload minitransformer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. minitransformer.py +9 -4
minitransformer.py CHANGED
@@ -7,7 +7,7 @@ import os
7
  from transformers import AutoTokenizer, logging
8
  import pandas as pd
9
  from tqdm import tqdm
10
-
11
 
12
  logging.set_verbosity_error()
13
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -15,14 +15,14 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
15
  # ----------------- CONFIG -----------------
16
  SAVE_EVERY = 5
17
  MODEL_NAME = "mini_transformer_v3"
18
- N_DATA_WORKERS = 6
19
  PIN_MEMORY = True if N_DATA_WORKERS > 0 and torch.cuda.is_available() else False
20
- BATCH_SIZE = 128
21
  EVAL_EVERY = 5
22
  LEARNING_RATE = 3e-4
23
  NUM_EPOCHS = 50
24
  USE_AMP = True
25
- STRIDE = 32
26
  CHECKPOINT_DIR = f"MODELS/checkpoints/{MODEL_NAME}"
27
  os.makedirs(CHECKPOINT_DIR, exist_ok=True)
28
  DATASET = "DATA/generated_dataset_very_big.csv"
@@ -263,6 +263,11 @@ for epoch in range(start_epoch, NUM_EPOCHS):
263
  },
264
  os.path.join(CHECKPOINT_DIR, f"checkpoint_{MODEL_NAME}_epoch_{epoch+1}.pt"),
265
  )
 
 
 
 
 
266
 
267
  # check GPU utilization metrics here:
268
  # nvidia-smi dmon -s u
 
7
  from transformers import AutoTokenizer, logging
8
  import pandas as pd
9
  from tqdm import tqdm
10
+ from safetensors.torch import save_file
11
 
12
  logging.set_verbosity_error()
13
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
15
  # ----------------- CONFIG -----------------
16
  SAVE_EVERY = 5
17
  MODEL_NAME = "mini_transformer_v3"
18
+ N_DATA_WORKERS = 8
19
  PIN_MEMORY = True if N_DATA_WORKERS > 0 and torch.cuda.is_available() else False
20
+ BATCH_SIZE = 512
21
  EVAL_EVERY = 5
22
  LEARNING_RATE = 3e-4
23
  NUM_EPOCHS = 50
24
  USE_AMP = True
25
+ STRIDE = 64
26
  CHECKPOINT_DIR = f"MODELS/checkpoints/{MODEL_NAME}"
27
  os.makedirs(CHECKPOINT_DIR, exist_ok=True)
28
  DATASET = "DATA/generated_dataset_very_big.csv"
 
263
  },
264
  os.path.join(CHECKPOINT_DIR, f"checkpoint_{MODEL_NAME}_epoch_{epoch+1}.pt"),
265
  )
266
+ save_file(
267
+ model.state_dict(),
268
+ os.path.join(CHECKPOINT_DIR, f"model_{epoch+1}.safetensors"),
269
+ )
270
+
271
 
272
  # check GPU utilization metrics here:
273
  # nvidia-smi dmon -s u