# bert_remark / final / find_threshold.py
# BaltimoreCA68's picture
# Add files using upload-large-folder tool
# 027ce51 verified
import pandas as pd
import numpy as np
import torch
from sklearn.metrics import f1_score, classification_report
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer
# --- Configuration ---
# Best model from the training run that was just completed.
MODEL_PATH = "./final_model_best_macro"
# Validation set provided by the instructor.
VALID_FILE = "/tmp/home/wzh/file/val_data.csv"

# --- Loading ---
print(f"正在加载模型: {MODEL_PATH} ...")

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the classifier weights, then move them onto the selected device.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
model = model.to(device)

# The tokenizer is read from the same directory as the model.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# Load the validation CSV and convert its string labels to integer class ids.
print(f"正在加载验证集: {VALID_FILE} ...")
val_df = pd.read_csv(VALID_FILE)
label_map = {"real": 0, "fake": 1}

# Series.map replaces every value that is not a key of label_map with NaN.
# Detect that here and fail with a clear message instead of letting NaN
# labels flow into prediction and metric computation.
mapped_labels = val_df['label'].map(label_map)
unmapped_rows = mapped_labels.isna()
if unmapped_rows.any():
    unexpected_values = val_df.loc[unmapped_rows, 'label'].unique().tolist()
    raise ValueError(
        f"Unexpected label values in {VALID_FILE}: {unexpected_values!r}; "
        f"expected one of {sorted(label_map)}"
    )
val_df['label'] = mapped_labels
# Convert the DataFrame to a Hugging Face Dataset.
val_dataset = Dataset.from_pandas(val_df)


def tokenize(examples):
    """Tokenize the "text" field of a batch.

    Every sequence is truncated and padded to exactly 512 tokens
    (padding="max_length"), so all rows come out with the same length.
    """
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=512)


val_dataset = val_dataset.map(tokenize, batched=True)

# Build the drop list from the Dataset's own columns. "__index_level_0__" is
# the name Dataset.from_pandas gives to a preserved pandas index, so it can
# only appear on the Dataset; looking for it in val_df.columns never finds it.
# Filtering by presence also avoids an error when the CSV has no "id" column.
columns_to_drop = [
    name
    for name in ("id", "text", "__index_level_0__")
    if name in val_dataset.column_names
]
if columns_to_drop:
    val_dataset = val_dataset.remove_columns(columns_to_drop)

# Rename the target column from "label" to "labels" before prediction.
val_dataset = val_dataset.rename_column("label", "labels")
# --- Prediction ---
print("正在进行预测...")
trainer = Trainer(model=model)
predictions = trainer.predict(val_dataset)

# Convert the raw logits into probabilities with a softmax over the last axis.
logits = torch.tensor(predictions.predictions)
probs = torch.softmax(logits, dim=-1)

# Labels returned alongside the predictions.
true_labels = predictions.label_ids
# Column 1 is the probability of class index 1 ("fake").
fake_probs = probs[:, 1].numpy()
# --- Search for the best decision threshold ---
# NOTE(review): the threshold is selected and scored on the same labels and
# probabilities, so best_f1 is an optimistic estimate for unseen data.
print("\n开始搜索最佳阈值 (Threshold Search)...")
best_f1 = 0
best_thresh = 0.5

# Candidates 0.10, 0.11, ..., 0.90 (81 values). np.linspace fixes both end
# points and the number of candidates, whereas np.arange with a fractional
# step can produce one element more or fewer because of floating-point
# rounding. Rounding to two decimals makes each candidate equal to the value
# later printed with "{:.2f}".
candidate_thresholds = np.round(np.linspace(0.1, 0.9, 81), 2)

for thresh in candidate_thresholds:
    # Predict 1 ("fake") when its probability exceeds the threshold, else 0.
    preds = (fake_probs > thresh).astype(int)
    current_f1 = f1_score(true_labels, preds, average='macro')
    # Strict ">" keeps the lowest threshold when several share the best score.
    if current_f1 > best_f1:
        best_f1 = current_f1
        best_thresh = thresh
# --- Summary ---
banner = "=" * 40
print("\n" + banner)
print("🎉 搜索完成!")

# Baseline for comparison: macro-F1 at the default 0.5 cut-off.
default_preds = (fake_probs > 0.5).astype(int)
default_f1 = f1_score(true_labels, default_preds, average='macro')
print(f"默认阈值 (0.50) Macro-F1: {default_f1:.4f}")

print(f"🏆 最佳阈值: {best_thresh:.2f}")
print(f"🚀 优化后 Macro-F1: {best_f1:.4f}")
print(banner)

# Detailed per-class report at the best threshold.
final_preds = (fake_probs > best_thresh).astype(int)
print("\n最佳阈值下的详细报告:")
print(
    classification_report(
        true_labels,
        final_preds,
        target_names=["real", "fake"],
    )
)