Spaces:
Sleeping
Sleeping
Teslim Olunlade
committed on
Commit
·
75c005d
1
Parent(s):
25ac0e5
Successfully trained
Browse files
- app/main.py +3 -3
- app/train.py +54 -0
- requirements.txt +1 -0
app/main.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
from transformers import AutoTokenizer
|
| 3 |
from transformers import (
|
|
@@ -15,9 +17,7 @@ model_name = st.selectbox(
|
|
| 15 |
"Select the model you want to use below.",
|
| 16 |
(
|
| 17 |
"distilbert-base-uncased-finetuned-sst-2-english",
|
| 18 |
-
"
|
| 19 |
-
"finiteautomata/bertweet-base-sentiment-analysis",
|
| 20 |
-
"ProsusAI/finbert",
|
| 21 |
),
|
| 22 |
)
|
| 23 |
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
|
| 3 |
import streamlit as st
|
| 4 |
from transformers import AutoTokenizer
|
| 5 |
from transformers import (
|
|
|
|
| 17 |
"Select the model you want to use below.",
|
| 18 |
(
|
| 19 |
"distilbert-base-uncased-finetuned-sst-2-english",
|
| 20 |
+
"roberta-large-mnli",
|
|
|
|
|
|
|
| 21 |
),
|
| 22 |
)
|
| 23 |
|
app/train.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import tensorflow as tf
|
| 7 |
+
from transformers import (
|
| 8 |
+
AutoTokenizer,
|
| 9 |
+
TFAutoModelForSequenceClassification,
|
| 10 |
+
TFTrainer,
|
| 11 |
+
TFTrainingArguments,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
labels = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
|
| 15 |
+
label2id = {label: id for id, label in enumerate(labels)}
|
| 16 |
+
id2label = {id: label for id, label in enumerate(labels)}
|
| 17 |
+
|
| 18 |
+
data = pd.read_csv("./train.csv")
|
| 19 |
+
|
| 20 |
+
batch_encodings = defaultdict(list)
|
| 21 |
+
batch_labels = list()
|
| 22 |
+
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
|
| 23 |
+
|
| 24 |
+
for _, row in data.sample(n=3).iterrows():
|
| 25 |
+
text = row["comment_text"]
|
| 26 |
+
text_labels = {k: row[k] for k in row.keys() if k in labels}
|
| 27 |
+
|
| 28 |
+
encoding = tokenizer(text, padding="max_length", truncation=True)
|
| 29 |
+
|
| 30 |
+
batch_labels.append([text_labels[id2label[id]] for id in range(len(labels))])
|
| 31 |
+
|
| 32 |
+
for key in encoding.keys():
|
| 33 |
+
batch_encodings[key].append(encoding[key])
|
| 34 |
+
|
| 35 |
+
model = TFAutoModelForSequenceClassification.from_pretrained(
|
| 36 |
+
"bert-base-cased",
|
| 37 |
+
problem_type="multi_label_classification",
|
| 38 |
+
num_labels=len(labels),
|
| 39 |
+
label2id=label2id,
|
| 40 |
+
id2label=id2label,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
training_args = TFTrainingArguments(output_dir="test_trainer")
|
| 44 |
+
train_dataset = tf.data.Dataset.from_tensor_slices(
|
| 45 |
+
(dict(batch_encodings), batch_labels)
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
trainer = TFTrainer(
|
| 49 |
+
model=model,
|
| 50 |
+
args=training_args,
|
| 51 |
+
train_dataset=train_dataset,
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
trainer.train()
|
requirements.txt
CHANGED
|
@@ -7,6 +7,7 @@ cachetools==5.3.0
|
|
| 7 |
certifi==2022.12.7
|
| 8 |
charset-normalizer==3.1.0
|
| 9 |
click==8.1.3
|
|
|
|
| 10 |
decorator==5.1.1
|
| 11 |
entrypoints==0.4
|
| 12 |
filelock==3.10.7
|
|
|
|
| 7 |
certifi==2022.12.7
|
| 8 |
charset-normalizer==3.1.0
|
| 9 |
click==8.1.3
|
| 10 |
+
datasets==2.11.0
|
| 11 |
decorator==5.1.1
|
| 12 |
entrypoints==0.4
|
| 13 |
filelock==3.10.7
|