Telugu
Raj411 committed on
Commit
4890177
·
verified ·
1 Parent(s): aaca83a

Upload 6 files

Browse files
ferret_faithfullness.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Run ferret faithfulness explanations on the test rows a model predicts correctly."""
import os
import gc
import pandas as pd
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from ferret import Benchmark
from scipy.stats import rankdata
import torch

# ================================================================
# CONFIGURATION
# ================================================================

# Set your Hugging Face model repo name here
hf_model_name = "PLACE_YOUR_MODEL_NAME"

# CSV test file (expected in current directory)
test_file = "test.csv"

# Batch sizes
prediction_batch_size = 64  # rows per forward pass when predicting
ferret_batch_size = 1  # ferret explains one text at a time

# Label mapping
label_map = {
    "Negative": 0,
    "Neutral": 1,
    "Positive": 2
}

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {device}")

# ================================================================
# LOAD TEST DATA
# ================================================================

if not os.path.exists(test_file):
    raise FileNotFoundError(f"Test file not found: {test_file}")

df_full = pd.read_csv(test_file)
# NOTE(review): assumes test.csv provides 'Content' and 'final_label' columns — confirm.
df_full["final_label"] = df_full["final_label"].astype("category")
df_full["final_label_numeric"] = df_full["final_label"].map(label_map)

texts_all = df_full["Content"].tolist()
labels_all = df_full["final_label"].tolist()
print(f"[INFO] ✅ Loaded test data: {len(df_full)} rows.")

# ================================================================
# PIPELINE FOR SINGLE MODEL (MATCHED ONLY)
# ================================================================

print(f"\n==============================")
print(f"[INFO] 🚀 Starting pipeline for model: {hf_model_name}")
print(f"==============================")

# -----------------------------
# LOAD MODEL & TOKENIZER
# -----------------------------
tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    hf_model_name,
    trust_remote_code=True,
    use_safetensors=True
)
model = model.to(device)
model.eval()
print("[INFO] ✅ Model loaded.")

# -----------------------------
# PREDICTIONS
# -----------------------------
predictions = []
print("[INFO] 🔎 Predicting on test set...")

for i in tqdm(range(0, len(texts_all), prediction_batch_size), desc="Predicting"):
    batch_texts = texts_all[i : i + prediction_batch_size]
    inputs = tokenizer(
        batch_texts,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=256
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    preds = torch.argmax(outputs.logits, dim=1).cpu().tolist()
    predictions.extend(preds)

# Store predictions
df = df_full.copy()
df["prediction"] = predictions
df["prediction"] = df["prediction"].astype("int8")

predictions_filename = f"{hf_model_name.replace('/', '_')}_predictions.csv"
df.to_csv(predictions_filename, index=False)
print(f"[INFO] ✅ Predictions saved to {predictions_filename}.")

# -----------------------------
# SPLIT MATCHED ONLY
# -----------------------------
# Keep only rows where the model agrees with the gold label; faithfulness is
# evaluated on correctly-classified examples only.
matched_df = df[df["prediction"] == df["final_label_numeric"]].reset_index(drop=True)
print(f"[INFO] ✅ {len(matched_df)} matched rows retained.")

# Save matched for records
matched_df.to_csv(f"{hf_model_name.replace('/', '_')}_matched.csv", index=False)

# -----------------------------
# FERRET ON MATCHED
# -----------------------------
if len(matched_df) > 0:
    ferret_rows = []
    bench = Benchmark(model, tokenizer)
    print(f"[INFO] 🚀 Running FERRET on matched rows...")
    for i in tqdm(range(0, len(matched_df), ferret_batch_size), desc="FERRET (Matched)"):
        batch = matched_df.iloc[i : i + ferret_batch_size]
        for _, row in batch.iterrows():
            text = row["Content"]
            label = int(row["final_label_numeric"])
            try:
                explanations = bench.explain(text, target=label)
                evaluations = bench.evaluate_explanations(explanations, target=label)
            except Exception as ex:
                # Best-effort: skip texts ferret cannot explain rather than abort.
                print(f"[WARN] FERRET failed on matched text: {text}\nReason: {ex}")
                continue
            ferret_row = {
                "Text": text,
                "final_label": row["final_label"],
                "final_label_numeric": label,
                "Annotations": row.get("Annotations", ""),
                "Rationale": row.get("Rationale", ""),
            }
            if explanations:
                ferret_row["Tokens"] = " ".join(explanations[0].tokens)
                for expl, evaluation in zip(explanations, evaluations):
                    explainer_name = expl.explainer
                    scores = expl.scores
                    # Rank 1 = most important token (highest importance score).
                    ranks = rankdata(-np.array(scores), method="min")
                    ferret_row[f"{explainer_name}_ImportanceScores"] = " ".join(map(str, scores))
                    ferret_row[f"{explainer_name}_RankVector"] = " ".join(map(str, ranks))
                    if evaluation and hasattr(evaluation, "evaluation_scores"):
                        for eval_score in evaluation.evaluation_scores:
                            ferret_row[f"{explainer_name}_{eval_score.name}"] = float(eval_score.score)
            ferret_rows.append(ferret_row)
            # Explanations hold tensors; free them eagerly to bound memory.
            del explanations
            del evaluations
            gc.collect()
    ferret_filename = f"{hf_model_name.replace('/', '_')}_ferret_matched.csv"
    pd.DataFrame(ferret_rows).to_csv(ferret_filename, index=False)
    print(f"[INFO] ✅ FERRET results saved to {ferret_filename}.")
else:
    print("[INFO] ⚠ No matched rows to run FERRET on.")

print("[INFO] ✅ Pipeline finished for matched rows only!")
ferret_plausibility.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Score ferret explanations for plausibility against human rationales."""
import os
import ferret
from ferret import Benchmark
import csv
import gc
import pandas as pd
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# ==============================
# File path and Data Loading
# ==============================
input_file = "modelname_ferret_matched.csv"  # Columns: 'Text', 'final_label', 'final_label_numeric', 'Annotations', 'Rationale', plus FERRET columns

if not os.path.exists(input_file):
    raise FileNotFoundError(f"Input file not found: {input_file}")

df = pd.read_csv(input_file)

# ==============================
# Model and Plausibility Configs
# ==============================
hf_model_names = [
    #keep your model names as "model_name"
]
label_map = {"Negative": 0, "Neutral": 1, "Positive": 2}
inv_label_map = {v: k for k, v in label_map.items()}  # numeric id -> label string
max_length = 128  # truncation length when aligning rationales to tokens
32
# ==============================
# Rationales to Attention Vectors
# ==============================
def generate_attention_vectors_from_rationales(df, tokenizer, max_length=128):
    """
    For each row, generate attention vectors for each annotator.
    Vector is zero if annotator's label does not match final label.
    """
    per_row_vectors = []
    per_row_token_counts = []
    widest = 0  # largest annotator count observed on any row

    for _, row in df.iterrows():
        text = str(row["Text"])
        gold_name = inv_label_map[row["final_label_numeric"]]

        enc = tokenizer(
            text,
            add_special_tokens=True,
            return_offsets_mapping=True,
            return_attention_mask=False,
            return_token_type_ids=False,
            max_length=max_length,
            truncation=True
        )
        offsets = enc["offset_mapping"]
        # Special tokens carry empty (start == end) offsets; keep real tokens only.
        keep = [i for i, (s, e) in enumerate(offsets) if s != e and s >= 0]
        n_tokens = len(keep)
        per_row_token_counts.append(n_tokens)

        labels = str(row["Annotations"]).split("|")
        spans_per_annot = str(row["Rationale"]).split("|")
        widest = max(widest, len(labels))

        vectors = []
        for annot_label, annot_spans in zip(labels, spans_per_annot):
            vec = np.zeros(n_tokens, dtype=np.float32)
            # Annotators disagreeing with the final label contribute all zeros.
            if not annot_label.strip() or annot_label.split("-")[0].strip() != gold_name:
                vectors.append(vec)
                continue

            spans = [s.strip() for s in annot_spans.split(",") if s.strip()]
            if not spans:
                vectors.append(vec)
                continue

            for span_text in spans:
                search_from = 0
                # Mark every occurrence of the span text, not just the first.
                while True:
                    hit = text.find(span_text, search_from)
                    if hit < 0:
                        break
                    lo, hi = hit, hit + len(span_text)
                    for out_i, tok_i in enumerate(keep):
                        t_lo, t_hi = offsets[tok_i]
                        # Any character overlap marks the token as rationale.
                        if t_hi > lo and t_lo < hi:
                            vec[out_i] = 1.0
                    search_from = hit + 1

            vectors.append(vec)

        per_row_vectors.append(vectors)

    # Write attention vectors to new columns; rows with fewer annotators than
    # the widest row get all-zero vectors of their own token length.
    for slot in range(widest):
        column = []
        for vectors, n_tokens in zip(per_row_vectors, per_row_token_counts):
            if slot < len(vectors):
                column.append(" ".join(f"{int(v)}" for v in vectors[slot]))
            else:
                column.append(" ".join(["0"] * n_tokens))
        df[f"embert_attention_{slot + 1}"] = column

    return df
