Spaces:
Sleeping
Sleeping
Teslim Olunlade
committed on
Commit
·
ede725f
1
Parent(s):
5b67bcb
Fixed trained model
Browse files
- app/train.py +4 -3
app/train.py
CHANGED
|
@@ -11,6 +11,7 @@ from transformers import (
|
|
| 11 |
TFAutoModelForSequenceClassification,
|
| 12 |
)
|
| 13 |
|
|
|
|
| 14 |
output_dir = "out/model"
|
| 15 |
checkpoint_path = "out/cp.ckpt"
|
| 16 |
|
|
@@ -19,7 +20,7 @@ label2id = {label: id for id, label in enumerate(labels)}
|
|
| 19 |
id2label = {id: label for id, label in enumerate(labels)}
|
| 20 |
|
| 21 |
dataset = load_dataset("csv", data_files="train.csv")
|
| 22 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
| 23 |
|
| 24 |
|
| 25 |
def process_data(row):
|
|
@@ -40,7 +41,7 @@ def process_data(row):
|
|
| 40 |
|
| 41 |
|
| 42 |
model = TFAutoModelForSequenceClassification.from_pretrained(
|
| 43 |
-
|
| 44 |
problem_type="multi_label_classification",
|
| 45 |
num_labels=len(labels),
|
| 46 |
label2id=label2id,
|
|
@@ -67,5 +68,5 @@ push_to_hub_callback = PushToHubCallback(
|
|
| 67 |
hub_model_id="ogtega/tweet-toxicity-classifier",
|
| 68 |
)
|
| 69 |
|
| 70 |
-
model.compile(optimizer=Adam(3e-5), loss="
|
| 71 |
model.fit(tf_dataset, callbacks=[cp_callback, push_to_hub_callback])
|
|
|
|
| 11 |
TFAutoModelForSequenceClassification,
|
| 12 |
)
|
| 13 |
|
| 14 |
+
base_model = "distilbert-base-uncased"
|
| 15 |
output_dir = "out/model"
|
| 16 |
checkpoint_path = "out/cp.ckpt"
|
| 17 |
|
|
|
|
| 20 |
id2label = {id: label for id, label in enumerate(labels)}
|
| 21 |
|
| 22 |
dataset = load_dataset("csv", data_files="train.csv")
|
| 23 |
+
tokenizer = AutoTokenizer.from_pretrained(base_model)
|
| 24 |
|
| 25 |
|
| 26 |
def process_data(row):
|
|
|
|
| 41 |
|
| 42 |
|
| 43 |
model = TFAutoModelForSequenceClassification.from_pretrained(
|
| 44 |
+
base_model,
|
| 45 |
problem_type="multi_label_classification",
|
| 46 |
num_labels=len(labels),
|
| 47 |
label2id=label2id,
|
|
|
|
| 68 |
hub_model_id="ogtega/tweet-toxicity-classifier",
|
| 69 |
)
|
| 70 |
|
| 71 |
+
model.compile(optimizer=Adam(3e-5), loss="BinaryCrossentropy")
|
| 72 |
model.fit(tf_dataset, callbacks=[cp_callback, push_to_hub_callback])
|