Upload 6 files
Browse files- ferret_faithfullness.py +157 -0
- ferret_plausibility.py +313 -0
- hyperparameter_tuning_for_rationale.py +326 -0
- hyperparameter_tuning_without_rationale.py +165 -0
- model_training_with_rationale.py +318 -0
- model_training_without_rationale.py +251 -0
ferret_faithfullness.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import gc
import pandas as pd
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from ferret import Benchmark
from scipy.stats import rankdata
import torch

# ================================================================
# CONFIGURATION
# ================================================================

# Set your Hugging Face model repo name here
hf_model_name = "PLACE_YOUR_MODEL_NAME"

# CSV test file (expected in current directory)
test_file = "test.csv"

# Batch sizes
prediction_batch_size = 64  # rows per forward pass during prediction
ferret_batch_size = 1       # FERRET explanations are computed one text at a time

# Label mapping: string sentiment label -> class index used by the model head
label_map = {
    "Negative": 0,
    "Neutral": 1,
    "Positive": 2
}

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {device}")

# ================================================================
# LOAD TEST DATA
# ================================================================

if not os.path.exists(test_file):
    raise FileNotFoundError(f"Test file not found: {test_file}")

df_full = pd.read_csv(test_file)
# NOTE(review): any final_label not present in label_map maps to NaN in
# final_label_numeric; such rows can never match a prediction below — confirm
# the test CSV only contains the three expected labels.
df_full["final_label"] = df_full["final_label"].astype("category")
df_full["final_label_numeric"] = df_full["final_label"].map(label_map)

texts_all = df_full["Content"].tolist()
labels_all = df_full["final_label"].tolist()
print(f"[INFO] ✅ Loaded test data: {len(df_full)} rows.")
|
| 50 |
+
|
| 51 |
+
# ================================================================
|
| 52 |
+
# PIPELINE FOR SINGLE MODEL (MATCHED ONLY)
|
| 53 |
+
# ================================================================
|
| 54 |
+
|
| 55 |
+
print(f"\n==============================")
print(f"[INFO] 🚀 Starting pipeline for model: {hf_model_name}")
print(f"==============================")

# -----------------------------
# LOAD MODEL & TOKENIZER
# -----------------------------
tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    hf_model_name,
    trust_remote_code=True,
    use_safetensors=True
)
model = model.to(device)
model.eval()
print("[INFO] ✅ Model loaded.")

# -----------------------------
# PREDICTIONS
# -----------------------------
# Batched argmax predictions over the whole test set.
predictions = []
print("[INFO] 🔎 Predicting on test set...")