111
+
112
# ==============================
# Explanation class for FERRET
# ==============================
class Explanation:
    """Lightweight stand-in for ferret's explanation objects.

    Holds the explained text, its tokens, per-token importance scores
    (stored as float32), the explainer name and the target class id.
    """

    def __init__(self, text, tokens, scores, explainer, target):
        self.text = text
        self.tokens = tokens
        self.scores = np.array(scores, dtype=np.float32)
        self.explainer = explainer
        self.target = target

    def __repr__(self):
        return (
            f"Explanation(text={self.text!r}, tokens={self.tokens}, "
            f"scores=array({self.scores}, dtype=float32), "
            f"explainer={self.explainer!r}, target={self.target})"
        )
125
+
126
# ==============================
# DEVICE SETUP
# ==============================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {device}")

# ==============================
# FERRET PIPELINE
# ==============================
# NOTE(review): the cleanup at the bottom of this loop deletes `df`; a second
# entry in hf_model_names would then raise NameError — confirm the intended
# single-model usage or reload df per model.
for hf_model_name in hf_model_names:
    print(f"\n==============================")
    print(f"[INFO] Starting pipeline for model: {hf_model_name}")
    print(f"==============================")

    # Load model and tokenizer
    model = AutoModelForSequenceClassification.from_pretrained(
        hf_model_name,
        trust_remote_code=True,
        use_safetensors=True
    )
    model.to(device)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(
        hf_model_name,
        trust_remote_code=True,
        use_safetensors=True
    )

    bench = Benchmark(model, tokenizer)

    # Rebuild per-annotator attention columns with this model's tokenizer.
    df = generate_attention_vectors_from_rationales(df, tokenizer)

    # List of explainers you want to use
    explainer_names = [
        "Partition SHAP", "LIME", "Gradient", "Gradient (x Input)",
        "Integrated Gradient", "Integrated Gradient (x Input)"
    ]

    ferret_filename = f"{hf_model_name.replace('/', '_')}_ferret_plausibility.csv"
    # NOTE(review): if the output file already exists, the header is never
    # written and rows are simply appended — confirm this resume behaviour.
    header_written = os.path.exists(ferret_filename)

    # To ensure no empty cells: collect all possible output columns
    # NOTE(review): all_fieldnames grows while rows are processed, so early rows
    # may have fewer columns than later ones, and column order follows set
    # iteration order — verify downstream readers tolerate this.
    all_fieldnames = set(["Index", "Text", "final_label", "final_label_numeric", "Annotations", "Rationale"])

    # --- MAIN LOOP ---
    for idx in tqdm(range(len(df)), desc="FERRET (Plausibility Only)"):
        row = df.iloc[idx]

        ferret_row = {
            "Index": idx,
            "Text": row["Text"],
            "final_label": row["final_label"],
            "final_label_numeric": int(row["final_label_numeric"]),
            "Annotations": row.get("Annotations", ""),
            "Rationale": row.get("Rationale", ""),
        }

        # Prepare explanations for all explainers, rebuilt from the importance
        # scores stored by the earlier faithfulness run (no re-explaining here).
        row_explanations = {}
        for explainer_name in explainer_names:
            score_col = f"{explainer_name}_ImportanceScores"
            tokens_col = "Tokens"
            if pd.notna(row.get(score_col)) and pd.notna(row.get(tokens_col)):
                try:
                    scores = [float(score) for score in str(row[score_col]).split()]
                    tokens = str(row[tokens_col]).split()
                    target_label = int(row["final_label_numeric"])
                    row_explanations[explainer_name] = Explanation(
                        text=row["Text"], tokens=tokens, scores=scores,
                        explainer=explainer_name, target=target_label
                    )
                except Exception as e:
                    print(f"Could not create explanation for explainer {explainer_name} at index {idx}: {e}")
                    continue

        # Discover available metrics for plausibility by probing one
        # explainer/annotator pair against this ferret version.
        available_metrics = set()
        if row_explanations:
            first_explainer = next(iter(row_explanations.keys()))
            first_explanation = row_explanations[first_explainer]
            for test_annot_idx in range(3):
                test_attn_col = f"embert_attention_{test_annot_idx+1}"
                test_human_rationale_str = str(row.get(test_attn_col, ""))
                test_human_rationale = [int(v) for v in test_human_rationale_str.split() if v.isdigit()]
                if any(test_human_rationale):
                    try:
                        test_plaus_eval = bench.evaluate_explanations(
                            [first_explanation],
                            human_rationale=test_human_rationale,
                            target=first_explanation.target,
                            skip_faithfulness=True
                        )
                        if test_plaus_eval and len(test_plaus_eval) > 0:
                            test_eval_obj = test_plaus_eval[0]
                            if hasattr(test_eval_obj, "evaluation_scores") and test_eval_obj.evaluation_scores:
                                for score in test_eval_obj.evaluation_scores:
                                    if score.name in ['auprc_plau', 'token_f1_plau', 'token_iou_plau']:
                                        available_metrics.add(score.name)
                        # One successful probe is enough.
                        break
                    except Exception as e:
                        print(f"Error discovering metrics with {first_explainer} and annotator {test_annot_idx+1}: {e}")
                        continue

        print(f"Row {idx}: Using FERRET plausibility metrics: {list(available_metrics)}")

        # --- Evaluate plausibility for each explainer/annotator combination ---
        for explainer_name in explainer_names:
            if explainer_name not in row_explanations:
                # No stored scores for this explainer: emit N/A cells.
                for annot_idx in range(3):
                    for metric in available_metrics:
                        colname = f"{explainer_name}Annotator{annot_idx+1}{metric}"
                        ferret_row[colname] = "N/A"
                        all_fieldnames.add(colname)
                continue

            explanation = row_explanations[explainer_name]
            label = explanation.target

            for annot_idx in range(3):
                attn_col = f"embert_attention_{annot_idx+1}"
                human_rationale_str = str(row.get(attn_col, ""))
                human_rationale = [int(v) for v in human_rationale_str.split() if v.isdigit()]

                annot_labels_list = str(row["Annotations"]).split("|")
                if annot_idx < len(annot_labels_list):
                    annot_label_str = annot_labels_list[annot_idx].split("-")[0].strip()
                else:
                    annot_label_str = ""

                final_label_str = inv_label_map[label]
                for metric in available_metrics:
                    colname = f"{explainer_name}Annotator{annot_idx+1}{metric}"
                    all_fieldnames.add(colname)

                # Annotator disagreed with the final label: not comparable.
                if annot_label_str != final_label_str:
                    for metric in available_metrics:
                        ferret_row[f"{explainer_name}Annotator{annot_idx+1}{metric}"] = "N/A"
                    continue

                if any(human_rationale):
                    try:
                        plaus_eval = bench.evaluate_explanations(
                            [explanation],
                            human_rationale=human_rationale,
                            target=label,
                            skip_faithfulness=True
                        )
                        if plaus_eval and len(plaus_eval) > 0:
                            eval_obj = plaus_eval[0]
                            if hasattr(eval_obj, "evaluation_scores") and eval_obj.evaluation_scores:
                                for score in eval_obj.evaluation_scores:
                                    if score.name in ['auprc_plau', 'token_f1_plau', 'token_iou_plau']:
                                        ferret_row[f"{explainer_name}Annotator{annot_idx+1}{score.name}"] = float(score.score)
                    except Exception as e:
                        print(f"Error evaluating {explainer_name} for annotator {annot_idx+1} at index {idx}: {e}")
                        for metric in available_metrics:
                            ferret_row[f"{explainer_name}Annotator{annot_idx+1}{metric}"] = "N/A"
                else:
                    # Annotator marked no rationale tokens.
                    for metric in available_metrics:
                        ferret_row[f"{explainer_name}Annotator{annot_idx+1}{metric}"] = "N/A"

        # --- Ensure no empty cells: fill missing columns with "N/A" ---
        for col in all_fieldnames:
            if col not in ferret_row:
                ferret_row[col] = "N/A"

        # === SAVE THIS ROW TO CSV IMMEDIATELY ===
        write_header = not header_written
        with open(ferret_filename, mode='a', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=list(all_fieldnames))
            if write_header:
                writer.writeheader()
                header_written = True
            writer.writerow(ferret_row)

    print(f"[INFO] FERRET plausibility results saved row-wise to {ferret_filename}.")

    # --- Memory cleanup ---
    print(f"[INFO] Cleaning up memory for model {hf_model_name}...")
    del bench, model, tokenizer, df
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()

