"""Tune the decision threshold of a binary real/fake text classifier.

Loads a fine-tuned sequence-classification model, predicts P(fake) for every
row of a validation CSV, sweeps thresholds 0.10..0.90 (step 0.01), and reports
the threshold that maximizes macro-F1 plus a per-class report at that
threshold.

NOTE: the threshold is both selected and scored on the same validation data,
so the "optimized" macro-F1 is an optimistic estimate; confirm on held-out data.
"""

import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from sklearn.metrics import classification_report, f1_score
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer

# --- Configuration ---
MODEL_PATH = "./final_model_best_macro"  # best checkpoint from the preceding training run
VALID_FILE = "/tmp/home/wzh/file/val_data.csv"  # validation set provided by the instructor

# --- Load model and data ---
print(f"正在加载模型: {MODEL_PATH} ...")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH).to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

print(f"正在加载验证集: {VALID_FILE} ...")
val_df = pd.read_csv(VALID_FILE)

# Assumes the model was trained with real -> 0 and fake -> 1;
# TODO confirm against model.config.id2label.
label_map = {"real": 0, "fake": 1}
# Fail loudly on unexpected labels: Series.map would silently turn them into
# NaN, corrupting the label column (float dtype) and every F1 score below.
unknown_labels = val_df.loc[~val_df["label"].isin(list(label_map)), "label"].unique()
if len(unknown_labels) > 0:
    raise ValueError(
        f"验证集中存在未知标签: {unknown_labels.tolist()} (期望: {list(label_map)})"
    )
val_df["label"] = val_df["label"].map(label_map).astype(int)

# Convert to a Hugging Face Dataset. preserve_index=False guarantees that no
# "__index_level_0__" column is created, whatever the DataFrame index is.
val_dataset = Dataset.from_pandas(val_df, preserve_index=False)


def tokenize(examples):
    """Tokenize a batch of texts, padding/truncating every row to 512 tokens."""
    return tokenizer(
        examples["text"], padding="max_length", truncation=True, max_length=512
    )


# Keep only the tokenizer outputs plus the label. Dropping every other source
# column (id, text, ...) works whether or not an "id" column is present.
val_dataset = val_dataset.map(
    tokenize,
    batched=True,
    remove_columns=[col for col in val_dataset.column_names if col != "label"],
)
val_dataset = val_dataset.rename_column("label", "labels")

# --- Predict ---
print("正在进行预测...")
trainer = Trainer(model=model)
predictions = trainer.predict(val_dataset)

# Softmax over the last (class) dimension turns logits into probabilities.
logits = torch.tensor(predictions.predictions)
probs = torch.nn.functional.softmax(logits, dim=-1)
# Probability of the "fake" class (label 1).
fake_probs = probs[:, 1].numpy()
true_labels = predictions.label_ids


def _macro_f1_at(threshold):
    """Return macro-F1 when rows with P(fake) > threshold are predicted fake (1)."""
    preds = (fake_probs > threshold).astype(int)
    # zero_division=0 yields the same value as the default ("warn") without
    # emitting a warning for every threshold at which a class is never predicted.
    return f1_score(true_labels, preds, average="macro", zero_division=0)


# --- Threshold search ---
print("\n开始搜索最佳阈值 (Threshold Search)...")
# The 81 thresholds 0.10, 0.11, ..., 0.90. Rounding avoids the accumulated
# float error of np.arange (e.g. 0.30000000000000004), so grid points are exact.
thresholds = np.round(np.linspace(0.1, 0.9, 81), 2)
f1_scores = [_macro_f1_at(t) for t in thresholds]
# np.argmax returns the first maximum, so ties resolve to the lowest threshold.
best_idx = int(np.argmax(f1_scores))
best_thresh = float(thresholds[best_idx])
best_f1 = f1_scores[best_idx]
default_f1 = _macro_f1_at(0.5)

print("\n" + "=" * 40)
print("🎉 搜索完成!")
print(f"默认阈值 (0.50) Macro-F1: {default_f1:.4f}")
print(f"🏆 最佳阈值: {best_thresh:.2f}")
print(f"🚀 优化后 Macro-F1: {best_f1:.4f}")
print("=" * 40)

# Detailed per-class report at the best threshold. labels=[0, 1] keeps the
# report valid even if one class is absent from both labels and predictions.
final_preds = (fake_probs > best_thresh).astype(int)
print("\n最佳阈值下的详细报告:")
print(
    classification_report(
        true_labels,
        final_preds,
        labels=[0, 1],
        target_names=["real", "fake"],
        zero_division=0,
    )
)