import os

# --- 1. Network / cache configuration ---
# Offline mode must be configured BEFORE transformers / datasets are imported:
# huggingface_hub reads HF_HUB_OFFLINE into its constants when it is first
# imported, so setting it after `from transformers import ...` may be ignored.
os.environ['HF_HUB_OFFLINE'] = '1'

import numpy as np
import optuna
import pandas as pd
import torch
from datasets import Dataset
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.utils.class_weight import compute_class_weight
from torch import nn
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    EvalPrediction,
    Trainer,
    TrainingArguments,
)

# Root of the locally cached model repository (Hugging Face hub cache layout).
base_path = "/root/autodl-tmp/huggingface_cache/models--roberta-base"

# Resolve the actual snapshot directory; fall back to base_path itself.
MODEL_NAME_OR_PATH = base_path
snap_path = os.path.join(base_path, "snapshots")
if os.path.isdir(snap_path):
    # Sorted so the choice is deterministic (os.listdir order is arbitrary).
    snapshots = sorted(
        d for d in os.listdir(snap_path) if os.path.isdir(os.path.join(snap_path, d))
    )
    # In the hub cache, refs/main holds the commit hash that `main` points to;
    # prefer that snapshot when several revisions are cached.
    ref_file = os.path.join(base_path, "refs", "main")
    if os.path.isfile(ref_file):
        with open(ref_file, encoding="utf-8") as fh:
            main_commit = fh.read().strip()
        if main_commit in snapshots:
            snapshots.remove(main_commit)
            snapshots.insert(0, main_commit)
    if snapshots:
        MODEL_NAME_OR_PATH = os.path.join(snap_path, snapshots[0])
        print(f"✅ 成功定位模型真实路径: {MODEL_NAME_OR_PATH}")
    else:
        print("⚠️ 警告: snapshots 文件夹为空,尝试使用根目录")
else:
    print(f"ℹ️ 未发现 snapshots 结构,尝试使用路径: {MODEL_NAME_OR_PATH}")

# --- 2. Data file paths ---
TRAIN_FILE_PATH = "/tmp/home/wzh/file/train_data.csv"
VALID_FILE_PATH = "/tmp/home/wzh/file/val_data.csv"

# --- 3. Load data ---
print(f"加载训练集: {TRAIN_FILE_PATH}")
train_df = pd.read_csv(TRAIN_FILE_PATH)
print(f"加载验证集: {VALID_FILE_PATH}")
eval_df = pd.read_csv(VALID_FILE_PATH)

# String label -> integer class id (also the index into the class-weight vector).
label_map = {"real": 0, "fake": 1}


def _encode_labels(df, split_name):
    """Replace df['label'] in place with integer ids from ``label_map``.

    Raises:
        ValueError: if any label is missing or not a key of ``label_map``;
            ``Series.map`` would otherwise silently turn them into NaN.
    """
    encoded = df['label'].map(label_map)
    unknown = df.loc[encoded.isna(), 'label'].unique()
    if len(unknown) > 0:
        raise ValueError(
            f"{split_name}: labels not in label_map {sorted(label_map)}: {list(unknown)[:10]}"
        )
    df['label'] = encoded.astype("int64")


_encode_labels(train_df, "train")
_encode_labels(eval_df, "validation")

# --- 4. Class weights ---
print("\n正在计算类别权重...")
train_labels = np.array(train_df["label"])
# Use the full id-ordered label set (not np.unique(train_labels)) so that
# weight[i] always belongs to class id i and the vector has one entry per
# class; a class absent from the training set now raises here instead of
# producing a too-short weight vector that breaks the loss later.
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.array(sorted(label_map.values())),
    y=train_labels,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float).to(device)
print(f"计算出的类别权重: {class_weights}")
# --- 5. Build the Dataset objects and tokenize ---
train_dataset = Dataset.from_pandas(train_df)
eval_dataset = Dataset.from_pandas(eval_df)

print(f"\n正在加载本地模型: {MODEL_NAME_OR_PATH} ...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH)


def tokenize_function(examples):
    """Tokenize a batch's "text" column, padded/truncated to exactly 512 tokens."""
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=512)


def _tokenize_split(dataset):
    """Tokenize one split, keep only tokenizer outputs + label, rename label -> labels.

    Every raw column other than "label" (e.g. "id", "text", or a pandas
    "__index_level_0__" column) is removed based on the Dataset's own
    columns, so a CSV without an "id" column no longer crashes. The
    previous check looked for "__index_level_0__" in the DataFrame, where
    it never appears.
    """
    raw_columns = [name for name in dataset.column_names if name != "label"]
    tokenized = dataset.map(tokenize_function, batched=True, remove_columns=raw_columns)
    return tokenized.rename_column("label", "labels")


tokenized_train_dataset = _tokenize_split(train_dataset)
tokenized_eval_dataset = _tokenize_split(eval_dataset)