# ==============================
# End of pipeline
# ==============================
hyperparameter_tuning_for_rationale.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import os
3
+ import numpy as np
4
+ import pandas as pd
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.optim as optim
8
+ from torch.utils.data import Dataset, DataLoader
9
+ from transformers import AutoTokenizer, AutoModel
10
+ from sklearn.metrics import f1_score, roc_auc_score, accuracy_score, precision_recall_fscore_support
11
+ import itertools
12
+ import warnings
13
+ import random
14
+
15
def set_seed(seed=13):
    """Seed every RNG in use (python, numpy, torch CPU and CUDA).

    Also forces deterministic cuDNN kernel selection and sets the cuBLAS
    workspace variable used for deterministic GEMMs on CUDA.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

set_seed(13)
warnings.filterwarnings("ignore", category=FutureWarning)
+
27
+ # --- CONFIG ---
28
+ param_grid = {
29
+ "learning_rate": [1e-5, 2e-5, 3e-5, 4e-5, 5e-5],
30
+ "batch_size": [16, 32, 64],
31
+ "optimizer": ["Adam"],
32
+ "lambda": [0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
33
+ }
34
+ num_epochs = 7
35
+ max_length = 128
36
+ model_name = "bert-base-multilingual-cased"
37
+ num_labels = 3
38
+
39
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
+
41
+ # --- LOAD DATA ---
42
+ train_df = pd.read_csv("train.csv")
43
+ val_df = pd.read_csv("val.csv")
44
+
45
+ valid_labels = {"Negative": 0, "Neutral": 1, "Positive": 2}
46
+ train_df = train_df[train_df["final_label"].isin(valid_labels.keys())]
47
+ val_df = val_df[val_df["final_label"].isin(valid_labels.keys())]
48
+ if train_df.empty:
49
+ raise ValueError("Train dataset empty after filtering.")
50
+ if val_df.empty:
51
+ raise ValueError("Validation dataset empty after filtering.")
52
+
53
+ # --- INITIALIZE TOKENIZER & ADD EMOJIS ---
54
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
55
+ emoji_path = "emoji.csv" # adjust path if needed
56
+ if os.path.exists(emoji_path):
57
+ emoji_df = pd.read_csv(emoji_path)
58
+ emoji_list = emoji_df["emoji"].dropna().astype(str).str.strip().tolist()
59
+ existing_vocab = set(tokenizer.get_vocab().keys())
60
+ emoji_set = set(emoji_list) - existing_vocab
61
+ if emoji_set:
62
+ tokenizer.add_tokens(list(emoji_set))
63
+ print(f"Added {len(emoji_set)} new emoji tokens to the tokenizer.")
64
+ else:
65
+ print("No new emojis to add.")
66
+ else:
67
+ print(f"Emoji file not found at: {emoji_path}")
68
+
69
+ # --- FUNCTIONS ---
70
+
71
+ def generate_attention_vectors_from_rationales(df, tokenizer, epsilon=1e-8):
72
+ attention_vectors = []
73
+ for _, row in df.iterrows():
74
+ text = str(row["Content"])
75
+ final_label = str(row["final_label"]).strip()
76
+ encoding = tokenizer(text, add_special_tokens=False, return_offsets_mapping=True)
77
+ offsets = encoding["offset_mapping"]
78
+ num_tokens = len(offsets)
79
+ avg_vector = np.zeros(num_tokens, dtype=np.float32)
80
+ annotations = str(row.get("Annotations", "")).split("|")
81
+ rationales = str(row.get("Rationale", "")).split("|")
82
+ annot_vectors = []
83
+ for annot_label, annot_rationale in zip(annotations, rationales):
84
+ if not annot_label:
85
+ continue
86
+ if annot_label.split("-")[0].strip() != final_label:
87
+ continue
88
+ spans = [s.strip() for s in annot_rationale.split(",") if s.strip()]
89
+ if not spans:
90
+ continue
91
+ vec = np.zeros(num_tokens, dtype=np.float32)
92
+ for span_text in spans:
93
+ start = 0
94
+ while True:
95
+ idx = text.find(span_text, start)
96
+ if idx < 0:
97
+ break
98
+ span_start, span_end = idx, idx + len(span_text)
99
+ for i, (tok_start, tok_end) in enumerate(offsets):
100
+ if tok_end > span_start and tok_start < span_end:
101
+ vec[i] = 1.0
102
+ start = idx + 1
103
+ if vec.sum() > 0:
104
+ annot_vectors.append(vec)
105
+ if annot_vectors:
106
+ avg_vector = np.mean(annot_vectors, axis=0)
107
+ avg_vector = np.where(avg_vector == 0, epsilon, avg_vector)
108
+ attn_str = " ".join(f"{v:.8f}" for v in avg_vector)
109
+ attention_vectors.append(attn_str)
110
+ df["embert_attention"] = attention_vectors
111
+ return df
112
+
113
class RationaleDataset(Dataset):
    """Yields (input_ids, attention_mask, rationale_probs, label, has_rationale).

    The rationale string in 'embert_attention' is aligned with the tokenized
    input by padding one zero slot on each side for the special tokens, then
    truncated/padded to max_length and softmax-normalised.  Rows without any
    rationale get a uniform distribution and has_rationale=False so the
    training loop can skip the KL term for them.
    """

    def __init__(self, df, tokenizer, max_length=128, label_mapping=None):
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.label_mapping = label_mapping  # e.g. {"Negative": 0, ...}

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        text = row["Content"]
        label = self.label_mapping[row["final_label"]]
        encoding = self.tokenizer(
            text, padding="max_length", truncation=True,
            max_length=self.max_length, return_tensors="pt"
        )
        # Parse the space-separated rationale weights (column may be NaN/empty).
        rationale_raw = [float(x) for x in row["embert_attention"].split()] \
            if pd.notna(row["embert_attention"]) and row["embert_attention"].strip() else []
        # Zero-pad one slot on each side for the leading/trailing special tokens.
        rationale_vector = np.concatenate([
            np.array([0.0], dtype=np.float32),
            np.array(rationale_raw, dtype=np.float32),
            np.array([0.0], dtype=np.float32)
        ])
        rationale_vector = rationale_vector[:self.max_length]
        if len(rationale_vector) < self.max_length:
            rationale_vector = np.pad(rationale_vector, (0, self.max_length - len(rationale_vector)), constant_values=0.0)
        rationale_tensor = torch.tensor(rationale_vector, dtype=torch.float32)
        if torch.sum(rationale_tensor) == 0.0:
            # No human rationale: uniform fallback distribution, flagged False.
            has_rationale = False
            rationale_probs = torch.ones(self.max_length, dtype=torch.float32) / self.max_length
        else:
            has_rationale = True
            rationale_probs = torch.softmax(rationale_tensor, dim=0)
        return (
            encoding["input_ids"].squeeze(0),
            encoding["attention_mask"].squeeze(0),
            rationale_probs,
            torch.tensor(label, dtype=torch.long),
            torch.tensor(has_rationale, dtype=torch.bool)
        )
+ )
155
+
156
class RationaleModel(nn.Module):
    """Encoder with a linear classification head that also returns CLS attention.

    forward() returns (logits, cls_attention) where cls_attention is the last
    encoder layer's attention from the [CLS] position to every token, averaged
    over heads.
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name, output_attentions=True)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = encoded.last_hidden_state[:, 0, :]  # [CLS] representation
        class_logits = self.classifier(pooled)
        # Last layer attention: (batch, heads, seq, seq) -> [CLS] row per head.
        final_layer_attn = encoded.attentions[-1]
        cls_to_tokens = final_layer_attn[:, :, 0, :]  # (batch, heads, seq)
        return class_logits, cls_to_tokens.mean(dim=1)  # head-averaged (batch, seq)
