| import os |
| import glob |
| import pandas as pd |
| import argparse |
| from google import genai |
| from tqdm import tqdm |
| import time |
| import re |
| from word_segmentation_vi import word_segmentation_vi |
|
|
def setup_genai(api_key):
    """Build and return a Google Generative AI client authenticated with *api_key*."""
    client = genai.Client(api_key=api_key)
    return client
|
|
def classify_text(model, text, suggest_label=False):
    """Classify Vietnamese text into hate speech categories using Google's Generative AI.

    Args:
        model: A genai client (as returned by ``setup_genai``); only its
            ``models.generate_content`` method is used.
        text: One or more sentences separated by newlines. When
            ``suggest_label`` is True, each sentence is expected to carry a
            trailing ``<SuggestLabel>n</SuggestLabel>`` hint.
        suggest_label: Whether to tell the model to use the suggestion tags.

    Returns:
        A list with one inner list of label strings per response line
        (expected to be 5 values each: individual, groups, religion/creed,
        race/ethnicity, politics), or None when the API call raises.
    """
    prompt = f"""
    Analyze the following Vietnamese text for hate speech (each sentence is separated by a newline):
    "{text}"

    Rate it on these categories (0=NORMAL, 1=CLEAN, 2=OFFENSIVE, 3=HATE):
    - individual (targeting specific individuals)
    - groups (targeting groups or organizations)
    - religion/creed (targeting religious groups or beliefs)
    - race/ethnicity (racial/ethnic hate speech)
    - politics (political hate speech)
    If the text doesn't specify a person or group in a category, return 0 for that category.
    Else, return 1 for CLEAN, 2 for OFFENSIVE, or 3 for HATE.

    {'The number at the end of the sentence (between <SuggestLabel> and </SuggestLabel> tags is the suggestion label for the sentence. (0 is normal/clean, 1 is offensive/hate in at least one category)' if suggest_label else ''}

    For each sentence in the text, return only 5 numbers separated by commas (corresponding to the label of individual, groups, religion/creed, race/ethnicity, politics) and numbers for each sentence seperated by newlines, like (with no other text):
    0,1,0,0,0
    1,0,0,0,2
    """

    try:
        response = model.models.generate_content(model="gemini-2.0-flash", contents=prompt)
        # Strip whitespace around lines and values, and drop blank lines:
        # stray spaces or empty rows in the model's reply would otherwise
        # produce tokens like ' 1' / '' and crash the caller's int() cast.
        rows = []
        for line in response.text.strip().split('\n'):
            line = line.strip()
            if not line:
                continue
            rows.append([value.strip() for value in line.split(',')])
        return rows

    except Exception as e:
        # Best-effort contract: callers treat None as "retry later".
        print(f"Error classifying text: {e}")
        return None
|
|
def _classify_with_retries(model, text, suggest_label, pause, retries=2):
    """Retry classify_text up to *retries* times, sleeping *pause* seconds after each failure.

    Unlike the previous inline retry loops, this keeps the ``suggest_label``
    flag on every attempt — the <SuggestLabel> tags are still embedded in
    *text*, so dropping the flag on retry would desynchronize the prompt.
    Returns the classification rows, or None when all attempts fail.
    """
    for _ in range(retries):
        result = classify_text(model, text, suggest_label=suggest_label)
        if result is not None:
            return result
        time.sleep(pause)
    return None


def _apply_batch(df, start, rows, category_columns):
    """Write one batch of classification rows into df starting at row *start*.

    Raises (ValueError/IndexError) on malformed rows — e.g. a non-numeric
    token or fewer than 5 values — so the caller can re-classify the batch.
    """
    for i, row in enumerate(rows):
        for j, col in enumerate(category_columns):
            df.at[start + i, col] = int(row[j])


