Teslim Olunlade commited on
Commit
ede725f
·
1 Parent(s): 5b67bcb

Fixed trained model

Browse files
Files changed (1) hide show
  1. app/train.py +4 -3
app/train.py CHANGED
@@ -11,6 +11,7 @@ from transformers import (
11
  TFAutoModelForSequenceClassification,
12
  )
13
 
 
14
  output_dir = "out/model"
15
  checkpoint_path = "out/cp.ckpt"
16
 
@@ -19,7 +20,7 @@ label2id = {label: id for id, label in enumerate(labels)}
19
  id2label = {id: label for id, label in enumerate(labels)}
20
 
21
  dataset = load_dataset("csv", data_files="train.csv")
22
- tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
23
 
24
 
25
  def process_data(row):
@@ -40,7 +41,7 @@ def process_data(row):
40
 
41
 
42
  model = TFAutoModelForSequenceClassification.from_pretrained(
43
- "bert-base-uncased",
44
  problem_type="multi_label_classification",
45
  num_labels=len(labels),
46
  label2id=label2id,
@@ -67,5 +68,5 @@ push_to_hub_callback = PushToHubCallback(
67
  hub_model_id="ogtega/tweet-toxicity-classifier",
68
  )
69
 
70
- model.compile(optimizer=Adam(3e-5), loss="categorical_crossentropy")
71
  model.fit(tf_dataset, callbacks=[cp_callback, push_to_hub_callback])
 
11
  TFAutoModelForSequenceClassification,
12
  )
13
 
14
+ base_model = "distilbert-base-uncased"
15
  output_dir = "out/model"
16
  checkpoint_path = "out/cp.ckpt"
17
 
 
20
  id2label = {id: label for id, label in enumerate(labels)}
21
 
22
  dataset = load_dataset("csv", data_files="train.csv")
23
+ tokenizer = AutoTokenizer.from_pretrained(base_model)
24
 
25
 
26
  def process_data(row):
 
41
 
42
 
43
  model = TFAutoModelForSequenceClassification.from_pretrained(
44
+ base_model,
45
  problem_type="multi_label_classification",
46
  num_labels=len(labels),
47
  label2id=label2id,
 
68
  hub_model_id="ogtega/tweet-toxicity-classifier",
69
  )
70
 
71
+ model.compile(optimizer=Adam(3e-5), loss="BinaryCrossentropy")
72
  model.fit(tf_dataset, callbacks=[cp_callback, push_to_hub_callback])