+ return logits, cls_attn_avg
170
+
171
def evaluate_model(model, val_loader, criterion_cls, device):
    """Evaluate *model* on *val_loader*.

    Returns (avg_val_loss, accuracy, f1_macro, auroc_ovr, class_wise_metrics)
    where class_wise_metrics maps '<label>_{precision,recall,f1,accuracy,auroc}'
    to floats; -1.0 marks a metric that could not be computed.

    Relies on the module-level globals ``num_labels`` and ``valid_labels``.
    """
    model.eval()
    total_val_loss = 0.0
    all_preds = []
    all_labels = []
    all_probs = []
    with torch.no_grad():
        for batch in val_loader:
            # Rationale probabilities and has_rationale flag are unused here.
            input_ids, attention_mask, _, labels, _ = [b.to(device) for b in batch]
            logits, _ = model(input_ids, attention_mask)
            loss = criterion_cls(logits, labels)
            total_val_loss += loss.item()
            probs = torch.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
    avg_val_loss = total_val_loss / len(val_loader)
    # Overall metrics
    accuracy = accuracy_score(all_labels, all_preds)
    f1_macro = f1_score(all_labels, all_preds, average="macro")
    try:
        # One-vs-rest AUROC needs one-hot targets; raises if a class is absent.
        y_true_oh = np.eye(num_labels)[all_labels]
        auroc_ovr = roc_auc_score(y_true_oh, all_probs, multi_class="ovr")
    except Exception:
        auroc_ovr = -1.0
    # Class-wise metrics
    class_wise_metrics = {}
    target_names = sorted(valid_labels, key=valid_labels.get)
    precision, recall, f1_per_class, _ = precision_recall_fscore_support(all_labels, all_preds, average=None, labels=[valid_labels[label_name] for label_name in target_names])
    for i, label_name in enumerate(target_names):
        class_wise_metrics[f"{label_name}_precision"] = precision[i]
        class_wise_metrics[f"{label_name}_recall"] = recall[i]
        class_wise_metrics[f"{label_name}_f1"] = f1_per_class[i]
        # Per-class accuracy: of true class samples, how many were predicted correctly
        idx = np.array(all_labels) == valid_labels[label_name]
        if idx.sum() > 0:
            acc = (np.array(all_preds)[idx] == valid_labels[label_name]).sum() / idx.sum()
        else:
            acc = -1.0
        class_wise_metrics[f"{label_name}_accuracy"] = acc
        # Class-wise AUROC (one-vs-rest on this class's probability column)
        try:
            binary_labels = (np.array(all_labels) == valid_labels[label_name]).astype(int)
            class_probs = np.array(all_probs)[:, valid_labels[label_name]]
            if len(np.unique(binary_labels)) > 1:
                class_wise_metrics[f"{label_name}_auroc"] = roc_auc_score(binary_labels, class_probs)
            else:
                # Only one class present in the truth: AUROC is undefined.
                class_wise_metrics[f"{label_name}_auroc"] = -1.0
        except Exception:
            class_wise_metrics[f"{label_name}_auroc"] = -1.0
    return avg_val_loss, accuracy, f1_macro, auroc_ovr, class_wise_metrics
223
+
224
def train_model(model, train_loader, val_loader, num_epochs, device, lambda_attn=1.0, optimizer=None, learning_rate=2e-5, results_writer=None, results_file_handle=None, params=None):
    """Train with cross-entropy plus a KL term pulling model attention to rationales.

    For samples carrying a human rationale, the KL divergence between the
    model's head-averaged [CLS] attention and the rationale distribution is
    added to the loss, weighted by *lambda_attn*.  After each epoch the model
    is evaluated and, when *results_writer*/*results_file_handle* are given, a
    CSV row of hyper-parameters plus all metrics is written and flushed.
    """
    criterion_cls = nn.CrossEntropyLoss()
    criterion_kl = nn.KLDivLoss(reduction="batchmean")
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    for epoch in range(num_epochs):
        model.train()
        total_train_loss = 0.0
        for batch in train_loader:
            input_ids, attention_mask, rationale_probs, labels, has_rationale = [b.to(device) for b in batch]
            optimizer.zero_grad()
            logits, model_attention = model(input_ids, attention_mask)
            loss_cls = criterion_cls(logits, labels)
            loss = loss_cls
            if has_rationale.any():
                # Restrict the KL term to samples that actually have a rationale.
                model_attn_batch = model_attention[has_rationale]
                rationale_batch = rationale_probs[has_rationale]
                # KLDivLoss expects log-probabilities for the model side; the
                # epsilon guards log(0).
                log_model_attn = torch.log(model_attn_batch + 1e-8)
                loss_kl = criterion_kl(log_model_attn, rationale_batch)
                loss += lambda_attn * loss_kl
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
        avg_train_loss = total_train_loss / len(train_loader)
        val_loss, val_acc, val_f1_macro, val_auroc_ovr, class_wise_metrics = evaluate_model(model, val_loader, criterion_cls, device)
        print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | Val F1 (Macro): {val_f1_macro:.4f} | Val AUROC (OvR): {val_auroc_ovr:.4f}")
        sorted_labels = sorted(valid_labels, key=valid_labels.get)
        for label_name in sorted_labels:
            print(f" {label_name}: P={class_wise_metrics[f'{label_name}_precision']:.4f}, R={class_wise_metrics[f'{label_name}_recall']:.4f}, F1={class_wise_metrics[f'{label_name}_f1']:.4f}, Acc={class_wise_metrics[f'{label_name}_accuracy']:.4f}, AUROC={class_wise_metrics[f'{label_name}_auroc']:.4f}")
        if results_writer and results_file_handle:
            row_data = [
                params["learning_rate"],
                params["batch_size"],
                params["optimizer"],
                params["lambda"],
                epoch + 1,
                avg_train_loss,
                val_loss,
                val_acc,
                val_f1_macro,
                val_auroc_ovr
            ]
            for label_name in sorted_labels:
                row_data.extend([
                    class_wise_metrics[f"{label_name}_precision"],
                    class_wise_metrics[f"{label_name}_recall"],
                    class_wise_metrics[f"{label_name}_f1"],
                    class_wise_metrics[f"{label_name}_accuracy"],
                    class_wise_metrics[f"{label_name}_auroc"]
                ])
            results_writer.writerow(row_data)
            # Flush and fsync so a crash mid-grid-search loses no finished rows.
            results_file_handle.flush()
            os.fsync(results_file_handle.fileno())
277
+
278
# --- PREPARE DATASETS ---
print("Generating attention vectors for training data...")
train_df = generate_attention_vectors_from_rationales(train_df, tokenizer)
print("Generating attention vectors for validation data...")
val_df = generate_attention_vectors_from_rationales(val_df, tokenizer)

train_dataset = RationaleDataset(train_df, tokenizer, max_length, label_mapping=valid_labels)
val_dataset = RationaleDataset(val_df, tokenizer, max_length, label_mapping=valid_labels)

# --- GRID SEARCH LOOP ---
# Cartesian product of every hyper-parameter value in param_grid.
keys, values = zip(*param_grid.items())
param_combinations = [dict(zip(keys, v)) for v in itertools.product(*values)]
results_file = "grid_results_detailed.csv"
headers = ["learning_rate", "batch_size", "optimizer", "lambda", "epoch", "train_loss", "val_loss", "val_accuracy", "val_f1_macro", "val_auroc_ovr"]
sorted_labels = sorted(valid_labels, key=valid_labels.get)
for label_name in sorted_labels:
    headers.extend([f"{label_name}_precision", f"{label_name}_recall", f"{label_name}_f1", f"{label_name}_accuracy", f"{label_name}_auroc"])
with open(results_file, mode="w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(headers)
    for params in param_combinations:
        print("\nRunning:", params)
        learning_rate = params["learning_rate"]
        batch_size = params["batch_size"]
        optimizer_type = params["optimizer"]
        lambda_attn = params["lambda"]
        # Fresh model per combination so runs do not contaminate each other.
        model = RationaleModel(model_name=model_name, num_labels=num_labels).to(device)
        if 'emoji_set' in locals() and len(emoji_set) > 0:
            # Embedding table must grow to cover the added emoji tokens.
            model.bert.resize_token_embeddings(len(tokenizer))
        if optimizer_type == "Adam":
            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        else:
            raise ValueError("Unsupported optimizer")
        # Seeded generator keeps the shuffle order reproducible across runs.
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, generator=torch.Generator().manual_seed(13))
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        train_model(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            num_epochs=num_epochs,
            device=device,
            lambda_attn=lambda_attn,
            optimizer=optimizer,
            learning_rate=learning_rate,
            results_writer=writer,
            results_file_handle=f,
            params=params
        )