def process_file(input_file, output_file, model, rate_limit_pause=4, text_col="free_text", suggest_column="labels"):
    """Process a single CSV file to match the test.csv format.

    Reads *input_file*, classifies its text in batches of 100 via the GenAI
    client, word-segments the Vietnamese content, and writes the labeled
    result to *output_file*.

    Args:
        input_file: Path of the CSV to read.
        output_file: Path the processed CSV is written to.
        model: genai client (see ``setup_genai``).
        rate_limit_pause: Seconds to sleep between API calls and retries.
        text_col: Input column holding the raw text; renamed to 'content'.
        suggest_column: Optional column whose values are appended to each
            sentence as <SuggestLabel> hints when present.
    """
    print(f"Processing {input_file}...")

    try:
        df = pd.read_csv(input_file)
    except Exception as e:
        print(f"Error reading {input_file}: {e}")
        return

    # Normalize the text column name to 'content'.
    if text_col in df.columns:
        df.rename(columns={text_col: 'content'}, inplace=True)
    elif 'content' not in df.columns:
        print(f"Error: 'content' column not found in {input_file}")
        return

    category_columns = ['individual', 'groups', 'religion/creed', 'race/ethnicity', 'politics']
    for col in category_columns:
        if col not in df.columns:
            df[col] = 0  # 0 == NORMAL / not yet labeled

    print("Suggesting labels: ", 'True' if suggest_column in df.columns else 'False')

    batch_size = 100
    for start in tqdm(range(0, len(df), batch_size), desc="Processing batches"):
        end = min(start + batch_size, len(df))
        batch_df = df.iloc[start:end]

        # Resume support: skip a batch whose category columns are all
        # entirely non-zero — presumably labeled on a previous run.
        if all(batch_df[cat].all() != 0 for cat in category_columns):
            continue

        batch_strings = [str(sentence) for sentence in batch_df['content'].tolist()]
        suggest_label = suggest_column in df.columns
        if suggest_label:
            batch_strings = [
                str(sentence) + " " + f"<SuggestLabel>{str(label)}</SuggestLabel>"
                for sentence, label in zip(batch_strings, batch_df[suggest_column].tolist())
            ]

        text_to_classify = "\n".join(batch_strings)

        classifications = classify_text(model, text_to_classify, suggest_label=suggest_label)
        if classifications is None:
            classifications = _classify_with_retries(model, text_to_classify, suggest_label, rate_limit_pause)
        if classifications is None:
            print(f"Error classifying batch starting at index {start}. Skipping...")
            continue

        try:
            _apply_batch(df, start, classifications, category_columns)
        except Exception:
            # Malformed response (wrong row/value count): re-classify once
            # more and retry the update before giving up on this batch.
            classifications = _classify_with_retries(model, text_to_classify, suggest_label, rate_limit_pause)
            if classifications is None:
                print(f"Error classifying batch starting at index {start}. Skipping...")
                continue
            try:
                _apply_batch(df, start, classifications, category_columns)
            except Exception as e:
                print(f"Error updating DataFrame: {e}")
                continue

        time.sleep(rate_limit_pause)

    # Word-segment the Vietnamese text for downstream tokenization.
    df['content'] = df['content'].apply(lambda x: word_segmentation_vi(str(x)))

    # Labels written via .at may be stored as floats; force integer dtype.
    for col in category_columns:
        df[col] = df[col].astype(int)

    if 'label_id' in df.columns:
        df.drop(columns=['label_id'], inplace=True)
    df.to_csv(output_file, index=False)
    print(f"Saved processed file to {output_file}")
|
|
def main():
    """CLI entry point: label every CSV in --input_dir and write results to --output_dir."""
    parser = argparse.ArgumentParser(description="Process ViHSD CSV files with Google Generative AI")
    parser.add_argument("--input_dir", required=True, help="Directory containing input CSV files")
    parser.add_argument("--output_dir", required=True, help="Directory to save processed files")
    parser.add_argument("--api_key", required=True, help="Google Generative AI API key")
    parser.add_argument("--pause", type=float, default=4.0, help="Pause between API calls (seconds)")
    parser.add_argument("--text_col", default="free_text", help="Column name for text content in input CSV files")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    client = setup_genai(args.api_key)

    csv_files = glob.glob(os.path.join(args.input_dir, "*.csv"))
    if not csv_files:
        print(f"No CSV files found in {args.input_dir}")
        return
    print(f"Found {len(csv_files)} CSV files to process")

    # Process each file, skipping any whose output already exists (resume support).
    for source_path in csv_files:
        target_path = os.path.join(args.output_dir, os.path.basename(source_path))
        if os.path.exists(target_path):
            print(f"Output file {target_path} already exists. Skipping...")
            continue
        process_file(source_path, target_path, client, args.pause, text_col=args.text_col)
|
|
if __name__ == "__main__":
    # Run the CLI entry point only when executed as a script, not on import.


    main()