# --- 6. Custom Trainer (applies the class weights) ---
class CustomTrainer(Trainer):
    """Trainer whose loss is class-weighted cross-entropy.

    Weights come from the module-level ``class_weights_tensor``.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the weighted CE loss (and the model outputs if requested).

        ``**kwargs`` absorbs any extra keyword arguments the Trainer passes;
        they are unused here.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # class_weights_tensor was placed on a device picked at import time;
        # move it to wherever the Trainer actually put the logits.
        weight = class_weights_tensor.to(device=logits.device, dtype=logits.dtype)
        loss_fct = nn.CrossEntropyLoss(weight=weight)
        # logits.size(-1) is the number of labels and does not depend on
        # `self.model` exposing a config.
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return (loss, outputs) if return_outputs else loss


# --- 7. Hyperparameter search configuration ---
# Derived from label_map so the id <-> name mapping has a single source of truth.
id2label = {idx: name for name, idx in label_map.items()}
label2id = dict(label_map)


def model_init(trial):
    """Create a fresh sequence-classification model for each training run.

    ``trial`` is part of Trainer's model_init protocol; it is unused here.
    """
    return AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME_OR_PATH,
        num_labels=len(id2label),
        id2label=id2label,
        label2id=label2id,
    )
# Metrics callback for Trainer. It returns the full dict (not a bare float):
# Trainer prefixes each key with "eval_", and both metric_for_best_model and
# compute_objective look metrics up by those keys.
def compute_metrics_macro(p: EvalPrediction):
    """Return accuracy and macro-averaged F1 / precision / recall.

    Args:
        p: predictions and label ids gathered by the Trainer's evaluation loop.

    Returns:
        dict with keys "accuracy", "f1_macro", "precision_macro", "recall_macro".
    """
    # If the model returns extra outputs, Trainer passes a tuple whose first
    # element is the logits.
    logits = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    labels = p.label_ids
    preds = np.argmax(logits, axis=1)
    f1_macro = f1_score(labels, preds, average='macro', zero_division=0)
    acc = accuracy_score(labels, preds)
    precision_macro = precision_score(labels, preds, average='macro', zero_division=0)
    recall_macro = recall_score(labels, preds, average='macro', zero_division=0)
    return {
        "accuracy": acc,
        "f1_macro": f1_macro,
        "precision_macro": precision_macro,
        "recall_macro": recall_macro,
    }


def compute_objective(metrics):
    """Return the value Optuna maximizes: macro-F1 from the evaluation metrics.

    Trainer prefixes compute_metrics keys with "eval_".
    """
    return metrics['eval_f1_macro']


def optuna_hp_space(trial):
    """Optuna search space for Trainer.hyperparameter_search.

    Mirrors transformers' default Optuna space but deliberately omits
    per_device_train_batch_size: the default samples it from
    [4, 8, 16, 32, 64], which overrides the fixed batch size of 16 below
    (chosen to avoid running out of GPU memory).
    """
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
        "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 5),
        "seed": trial.suggest_int("seed", 1, 40),
    }


training_args = TrainingArguments(
    output_dir="./results_hyper_search_MACRO",
    per_device_train_batch_size=16,  # keep at 16 to avoid GPU OOM
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    eval_strategy="steps",
    eval_steps=1000,
    save_strategy="steps",  # must match eval_strategy for load_best_model_at_end
    save_steps=1000,
    logging_strategy="steps",
    logging_steps=1000,
    load_best_model_at_end=True,
    metric_for_best_model="f1_macro",  # key returned by compute_metrics_macro
    greater_is_better=True,
    save_total_limit=1,
)

trainer = CustomTrainer(
    model=None,
    args=training_args,
    model_init=model_init,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_eval_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics_macro,  # returns a dict of metrics
)

# --- 8. Automatic hyperparameter search ---
print("\n" + "="*50)
print("🚀 开始自动超参数搜索 (目标: 最大化 Macro-F1)...")
print("="*50)

best_run = trainer.hyperparameter_search(
    direction="maximize",
    hp_space=optuna_hp_space,  # fixed batch size; see optuna_hp_space
    n_trials=10,
    compute_objective=compute_objective,  # tells Optuna which metric to optimize
    backend="optuna",
)

print("\n" + "="*50)
print("🎉 超参数搜索完成!")
print("="*50)
print(f"最佳 Macro-F1: {best_run.objective:.4f}")
print("最佳参数组合:", best_run.hyperparameters)
# --- 9. Final training with the best hyperparameters ---
import gc

print("\n" + "="*50)
print("🚀 使用最佳参数进行最终训练...")
print("="*50)

# Apply the winning hyperparameters the same way Trainer's own search does:
# skip keys that are not TrainingArguments fields (a plain setattr would
# silently create an attribute nothing reads), and cast each value to the
# type of the existing field.
for k, v in best_run.hyperparameters.items():
    if not hasattr(training_args, k):
        print(f"⚠️ Skipping hyperparameter {k!r}: not a TrainingArguments field")
        continue
    old_value = getattr(training_args, k)
    setattr(training_args, k, type(old_value)(v) if old_value is not None else v)

training_args.output_dir = "./results_final_best_MACRO"
training_args.logging_steps = 200

# Drop the search Trainer (it still references the last trial's model and
# optimizer state) before model_init allocates a new model, lowering peak
# GPU memory.
del trainer
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Recreate the Trainer with the same components as the search Trainer,
# including the tokenizer, so batching matches the search runs.
trainer = CustomTrainer(
    model_init=model_init,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_eval_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics_macro,
)

trainer.train()

print("\n" + "="*50)
print("🎉 最终训练完成!")
print("="*50)

# --- 10. Save and final report ---
final_model_path = "./final_model_best_macro"
trainer.save_model(final_model_path)
tokenizer.save_pretrained(final_model_path)
print(f"\n最终最优模型已保存至: {final_model_path}")

print("\n--- 最终成绩单 (验证集) ---")
final_metrics = trainer.evaluate()
for key, value in final_metrics.items():
    # Strip the "eval_" prefix Trainer adds, for a cleaner report.
    if key.startswith("eval_"):
        key = key[5:]
    if isinstance(value, float):
        print(f" - {key}: {value:.4f}")
    else:
        print(f" - {key}: {value}")