print("Grid search complete. Results saved to:", results_file)
hyperparameter_tuning_without_rationale.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import random
4
+ import numpy as np
5
+ import pandas as pd
6
+ import itertools
7
+ from datetime import datetime
8
+ from torch.utils.data import Dataset, DataLoader
9
+ from transformers import BertTokenizer, BertForSequenceClassification
10
+ from torch.optim import Adam
11
+ from sklearn.metrics import f1_score, roc_auc_score, precision_recall_fscore_support, accuracy_score
12
+
13
# Reproducibility
def set_seed(seed=13):
    """Seed every RNG used by this script (python, numpy, torch CPU/GPU).

    Also pins cuDNN to deterministic kernels so repeated runs with the same
    seed produce identical results.

    Args:
        seed: integer seed applied to all RNGs (default 13).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one — matches the
    # sibling rationale-training script's set_seed; no-op without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
21
+
22
set_seed(13)

# Configurations
param_grid = {
    "learning_rate": [1e-5, 2e-5, 3e-5, 4e-5, 5e-5],
    "batch_size": [16, 32, 64]
}
num_epochs = 10
max_length = 128
model_name = "bert-base-multilingual-cased"
num_labels = 3

# Tokenizer + Emoji Extension
# First column of emoji.csv holds the emoji characters to add as tokens.
emoji_df = pd.read_csv("emoji.csv")
emoji_list = emoji_df.iloc[:, 0].dropna().astype(str).unique().tolist()

tokenizer = BertTokenizer.from_pretrained(model_name)
# Only register emojis the base vocabulary does not already contain.
new_tokens = list(set(emoji_list) - set(tokenizer.vocab.keys()))
if new_tokens:
    tokenizer.add_tokens(new_tokens)
    print(f"Added {len(new_tokens)} emojis to tokenizer.")

# Data loading
train_df = pd.read_csv("train.csv")
val_df = pd.read_csv("val.csv")

# Drop rows whose label is not one of the three sentiment classes.
valid_labels = {"Negative": 0, "Neutral": 1, "Positive": 2}
train_df = train_df[train_df["final_label"].isin(valid_labels)]
val_df = val_df[val_df["final_label"].isin(valid_labels)]
51
+
52
class CustomDataset(Dataset):
    """Wraps a dataframe of (Content, final_label) rows as BERT-ready tensors.

    Yields (input_ids, attention_mask, label) per sample; the label is mapped
    through the module-level valid_labels dict.
    """

    def __init__(self, dataframe, tokenizer, max_length):
        self.dataframe = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        sample = self.dataframe.iloc[idx]
        target = valid_labels[sample["final_label"]]
        enc = self.tokenizer(
            sample["Content"],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        input_ids = enc["input_ids"].squeeze(0)
        attention_mask = enc["attention_mask"].squeeze(0)
        return input_ids, attention_mask, torch.tensor(target, dtype=torch.long)
77
+
78
train_dataset = CustomDataset(train_df, tokenizer, max_length)
val_dataset = CustomDataset(val_df, tokenizer, max_length)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Results directory and file
os.makedirs("results", exist_ok=True)
results_path = "results/grid_search_metrics.csv"

# Write the CSV header once; runs append rows to the same file below.
if not os.path.exists(results_path):
    with open(results_path, "w") as f:
        f.write("timestamp,learning_rate,batch_size,epoch,val_macro_f1,val_auroc,"
                "acc_negative,prec_negative,rec_negative,f1_negative,"
                "acc_neutral,prec_neutral,rec_neutral,f1_neutral,"
                "acc_positive,prec_positive,rec_positive,f1_positive\n")
93
+
94
# Hyperparameter grid search
for lr, bs in itertools.product(param_grid["learning_rate"], param_grid["batch_size"]):
    print(f"\nStarting config: LR={lr}, Batch Size={bs}")
    # Re-seed per configuration so each run is independent and reproducible.
    set_seed(13)
    train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=bs)

    model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels).to(device)
    # Grow embeddings to cover the emoji tokens added to the tokenizer.
    if new_tokens:
        model.resize_token_embeddings(len(tokenizer))

    optimizer = Adam(model.parameters(), lr=lr)

    for epoch in range(1, num_epochs + 1):
        model.train()
        for batch in train_loader:
            input_ids, attention_mask, labels = [b.to(device) for b in batch]
            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            outputs.loss.backward()
            optimizer.step()

        # Evaluation
        model.eval()
        val_preds, val_probs, val_labels = [], [], []
        with torch.no_grad():
            for batch in val_loader:
                input_ids, attention_mask, labels = [b.to(device) for b in batch]
                logits = model(input_ids, attention_mask=attention_mask).logits
                probs = torch.softmax(logits, dim=1).cpu().numpy()
                preds = torch.argmax(logits, axis=1).cpu().tolist()

                val_probs.extend(probs)
                val_preds.extend(preds)
                val_labels.extend(labels.cpu().tolist())

        val_macro_f1 = f1_score(val_labels, val_preds, average="macro")
        # One-vs-rest macro AUROC over softmax probabilities.
        # NOTE(review): raises if a class is entirely absent from the
        # validation split — confirm all three labels appear in val.csv.
        val_auroc = roc_auc_score(
            np.eye(num_labels)[val_labels],
            np.array(val_probs),
            average="macro",
            multi_class="ovr"
        )

        # Label-wise metrics
        report = precision_recall_fscore_support(val_labels, val_preds, labels=[0, 1, 2], zero_division=0)
        acc_per_label = []
        for i in range(num_labels):
            # Per-class accuracy over the rows whose true label is class i.
            idx = np.array(val_labels) == i
            correct = (np.array(val_preds)[idx] == i).sum()
            total = idx.sum()
            acc = correct / total if total > 0 else 0
            acc_per_label.append(acc)

        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        row = [
            timestamp, lr, bs, epoch, f"{val_macro_f1:.4f}", f"{val_auroc:.4f}"
        ]
        for i in range(num_labels):
            row.extend([
                f"{acc_per_label[i]:.4f}",
                f"{report[0][i]:.4f}",  # precision
                f"{report[1][i]:.4f}",  # recall
                f"{report[2][i]:.4f}"  # f1
            ])

        # Append one row per epoch so partial results survive interruption.
        with open(results_path, "a") as f:
            f.write(",".join(map(str, row)) + "\n")

        print(f"[Epoch {epoch}] LR={lr}, BS={bs} | F1={val_macro_f1:.4f} | AUROC={val_auroc:.4f}")

print(f"\nGrid Search Complete. Results saved to: {results_path}")
model_training_with_rationale.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import os
3
+ import numpy as np
4
+ import pandas as pd
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.utils.data import Dataset, DataLoader
8
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
9
+ from sklearn.metrics import f1_score, roc_auc_score, accuracy_score, precision_recall_fscore_support
10
+ import warnings
11
+ import random
12
+
13
def set_seed(seed=13):
    """Make the pipeline deterministic.

    Seeds python, numpy and torch (CPU + all GPUs) and forces deterministic
    cuDNN/cuBLAS behaviour.
    """
    # Deterministic cuBLAS workspace (required for reproducible GEMMs).
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
21
+
22
set_seed(13)
warnings.filterwarnings("ignore", category=FutureWarning)

# --- CONFIG ---
model_name = "bert-base-multilingual-cased" # Set your model name here
num_epochs = 4
max_length = 128
num_labels = 3
learning_rate = 2e-5
batch_size = 64
optimizer_type = "Adam"
lambda_attn = 0.6  # weight of the rationale-attention (KL) loss term

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- LOAD DATA ---
train_df = pd.read_csv("train.csv")
val_df = pd.read_csv("val.csv")
test_df = pd.read_csv("test.csv")
# Keep only rows labelled with one of the three sentiment classes.
valid_labels = {"Negative": 0, "Neutral": 1, "Positive": 2}
train_df = train_df[train_df["final_label"].isin(valid_labels.keys())]
val_df = val_df[val_df["final_label"].isin(valid_labels.keys())]
test_df = test_df[test_df["final_label"].isin(valid_labels.keys())]
if train_df.empty:
    raise ValueError("Train dataset empty after filtering.")
if val_df.empty:
    raise ValueError("Validation dataset empty after filtering.")
49
+
50
# --- FUNCTIONS ---
def generate_attention_vectors_from_rationales(df, tokenizer, epsilon=1e-8):
    """Attach an "embert_attention" column to *df* and return it.

    For every row, annotator rationale spans whose label prefix matches the
    row's final_label are projected onto token positions (using the
    tokenizer's offset mapping) as 0/1 vectors, then averaged across
    annotators. Zeros are replaced by *epsilon* so downstream softmax/KL
    terms never see exact zeros. The result is serialised as a
    space-separated string of 8-decimal floats.
    """
    serialized = []
    for _, row in df.iterrows():
        text = str(row["Content"])
        final_label = str(row["final_label"]).strip()
        offsets = tokenizer(text, add_special_tokens=False, return_offsets_mapping=True)["offset_mapping"]
        n_tok = len(offsets)
        annotations = str(row.get("Annotations", "")).split("|")
        rationales = str(row.get("Rationale", "")).split("|")
        annotator_vecs = []
        for annot_label, annot_rationale in zip(annotations, rationales):
            # Skip empty annotations and annotators who disagreed with the final label.
            if not annot_label or annot_label.split("-")[0].strip() != final_label:
                continue
            spans = [s.strip() for s in annot_rationale.split(",") if s.strip()]
            if not spans:
                continue
            vec = np.zeros(n_tok, dtype=np.float32)
            for span_text in spans:
                # Mark every (possibly repeated) occurrence of the span text.
                search_from = 0
                while (hit := text.find(span_text, search_from)) >= 0:
                    lo, hi = hit, hit + len(span_text)
                    for i, (tok_start, tok_end) in enumerate(offsets):
                        # Any token overlapping the character span counts.
                        if tok_end > lo and tok_start < hi:
                            vec[i] = 1.0
                    search_from = hit + 1
            if vec.sum() > 0:
                annotator_vecs.append(vec)
        avg = np.mean(annotator_vecs, axis=0) if annotator_vecs else np.zeros(n_tok, dtype=np.float32)
        avg = np.where(avg == 0, epsilon, avg)
        serialized.append(" ".join(f"{v:.8f}" for v in avg))
    df["embert_attention"] = serialized
    return df
92
+
93
class RationaleDataset(Dataset):
    """Dataset yielding (input_ids, attention_mask, rationale_probs, label, has_rationale).

    The rationale values parsed from "embert_attention" are shifted right by
    one position (for [CLS]) with a trailing zero appended, then
    truncated/padded to max_length and turned into a probability distribution
    with softmax. Rows with no rationale mass fall back to a uniform
    distribution and are flagged False so the training loop can skip the
    attention loss for them.
    """

    def __init__(self, df, tokenizer, max_length=128, label_mapping=None):
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.label_mapping = label_mapping

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        enc = self.tokenizer(
            row["Content"], padding="max_length", truncation=True,
            max_length=self.max_length, return_tensors="pt"
        )
        raw = row["embert_attention"]
        values = [float(x) for x in raw.split()] if pd.notna(raw) and raw.strip() else []
        # Shift by one for [CLS], append a zero slot, clip, and zero-pad.
        clipped = np.array([0.0] + values + [0.0], dtype=np.float32)[:self.max_length]
        padded = np.zeros(self.max_length, dtype=np.float32)
        padded[:len(clipped)] = clipped
        rationale = torch.from_numpy(padded)
        if rationale.sum() == 0.0:
            flag = False
            probs = torch.full((self.max_length,), 1.0 / self.max_length)
        else:
            flag = True
            probs = torch.softmax(rationale, dim=0)
        return (
            enc["input_ids"].squeeze(0),
            enc["attention_mask"].squeeze(0),
            probs,
            torch.tensor(self.label_mapping[row["final_label"]], dtype=torch.long),
            torch.tensor(flag, dtype=torch.bool),
        )
135
+
136
class RationaleModel(nn.Module):
    """Sequence classifier that also exposes head-averaged [CLS] attention
    from the last transformer layer (supervision target for the KL loss)."""

    def __init__(self, model_name, num_labels):
        super().__init__()
        self.bert = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels, output_attentions=True
        )

    def forward(self, input_ids, attention_mask):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # attentions[-1] has shape (batch, heads, seq, seq); take the
        # attention paid by the [CLS] query (position 0) and average heads.
        cls_attention = out.attentions[-1][:, :, 0, :].mean(dim=1)
        return out.logits, cls_attention
147
+
148
def evaluate_model(model, val_loader, criterion_cls, device, valid_labels, num_labels):
    """Run a full validation pass and compute aggregate + per-class metrics.

    Args:
        model: callable returning (logits, attention) for (input_ids, mask).
        val_loader: yields (input_ids, mask, rationale, labels, has_rationale).
        criterion_cls: classification loss (e.g. CrossEntropyLoss).
        device: torch device to run on.
        valid_labels: dict mapping label name -> numeric id.
        num_labels: number of classes.

    Returns:
        (avg_val_loss, accuracy, f1_macro, auroc_ovr, class_wise_metrics)
        where class_wise_metrics maps "<label>_<metric>" to a float. Metrics
        that cannot be computed (e.g. AUROC when only one class is present)
        are reported as the sentinel -1.0.
    """
    model.eval()
    total_val_loss = 0.0
    all_preds = []
    all_labels = []
    all_probs = []
    with torch.no_grad():
        for batch in val_loader:
            # rationale vector and has_rationale flag are unused at eval time
            input_ids, attention_mask, _, labels, _ = [b.to(device) for b in batch]
            logits, _ = model(input_ids, attention_mask)
            loss = criterion_cls(logits, labels)
            total_val_loss += loss.item()
            probs = torch.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
    avg_val_loss = total_val_loss / len(val_loader)
    all_labels_np = np.array(all_labels)
    all_preds_np = np.array(all_preds)
    all_probs_np = np.array(all_probs)
    accuracy = accuracy_score(all_labels_np, all_preds_np)
    f1_macro = f1_score(all_labels_np, all_preds_np, average="macro")
    try:
        y_true_oh = np.eye(num_labels)[all_labels_np]
        auroc_ovr = roc_auc_score(y_true_oh, all_probs_np, multi_class="ovr")
    except ValueError:
        # BUGFIX: was a bare except; sklearn raises ValueError when e.g.
        # only one class is present — keep the sentinel for that case only.
        auroc_ovr = -1.0
    class_wise_metrics = {}
    # Label names sorted by numeric id so the output order is stable.
    target_names = sorted(valid_labels, key=valid_labels.get)
    label_indices = [valid_labels[label_name] for label_name in target_names]
    precision, recall, f1_per_class, support = precision_recall_fscore_support(
        all_labels_np, all_preds_np, labels=label_indices, average=None)
    for i, label_name in enumerate(target_names):
        label_id = valid_labels[label_name]
        class_wise_metrics[f"{label_name}_precision"] = precision[i]
        class_wise_metrics[f"{label_name}_recall"] = recall[i]
        class_wise_metrics[f"{label_name}_f1"] = f1_per_class[i]
        # Per-class accuracy restricted to rows whose true label is this class.
        label_mask = all_labels_np == label_id
        correct_preds = np.sum((all_preds_np == label_id) & label_mask)
        total_label = np.sum(label_mask)
        if total_label > 0:
            class_wise_metrics[f"{label_name}_accuracy"] = correct_preds / total_label
        else:
            class_wise_metrics[f"{label_name}_accuracy"] = -1.0
        try:
            # One-vs-rest AUROC for this class from its predicted probability.
            binary_labels = (all_labels_np == label_id).astype(int)
            class_probs = all_probs_np[:, label_id]
            if len(np.unique(binary_labels)) > 1:
                class_wise_metrics[f"{label_name}_auroc"] = roc_auc_score(binary_labels, class_probs)
            else:
                class_wise_metrics[f"{label_name}_auroc"] = -1.0
        except ValueError:
            # BUGFIX: narrowed from bare except; degenerate inputs keep the sentinel.
            class_wise_metrics[f"{label_name}_auroc"] = -1.0
    return avg_val_loss, accuracy, f1_macro, auroc_ovr, class_wise_metrics
203
+
204
def train_model(model, train_loader, val_loader, num_epochs, device, lambda_attn=1.0, optimizer=None, learning_rate=2e-5, results_writer=None, results_file_handle=None):
    """Train with a combined loss: cross-entropy + lambda_attn * KL between
    the model's [CLS] attention and the annotator rationale distribution.

    The KL term is applied only to samples flagged as having a real
    rationale. After every epoch, validation metrics are printed and, when
    a csv writer/file handle pair is supplied, appended to the results file.

    NOTE(review): the logged row also reads the module-level globals
    batch_size, optimizer_type, valid_labels and num_labels — confirm they
    match the loaders/optimizer actually passed in.
    """
    criterion_cls = nn.CrossEntropyLoss()
    criterion_kl = nn.KLDivLoss(reduction="batchmean")
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    for epoch in range(num_epochs):
        model.train()
        total_train_loss = 0.0
        for batch in train_loader:
            input_ids, attention_mask, rationale_probs, labels, has_rationale = [b.to(device) for b in batch]
            optimizer.zero_grad()
            logits, model_attention = model(input_ids, attention_mask)
            loss_cls = criterion_cls(logits, labels)
            loss = loss_cls
            # Attention supervision only for samples with a real rationale.
            if has_rationale.any():
                model_attn_batch = model_attention[has_rationale]
                rationale_batch = rationale_probs[has_rationale]
                # KLDivLoss expects log-probabilities as input; the epsilon
                # guards against log(0).
                log_model_attn = torch.log(model_attn_batch + 1e-8)
                loss_kl = criterion_kl(log_model_attn, rationale_batch)
                loss += lambda_attn * loss_kl
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
        avg_train_loss = total_train_loss / len(train_loader)
        val_loss, val_acc, val_f1_macro, val_auroc_ovr, class_wise_metrics = evaluate_model(model, val_loader, criterion_cls, device, valid_labels, num_labels)
        print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | Val F1 (Macro): {val_f1_macro:.4f} | Val AUROC (OvR): {val_auroc_ovr:.4f}")
        sorted_labels = sorted(valid_labels, key=valid_labels.get)
        for label_name in sorted_labels:
            print(f" {label_name}: P={class_wise_metrics[f'{label_name}_precision']:.4f}, R={class_wise_metrics[f'{label_name}_recall']:.4f}, F1={class_wise_metrics[f'{label_name}_f1']:.4f}, Acc={class_wise_metrics[f'{label_name}_accuracy']:.4f}, AUROC={class_wise_metrics[f'{label_name}_auroc']:.4f}")

        if results_writer and results_file_handle:
            row_data = [
                learning_rate,
                batch_size,
                optimizer_type,
                lambda_attn,
                epoch + 1,
                avg_train_loss,
                val_loss,
                val_acc,
                val_f1_macro,
                val_auroc_ovr
            ]
            for label_name in sorted_labels:
                row_data.extend([
                    class_wise_metrics[f"{label_name}_precision"],
                    class_wise_metrics[f"{label_name}_recall"],
                    class_wise_metrics[f"{label_name}_f1"],
                    class_wise_metrics[f"{label_name}_accuracy"],
                    class_wise_metrics[f"{label_name}_auroc"]
                ])
            results_writer.writerow(row_data)
            # Flush + fsync so partial results survive a crash mid-run.
            results_file_handle.flush()
            os.fsync(results_file_handle.fileno())
258
+
259
# --- OUTPUT FOLDERS ---
csv_output_dir = "csv_outputs"
os.makedirs(csv_output_dir, exist_ok=True)
results_file = os.path.join(csv_output_dir, "results_detailed.csv")
headers = ["learning_rate", "batch_size", "optimizer", "lambda", "epoch", "train_loss", "val_loss", "val_accuracy", "val_f1_macro", "val_auroc_ovr"]
# Per-class metric columns, ordered by numeric label id.
sorted_labels = sorted(valid_labels, key=valid_labels.get)
for label in sorted_labels:
    headers.extend([f"{label}_precision", f"{label}_recall", f"{label}_f1", f"{label}_accuracy", f"{label}_auroc"])

# --- INITIALIZE TOKENIZER & ADD EMOJIS ---
tokenizer = AutoTokenizer.from_pretrained(model_name)
emoji_path = "emoji.csv"
if os.path.exists(emoji_path):
    emoji_df = pd.read_csv(emoji_path)
    emoji_list = emoji_df["emoji"].dropna().astype(str).str.strip().tolist()
    existing_vocab = set(tokenizer.get_vocab().keys())
    # Only add emojis the base vocabulary does not already contain.
    emoji_set = set(emoji_list) - existing_vocab
    if emoji_set:
        tokenizer.add_tokens(list(emoji_set))
        print(f"Added {len(emoji_set)} new emoji tokens to the tokenizer.")
    else:
        print("No new emojis to add.")
else:
    print(f"Emoji file not found at: {emoji_path}")

# --- PREPARE DATASETS ---
print("Generating attention vectors for training data...")
train_df_model = generate_attention_vectors_from_rationales(train_df.copy(), tokenizer)
print("Generating attention vectors for validation data...")
val_df_model = generate_attention_vectors_from_rationales(val_df.copy(), tokenizer)

train_dataset = RationaleDataset(train_df_model, tokenizer, max_length, label_mapping=valid_labels)
val_dataset = RationaleDataset(val_df_model, tokenizer, max_length, label_mapping=valid_labels)
# Seeded generator keeps shuffling reproducible across runs.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, generator=torch.Generator().manual_seed(13))
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# --- CSV Setup ---
with open(results_file, mode="w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(headers)
    model = RationaleModel(model_name=model_name, num_labels=num_labels).to(device)
    # Grow the embedding matrix if emoji tokens were added to the tokenizer.
    if 'emoji_set' in locals() and len(emoji_set) > 0:
        model.bert.resize_token_embeddings(len(tokenizer))
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # train_model writes one metrics row per epoch through writer/f.
    train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=num_epochs,
        device=device,
        lambda_attn=lambda_attn,
        optimizer=optimizer,
        learning_rate=learning_rate,
        results_writer=writer,
        results_file_handle=f
    )
# Save final model and tokenizer
model.bert.save_pretrained("model_outputs")
tokenizer.save_pretrained("model_outputs")
print(f"Final model and tokenizer saved to model_outputs")
model_training_without_rationale.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset, DataLoader
3
+ from torch.optim import Adam
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
+ import pandas as pd
6
+ import numpy as np
7
+ import os
8
+ import warnings
9
+ import matplotlib.pyplot as plt
10
+ import seaborn as sns
11
+ from sklearn.metrics import (
12
+ f1_score, roc_auc_score, accuracy_score,
13
+ precision_recall_fscore_support, confusion_matrix
14
+ )
15
+
16
warnings.filterwarnings("ignore", category=FutureWarning)

# --- CONFIG ---
args_dict = {
    "batch_size": 64,
    "num_epochs": 4,
    "learning_rate": 2e-5,
    "max_length": 128,
    "model_name": "bert-base-multilingual-cased",#replace with your model_name
    "num_labels": 3,
    "save_dir": "./saved_model"
}
os.makedirs(args_dict["save_dir"], exist_ok=True)

# --- LABEL MAPPING ---
label_mapping = {"Negative": 0, "Neutral": 1, "Positive": 2}
label2name = {v: k for k, v in label_mapping.items()}  # id -> class name
label_ids = list(label2name.keys())

# --- LOAD DATA ---
train_df = pd.read_csv("train.csv")
val_df = pd.read_csv("val.csv")
test_df = pd.read_csv("test.csv")
emoji_df = pd.read_csv("emoji.csv")

# --- FILTER INVALID LABELS ---
# Drop rows whose label is not one of the three sentiment classes.
train_df = train_df[train_df["final_label"].isin(label_mapping)]
val_df = val_df[val_df["final_label"].isin(label_mapping)]
test_df = test_df[test_df["final_label"].isin(label_mapping)]

# --- TOKENIZER ---
tokenizer = AutoTokenizer.from_pretrained(args_dict["model_name"])
emoji_list = emoji_df["emoji"].dropna().astype(str).str.strip().tolist()
# Only register emojis absent from the base vocabulary.
emoji_set = set(emoji_list) - set(tokenizer.vocab.keys())
if emoji_set:
    tokenizer.add_tokens(list(emoji_set))
    print(f"Added {len(emoji_set)} emojis to tokenizer.")

# --- MODEL ---
model = AutoModelForSequenceClassification.from_pretrained(
    args_dict["model_name"], num_labels=args_dict["num_labels"]
)
# Resize embeddings to cover any newly added emoji tokens.
model.resize_token_embeddings(len(tokenizer))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
61
+
62
+ # --- DATASET ---
63
class SimpleTextDataset(Dataset):
    """Plain text-classification dataset.

    Yields (input_ids, attention_mask, label, raw_sentence); the raw sentence
    is carried along so the test loop can report per-sample predictions.
    """

    def __init__(self, dataframe, tokenizer, max_length=128):
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        record = self.dataframe.iloc[idx]
        sentence = record["Content"]
        target = label_mapping[record["final_label"]]
        features = self.tokenizer(
            sentence,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        return (
            features["input_ids"].squeeze(0),
            features["attention_mask"].squeeze(0),
            torch.tensor(target, dtype=torch.long),
            sentence,
        )
86
+
87
# --- DATALOADERS ---
train_loader = DataLoader(SimpleTextDataset(train_df, tokenizer), batch_size=args_dict["batch_size"], shuffle=True)
val_loader = DataLoader(SimpleTextDataset(val_df, tokenizer), batch_size=args_dict["batch_size"])
test_loader = DataLoader(SimpleTextDataset(test_df, tokenizer), batch_size=args_dict["batch_size"])

# --- TRAINING ---
optimizer = Adam(model.parameters(), lr=args_dict["learning_rate"])
val_metrics_history = []  # one dict per epoch, written to CSV after training

for epoch in range(1, args_dict["num_epochs"] + 1):
    model.train()
    total_loss = 0
    for batch in train_loader:
        # batch is (input_ids, attention_mask, label, raw_sentence); only the
        # first three items are tensors that must move to the device.
        # BUGFIX: the unpacking previously listed FOUR targets for the
        # three-element slice batch[:3], which raises
        # "ValueError: not enough values to unpack" on the first batch.
        input_ids, attn_mask, labels = [x.to(device) for x in batch[:3]]
        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attn_mask, labels=labels)
        outputs.loss.backward()
        optimizer.step()
        total_loss += outputs.loss.item()
    avg_train_loss = total_loss / len(train_loader)

    # --- VALIDATION ---
    model.eval()
    val_preds, val_labels, val_loss = [], [], 0
    with torch.no_grad():
        for batch in val_loader:
            # BUGFIX: same four-from-three unpacking error as above.
            input_ids, attn_mask, labels = [x.to(device) for x in batch[:3]]
            outputs = model(input_ids, attention_mask=attn_mask, labels=labels)
            val_preds.extend(outputs.logits.argmax(dim=1).cpu().numpy())
            val_labels.extend(labels.cpu().numpy())
            val_loss += outputs.loss.item()
    val_loss /= len(val_loader)
    val_acc = accuracy_score(val_labels, val_preds)
    val_f1 = f1_score(val_labels, val_preds, average="weighted")
    # NOTE(review): AUROC is computed from one-hot *predictions*, not class
    # probabilities, so it degenerates to a ranking of hard labels; consider
    # feeding softmax probabilities instead.
    try:
        val_auroc = roc_auc_score(
            pd.get_dummies(val_labels), pd.get_dummies(val_preds),
            average="weighted", multi_class="ovo"
        )
    except ValueError:
        # BUGFIX: narrowed from a bare except. A class missing from the
        # predictions yields mismatched dummy columns -> ValueError.
        val_auroc = float("nan")

    # --- Label-wise Metrics ---
    prec, rec, f1, supp = precision_recall_fscore_support(val_labels, val_preds, labels=[0,1,2])
    labelwise = {}
    for i in [0, 1, 2]:
        # Per-class accuracy over the rows whose true label is class i.
        idx = np.array(val_labels) == i
        if idx.sum() > 0:
            acc = (np.array(val_preds)[idx] == i).sum() / idx.sum()
        else:
            acc = 0.0
        labelwise[label2name[i]] = {
            "val_acc": acc,
            "val_f1": f1[i],
            "val_precision": prec[i],
            "val_recall": rec[i],
            "val_support": supp[i]
        }

    val_metrics_history.append({
        "epoch": epoch,
        "train_loss": avg_train_loss,
        "val_loss": val_loss,
        "val_accuracy": val_acc,
        "val_f1": val_f1,
        "val_auroc": val_auroc,
        **{f"{label}_{m}": labelwise[label][m]
           for label in labelwise for m in labelwise[label]}
    })

    print(f"Epoch {epoch}: Train Loss={avg_train_loss:.4f} | Val Acc={val_acc:.4f} | Val F1={val_f1:.4f} | AUROC={val_auroc:.4f}")
158
+
159
model.save_pretrained(args_dict["save_dir"])
tokenizer.save_pretrained(args_dict["save_dir"])
print(f"Last model saved after epoch {args_dict['num_epochs']}")
# --- SAVE VAL METRICS ---
pd.DataFrame(val_metrics_history).to_csv("val_metrics_detailed.csv", index=False)

# --- LOAD BEST MODEL ---
# Reload the just-saved weights so test evaluation uses exactly what is on disk.
model = AutoModelForSequenceClassification.from_pretrained(args_dict["save_dir"]).to(device)
tokenizer = AutoTokenizer.from_pretrained(args_dict["save_dir"])

# --- TEST EVAL ---
model.eval()
all_preds, all_labels, all_sentences, all_tokens = [], [], [], []
test_loss = 0

with torch.no_grad():
    for batch in test_loader:
        input_ids, attn_mask, labels, sentences = batch
        input_ids, attn_mask, labels = input_ids.to(device), attn_mask.to(device), labels.to(device)
        outputs = model(input_ids, attention_mask=attn_mask, labels=labels)
        test_loss += outputs.loss.item()
        preds = outputs.logits.argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        all_sentences.extend(sentences)
        all_tokens.extend(tokenizer.batch_decode(input_ids.cpu(), skip_special_tokens=True))

test_loss /= len(test_loader)
test_acc = accuracy_score(all_labels, all_preds)
test_f1 = f1_score(all_labels, all_preds, average="weighted")
# NOTE(review): AUROC from one-hot hard predictions, not softmax
# probabilities — the value is not a true probability-ranked AUROC.
try:
    test_auroc = roc_auc_score(pd.get_dummies(all_labels), pd.get_dummies(all_preds), average="weighted", multi_class="ovo")
except ValueError:
    # BUGFIX: narrowed from a bare except that hid unrelated errors.
    test_auroc = float("nan")

# --- LABEL-WISE TEST METRICS ---
prec, rec, f1, supp = precision_recall_fscore_support(all_labels, all_preds, labels=[0,1,2])
label_metrics = {
    "Label": [], "Accuracy": [], "F1": [], "Precision": [], "Recall": [], "Support": []
}
for i in [0, 1, 2]:
    # Per-class accuracy over the rows whose true label is class i.
    idx = np.array(all_labels) == i
    if idx.sum() > 0:
        acc = (np.array(all_preds)[idx] == i).sum() / idx.sum()
    else:
        acc = 0.0
    label_name = label2name[i]
    label_metrics["Label"].append(label_name)
    label_metrics["Accuracy"].append(acc)
    label_metrics["F1"].append(f1[i])
    label_metrics["Precision"].append(prec[i])
    label_metrics["Recall"].append(rec[i])
    label_metrics["Support"].append(supp[i])
pd.DataFrame(label_metrics).to_csv("labelwise_test_metrics.csv", index=False)

# --- OVERALL TEST METRICS CSV ---
pd.DataFrame([{
    "Test Loss": test_loss,
    "Test Accuracy": test_acc,
    "Test F1 Score": test_f1,
    "Test AUROC": test_auroc
}]).to_csv("overall_test_metrics.csv", index=False)

# --- TEST PREDICTIONS ---
pd.DataFrame({
    "Content": all_sentences,
    "Tokens": all_tokens,
    "final_label": [label2name[l] for l in all_labels],
    "predicted_label": [label2name[p] for p in all_preds]
}).to_csv("test_predictions.csv", index=False)

# --- CONFUSION MATRIX ---
conf_matrix = confusion_matrix(all_labels, all_preds, labels=[0, 1, 2])
conf_matrix_df = pd.DataFrame(conf_matrix, index=[label2name[i] for i in [0,1,2]],
                              columns=[label2name[i] for i in [0,1,2]])
conf_matrix_df.to_csv("confusion_matrix.csv")

# --- CONFUSION MATRIX PLOT ---
plt.figure(figsize=(6, 5))
sns.heatmap(conf_matrix_df, annot=True, fmt='d', cmap='Blues')
plt.title("Confusion Matrix")
plt.ylabel("True Label")
plt.xlabel("Predicted Label")
plt.tight_layout()
plt.savefig("confusion_matrix.png")
plt.close()

# --- DONE ---
print("\n=== FINAL TEST METRICS ===")
print(f"Test Accuracy : {test_acc:.4f}")
print(f"Test F1 : {test_f1:.4f}")
print(f"Test AUROC : {test_auroc:.4f}")
print("All test metrics, predictions, and confusion matrix saved.")