Commit e2d86a3
Parent(s): e5c407e

Update datasets

Files changed:
- api.py +1 -1
- app.py +25 -0
- datasets_voz_gemini/dev.csv +3 -0
- datasets_voz_gemini/test.csv +3 -0
- datasets_voz_gemini/train.csv +3 -0
- scripts/merge_datasets.py +1 -1
- utils/convert_vihsd_gemini.py +18 -5
api.py
CHANGED
@@ -40,7 +40,7 @@ def load_model_lstm():
     model = model.to(device)
     return model, device
 
-def inference(model, device, comments: str | list, threshold: float = 0.
+def inference(model, device, comments: str | list, threshold: float = 0.6):
     if isinstance(comments, str):
         comments = [comments]
     elif not isinstance(comments, list):
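
The only functional change here is the default `threshold` of `inference`; the old value is truncated in the diff view, and the new default is 0.6. A minimal usage sketch, assuming only the signatures visible in this diff (the comment strings are placeholders):

# Sketch: calling inference with the new default threshold (assumes the
# signatures shown in the diff; the Vietnamese strings are placeholders).
from api import load_model_lstm, inference

model, device = load_model_lstm()
preds = inference(model, device, "bình luận cần kiểm tra")  # default threshold=0.6
preds = inference(model, device, ["bình luận 1", "bình luận 2"], threshold=0.5)  # explicit override
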
app.py
CHANGED
@@ -1,6 +1,31 @@
 import streamlit as st
 from api import load_model_bert, load_model_lstm, inference
 import pandas as pd
+from huggingface_hub import hf_hub_download
+import os
+
+# Download the model files from Hugging Face Hub: https://huggingface.co/jesse-tong/vietnamese_hate_speech_detection_phobert
+# to vietnamese_hate_speech_detection_phobert directory
+if os.path.exists("vietnamese_hate_speech_detection_phobert") == False:
+    try:
+        os.mkdir("vietnamese_hate_speech_detection_phobert")
+    except FileExistsError:
+        pass
+
+    # Download the model files
+    hf_hub_download(
+        repo_id="jesse-tong/vietnamese_hate_speech_detection_phobert",
+        filename="vinai_phobert-base-v2_finetuned.pth",
+        repo_type="model",
+        local_dir="vietnamese_hate_speech_detection_phobert"
+    )
+    hf_hub_download(
+        repo_id="jesse-tong/vietnamese_hate_speech_detection_phobert",
+        filename="distilled_lstm_model.pth",
+        repo_type="model",
+        local_dir="vietnamese_hate_speech_detection_phobert"
+    )
+
 
 # Set up the Streamlit app
 def app():
datasets_voz_gemini/dev.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0ee6611003e74122c4ff00a33949c0be5d64485afe08968b150b55883fd2c9b6
+size 865240

datasets_voz_gemini/test.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6d29e8328cdb0f415c32981bfe008fb59fd1ad7a84cf4e1948e1f16855205f35
+size 1814883

datasets_voz_gemini/train.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e30f5595d43adcc84998a31956b47343ccfb75b062fb8965ee2fc503b06111c6
+size 8002834
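
All three new CSVs are tracked with Git LFS, so the diff records only pointer files: `oid` is the SHA-256 of the real content and `size` its length in bytes (so train.csv is roughly 8 MB). A plain clone without LFS yields these three-line pointers, not the data; a sketch for fetching the real file from the Hub, with the repo id left as a placeholder:

# Sketch: resolving an LFS pointer to the actual CSV (repo_id is a
# placeholder for the repository this commit belongs to).
import pandas as pd
from huggingface_hub import hf_hub_download

path = hf_hub_download(
    repo_id="<owner>/<repo>",                  # placeholder, not a real repo id
    filename="datasets_voz_gemini/train.csv",
    # repo_type="space",                       # uncomment if this repo is a Space
)
df = pd.read_csv(path)
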
scripts/merge_datasets.py
CHANGED
@@ -19,7 +19,7 @@ def merge_csv_files(directories, output_file, target_name):
     merged.to_csv(output_file, index=False)
 
 if __name__ == "__main__":
-    directories = ["../datasets_vithsd", "../datasets_vihsd_gemini"]
+    directories = ["../datasets_vithsd", "../datasets_vihsd_gemini", "../datasets_voz_gemini"]
     merge_csv_files(directories, "../datasets/train.csv", "train.csv")
     merge_csv_files(directories, "../datasets/dev.csv", "dev.csv")
     merge_csv_files(directories, "../datasets/test.csv", "test.csv")
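
Only the `directories` list changes, pulling the new `datasets_voz_gemini` split files into the merged train/dev/test CSVs. The body of `merge_csv_files` is not part of this commit; the following is a plausible sketch consistent with the visible call sites and the `merged.to_csv(...)` line, not the committed implementation:

# Hypothetical reconstruction of merge_csv_files (an assumption, not the
# committed code): read target_name from each directory and concatenate.
import os
import pandas as pd

def merge_csv_files(directories, output_file, target_name):
    frames = [
        pd.read_csv(os.path.join(d, target_name))
        for d in directories
        if os.path.exists(os.path.join(d, target_name))
    ]
    merged = pd.concat(frames, ignore_index=True)
    merged.to_csv(output_file, index=False)
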
utils/convert_vihsd_gemini.py
CHANGED
@@ -12,7 +12,7 @@ def setup_genai(api_key):
     """Configure the Google Generative AI client with your API key"""
     return genai.Client(api_key=api_key)
 
-def classify_text(model, text):
+def classify_text(model, text, suggest_label=False):
     """Classify Vietnamese text into hate speech categories using Google's Generative AI"""
     prompt = f"""
     Analyze the following Vietnamese text for hate speech (each sentence is separated by a newline):
@@ -26,6 +26,8 @@ def classify_text(model, text):
     - politics (political hate speech)
     If the text doesn't specify a person or group in a category, return 0 for that category.
     Else, return 1 for CLEAN, 2 for OFFENSIVE, or 3 for HATE.
+
+    {'The number at the end of the sentence (between <SuggestLabel> and </SuggestLabel> tags is the suggestion label for the sentence. (0 is normal/clean, 1 is offensive/hate in at least one category)' if suggest_label else ''}
 
     For each sentence in the text, return only 5 numbers separated by commas (corresponding to the label of individual, groups, religion/creed, race/ethnicity, politics) and numbers for each sentence seperated by newlines, like (with no other text):
     0,1,0,0,0
@@ -42,7 +44,7 @@ def classify_text(model, text):
         print(f"Error classifying text: {e}")
         return None
 
-def process_file(input_file, output_file, model, rate_limit_pause=4, text_col="free_text"):
+def process_file(input_file, output_file, model, rate_limit_pause=4, text_col="free_text", suggest_column="labels"):
     """Process a single CSV file to match the test.csv format"""
     print(f"Processing {input_file}...")
 
@@ -66,6 +68,8 @@ def process_file(input_file, output_file, model, rate_limit_pause=4, text_col="f
         if col not in df.columns:
             # Change column type to int if it doesn't exist
             df[col] = 0
+
+    print("Suggesting labels: ", 'True' if suggest_column in df.columns else 'False')
 
     # Process each batch (100 rows at a time)
     batch_size = 100
@@ -78,8 +82,16 @@ def process_file(input_file, output_file, model, rate_limit_pause=4, text_col="f
             continue
 
         # Join 50 rows by newlines, and classify all at once
-
-
+        batch_strings = [str(sentence) for sentence in batch_df['content'].tolist()]
+        suggest_label = False
+        if suggest_column in df.columns:
+            batch_strings = [str(sentence) + " " + f"<SuggestLabel>{str(label)}</SuggestLabel>" for sentence, label in zip(batch_strings, batch_df[suggest_column].tolist())]
+            suggest_label = True
+
+
+        text_to_classify = "\n".join(batch_strings)
+        classifications = classify_text(model, text_to_classify, suggest_label=suggest_label)
+
 
         # Try 2 more times, else skip
         if classifications is None:
@@ -135,6 +147,7 @@ def main():
     parser.add_argument("--output_dir", required=True, help="Directory to save processed files")
     parser.add_argument("--api_key", required=True, help="Google Generative AI API key")
     parser.add_argument("--pause", type=float, default=4.0, help="Pause between API calls (seconds)")
+    parser.add_argument("--text_col", default="free_text", help="Column name for text content in input CSV files")
 
     args = parser.parse_args()
 
@@ -158,7 +171,7 @@ def main():
         if os.path.exists(output_file):
            print(f"Output file {output_file} already exists. Skipping...")
            continue
-        process_file(input_file, output_file, model, args.pause)
+        process_file(input_file, output_file, model, args.pause, text_col=args.text_col)
 
 if __name__ == "__main__":
     # This script is used to process ViHSD CSV files with Google Generative AI
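
The prompt asks the model to answer with one comma-separated quintuple per input sentence (`0,1,0,0,0` style, one line each). How the reply is mapped back onto the batch is not part of this diff; the following is a sketch of one way to parse it into the five categories (an assumption, not the committed parser):

# Sketch: parsing the model reply back into per-category labels
# (assumption about surrounding code, not shown in this commit).
CATEGORIES = ["individual", "groups", "religion/creed", "race/ethnicity", "politics"]

def parse_classifications(reply: str) -> list[dict]:
    rows = []
    for line in reply.strip().splitlines():
        parts = [p.strip() for p in line.split(",")]
        if len(parts) != len(CATEGORIES) or not all(p.isdigit() for p in parts):
            continue  # drop malformed lines; the caller retries the batch on a count mismatch
        rows.append(dict(zip(CATEGORIES, (int(p) for p in parts))))
    return rows

A length mismatch between the parsed rows and the batch size is the natural trigger for the "Try 2 more times, else skip" retry logic visible in the diff above.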