|
|
import os

import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from sklearn.metrics import classification_report, f1_score
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer
|
|
|
|
|
|
|
|
# Location of the fine-tuned checkpoint to evaluate and of the labelled
# validation CSV. Either can be overridden from the environment
# (EVAL_MODEL_PATH / EVAL_VALID_FILE) without editing the script; the
# defaults are the original hard-coded locations.
MODEL_PATH = os.environ.get("EVAL_MODEL_PATH", "./final_model_best_macro")
VALID_FILE = os.environ.get("EVAL_VALID_FILE", "/tmp/home/wzh/file/val_data.csv")
|
|
|
|
|
|
|
|
# Load the fine-tuned sequence classifier and its tokenizer from the same
# checkpoint directory.
print(f"正在加载模型: {MODEL_PATH} ...")

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
model = model.to(device)  # nn.Module.to() moves in place and returns the module

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
|
|
|
|
|
# Read the validation CSV, convert string labels to integer class ids and wrap
# the result in a Hugging Face Dataset.
print(f"正在加载验证集: {VALID_FILE} ...")
val_df = pd.read_csv(VALID_FILE)

# String class name -> integer class id ("fake" is the positive class, id 1).
label_map = {"real": 0, "fake": 1}

raw_labels = val_df["label"]
val_df["label"] = raw_labels.map(label_map)

# Series.map yields NaN for any value missing from label_map (typos, different
# casing, labels that are already numeric, empty cells). Left alone, that turns
# the column into floats and fails much later inside Trainer / sklearn with an
# unrelated-looking error, so fail fast and name the offending values.
unmapped = val_df["label"].isna()
if unmapped.any():
    unknown = sorted(raw_labels[unmapped].astype(str).unique())
    raise ValueError(
        f"Unknown labels in {VALID_FILE}: {unknown}; "
        f"expected one of {sorted(label_map)}"
    )
val_df["label"] = val_df["label"].astype(int)

val_dataset = Dataset.from_pandas(val_df)
|
|
def tokenize(examples):
    """Tokenize a batch of validation examples for the classifier.

    Used with ``Dataset.map(..., batched=True)``, so ``examples["text"]``
    holds a batch of texts. Every sequence is truncated and padded to exactly
    512 tokens (presumably so batches can be stacked without a padding-aware
    collator).

    Args:
        examples: batch mapping that contains a ``"text"`` column.

    Returns:
        The tokenizer's encoding for the batch (e.g. ``input_ids``,
        ``attention_mask``), which ``Dataset.map`` merges into the dataset.
    """
    texts = examples["text"]
    encoded = tokenizer(
        texts,
        truncation=True,
        max_length=512,
        padding="max_length",
    )
    return encoded
|
|
# Tokenize the whole validation set in batches.
val_dataset = val_dataset.map(tokenize, batched=True)

# Drop raw columns the model cannot consume. `__index_level_0__` is created by
# Dataset.from_pandas (when it stores a non-default index) and therefore only
# ever appears on the Dataset, never in val_df.columns — so look for every
# candidate on the Dataset itself. Filtering also tolerates a CSV without "id".
columns_to_drop = [
    name
    for name in ("id", "text", "__index_level_0__")
    if name in val_dataset.column_names
]
val_dataset = val_dataset.remove_columns(columns_to_drop)

# Transformers sequence-classification models take their targets via the
# `labels` keyword, so expose the gold column under that name.
val_dataset = val_dataset.rename_column("label", "labels")
|
|
|
|
|
|
|
|
# Batched inference over the tokenized validation set.
print("正在进行预测...")

# Trainer is used only for its predict() loop, with default arguments.
trainer = Trainer(model=model)
predictions = trainer.predict(val_dataset)

# NOTE(review): assumes the model returns only logits, so predictions.predictions
# is a single (num_examples, num_labels) array — confirm if output flags change.
logits = torch.tensor(predictions.predictions)
probs = logits.softmax(dim=-1)  # per-row class probabilities

# Column 1 is the "fake" class (label id 1); label_ids are the gold integer labels.
fake_probs = probs[:, 1].numpy()
true_labels = predictions.label_ids
|
|
|
|
|
|
|
|
# Grid-search the probability cut-off for the "fake" class that maximises
# macro-F1 on the validation set.
print("\n开始搜索最佳阈值 (Threshold Search)...")

# Candidate thresholds 0.10, 0.11, ..., 0.90 (81 values). np.linspace with an
# explicit count avoids np.arange's float-step endpoint ambiguity, and rounding
# to two decimals yields clean values (0.3 rather than 0.30000000000000004)
# that match what gets printed and reused later.
thresholds = np.round(np.linspace(0.1, 0.9, 81), 2)

best_f1 = 0
best_thresh = 0.5  # kept if no candidate beats an F1 of 0

for thresh in thresholds:
    preds = (fake_probs > thresh).astype(int)
    # Extreme thresholds can predict a single class, leaving precision/recall
    # undefined for the other one. zero_division=0 gives the same score as
    # sklearn's default but does not emit a warning on every such step.
    current_f1 = f1_score(true_labels, preds, average="macro", zero_division=0)
    # Strict ">" keeps the lowest threshold among equally good candidates.
    if current_f1 > best_f1:
        best_f1 = current_f1
        best_thresh = thresh
|
|
|
|
|
# Summarise the search: macro-F1 at the default 0.5 cut-off versus at the best
# threshold found, then a per-class report at the best threshold.
default_f1 = f1_score(
    true_labels, (fake_probs > 0.5).astype(int), average="macro", zero_division=0
)

print("\n" + "=" * 40)
print("🎉 搜索完成!")
print(f"默认阈值 (0.50) Macro-F1: {default_f1:.4f}")
print(f"🏆 最佳阈值: {best_thresh:.2f}")
print(f"🚀 优化后 Macro-F1: {best_f1:.4f}")
print("=" * 40)

final_preds = (fake_probs > best_thresh).astype(int)

print("\n最佳阈值下的详细报告:")
# labels=[0, 1] pins the report to both classes so target_names always lines
# up, even if only one class occurs in the gold labels or the predictions
# (without it classification_report raises a ValueError in that case).
print(
    classification_report(
        true_labels,
        final_preds,
        labels=[0, 1],
        target_names=["real", "fake"],
        zero_division=0,
    )
)