Teslim Olunlade committed on
Commit
10a61c2
·
1 Parent(s): 131f8ea

Added model saving

Browse files
Files changed (1) hide show
  1. app/train.py +9 -1
app/train.py CHANGED
@@ -1,10 +1,13 @@
1
  #!/usr/bin/env python3
2
 
3
  import numpy as np
 
4
  from datasets import load_dataset
5
  from tensorflow.keras.optimizers import Adam
6
  from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
7
 
 
 
8
  labels = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
9
  label2id = {label: id for id, label in enumerate(labels)}
10
  id2label = {id: label for id, label in enumerate(labels)}
@@ -46,5 +49,10 @@ tf_dataset = model.prepare_tf_dataset(
46
  encoded["train"], batch_size=16, shuffle=True, tokenizer=tokenizer
47
  )
48
 
 
 
 
 
49
  model.compile(optimizer=Adam(3e-5), loss="categorical_crossentropy")
50
- model.fit(tf_dataset)
 
 
1
  #!/usr/bin/env python3
2
 
3
  import numpy as np
4
+ import tensorflow as tf
5
  from datasets import load_dataset
6
  from tensorflow.keras.optimizers import Adam
7
  from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
8
 
9
+ checkpoint_path = "out/cp.ckpt"
10
+
11
  labels = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
12
  label2id = {label: id for id, label in enumerate(labels)}
13
  id2label = {id: label for id, label in enumerate(labels)}
 
49
  encoded["train"], batch_size=16, shuffle=True, tokenizer=tokenizer
50
  )
51
 
52
+ cp_callback = tf.keras.callbacks.ModelCheckpoint(
53
+ filepath=checkpoint_path, save_weights_only=True, verbose=1
54
+ )
55
+
56
  model.compile(optimizer=Adam(3e-5), loss="categorical_crossentropy")
57
+ model.fit(tf_dataset, callbacks=[cp_callback])
58
+ model.save('out/model')