"""Train a TF-IDF + MLP binary safety classifier on Aegis 2.0 responses.

Loads the ``train`` split of ``nvidia/Aegis-AI-Content-Safety-Dataset-2.0``,
maps each response label to 1 (safe) or 0 (anything else), carves out
stratified 70% / 15% / 15% train / validation / test subsets, fits a
TF-IDF -> MLPClassifier pipeline, prints classification reports and
accuracies, and saves the fitted pipeline with joblib.
"""

from datasets import load_dataset
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import numpy as np  # noqa: F401  (currently unused)
import joblib

# ---- Configuration ----
DATASET_NAME = "nvidia/Aegis-AI-Content-Safety-Dataset-2.0"
TEXT_COL = "response"
LABEL_COL = "response_label"
MODEL_PATH = "mlp_tfidf_aegis2.joblib"
RANDOM_STATE = 42
TEST_FRACTION = 0.15  # share of all usable rows held out for the final test
VAL_FRACTION = 0.15   # share of all usable rows used for validation

# ---- Binary mapping: safe -> 1, everything else -> 0 ----
# Entries must be lowercase: raw labels are stripped and lower-cased before
# lookup. NOTE(review): an earlier comment said Aegis uses "safe" /
# "needs_caution" / unsafe categories — confirm the actual label vocabulary
# of the 2.0 release before relying on that.
SAFE_TOKENS = {"safe"}
# Display names for the binary classes; list index == binary label.
TARGET_NAMES = ["not_safe", "safe"]


def to_binary_label(raw):
    """Map a raw label to 1 if it is in SAFE_TOKENS, otherwise 0.

    Matching ignores surrounding whitespace and case. ``None`` maps to 0.
    """
    if raw is None:
        return 0
    raw = str(raw).strip().lower()
    return 1 if raw in SAFE_TOKENS else 0


def _has_text(value):
    """Return True if *value* is a string with non-whitespace content."""
    return isinstance(value, str) and bool(value.strip())


def _has_label(raw):
    """Return True if *raw* is present and not blank once stringified."""
    return raw is not None and bool(str(raw).strip())


def load_examples(split="train"):
    """Return ``(texts, labels)`` built from usable rows of one dataset split.

    A row is kept only when its response text is non-blank AND its response
    label is present. Rows with a missing label are dropped instead of being
    mapped to 0, which would inject unlabeled data into the not-safe class.
    Texts are returned stripped of surrounding whitespace.
    """
    ds = load_dataset(DATASET_NAME)
    # NOTE(review): the loaded DatasetDict may also contain official
    # validation/test splits; this script re-splits `split` only.
    rows = ds[split]

    records = [r for r in rows if _has_text(r.get(TEXT_COL))]
    labelled = [r for r in records if _has_label(r.get(LABEL_COL))]
    dropped = len(records) - len(labelled)
    if dropped:
        print(f"Dropped {dropped} rows with text but no {LABEL_COL!r}")

    texts = [r[TEXT_COL].strip() for r in labelled]
    labels = [to_binary_label(r[LABEL_COL]) for r in labelled]
    return texts, labels


def split_three_way(texts, labels):
    """Stratified split into train / validation / test lists.

    First holds out TEST_FRACTION of all rows as the test set, then splits
    the remainder so the validation set is VAL_FRACTION of the *original*
    total (0.15 / 0.85 of the remainder, i.e. 70/15/15 overall).

    Returns ``(x_train, x_val, x_test, y_train, y_val, y_test)``.
    """
    x_rest, x_test, y_rest, y_test = train_test_split(
        texts, labels,
        test_size=TEST_FRACTION,
        random_state=RANDOM_STATE,
        stratify=labels,
    )
    # Rescale so the fraction is relative to the remainder, not the total.
    val_share_of_rest = VAL_FRACTION / (1.0 - TEST_FRACTION)
    x_train, x_val, y_train, y_val = train_test_split(
        x_rest, y_rest,
        test_size=val_share_of_rest,
        random_state=RANDOM_STATE,
        stratify=y_rest,
    )
    return x_train, x_val, x_test, y_train, y_val, y_test


def build_pipeline():
    """Return an unfitted TF-IDF (uni+bigram) -> MLP classification pipeline."""
    return Pipeline([
        ("tfidf", TfidfVectorizer(max_features=100_000, ngram_range=(1, 2), min_df=3)),
        ("clf", MLPClassifier(
            hidden_layer_sizes=(128, 64),
            activation="relu",
            batch_size=256,
            # Holds out validation_fraction (default 10%) of the training
            # data internally and stops when its score stops improving.
            early_stopping=True,
            # With the default patience (10) and max_iter=10, early stopping
            # could never trigger: the first epoch always sets the best score,
            # so at most 9 non-improving epochs can follow.
            n_iter_no_change=3,
            max_iter=10,
            verbose=True,
            random_state=RANDOM_STATE,
        )),
    ])


def _print_report(title, pipe, texts, labels):
    """Print a per-class classification report for *pipe* on one subset."""
    print(title)
    predicted = pipe.predict(texts)
    # Fixed labels keep the report well-formed even if one class is absent
    # from the predictions.
    print(classification_report(
        labels, predicted,
        labels=[0, 1],
        target_names=TARGET_NAMES,
        digits=3,
    ))


def main():
    """Run the full load -> split -> train -> evaluate -> save workflow.

    Returns the fitted pipeline so interactive callers can keep using it.
    """
    texts, labels = load_examples()
    x_train, x_val, x_test, y_train, y_val, y_test = split_three_way(texts, labels)
    print(f"Train size: {len(x_train)}, Val size: {len(x_val)}, Test size: {len(x_test)}")

    pipe = build_pipeline()
    pipe.fit(x_train, y_train)

    _print_report("Validation results:", pipe, x_val, y_val)
    _print_report("Test results:", pipe, x_test, y_test)

    print("Train accuracy:", pipe.score(x_train, y_train))
    print("Val accuracy:", pipe.score(x_val, y_val))
    print("Test accuracy:", pipe.score(x_test, y_test))

    joblib.dump(pipe, MODEL_PATH)
    print(f"Saved to {MODEL_PATH}")
    return pipe


if __name__ == "__main__":
    main()