for i in tqdm(range(0, len(texts_all), prediction_batch_size), desc="Predicting"):
    batch_texts = texts_all[i : i + prediction_batch_size]
    inputs = tokenizer(
        batch_texts,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=256
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    preds = torch.argmax(outputs.logits, dim=1).cpu().tolist()
    predictions.extend(preds)

# Store predictions
df = df_full.copy()
df["prediction"] = predictions
df["prediction"] = df["prediction"].astype("int8")

predictions_filename = f"{hf_model_name.replace('/', '_')}_predictions.csv"
df.to_csv(predictions_filename, index=False)
print(f"[INFO] ✅ Predictions saved to {predictions_filename}.")

# -----------------------------
# SPLIT MATCHED ONLY
# -----------------------------
# Keep only rows where the model's prediction agrees with the gold label;
# faithfulness metrics are computed on these "matched" rows only.
matched_df = df[df["prediction"] == df["final_label_numeric"]].reset_index(drop=True)
print(f"[INFO] ✅ {len(matched_df)} matched rows retained.")

# Save matched for records
matched_df.to_csv(f"{hf_model_name.replace('/', '_')}_matched.csv", index=False)

# -----------------------------
# FERRET ON MATCHED
# -----------------------------
if len(matched_df) > 0:
    ferret_rows = []
    bench = Benchmark(model, tokenizer)
    print(f"[INFO] 🚀 Running FERRET on matched rows...")
    # ferret_batch_size is 1, so the outer loop effectively iterates row by row.
    for i in tqdm(range(0, len(matched_df), ferret_batch_size), desc="FERRET (Matched)"):
        batch = matched_df.iloc[i : i + ferret_batch_size]
        for _, row in batch.iterrows():
            text = row["Content"]
            label = int(row["final_label_numeric"])
            try:
                # One Explanation per explainer, then the matching evaluations.
                explanations = bench.explain(text, target=label)
                evaluations = bench.evaluate_explanations(explanations, target=label)
            except Exception as ex:
                # Best-effort: skip texts FERRET cannot handle, keep going.
                print(f"[WARN] FERRET failed on matched text: {text}\nReason: {ex}")
                continue
            ferret_row = {
                "Text": text,
                "final_label": row["final_label"],
                "final_label_numeric": label,
                "Annotations": row.get("Annotations", ""),
                "Rationale": row.get("Rationale", ""),
            }
            if explanations:
                # All explainers tokenize identically, so tokens come from the first.
                ferret_row["Tokens"] = " ".join(explanations[0].tokens)
                for expl, evaluation in zip(explanations, evaluations):
                    explainer_name = expl.explainer
                    scores = expl.scores
                    # Rank 1 = most important token (descending score order).
                    ranks = rankdata(-np.array(scores), method="min")
                    ferret_row[f"{explainer_name}_ImportanceScores"] = " ".join(map(str, scores))
                    ferret_row[f"{explainer_name}_RankVector"] = " ".join(map(str, ranks))
                    if evaluation and hasattr(evaluation, "evaluation_scores"):
                        for eval_score in evaluation.evaluation_scores:
                            ferret_row[f"{explainer_name}_{eval_score.name}"] = float(eval_score.score)
            ferret_rows.append(ferret_row)
            # Free per-row FERRET objects aggressively to cap memory use.
            del explanations
            del evaluations
            gc.collect()
    ferret_filename = f"{hf_model_name.replace('/', '_')}_ferret_matched.csv"
    pd.DataFrame(ferret_rows).to_csv(ferret_filename, index=False)
    print(f"[INFO] ✅ FERRET results saved to {ferret_filename}.")
else:
    print("[INFO] ⚠ No matched rows to run FERRET on.")

print("[INFO] ✅ Pipeline finished for matched rows only!")
|
ferret_plausibility.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import ferret
from ferret import Benchmark
import csv
import gc
import pandas as pd
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# ==============================
# File path and Data Loading
# ==============================
# Output of the faithfulness script: per-row explainer importance scores.
input_file = "modelname_ferret_matched.csv"  # Columns: 'Text', 'final_label', 'final_label_numeric', 'Annotations', 'Rationale', plus FERRET columns

if not os.path.exists(input_file):
    raise FileNotFoundError(f"Input file not found: {input_file}")

df = pd.read_csv(input_file)

# ==============================
# Model and Plausibility Configs
# ==============================
hf_model_names = [
    #keep your model names as "model_name"
]
# String label -> class index, and the inverse for pretty-printing.
label_map = {"Negative": 0, "Neutral": 1, "Positive": 2}
inv_label_map = {v: k for k, v in label_map.items()}
max_length = 128
|
| 31 |
+
|
| 32 |
+
# ==============================
|
| 33 |
+
# Rationales to Attention Vectors
|
| 34 |
+
# ==============================
|
| 35 |
+
def generate_attention_vectors_from_rationales(df, tokenizer, max_length=128, label_names=None):
    """
    For each row, generate one binary token-attention vector per annotator.

    An annotator's vector marks with 1.0 every real (non-special) token that
    overlaps one of the annotator's rationale spans. The vector stays all-zero
    when the annotator's label does not match the row's final label, when the
    annotator gave no spans, or when the annotator slot is absent for the row.

    Args:
        df: DataFrame with 'Text', 'final_label_numeric', 'Annotations'
            ('|'-separated "<Label>-<id>" entries) and 'Rationale'
            ('|'-separated, ','-separated span lists) columns.
        tokenizer: HF tokenizer supporting return_offsets_mapping.
        max_length: truncation length passed to the tokenizer.
        label_names: optional {label_id: label_name} mapping. Defaults to the
            module-level `inv_label_map` (kept for backward compatibility).

    Returns:
        The same DataFrame with 'embert_attention_<i>' string columns added
        (space-separated 0/1 per real token), one per annotator slot.
    """
    id_to_label = inv_label_map if label_names is None else label_names

    all_attention_vectors = []
    token_lengths = []
    max_annotators = 0

    for _, row in df.iterrows():
        text = str(row["Text"])
        final_label = id_to_label[row["final_label_numeric"]]

        encoding = tokenizer(
            text,
            add_special_tokens=True,
            return_offsets_mapping=True,
            return_attention_mask=False,
            return_token_type_ids=False,
            max_length=max_length,
            truncation=True
        )
        offsets = encoding["offset_mapping"]
        # Special tokens report a zero-width (start == end) offset; keep only
        # real tokens so vectors align with the visible text.
        real_token_indices = [i for i, (start, end) in enumerate(offsets) if start != end]
        num_real_tokens = len(real_token_indices)
        token_lengths.append(num_real_tokens)

        annotations = str(row["Annotations"]).split("|")
        rationales = str(row["Rationale"]).split("|")
        max_annotators = max(max_annotators, len(annotations))

        row_attention_vectors = []

        for annot_label, annot_rationale in zip(annotations, rationales):
            vec = np.zeros(num_real_tokens, dtype=np.float32)
            # Keep all zeros (and move on) if the annotator label is blank or
            # disagrees with the final label.
            if not annot_label.strip() or annot_label.split("-")[0].strip() != final_label:
                row_attention_vectors.append(vec)
                continue

            spans = [s.strip() for s in annot_rationale.split(",") if s.strip()]
            if not spans:
                row_attention_vectors.append(vec)
                continue

            # Mark every real token that overlaps ANY occurrence of ANY span.
            for span_text in spans:
                start = 0
                while True:
                    idx = text.find(span_text, start)
                    if idx < 0:
                        break
                    span_start, span_end = idx, idx + len(span_text)
                    for vec_idx, tok_idx in enumerate(real_token_indices):
                        tok_start, tok_end = offsets[tok_idx]
                        if tok_end > span_start and tok_start < span_end:
                            vec[vec_idx] = 1.0
                    start = idx + 1

            row_attention_vectors.append(vec)

        all_attention_vectors.append(row_attention_vectors)

    # Write one column per annotator slot; rows missing that slot get zeros.
    for i in range(max_annotators):
        col_name = f"embert_attention_{i+1}"
        col_vectors = []
        for row_vecs, num_tokens in zip(all_attention_vectors, token_lengths):
            if i < len(row_vecs):
                vec_str = " ".join(f"{int(v)}" for v in row_vecs[i])
            else:
                vec_str = " ".join(["0"] * num_tokens)
            col_vectors.append(vec_str)
        df[col_name] = col_vectors

    return df
|
| 111 |
+
|
| 112 |
+
# ==============================
|
| 113 |
+
# Explanation class for FERRET
|
| 114 |
+
# ==============================
|
| 115 |
+
class Explanation:
    """Lightweight stand-in for a ferret explanation object.

    Holds the explained text, its tokens, per-token importance scores
    (stored as float32), the producing explainer's name, and the target
    class index the explanation was computed for.
    """

    def __init__(self, text, tokens, scores, explainer, target):
        self.text = text
        self.tokens = tokens
        # Normalise scores to float32 so numeric behaviour matches
        # ferret-produced explanations downstream.
        self.scores = np.asarray(scores, dtype=np.float32)
        self.explainer = explainer
        self.target = target

    def __repr__(self):
        parts = ", ".join((
            f"text={self.text!r}",
            f"tokens={self.tokens}",
            f"scores=array({self.scores}, dtype=float32)",
            f"explainer={self.explainer!r}",
            f"target={self.target}",
        ))
        return f"Explanation({parts})"
|
| 125 |
+
|
| 126 |
+
# ==============================
# DEVICE SETUP
# ==============================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {device}")

# ==============================
# FERRET PIPELINE
# ==============================

# Plausibility metrics ferret reports. Fixing the full column set up front
# keeps every written row consistent with the CSV header: the original code
# discovered metrics per row and grew the field set while iterating, so rows
# written after the header could carry columns the header never declared
# (and set iteration made the column order nondeterministic).
plausibility_metrics = ["auprc_plau", "token_f1_plau", "token_iou_plau"]
num_annotators = 3

# List of explainers you want to use
explainer_names = [
    "Partition SHAP", "LIME", "Gradient", "Gradient (x Input)",
    "Integrated Gradient", "Integrated Gradient (x Input)"
]

# Deterministic, complete column list so every row writes the same fields.
base_fields = ["Index", "Text", "final_label", "final_label_numeric", "Annotations", "Rationale"]
metric_fields = [
    f"{explainer_name}Annotator{annot_idx + 1}{metric}"
    for explainer_name in explainer_names
    for annot_idx in range(num_annotators)
    for metric in plausibility_metrics
]
all_fieldnames = base_fields + metric_fields

for hf_model_name in hf_model_names:
    print(f"\n==============================")
    print(f"[INFO] Starting pipeline for model: {hf_model_name}")
    print(f"==============================")

    # Load model and tokenizer
    model = AutoModelForSequenceClassification.from_pretrained(
        hf_model_name,
        trust_remote_code=True,
        use_safetensors=True
    )
    model.to(device)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(
        hf_model_name,
        trust_remote_code=True,
        use_safetensors=True
    )

    bench = Benchmark(model, tokenizer)

    # Work on a per-model copy. The original code rebound the shared `df`
    # with the tokenizer-specific attention columns and then deleted it in
    # the cleanup step, which raised NameError on the second model.
    df_model = generate_attention_vectors_from_rationales(df.copy(), tokenizer)

    ferret_filename = f"{hf_model_name.replace('/', '_')}_ferret_plausibility.csv"
    # If a previous (resumed) run already created the file, don't re-write the header.
    header_written = os.path.exists(ferret_filename)

    # --- MAIN LOOP ---
    for idx in tqdm(range(len(df_model)), desc="FERRET (Plausibility Only)"):
        row = df_model.iloc[idx]

        ferret_row = {
            "Index": idx,
            "Text": row["Text"],
            "final_label": row["final_label"],
            "final_label_numeric": int(row["final_label_numeric"]),
            "Annotations": row.get("Annotations", ""),
            "Rationale": row.get("Rationale", ""),
        }

        # Rebuild Explanation objects from the pre-computed importance scores.
        row_explanations = {}
        for explainer_name in explainer_names:
            score_col = f"{explainer_name}_ImportanceScores"
            tokens_col = "Tokens"
            if pd.notna(row.get(score_col)) and pd.notna(row.get(tokens_col)):
                try:
                    scores = [float(score) for score in str(row[score_col]).split()]
                    tokens = str(row[tokens_col]).split()
                    target_label = int(row["final_label_numeric"])
                    row_explanations[explainer_name] = Explanation(
                        text=row["Text"], tokens=tokens, scores=scores,
                        explainer=explainer_name, target=target_label
                    )
                except Exception as e:
                    print(f"Could not create explanation for explainer {explainer_name} at index {idx}: {e}")
                    continue

        annot_labels_list = str(row["Annotations"]).split("|")

        # --- Evaluate plausibility for each explainer/annotator combination ---
        for explainer_name in explainer_names:
            explanation = row_explanations.get(explainer_name)
            for annot_idx in range(num_annotators):
                cols = [f"{explainer_name}Annotator{annot_idx + 1}{metric}" for metric in plausibility_metrics]

                if explanation is None:
                    for colname in cols:
                        ferret_row[colname] = "N/A"
                    continue

                label = explanation.target

                attn_col = f"embert_attention_{annot_idx + 1}"
                human_rationale_str = str(row.get(attn_col, ""))
                human_rationale = [int(v) for v in human_rationale_str.split() if v.isdigit()]

                if annot_idx < len(annot_labels_list):
                    annot_label_str = annot_labels_list[annot_idx].split("-")[0].strip()
                else:
                    annot_label_str = ""

                # Only score annotators whose label agrees with the final label
                # and who actually marked at least one token.
                if annot_label_str != inv_label_map[label] or not any(human_rationale):
                    for colname in cols:
                        ferret_row[colname] = "N/A"
                    continue

                try:
                    plaus_eval = bench.evaluate_explanations(
                        [explanation],
                        human_rationale=human_rationale,
                        target=label,
                        skip_faithfulness=True
                    )
                    if plaus_eval and len(plaus_eval) > 0:
                        eval_obj = plaus_eval[0]
                        if hasattr(eval_obj, "evaluation_scores") and eval_obj.evaluation_scores:
                            for score in eval_obj.evaluation_scores:
                                if score.name in plausibility_metrics:
                                    ferret_row[f"{explainer_name}Annotator{annot_idx + 1}{score.name}"] = float(score.score)
                except Exception as e:
                    print(f"Error evaluating {explainer_name} for annotator {annot_idx + 1} at index {idx}: {e}")
                    for colname in cols:
                        ferret_row[colname] = "N/A"

        # --- Ensure no empty cells: fill missing columns with "N/A" ---
        for col in all_fieldnames:
            if col not in ferret_row:
                ferret_row[col] = "N/A"

        # === SAVE THIS ROW TO CSV IMMEDIATELY ===
        with open(ferret_filename, mode='a', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=all_fieldnames)
            if not header_written:
                writer.writeheader()
                header_written = True
            writer.writerow(ferret_row)

    print(f"[INFO] FERRET plausibility results saved row-wise to {ferret_filename}.")

    # --- Memory cleanup ---
    print(f"[INFO] Cleaning up memory for model {hf_model_name}...")
    del bench, model, tokenizer, df_model
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()

# ==============================
# End of pipeline
# ==============================
|
hyperparameter_tuning_for_rationale.py
ADDED
|
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
import os
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.optim as optim
|
| 8 |
+
from torch.utils.data import Dataset, DataLoader
|
| 9 |
+
from transformers import AutoTokenizer, AutoModel
|
| 10 |
+
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score, precision_recall_fscore_support
|
| 11 |
+
import itertools
|
| 12 |
+
import warnings
|
| 13 |
+
import random
|
| 14 |
+
|
| 15 |
+
def set_seed(seed=13):
    """Seed every relevant RNG (stdlib, NumPy, Torch CPU/CUDA) and force
    deterministic backend behaviour so runs are reproducible."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Deterministic cuDNN kernels; disable the autotuner that picks
    # potentially nondeterministic fastest kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Needed for deterministic cuBLAS matmul results on CUDA >= 10.2.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
|
| 23 |
+
|
| 24 |
+
set_seed(13)
warnings.filterwarnings("ignore", category=FutureWarning)

# --- CONFIG ---
# Grid searched exhaustively (itertools.product downstream).
# "lambda" weights the rationale-supervision term against the CE loss.
param_grid = {
    "learning_rate": [1e-5, 2e-5, 3e-5, 4e-5, 5e-5],
    "batch_size": [16, 32, 64],
    "optimizer": ["Adam"],
    "lambda": [0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
}
num_epochs = 7
max_length = 128
model_name = "bert-base-multilingual-cased"
num_labels = 3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- LOAD DATA ---
train_df = pd.read_csv("train.csv")
val_df = pd.read_csv("val.csv")

# Drop rows whose label is outside the three-class sentiment scheme.
valid_labels = {"Negative": 0, "Neutral": 1, "Positive": 2}
train_df = train_df[train_df["final_label"].isin(valid_labels.keys())]
val_df = val_df[val_df["final_label"].isin(valid_labels.keys())]
if train_df.empty:
    raise ValueError("Train dataset empty after filtering.")
if val_df.empty:
    raise ValueError("Validation dataset empty after filtering.")

# --- INITIALIZE TOKENIZER & ADD EMOJIS ---
# Extend the vocabulary with emojis seen in the data so they are not split
# into UNK/byte pieces. The model's embeddings must be resized accordingly
# after loading (resize_token_embeddings) for the new ids to be usable.
tokenizer = AutoTokenizer.from_pretrained(model_name)
emoji_path = "emoji.csv"  # adjust path if needed
if os.path.exists(emoji_path):
    emoji_df = pd.read_csv(emoji_path)
    emoji_list = emoji_df["emoji"].dropna().astype(str).str.strip().tolist()
    existing_vocab = set(tokenizer.get_vocab().keys())
    emoji_set = set(emoji_list) - existing_vocab
    if emoji_set:
        tokenizer.add_tokens(list(emoji_set))
        print(f"Added {len(emoji_set)} new emoji tokens to the tokenizer.")
    else:
        print("No new emojis to add.")
else:
    print(f"Emoji file not found at: {emoji_path}")
|
| 68 |
+
|
| 69 |
+
# --- FUNCTIONS ---
|
| 70 |
+
|
| 71 |
+
def generate_attention_vectors_from_rationales(df, tokenizer, epsilon=1e-8):
    """Add an 'embert_attention' column: per row, the average of the binary
    token-marking vectors of all annotators whose label matches the row's
    final label. Zero entries are replaced with `epsilon`, and the vector is
    serialised as space-separated 8-decimal floats.
    """

    def _mark_spans(text, offsets, spans):
        """Return a 0/1 float32 vector marking tokens overlapped by any
        occurrence of any rationale span inside `text`."""
        marked = np.zeros(len(offsets), dtype=np.float32)
        for span in spans:
            search_from = 0
            while True:
                hit = text.find(span, search_from)
                if hit < 0:
                    break
                lo, hi = hit, hit + len(span)
                for tok_i, (tok_lo, tok_hi) in enumerate(offsets):
                    if tok_hi > lo and tok_lo < hi:
                        marked[tok_i] = 1.0
                search_from = hit + 1
        return marked

    column = []
    for _, record in df.iterrows():
        text = str(record["Content"])
        gold = str(record["final_label"]).strip()
        offsets = tokenizer(text, add_special_tokens=False, return_offsets_mapping=True)["offset_mapping"]

        # Collect one marking vector per annotator that agrees with the
        # final label and marked at least one token.
        per_annotator = []
        labels = str(record.get("Annotations", "")).split("|")
        span_lists = str(record.get("Rationale", "")).split("|")
        for ann_label, ann_spans in zip(labels, span_lists):
            if not ann_label or ann_label.split("-")[0].strip() != gold:
                continue
            spans = [s.strip() for s in ann_spans.split(",") if s.strip()]
            if not spans:
                continue
            vec = _mark_spans(text, offsets, spans)
            if vec.sum() > 0:
                per_annotator.append(vec)

        if per_annotator:
            averaged = np.mean(per_annotator, axis=0)
        else:
            averaged = np.zeros(len(offsets), dtype=np.float32)
        # Replace exact zeros with epsilon so downstream divergence terms
        # never see a zero probability.
        averaged = np.where(averaged == 0, epsilon, averaged)
        column.append(" ".join(f"{v:.8f}" for v in averaged))

    df["embert_attention"] = column
    return df
|
| 112 |
+
|
| 113 |
+
class RationaleDataset(Dataset):
    """Torch dataset yielding (input_ids, attention_mask, rationale_probs,
    label, has_rationale) tuples.

    The rationale weights come from the 'embert_attention' column (space
    separated floats); they are padded at both ends with 0.0 — presumably to
    line up with the tokenizer's added special tokens (TODO confirm) — then
    clipped/zero-padded to max_length and softmax-normalised. Rows without
    any rationale fall back to a uniform distribution and are flagged via
    has_rationale=False.
    """

    def __init__(self, df, tokenizer, max_length=128, label_mapping=None):
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.label_mapping = label_mapping  # str label -> class index

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        sample = self.df.iloc[idx]
        encoded = self.tokenizer(
            sample["Content"], padding="max_length", truncation=True,
            max_length=self.max_length, return_tensors="pt"
        )

        # Parse the space-separated rationale weights; NaN/blank -> empty.
        raw = sample["embert_attention"]
        weights = [float(x) for x in raw.split()] if pd.notna(raw) and raw.strip() else []

        # 0.0 sentinel on each end, then force exactly max_length entries.
        padded = np.concatenate([
            np.array([0.0], dtype=np.float32),
            np.array(weights, dtype=np.float32),
            np.array([0.0], dtype=np.float32),
        ])[:self.max_length]
        if len(padded) < self.max_length:
            padded = np.pad(padded, (0, self.max_length - len(padded)), constant_values=0.0)

        rationale = torch.tensor(padded, dtype=torch.float32)
        if torch.sum(rationale) == 0.0:
            # No usable rationale: uniform target distribution.
            has_rationale = False
            probs = torch.ones(self.max_length, dtype=torch.float32) / self.max_length
        else:
            has_rationale = True
            probs = torch.softmax(rationale, dim=0)

        return (
            encoded["input_ids"].squeeze(0),
            encoded["attention_mask"].squeeze(0),
            probs,
            torch.tensor(self.label_mapping[sample["final_label"]], dtype=torch.long),
            torch.tensor(has_rationale, dtype=torch.bool),
        )
|
| 155 |
+
|
| 156 |
+
class RationaleModel(nn.Module):
    """BERT encoder + linear head that also exposes head-averaged [CLS] attention."""

    def __init__(self, model_name, num_labels):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name, output_attentions=True)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Classify from the [CLS] token of the final hidden layer.
        logits = self.classifier(outputs.last_hidden_state[:, 0, :])
        # Final-layer attention of the [CLS] query (position 0), averaged over heads
        # -> shape (batch, seq); this is what the rationale KL loss supervises.
        cls_attention = outputs.attentions[-1][:, :, 0, :].mean(dim=1)
        return logits, cls_attention
|
| 170 |
+
|
| 171 |
+
def evaluate_model(model, val_loader, criterion_cls, device, label_mapping=None, n_classes=None):
    """Run a full validation pass and compute overall and per-class metrics.

    Returns (avg_val_loss, accuracy, macro_f1, ovr_auroc, class_wise_metrics),
    where ``class_wise_metrics`` maps '<label>_{precision,recall,f1,accuracy,auroc}'
    to floats, using -1.0 as a sentinel when a metric is undefined (e.g. a class
    absent from the validation split).

    ``label_mapping`` / ``n_classes`` default to the module-level ``valid_labels``
    and ``num_labels`` globals the original implementation read implicitly, so
    existing call sites are unchanged; passing them explicitly removes the hidden
    global dependency (and matches the sibling script's signature).
    """
    label_mapping = valid_labels if label_mapping is None else label_mapping
    n_classes = num_labels if n_classes is None else n_classes
    model.eval()
    total_val_loss = 0.0
    all_preds, all_labels, all_probs = [], [], []
    with torch.no_grad():
        for batch in val_loader:
            input_ids, attention_mask, _, labels, _ = [b.to(device) for b in batch]
            logits, _ = model(input_ids, attention_mask)
            total_val_loss += criterion_cls(logits, labels).item()
            probs = torch.softmax(logits, dim=1)
            all_preds.extend(torch.argmax(probs, dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
    avg_val_loss = total_val_loss / len(val_loader)
    labels_np = np.array(all_labels)
    preds_np = np.array(all_preds)
    probs_np = np.array(all_probs)
    # Overall metrics
    accuracy = accuracy_score(labels_np, preds_np)
    f1_macro = f1_score(labels_np, preds_np, average="macro")
    try:
        auroc_ovr = roc_auc_score(np.eye(n_classes)[labels_np], probs_np, multi_class="ovr")
    except Exception:  # AUROC undefined, e.g. a class missing from the split
        auroc_ovr = -1.0
    # Class-wise metrics
    class_wise_metrics = {}
    target_names = sorted(label_mapping, key=label_mapping.get)
    precision, recall, f1_per_class, _ = precision_recall_fscore_support(
        labels_np, preds_np, average=None,
        labels=[label_mapping[name] for name in target_names])
    for i, name in enumerate(target_names):
        label_id = label_mapping[name]
        class_wise_metrics[f"{name}_precision"] = precision[i]
        class_wise_metrics[f"{name}_recall"] = recall[i]
        class_wise_metrics[f"{name}_f1"] = f1_per_class[i]
        # Per-class accuracy: of the true-class samples, how many were predicted correctly.
        mask = labels_np == label_id
        if mask.sum() > 0:
            class_wise_metrics[f"{name}_accuracy"] = (preds_np[mask] == label_id).sum() / mask.sum()
        else:
            class_wise_metrics[f"{name}_accuracy"] = -1.0
        # One-vs-rest AUROC for this class.
        try:
            binary_labels = (labels_np == label_id).astype(int)
            if len(np.unique(binary_labels)) > 1:
                class_wise_metrics[f"{name}_auroc"] = roc_auc_score(binary_labels, probs_np[:, label_id])
            else:
                class_wise_metrics[f"{name}_auroc"] = -1.0
        except Exception:  # degenerate labels/probabilities
            class_wise_metrics[f"{name}_auroc"] = -1.0
    return avg_val_loss, accuracy, f1_macro, auroc_ovr, class_wise_metrics
|
| 223 |
+
|
| 224 |
+
def train_model(model, train_loader, val_loader, num_epochs, device, lambda_attn=1.0, optimizer=None, learning_rate=2e-5, results_writer=None, results_file_handle=None, params=None):
    """Train with cross-entropy plus lambda_attn * KL(model CLS attention vs rationale).

    Per epoch: prints overall and per-class validation metrics, and (when a CSV
    writer is supplied) appends one row per epoch, flushing to disk immediately.
    Reads the module-level ``valid_labels`` mapping for label names.
    """
    criterion_cls = nn.CrossEntropyLoss()
    criterion_kl = nn.KLDivLoss(reduction="batchmean")
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for batch in train_loader:
            input_ids, attention_mask, rationale_probs, labels, has_rationale = [
                t.to(device) for t in batch]
            optimizer.zero_grad()
            logits, model_attention = model(input_ids, attention_mask)
            loss = criterion_cls(logits, labels)
            if has_rationale.any():
                # KL term only over the samples that carry a human rationale.
                attn_sel = model_attention[has_rationale]
                target_sel = rationale_probs[has_rationale]
                # KLDivLoss expects log-probabilities as its first argument.
                loss = loss + lambda_attn * criterion_kl(torch.log(attn_sel + 1e-8), target_sel)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        avg_train_loss = running_loss / len(train_loader)
        val_loss, val_acc, val_f1_macro, val_auroc_ovr, class_wise_metrics = evaluate_model(
            model, val_loader, criterion_cls, device)
        print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | Val F1 (Macro): {val_f1_macro:.4f} | Val AUROC (OvR): {val_auroc_ovr:.4f}")
        sorted_labels = sorted(valid_labels, key=valid_labels.get)
        for label_name in sorted_labels:
            print(f" {label_name}: P={class_wise_metrics[f'{label_name}_precision']:.4f}, R={class_wise_metrics[f'{label_name}_recall']:.4f}, F1={class_wise_metrics[f'{label_name}_f1']:.4f}, Acc={class_wise_metrics[f'{label_name}_accuracy']:.4f}, AUROC={class_wise_metrics[f'{label_name}_auroc']:.4f}")
        if results_writer and results_file_handle:
            row_data = [params["learning_rate"], params["batch_size"], params["optimizer"],
                        params["lambda"], epoch + 1, avg_train_loss, val_loss, val_acc,
                        val_f1_macro, val_auroc_ovr]
            for label_name in sorted_labels:
                row_data += [class_wise_metrics[f"{label_name}_{metric}"]
                             for metric in ("precision", "recall", "f1", "accuracy", "auroc")]
            results_writer.writerow(row_data)
            # Flush + fsync so partial grid-search results survive a crash.
            results_file_handle.flush()
            os.fsync(results_file_handle.fileno())
|
| 277 |
+
|
| 278 |
+
# --- PREPARE DATASETS ---
print("Generating attention vectors for training data...")
train_df = generate_attention_vectors_from_rationales(train_df, tokenizer)
print("Generating attention vectors for validation data...")
val_df = generate_attention_vectors_from_rationales(val_df, tokenizer)

train_dataset = RationaleDataset(train_df, tokenizer, max_length, label_mapping=valid_labels)
val_dataset = RationaleDataset(val_df, tokenizer, max_length, label_mapping=valid_labels)

# --- GRID SEARCH LOOP ---
keys, values = zip(*param_grid.items())
param_combinations = [dict(zip(keys, combo)) for combo in itertools.product(*values)]
results_file = "grid_results_detailed.csv"
headers = ["learning_rate", "batch_size", "optimizer", "lambda", "epoch",
           "train_loss", "val_loss", "val_accuracy", "val_f1_macro", "val_auroc_ovr"]
sorted_labels = sorted(valid_labels, key=valid_labels.get)
for label_name in sorted_labels:
    headers += [f"{label_name}_precision", f"{label_name}_recall", f"{label_name}_f1",
                f"{label_name}_accuracy", f"{label_name}_auroc"]

with open(results_file, mode="w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(headers)
    for params in param_combinations:
        print("\nRunning:", params)
        learning_rate = params["learning_rate"]
        batch_size = params["batch_size"]
        optimizer_type = params["optimizer"]
        lambda_attn = params["lambda"]
        model = RationaleModel(model_name=model_name, num_labels=num_labels).to(device)
        # Grow the embedding matrix when emoji tokens were added to the tokenizer.
        if 'emoji_set' in locals() and len(emoji_set) > 0:
            model.bert.resize_token_embeddings(len(tokenizer))
        if optimizer_type == "Adam":
            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        else:
            raise ValueError("Unsupported optimizer")
        # Fixed-seed generator keeps the shuffle order identical across configs.
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                  generator=torch.Generator().manual_seed(13))
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        train_model(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            num_epochs=num_epochs,
            device=device,
            lambda_attn=lambda_attn,
            optimizer=optimizer,
            learning_rate=learning_rate,
            results_writer=writer,
            results_file_handle=f,
            params=params,
        )
print("Grid search complete. Results saved to:", results_file)
|
hyperparameter_tuning_without_rationale.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import random
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import itertools
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from torch.utils.data import Dataset, DataLoader
|
| 9 |
+
from transformers import BertTokenizer, BertForSequenceClassification
|
| 10 |
+
from torch.optim import Adam
|
| 11 |
+
from sklearn.metrics import f1_score, roc_auc_score, precision_recall_fscore_support, accuracy_score
|
| 12 |
+
|
| 13 |
+
# Reproducibility
|
| 14 |
+
def set_seed(seed=13):
    """Make runs reproducible: seed all RNGs and pin cuDNN to deterministic kernels."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
| 21 |
+
|
| 22 |
+
set_seed(13)

# Hyperparameter grid and fixed training configuration
param_grid = {
    "learning_rate": [1e-5, 2e-5, 3e-5, 4e-5, 5e-5],
    "batch_size": [16, 32, 64],
}
num_epochs = 10
max_length = 128
model_name = "bert-base-multilingual-cased"
num_labels = 3

# Tokenizer, extended with any emoji not already in its vocabulary
emoji_df = pd.read_csv("emoji.csv")
emoji_list = emoji_df.iloc[:, 0].dropna().astype(str).unique().tolist()

tokenizer = BertTokenizer.from_pretrained(model_name)
new_tokens = list(set(emoji_list) - set(tokenizer.vocab.keys()))
if new_tokens:
    tokenizer.add_tokens(new_tokens)
    print(f"Added {len(new_tokens)} emojis to tokenizer.")

# Load splits and keep only rows carrying a recognised sentiment label
train_df = pd.read_csv("train.csv")
val_df = pd.read_csv("val.csv")

valid_labels = {"Negative": 0, "Neutral": 1, "Positive": 2}
train_df = train_df[train_df["final_label"].isin(valid_labels)]
val_df = val_df[val_df["final_label"].isin(valid_labels)]
|
| 51 |
+
|
| 52 |
+
class CustomDataset(Dataset):
    """Wraps a dataframe of (Content, final_label) rows as tokenised tensors."""

    def __init__(self, dataframe, tokenizer, max_length):
        self.dataframe = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        sample = self.dataframe.iloc[idx]
        encoded = self.tokenizer(
            sample["Content"],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        target = valid_labels[sample["final_label"]]  # module-level label map
        return (
            encoded["input_ids"].squeeze(0),
            encoded["attention_mask"].squeeze(0),
            torch.tensor(target, dtype=torch.long),
        )
|
| 77 |
+
|
| 78 |
+
train_dataset = CustomDataset(train_df, tokenizer, max_length)
val_dataset = CustomDataset(val_df, tokenizer, max_length)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Results directory and CSV header (append mode, so reruns keep old rows)
os.makedirs("results", exist_ok=True)
results_path = "results/grid_search_metrics.csv"

if not os.path.exists(results_path):
    with open(results_path, "w") as f:
        f.write("timestamp,learning_rate,batch_size,epoch,val_macro_f1,val_auroc,"
                "acc_negative,prec_negative,rec_negative,f1_negative,"
                "acc_neutral,prec_neutral,rec_neutral,f1_neutral,"
                "acc_positive,prec_positive,rec_positive,f1_positive\n")

# Hyperparameter grid search
for lr, bs in itertools.product(param_grid["learning_rate"], param_grid["batch_size"]):
    print(f"\nStarting config: LR={lr}, Batch Size={bs}")
    set_seed(13)  # re-seed so every config starts from identical RNG state
    train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=bs)

    model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels).to(device)
    if new_tokens:
        model.resize_token_embeddings(len(tokenizer))

    optimizer = Adam(model.parameters(), lr=lr)

    for epoch in range(1, num_epochs + 1):
        model.train()
        for batch in train_loader:
            input_ids, attention_mask, labels = [b.to(device) for b in batch]
            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            outputs.loss.backward()
            optimizer.step()

        # Validation pass
        model.eval()
        val_preds, val_probs, val_labels = [], [], []
        with torch.no_grad():
            for batch in val_loader:
                input_ids, attention_mask, labels = [b.to(device) for b in batch]
                logits = model(input_ids, attention_mask=attention_mask).logits
                val_probs.extend(torch.softmax(logits, dim=1).cpu().numpy())
                val_preds.extend(torch.argmax(logits, axis=1).cpu().tolist())
                val_labels.extend(labels.cpu().tolist())

        val_macro_f1 = f1_score(val_labels, val_preds, average="macro")
        val_auroc = roc_auc_score(
            np.eye(num_labels)[val_labels],
            np.array(val_probs),
            average="macro",
            multi_class="ovr",
        )

        # Label-wise precision/recall/F1 plus per-label accuracy
        report = precision_recall_fscore_support(val_labels, val_preds, labels=[0, 1, 2], zero_division=0)
        acc_per_label = []
        for i in range(num_labels):
            mask = np.array(val_labels) == i
            total = mask.sum()
            acc_per_label.append((np.array(val_preds)[mask] == i).sum() / total if total > 0 else 0)

        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        row = [timestamp, lr, bs, epoch, f"{val_macro_f1:.4f}", f"{val_auroc:.4f}"]
        for i in range(num_labels):
            row += [f"{acc_per_label[i]:.4f}",
                    f"{report[0][i]:.4f}",  # precision
                    f"{report[1][i]:.4f}",  # recall
                    f"{report[2][i]:.4f}"]  # f1

        with open(results_path, "a") as f:
            f.write(",".join(map(str, row)) + "\n")

        print(f"[Epoch {epoch}] LR={lr}, BS={bs} | F1={val_macro_f1:.4f} | AUROC={val_auroc:.4f}")

print(f"\nGrid Search Complete. Results saved to: {results_path}")
|
model_training_with_rationale.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
import os
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torch.utils.data import Dataset, DataLoader
|
| 8 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 9 |
+
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score, precision_recall_fscore_support
|
| 10 |
+
import warnings
|
| 11 |
+
import random
|
| 12 |
+
|
| 13 |
+
def set_seed(seed=13):
    """Seed python/numpy/torch RNGs and request fully deterministic CUDA behaviour."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # cuBLAS needs this workspace setting to make GPU matmuls deterministic.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
|
| 21 |
+
|
| 22 |
+
set_seed(13)
warnings.filterwarnings("ignore", category=FutureWarning)

# --- CONFIG ---
model_name = "bert-base-multilingual-cased"  # Set your model name here
num_epochs = 4
max_length = 128
num_labels = 3
learning_rate = 2e-5
batch_size = 64
optimizer_type = "Adam"
lambda_attn = 0.6  # weight of the attention-KL term in the loss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- LOAD DATA ---
train_df = pd.read_csv("train.csv")
val_df = pd.read_csv("val.csv")
test_df = pd.read_csv("test.csv")
valid_labels = {"Negative": 0, "Neutral": 1, "Positive": 2}
train_df = train_df[train_df["final_label"].isin(valid_labels.keys())]
val_df = val_df[val_df["final_label"].isin(valid_labels.keys())]
test_df = test_df[test_df["final_label"].isin(valid_labels.keys())]
if train_df.empty:
    raise ValueError("Train dataset empty after filtering.")
if val_df.empty:
    raise ValueError("Validation dataset empty after filtering.")
|
| 49 |
+
|
| 50 |
+
# --- FUNCTIONS ---
|
| 51 |
+
def generate_attention_vectors_from_rationales(df, tokenizer, epsilon=1e-8):
    """Attach an 'embert_attention' column of per-token rationale weights.

    For each row, annotator rationale spans whose annotation label matches the
    row's final label are projected onto the tokenizer's character offsets
    (special tokens excluded), averaged across annotators, and serialised as a
    space-separated float string. Zero entries are replaced by ``epsilon`` so a
    downstream softmax/KL stays finite. Mutates and returns ``df``.
    """
    serialized = []
    for _, row in df.iterrows():
        text = str(row["Content"])
        final_label = str(row["final_label"]).strip()
        offsets = tokenizer(text, add_special_tokens=False, return_offsets_mapping=True)["offset_mapping"]
        n_tokens = len(offsets)
        per_annotator = []
        annot_labels = str(row.get("Annotations", "")).split("|")
        annot_rationales = str(row.get("Rationale", "")).split("|")
        for annot_label, annot_rationale in zip(annot_labels, annot_rationales):
            # Keep only annotations that agree with the adjudicated label.
            if not annot_label or annot_label.split("-")[0].strip() != final_label:
                continue
            spans = [s.strip() for s in annot_rationale.split(",") if s.strip()]
            if not spans:
                continue
            marks = np.zeros(n_tokens, dtype=np.float32)
            for span_text in spans:
                # Mark every (possibly overlapping) occurrence of the span.
                search_from = 0
                while True:
                    hit = text.find(span_text, search_from)
                    if hit == -1:
                        break
                    lo, hi = hit, hit + len(span_text)
                    for tok_idx, (tok_lo, tok_hi) in enumerate(offsets):
                        if tok_hi > lo and tok_lo < hi:
                            marks[tok_idx] = 1.0
                    search_from = hit + 1
            if marks.sum() > 0:
                per_annotator.append(marks)
        if per_annotator:
            mean_vec = np.mean(per_annotator, axis=0)
        else:
            mean_vec = np.zeros(n_tokens, dtype=np.float32)
        mean_vec = np.where(mean_vec == 0, epsilon, mean_vec)
        serialized.append(" ".join(f"{v:.8f}" for v in mean_vec))
    df["embert_attention"] = serialized
    return df
|
| 92 |
+
|
| 93 |
+
class RationaleDataset(Dataset):
    """Dataset pairing tokenised text with rationale-derived attention targets.

    Items are (input_ids, attention_mask, rationale_probs, label, has_rationale);
    ``rationale_probs`` is a softmax over the padded rationale scores, or uniform
    when the sample has no rationale (then ``has_rationale`` is False).
    """

    def __init__(self, df, tokenizer, max_length=128, label_mapping=None):
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.label_mapping = label_mapping  # label string -> class index

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        class_id = self.label_mapping[record["final_label"]]
        enc = self.tokenizer(
            record["Content"],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        attn_str = record["embert_attention"]
        if pd.notna(attn_str) and attn_str.strip():
            weights = [float(w) for w in attn_str.split()]
        else:
            weights = []
        # Zero-pad the [CLS]/[SEP] positions, then clip/right-pad to max_length.
        head = np.asarray([0.0] + weights + [0.0], dtype=np.float32)[: self.max_length]
        full = np.zeros(self.max_length, dtype=np.float32)
        full[: len(head)] = head
        rationale = torch.from_numpy(full)
        if float(rationale.sum()) == 0.0:
            flag = False
            probs = torch.full((self.max_length,), 1.0 / self.max_length, dtype=torch.float32)
        else:
            flag = True
            probs = torch.softmax(rationale, dim=0)
        return (
            enc["input_ids"].squeeze(0),
            enc["attention_mask"].squeeze(0),
            probs,
            torch.tensor(class_id, dtype=torch.long),
            torch.tensor(flag, dtype=torch.bool),
        )
|
| 135 |
+
|
| 136 |
+
class RationaleModel(nn.Module):
    """Sequence-classification head that also returns head-averaged [CLS] attention."""

    def __init__(self, model_name, num_labels):
        super().__init__()
        self.bert = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels, output_attentions=True)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Final-layer attention of the [CLS] query (position 0), mean over heads
        # -> (batch, seq); used as the target of the rationale KL loss.
        cls_attention = outputs.attentions[-1][:, :, 0, :].mean(dim=1)
        return outputs.logits, cls_attention
|
| 147 |
+
|
| 148 |
+
def evaluate_model(model, val_loader, criterion_cls, device, valid_labels, num_labels):
    """Run a full validation pass and compute overall and per-class metrics.

    Returns (avg_val_loss, accuracy, macro_f1, ovr_auroc, class_wise_metrics),
    where ``class_wise_metrics`` maps '<label>_{precision,recall,f1,accuracy,auroc}'
    to floats, with -1.0 as a sentinel when a metric is undefined (e.g. a class
    absent from the validation split).

    Fix: the two bare ``except:`` clauses are narrowed to ``except Exception:`` so
    KeyboardInterrupt/SystemExit are no longer silently swallowed as "AUROC = -1".
    """
    model.eval()
    total_val_loss = 0.0
    all_preds, all_labels, all_probs = [], [], []
    with torch.no_grad():
        for batch in val_loader:
            input_ids, attention_mask, _, labels, _ = [b.to(device) for b in batch]
            logits, _ = model(input_ids, attention_mask)
            total_val_loss += criterion_cls(logits, labels).item()
            probs = torch.softmax(logits, dim=1)
            all_preds.extend(torch.argmax(probs, dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
    avg_val_loss = total_val_loss / len(val_loader)
    all_labels_np = np.array(all_labels)
    all_preds_np = np.array(all_preds)
    all_probs_np = np.array(all_probs)
    # Overall metrics
    accuracy = accuracy_score(all_labels_np, all_preds_np)
    f1_macro = f1_score(all_labels_np, all_preds_np, average="macro")
    try:
        y_true_oh = np.eye(num_labels)[all_labels_np]
        auroc_ovr = roc_auc_score(y_true_oh, all_probs_np, multi_class="ovr")
    except Exception:  # was a bare except: — don't swallow KeyboardInterrupt/SystemExit
        auroc_ovr = -1.0
    # Class-wise metrics
    class_wise_metrics = {}
    target_names = sorted(valid_labels, key=valid_labels.get)
    label_indices = [valid_labels[label_name] for label_name in target_names]
    precision, recall, f1_per_class, support = precision_recall_fscore_support(
        all_labels_np, all_preds_np, labels=label_indices, average=None)
    for i, label_name in enumerate(target_names):
        label_id = valid_labels[label_name]
        class_wise_metrics[f"{label_name}_precision"] = precision[i]
        class_wise_metrics[f"{label_name}_recall"] = recall[i]
        class_wise_metrics[f"{label_name}_f1"] = f1_per_class[i]
        # Per-class accuracy: of true-class samples, fraction predicted correctly.
        label_mask = all_labels_np == label_id
        total_label = np.sum(label_mask)
        if total_label > 0:
            class_wise_metrics[f"{label_name}_accuracy"] = np.sum((all_preds_np == label_id) & label_mask) / total_label
        else:
            class_wise_metrics[f"{label_name}_accuracy"] = -1.0
        # One-vs-rest AUROC for this class.
        try:
            binary_labels = (all_labels_np == label_id).astype(int)
            class_probs = all_probs_np[:, label_id]
            if len(np.unique(binary_labels)) > 1:
                class_wise_metrics[f"{label_name}_auroc"] = roc_auc_score(binary_labels, class_probs)
            else:
                class_wise_metrics[f"{label_name}_auroc"] = -1.0
        except Exception:  # was a bare except: — see note above
            class_wise_metrics[f"{label_name}_auroc"] = -1.0
    return avg_val_loss, accuracy, f1_macro, auroc_ovr, class_wise_metrics
|
| 203 |
+
|
| 204 |
+
def train_model(model, train_loader, val_loader, num_epochs, device, lambda_attn=1.0, optimizer=None, learning_rate=2e-5, results_writer=None, results_file_handle=None):
    """Train with CE + lambda_attn * KL(model CLS attention vs human rationale).

    Prints overall and per-class validation metrics each epoch and, when a CSV
    writer is supplied, appends one row per epoch using the module-level config
    globals (learning_rate, batch_size, optimizer_type, lambda_attn).
    """
    criterion_cls = nn.CrossEntropyLoss()
    criterion_kl = nn.KLDivLoss(reduction="batchmean")
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for batch in train_loader:
            input_ids, attention_mask, rationale_probs, labels, has_rationale = [
                t.to(device) for t in batch]
            optimizer.zero_grad()
            logits, model_attention = model(input_ids, attention_mask)
            loss = criterion_cls(logits, labels)
            if has_rationale.any():
                # Supervise attention only on samples that carry a rationale.
                attn_sel = model_attention[has_rationale]
                target_sel = rationale_probs[has_rationale]
                # KLDivLoss expects log-probabilities as its first argument.
                loss = loss + lambda_attn * criterion_kl(torch.log(attn_sel + 1e-8), target_sel)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        avg_train_loss = running_loss / len(train_loader)
        val_loss, val_acc, val_f1_macro, val_auroc_ovr, class_wise_metrics = evaluate_model(
            model, val_loader, criterion_cls, device, valid_labels, num_labels)
        print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | Val F1 (Macro): {val_f1_macro:.4f} | Val AUROC (OvR): {val_auroc_ovr:.4f}")
        sorted_labels = sorted(valid_labels, key=valid_labels.get)
        for label_name in sorted_labels:
            print(f" {label_name}: P={class_wise_metrics[f'{label_name}_precision']:.4f}, R={class_wise_metrics[f'{label_name}_recall']:.4f}, F1={class_wise_metrics[f'{label_name}_f1']:.4f}, Acc={class_wise_metrics[f'{label_name}_accuracy']:.4f}, AUROC={class_wise_metrics[f'{label_name}_auroc']:.4f}")

        if results_writer and results_file_handle:
            row_data = [learning_rate, batch_size, optimizer_type, lambda_attn, epoch + 1,
                        avg_train_loss, val_loss, val_acc, val_f1_macro, val_auroc_ovr]
            for label_name in sorted_labels:
                row_data += [class_wise_metrics[f"{label_name}_{metric}"]
                             for metric in ("precision", "recall", "f1", "accuracy", "auroc")]
            results_writer.writerow(row_data)
            # Flush and fsync so results survive an interrupted run.
            results_file_handle.flush()
            os.fsync(results_file_handle.fileno())
|
| 258 |
+
|
| 259 |
+
# --- OUTPUT FOLDERS ---
csv_output_dir = "csv_outputs"
os.makedirs(csv_output_dir, exist_ok=True)
results_file = os.path.join(csv_output_dir, "results_detailed.csv")
headers = ["learning_rate", "batch_size", "optimizer", "lambda", "epoch",
           "train_loss", "val_loss", "val_accuracy", "val_f1_macro", "val_auroc_ovr"]
sorted_labels = sorted(valid_labels, key=valid_labels.get)
for label in sorted_labels:
    headers += [f"{label}_precision", f"{label}_recall", f"{label}_f1",
                f"{label}_accuracy", f"{label}_auroc"]

# --- INITIALIZE TOKENIZER & ADD EMOJIS ---
tokenizer = AutoTokenizer.from_pretrained(model_name)
emoji_path = "emoji.csv"
if os.path.exists(emoji_path):
    emoji_df = pd.read_csv(emoji_path)
    emoji_list = emoji_df["emoji"].dropna().astype(str).str.strip().tolist()
    existing_vocab = set(tokenizer.get_vocab().keys())
    emoji_set = set(emoji_list) - existing_vocab
    if emoji_set:
        tokenizer.add_tokens(list(emoji_set))
        print(f"Added {len(emoji_set)} new emoji tokens to the tokenizer.")
    else:
        print("No new emojis to add.")
else:
    print(f"Emoji file not found at: {emoji_path}")

# --- PREPARE DATASETS ---
print("Generating attention vectors for training data...")
train_df_model = generate_attention_vectors_from_rationales(train_df.copy(), tokenizer)
print("Generating attention vectors for validation data...")
val_df_model = generate_attention_vectors_from_rationales(val_df.copy(), tokenizer)

train_dataset = RationaleDataset(train_df_model, tokenizer, max_length, label_mapping=valid_labels)
val_dataset = RationaleDataset(val_df_model, tokenizer, max_length, label_mapping=valid_labels)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          generator=torch.Generator().manual_seed(13))
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# --- CSV Setup + single training run ---
with open(results_file, mode="w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(headers)
    model = RationaleModel(model_name=model_name, num_labels=num_labels).to(device)
    # Grow the embedding matrix when emoji tokens were added to the tokenizer.
    if 'emoji_set' in locals() and len(emoji_set) > 0:
        model.bert.resize_token_embeddings(len(tokenizer))
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=num_epochs,
        device=device,
        lambda_attn=lambda_attn,
        optimizer=optimizer,
        learning_rate=learning_rate,
        results_writer=writer,
        results_file_handle=f,
    )
# Save final model and tokenizer
model.bert.save_pretrained("model_outputs")
tokenizer.save_pretrained("model_outputs")
print(f"Final model and tokenizer saved to model_outputs")
|
model_training_without_rationale.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import Dataset, DataLoader
|
| 3 |
+
from torch.optim import Adam
|
| 4 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import numpy as np
|
| 7 |
+
import os
|
| 8 |
+
import warnings
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
import seaborn as sns
|
| 11 |
+
from sklearn.metrics import (
|
| 12 |
+
f1_score, roc_auc_score, accuracy_score,
|
| 13 |
+
precision_recall_fscore_support, confusion_matrix
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 17 |
+
|
| 18 |
+
# --- CONFIG ---
# Central hyperparameter dictionary; change values here rather than inline.
args_dict = {
    "batch_size": 64,
    "num_epochs": 4,
    "learning_rate": 2e-5,
    "max_length": 128,
    "model_name": "bert-base-multilingual-cased",  # replace with your model_name
    "num_labels": 3,
    "save_dir": "./saved_model"
}
# Make sure the checkpoint directory exists before training starts.
os.makedirs(args_dict["save_dir"], exist_ok=True)
|
| 29 |
+
|
| 30 |
+
# --- LABEL MAPPING ---
# Fixed 3-class sentiment mapping; label2name inverts it for reporting.
label_mapping = {"Negative": 0, "Neutral": 1, "Positive": 2}
label2name = {v: k for k, v in label_mapping.items()}
label_ids = list(label2name.keys())

# --- LOAD DATA ---
# CSVs are expected in the working directory; train/val/test must each have
# a "Content" text column and a "final_label" column.
train_df = pd.read_csv("train.csv")
val_df = pd.read_csv("val.csv")
test_df = pd.read_csv("test.csv")
emoji_df = pd.read_csv("emoji.csv")  # expects an "emoji" column

# --- FILTER INVALID LABELS ---
# Drop any row whose label is not one of the three known classes.
train_df = train_df[train_df["final_label"].isin(label_mapping)]
val_df = val_df[val_df["final_label"].isin(label_mapping)]
test_df = test_df[test_df["final_label"].isin(label_mapping)]
|
| 45 |
+
|
| 46 |
+
# --- TOKENIZER ---
tokenizer = AutoTokenizer.from_pretrained(args_dict["model_name"])
# Add any emoji the pretrained vocabulary does not already contain, so they
# are not split into unknown/byte pieces during tokenization.
# Use get_vocab() rather than the .vocab attribute: get_vocab() is the
# portable accessor defined on the tokenizer base classes, while .vocab is
# not available on all (fast) tokenizer implementations.
emoji_list = emoji_df["emoji"].dropna().astype(str).str.strip().tolist()
emoji_set = set(emoji_list) - set(tokenizer.get_vocab().keys())
if emoji_set:
    tokenizer.add_tokens(list(emoji_set))
    print(f"Added {len(emoji_set)} emojis to tokenizer.")
|
| 53 |
+
|
| 54 |
+
# --- MODEL ---
model = AutoModelForSequenceClassification.from_pretrained(
    args_dict["model_name"], num_labels=args_dict["num_labels"]
)
# Grow the embedding matrix so the emoji tokens added above get embeddings.
model.resize_token_embeddings(len(tokenizer))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
|
| 61 |
+
|
| 62 |
+
# --- DATASET ---
class SimpleTextDataset(Dataset):
    """Tokenize dataframe rows on the fly for sequence classification.

    Each item is a tuple ``(input_ids, attention_mask, label, text)`` where
    the tensors are padded/truncated to ``max_length`` and ``text`` is the
    raw input string (kept so predictions can be traced back at eval time).
    """

    def __init__(self, dataframe, tokenizer, max_length=128, label_map=None):
        """
        Args:
            dataframe: pandas DataFrame with "Content" and "final_label" columns.
            tokenizer: a Hugging Face-style callable tokenizer.
            max_length: fixed sequence length for padding/truncation.
            label_map: optional ``{label_name: int}`` mapping; defaults to the
                module-level ``label_mapping`` for backward compatibility.
        """
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.label_map = label_mapping if label_map is None else label_map

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        text = row["Content"]
        label = self.label_map[row["final_label"]]
        encoding = self.tokenizer(
            text, padding="max_length", truncation=True,
            max_length=self.max_length, return_tensors="pt"
        )
        return (
            encoding["input_ids"].squeeze(0),       # shape (max_length,)
            encoding["attention_mask"].squeeze(0),  # shape (max_length,)
            torch.tensor(label, dtype=torch.long),
            text
        )
|
| 86 |
+
|
| 87 |
+
# --- DATALOADERS ---
# Only the training loader shuffles; val/test keep row order so predictions
# can be aligned back to the source dataframes.
train_loader = DataLoader(SimpleTextDataset(train_df, tokenizer), batch_size=args_dict["batch_size"], shuffle=True)
val_loader = DataLoader(SimpleTextDataset(val_df, tokenizer), batch_size=args_dict["batch_size"])
test_loader = DataLoader(SimpleTextDataset(test_df, tokenizer), batch_size=args_dict["batch_size"])
|
| 91 |
+
|
| 92 |
+
# --- TRAINING ---
optimizer = Adam(model.parameters(), lr=args_dict["learning_rate"])
val_metrics_history = []  # one metrics dict per epoch

for epoch in range(1, args_dict["num_epochs"] + 1):
    # --- train one epoch ---
    model.train()
    total_loss = 0
    for batch in train_loader:
        # batch = (input_ids, attention_mask, labels, raw_text). The raw text
        # strings are not tensors, so only the first three items move to the
        # device. NOTE: the original unpacked FOUR targets from the 3-element
        # list comprehension, which raises ValueError on the first batch.
        input_ids, attn_mask, labels = [x.to(device) for x in batch[:3]]
        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attn_mask, labels=labels)
        outputs.loss.backward()
        optimizer.step()
        total_loss += outputs.loss.item()
    avg_train_loss = total_loss / len(train_loader)

    # --- VALIDATION ---
    model.eval()
    val_preds, val_labels, val_loss = [], [], 0
    with torch.no_grad():
        for batch in val_loader:
            input_ids, attn_mask, labels = [x.to(device) for x in batch[:3]]
            outputs = model(input_ids, attention_mask=attn_mask, labels=labels)
            val_preds.extend(outputs.logits.argmax(dim=1).cpu().numpy())
            val_labels.extend(labels.cpu().numpy())
            val_loss += outputs.loss.item()
    val_loss /= len(val_loader)
    val_acc = accuracy_score(val_labels, val_preds)
    val_f1 = f1_score(val_labels, val_preds, average="weighted")
    # NOTE(review): AUROC is computed from one-hot hard predictions rather
    # than softmax probabilities, so it is only a coarse approximation.
    try:
        val_auroc = roc_auc_score(
            pd.get_dummies(val_labels), pd.get_dummies(val_preds),
            average="weighted", multi_class="ovo"
        )
    except ValueError:
        # Raised when a class is entirely absent from labels or predictions.
        val_auroc = float("nan")

    # --- Label-wise Metrics ---
    prec, rec, f1, supp = precision_recall_fscore_support(val_labels, val_preds, labels=[0, 1, 2])
    labelwise = {}
    for i in [0, 1, 2]:
        idx = np.array(val_labels) == i
        if idx.sum() > 0:
            acc = (np.array(val_preds)[idx] == i).sum() / idx.sum()
        else:
            acc = 0.0  # class absent from the validation split
        labelwise[label2name[i]] = {
            "val_acc": acc,
            "val_f1": f1[i],
            "val_precision": prec[i],
            "val_recall": rec[i],
            "val_support": supp[i]
        }

    # Flatten the per-label dicts into "<Label>_<metric>" columns.
    val_metrics_history.append({
        "epoch": epoch,
        "train_loss": avg_train_loss,
        "val_loss": val_loss,
        "val_accuracy": val_acc,
        "val_f1": val_f1,
        "val_auroc": val_auroc,
        **{f"{label}_{m}": labelwise[label][m]
           for label in labelwise for m in labelwise[label]}
    })

    print(f"Epoch {epoch}: Train Loss={avg_train_loss:.4f} | Val Acc={val_acc:.4f} | Val F1={val_f1:.4f} | AUROC={val_auroc:.4f}")
|
| 158 |
+
|
| 159 |
+
# Persist the model and tokenizer from the final epoch (no best-epoch
# selection is performed above).
model.save_pretrained(args_dict["save_dir"])
tokenizer.save_pretrained(args_dict["save_dir"])
print(f"Last model saved after epoch {args_dict['num_epochs']}")
# --- SAVE VAL METRICS ---
# One row per epoch, including the flattened label-wise columns.
pd.DataFrame(val_metrics_history).to_csv("val_metrics_detailed.csv", index=False)
|
| 164 |
+
|
| 165 |
+
# --- LOAD BEST MODEL ---
# NOTE(review): despite the heading, this reloads the checkpoint saved after
# the *last* epoch — no best-epoch tracking happens during training.
model = AutoModelForSequenceClassification.from_pretrained(args_dict["save_dir"]).to(device)
tokenizer = AutoTokenizer.from_pretrained(args_dict["save_dir"])

# --- TEST EVAL ---
model.eval()
all_preds, all_labels, all_sentences, all_tokens = [], [], [], []
test_loss = 0

with torch.no_grad():
    for batch in test_loader:
        # Test batches carry the raw sentence strings for the predictions CSV.
        input_ids, attn_mask, labels, sentences = batch
        input_ids, attn_mask, labels = input_ids.to(device), attn_mask.to(device), labels.to(device)
        outputs = model(input_ids, attention_mask=attn_mask, labels=labels)
        test_loss += outputs.loss.item()
        preds = outputs.logits.argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        all_sentences.extend(sentences)
        # "Tokens" here are actually detokenized strings (special tokens
        # stripped), i.e. what the model saw after truncation/padding.
        all_tokens.extend(tokenizer.batch_decode(input_ids.cpu(), skip_special_tokens=True))
|
| 185 |
+
|
| 186 |
+
# Aggregate test metrics over the whole test set.
test_loss /= len(test_loader)
test_acc = accuracy_score(all_labels, all_preds)
test_f1 = f1_score(all_labels, all_preds, average="weighted")
# NOTE(review): AUROC from one-hot hard predictions, not probabilities —
# only a coarse approximation.
try:
    test_auroc = roc_auc_score(pd.get_dummies(all_labels), pd.get_dummies(all_preds), average="weighted", multi_class="ovo")
except ValueError:
    # Raised when a class is entirely absent from labels or predictions;
    # a bare `except:` here previously also hid unrelated errors.
    test_auroc = float("nan")
|
| 193 |
+
|
| 194 |
+
# --- LABEL-WISE TEST METRICS ---
# Per-class precision/recall/F1/support plus a per-class accuracy
# (fraction of that class's examples predicted correctly).
prec, rec, f1, supp = precision_recall_fscore_support(all_labels, all_preds, labels=[0, 1, 2])
labels_arr = np.array(all_labels)
preds_arr = np.array(all_preds)
rows = []
for class_id in [0, 1, 2]:
    mask = labels_arr == class_id
    if mask.sum() > 0:
        class_acc = (preds_arr[mask] == class_id).sum() / mask.sum()
    else:
        class_acc = 0.0  # class absent from the test split
    rows.append({
        "Label": label2name[class_id],
        "Accuracy": class_acc,
        "F1": f1[class_id],
        "Precision": prec[class_id],
        "Recall": rec[class_id],
        "Support": supp[class_id],
    })
pd.DataFrame(rows).to_csv("labelwise_test_metrics.csv", index=False)
|
| 213 |
+
|
| 214 |
+
# --- OVERALL TEST METRICS CSV ---
# Single-row summary of the aggregate test metrics.
pd.DataFrame([{
    "Test Loss": test_loss,
    "Test Accuracy": test_acc,
    "Test F1 Score": test_f1,
    "Test AUROC": test_auroc
}]).to_csv("overall_test_metrics.csv", index=False)

# --- TEST PREDICTIONS ---
# One row per test example: raw text, detokenized model input, and the gold
# and predicted label names for error analysis.
pd.DataFrame({
    "Content": all_sentences,
    "Tokens": all_tokens,
    "final_label": [label2name[l] for l in all_labels],
    "predicted_label": [label2name[p] for p in all_preds]
}).to_csv("test_predictions.csv", index=False)

# --- CONFUSION MATRIX ---
# Rows are true labels, columns are predicted labels (see plot axes below).
conf_matrix = confusion_matrix(all_labels, all_preds, labels=[0, 1, 2])
conf_matrix_df = pd.DataFrame(conf_matrix, index=[label2name[i] for i in [0,1,2]],
                              columns=[label2name[i] for i in [0,1,2]])
conf_matrix_df.to_csv("confusion_matrix.csv")

# --- CONFUSION MATRIX PLOT ---
plt.figure(figsize=(6, 5))
sns.heatmap(conf_matrix_df, annot=True, fmt='d', cmap='Blues')
plt.title("Confusion Matrix")
plt.ylabel("True Label")
plt.xlabel("Predicted Label")
plt.tight_layout()
plt.savefig("confusion_matrix.png")
plt.close()  # free the figure so repeated runs don't accumulate figures

# --- DONE ---
print("\n=== FINAL TEST METRICS ===")
print(f"Test Accuracy : {test_acc:.4f}")
print(f"Test F1 : {test_f1:.4f}")
print(f"Test AUROC : {test_auroc:.4f}")
print("All test metrics, predictions, and confusion matrix saved.")
|