Spaces:
Runtime error
Runtime error
| import pandas as pd | |
| import torch | |
| import random | |
| import numpy as np | |
| import os | |
| import json | |
| import matplotlib.pyplot as plt | |
| import wandb | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig | |
| from datasets import Dataset | |
| from transformers import TrainingArguments, Trainer, EarlyStoppingCallback | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.utils.class_weight import compute_class_weight | |
| from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, confusion_matrix | |
| import seaborn as sns | |
| from datetime import datetime | |
| import torch.nn.functional as F | |
# ✅ Seed every random number generator so that runs are reproducible.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
# Ask cuDNN for deterministic kernels and switch off its autotuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# ✅ Weights & Biases tracking (optional; disable it if not needed).
# Option 1: authenticate with an API key
# WANDB_API_KEY = "YOUR_API_KEY_HERE"
# os.environ["WANDB_API_KEY"] = WANDB_API_KEY
# Option 2: disable wandb entirely
os.environ["WANDB_DISABLED"] = "true"

# Start a wandb run only when tracking has not been disabled above.
if os.environ.get("WANDB_DISABLED") != "true":
    run_config = {
        "model_name": "hfl/chinese-roberta-wwm-ext",
        "epochs": 12,
        "batch_size": 8,
        "learning_rate": 1e-5,
        "weight_decay": 0.01,
        "max_length": 128,
        "seed": seed,
    }
    run = wandb.init(
        project="chinese-topic-classifier",
        name=f"roberta-topic-classifier-{datetime.now().strftime('%Y%m%d_%H%M%S')}",
        config=run_config,
    )
# ✅ Create a timestamped output directory with one sub-folder per artifact type.
base_output_dir = "./roberta_output"
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = f"{base_output_dir}_{timestamp}"
os.makedirs(output_dir, exist_ok=True)
for subdir in ("checkpoints", "results", "logs", "api"):
    os.makedirs(f"{output_dir}/{subdir}", exist_ok=True)
# ✅ Configure logging: every record goes to both the run's log file and the console.
import logging

log_file = f"{output_dir}/logs/training.log"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
)
logger = logging.getLogger(__name__)
logger.info(f"開始訓練,輸出目錄: {output_dir}")
# ✅ Read the CSV file (expected to sit next to this script).
file_path = "ragproject7.csv"
logger.info(f"讀取資料集: {file_path}")
df = pd.read_csv(file_path)
logger.info(f"資料集大小: {df.shape}")

# ✅ Clean up: keep the two columns we need, drop rows with missing values,
# then drop rows whose text is duplicated.
df = df[['text', 'topic']].dropna().drop_duplicates(subset=['text'])
unique_topics = df["topic"].unique()
logger.info(f"類別數量: {len(unique_topics)}")
logger.info(f"類別分布: \n{df['topic'].value_counts()}")

# Topic name -> integer id (in order of first appearance) and the reverse map.
topic_dict = dict(zip(unique_topics, range(len(unique_topics))))
inv_topic_dict = {class_id: name for name, class_id in topic_dict.items()}

# Mirror the label information into the wandb config (if enabled).
if os.environ.get("WANDB_DISABLED") != "true":
    wandb.config.update({
        "num_classes": len(unique_topics),
        "class_distribution": df['topic'].value_counts().to_dict(),
        "topic_dict": topic_dict
    })

# Persist the label map so it can be reused later.
with open(f"{output_dir}/topic_dict.json", "w", encoding="utf-8") as f:
    json.dump(topic_dict, f, ensure_ascii=False, indent=2)
logger.info(f"保存類別對照表,共 {len(unique_topics)} 個類別")

# Encode topics as integer ids.
df["numeric_topic"] = df["topic"].map(topic_dict)

# ✅ Class weights to counter imbalance: total / (num_classes * class_count),
# listed in ascending class-id order.
class_counts = df['numeric_topic'].value_counts().sort_index()
total_samples = len(df)
num_classes = len(class_counts)
class_weights = torch.FloatTensor(
    [total_samples / (num_classes * n_in_class) for n_in_class in class_counts]
)
logger.info(f"類別權重: {class_weights}")
if os.environ.get("WANDB_DISABLED") != "true":
    wandb.config.update({"class_weights": class_weights.tolist()})
# ✅ Load the tokenizer, falling back to alternative checkpoints on failure.
model_name = "hfl/chinese-roberta-wwm-ext"
logger.info(f"正在載入分詞器: {model_name}")
# AutoTokenizer is used instead of a concrete tokenizer class for compatibility.
tokenizer = None
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    logger.info("成功載入分詞器")
except Exception as e:
    logger.error(f"載入分詞器時發生錯誤: {e}")
    # Try the backup checkpoints in order; the first that loads wins and
    # also becomes the model name used by the rest of the script.
    for backup_name in ("hfl/chinese-macbert-base", "bert-base-chinese"):
        try:
            logger.info(f"嘗試載入備用分詞器: {backup_name}")
            tokenizer = AutoTokenizer.from_pretrained(backup_name)
        except Exception as e2:
            logger.error(f"載入備用分詞器 {backup_name} 失敗: {e2}")
            continue
        model_name = backup_name
        logger.info(f"成功載入備用分詞器: {backup_name}")
        break
    if tokenizer is None:
        raise Exception("無法載入任何分詞器,請檢查環境設定")
# ✅ Evaluation metrics callback for the Trainer.
def compute_metrics(eval_pred):
    """Compute accuracy, weighted P/R/F1 and one F1 score per topic.

    Args:
        eval_pred: pair ``(logits, labels)``; the predicted class is the
            argmax of ``logits`` along axis 1.

    Returns:
        dict with ``accuracy``, ``f1``, ``precision``, ``recall`` and one
        ``f1_<topic name>`` entry for every class id in ``inv_topic_dict``.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=1)

    # zero_division=0 yields the same values as the default but without a
    # warning for every class that receives no predictions.
    result = {
        'accuracy': accuracy_score(labels, predictions),
        'f1': f1_score(labels, predictions, average='weighted', zero_division=0),
        'precision': precision_score(labels, predictions, average='weighted', zero_division=0),
        'recall': recall_score(labels, predictions, average='weighted', zero_division=0),
    }

    # Pin the label set explicitly: without `labels=`, scikit-learn returns
    # scores only for the classes that occur in this batch, so position i
    # would not necessarily be class id i and the topic names would shift.
    class_ids = sorted(inv_topic_dict)
    f1_per_class = f1_score(
        labels, predictions, labels=class_ids, average=None, zero_division=0
    )
    for class_id, score in zip(class_ids, f1_per_class):
        result[f'f1_{inv_topic_dict[class_id]}'] = float(score)
    return result
# ✅ Tokenization helper.
max_length = 128


def tokenize_function(examples):
    """Tokenize ``examples["text"]``, padded/truncated to ``max_length``, as NumPy arrays."""
    encode_options = dict(
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="np",
    )
    return tokenizer(examples["text"], **encode_options)
# ✅ Prepare the dataset: augment minority classes, tokenize, stratified split.
from datasets import ClassLabel  # needed to make the label column stratifiable

logger.info("正在處理數據集...")


def _perturb_text(text):
    """Return ``text`` with a random 1-3 character span deleted or duplicated.

    Texts of 20 characters or fewer are returned unchanged.
    """
    if len(text) <= 20:
        return text
    if random.random() < 0.5:
        # Delete a short span.
        remove_pos = random.randint(0, len(text) - 10)
        remove_len = random.randint(1, 3)
        return text[:remove_pos] + text[remove_pos + remove_len:]
    # Duplicate a short span in place.
    repeat_pos = random.randint(0, len(text) - 5)
    repeat_len = random.randint(1, 3)
    return text[:repeat_pos] + text[repeat_pos:repeat_pos + repeat_len] + text[repeat_pos:]


# Simple augmentation for minority classes: any class with fewer than half the
# samples of the largest class is topped up to 70% of the largest class.
# NOTE(review): augmentation happens before the train/test split, so perturbed
# copies of a training text can land in the evaluation split — confirm whether
# that leakage is acceptable.
class_samples = df['numeric_topic'].value_counts().sort_index()
max_samples = class_samples.max()
augmented_texts = []
augmented_topics = []
for class_id, count in class_samples.items():
    if count >= max_samples * 0.5:
        continue
    class_texts = df.loc[df['numeric_topic'] == class_id, 'text'].tolist()
    augment_count = int(max_samples * 0.7) - count
    if augment_count <= 0 or not class_texts:
        continue
    for _ in range(augment_count):
        augmented_texts.append(_perturb_text(random.choice(class_texts)))
        augmented_topics.append(int(class_id))

# Append the augmented samples to the original data.
if augmented_texts:
    aug_df = pd.DataFrame({
        'text': augmented_texts,
        'numeric_topic': augmented_topics
    })
    df = pd.concat([df, aug_df], ignore_index=True)
    logger.info(f"添加了 {len(augmented_texts)} 個增強樣本,新的數據集大小: {df.shape}")
    logger.info(f"增強後的類別分布: \n{df['numeric_topic'].value_counts().sort_index()}")

# Convert to a datasets.Dataset. preserve_index=False keeps the pandas index
# (non-contiguous after dropna/drop_duplicates) from becoming an extra
# `__index_level_0__` column, which would otherwise be forwarded to the model
# because the Trainer is configured with remove_unused_columns=False.
dataset = Dataset.from_pandas(
    df[['text', 'numeric_topic']].rename(columns={'numeric_topic': 'labels'}),
    preserve_index=False,
)
# train_test_split(stratify_by_column=...) only accepts ClassLabel columns;
# casting keeps the integer ids unchanged.
dataset = dataset.cast_column("labels", ClassLabel(num_classes=len(unique_topics)))

# Tokenize.
tokenized_dataset = dataset.map(
    lambda x: tokenizer(x['text'], padding="max_length", truncation=True, max_length=max_length),
    batched=True
)

# Stratified train / test split.
train_test_split_ratio = 0.2
train_test = tokenized_dataset.train_test_split(
    test_size=train_test_split_ratio, seed=seed, stratify_by_column="labels"
)
train_dataset = train_test["train"]
eval_dataset = train_test["test"]
logger.info(f"訓練集大小: {len(train_dataset)},測試集大小: {len(eval_dataset)}")
# ✅ Load and configure the model.
import re  # used to pull the encoder layer index out of parameter names

config = AutoConfig.from_pretrained(
    model_name,
    num_labels=len(unique_topics),
    hidden_dropout_prob=0.2,
    attention_probs_dropout_prob=0.2,
    classifier_dropout=0.3,  # heavier dropout on the classifier head to reduce over-fitting
)
logger.info(f"正在載入模型: {model_name}")
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    config=config,
    ignore_mismatched_sizes=True  # tolerate a classifier head of a different size
)

# ✅ Freeze the lower encoder layers and fine-tune only the top ones
# (suited to small datasets).
logger.info("正在凍結前10層,只微調最後2層...")
# The previous code called int(name.split(".")[1]) on names such as
# "bert.encoder.layer.3.attention...", i.e. int("encoder"), which raises
# ValueError; it also counted parameters rather than layers. Take the layer
# count from the config and parse the index that follows ".layer.".
total_layers = model.config.num_hidden_layers
# round() so that 80% of a 12-layer encoder is 10 frozen layers, which is what
# the log message above announces (int() would give 9).
freeze_layers = round(0.8 * total_layers)
layer_index_pattern = re.compile(r"\.layer\.(\d+)\.")
for name, param in model.named_parameters():
    layer_match = layer_index_pattern.search(name)
    is_frozen = layer_match is not None and int(layer_match.group(1)) < freeze_layers
    # Every parameter that is not inside a frozen encoder layer stays
    # trainable, as in the original else-branch.
    param.requires_grad = not is_frozen
    if is_frozen:
        logger.info(f"凍結層: {name}")
    else:
        logger.info(f"訓練層: {name}")

# Report how much of the model is actually trainable.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
logger.info(f"可訓練參數: {trainable_params:,} / 總參數: {total_params:,} ({trainable_params/total_params:.2%})")

# Attach the class weights to the model so the custom loss can read them.
model.class_weights = class_weights
# ✅ Training hyper-parameters (tuned for a small dataset).
import inspect  # used to detect which evaluation-strategy keyword is supported

logger.info("配置小數據集最佳訓練參數...")
# Mirror the hyper-parameters into the wandb config (if enabled).
if os.environ.get("WANDB_DISABLED") != "true":
    wandb.config.update({
        "epochs": 12,
        "batch_size": 8,
        "learning_rate": 1e-5,
        "weight_decay": 0.01,
        "gradient_accumulation_steps": 2,
        "frozen_layers": "前80%層",
        "early_stopping_patience": 3
    })

# Reporting back-ends.
report_to_list = ["tensorboard"]
if os.environ.get("WANDB_DISABLED") != "true":
    report_to_list.append("wandb")

# `evaluation_strategy` was renamed to `eval_strategy` in newer transformers
# releases; use whichever keyword the installed version accepts.
_eval_strategy_key = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments.__init__).parameters
    else "evaluation_strategy"
)

training_args = TrainingArguments(
    output_dir=f"{output_dir}/checkpoints",
    num_train_epochs=12,             # ✅ small datasets need more passes
    per_device_train_batch_size=8,   # ✅ small batches
    per_device_eval_batch_size=8,
    save_strategy="epoch",
    logging_dir=f"{output_dir}/logs/tensorboard",
    logging_strategy="steps",
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    learning_rate=1e-5,              # ✅ low learning rate for small data
    weight_decay=0.01,               # ✅ regularisation
    warmup_ratio=0.1,                # ✅ warm the learning rate up gradually
    gradient_accumulation_steps=2,   # ✅ compensates for the small batch size
    # Mixed precision needs a CUDA device; a hard-coded fp16=True makes
    # TrainingArguments raise on CPU-only hosts.
    fp16=torch.cuda.is_available(),
    remove_unused_columns=False,
    report_to=report_to_list,        # wandb is included only when enabled
    save_total_limit=3,              # keep at most 3 checkpoints on disk
    push_to_hub=False,
    dataloader_num_workers=2,
    group_by_length=True,
    **{_eval_strategy_key: "epoch"},
)
# Trainer with a custom, class-weighted loss.
class CustomTrainer(Trainer):
    """Trainer that computes class-weighted cross-entropy instead of the model's own loss."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the weighted cross-entropy loss for one batch.

        Args:
            model: the model being trained; it (or the module it wraps) must
                expose a ``class_weights`` tensor.
            inputs: batch dict; ``"labels"`` is removed from it before the
                forward pass so the model does not compute its own loss.
            return_outputs: when true, return ``(loss, outputs)``.
            num_items_in_batch: accepted because newer ``Trainer`` versions
                pass it as a keyword argument; not used here.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits

        # Under DataParallel/DDP `model` is a wrapper and custom attributes
        # live on `.module`.
        base_model = getattr(model, "module", model)
        class_weights = base_model.class_weights.to(logits.device)

        loss_fct = torch.nn.CrossEntropyLoss(weight=class_weights)
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
# Early-stopping callback to limit over-fitting: stop after 3 evaluations
# without an improvement of at least 0.001 in the monitored metric.
early_stopping = EarlyStoppingCallback(
    early_stopping_patience=3,
    early_stopping_threshold=0.001,
)
logger.info("配置訓練器,啟用Early Stopping...")
trainer_options = dict(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)
trainer = CustomTrainer(**trainer_options)

# ✅ Start training.
logger.info("開始訓練...")
trainer.train()
# ✅ Evaluate the model.
logger.info("評估最終模型...")
eval_results = trainer.evaluate()
logger.info(f"評估結果: {eval_results}")
if os.environ.get("WANDB_DISABLED") != "true":
    wandb.log({"final_results": eval_results})

# Detailed report on the held-out split.
logger.info("生成詳細測試報告...")
predictions = trainer.predict(eval_dataset)
preds = np.argmax(predictions.predictions, axis=1)
labels = predictions.label_ids

# Fix the label set explicitly so the report rows and matrix axes always line
# up with class_names, even if some class is absent from both the labels and
# the predictions of this split (classification_report raises otherwise).
class_ids = list(range(len(unique_topics)))
class_names = [inv_topic_dict[i] for i in class_ids]
classification_rep = classification_report(
    labels, preds, labels=class_ids, target_names=class_names,
    output_dict=True, zero_division=0,
)
with open(f"{output_dir}/results/classification_report.json", "w", encoding="utf-8") as f:
    json.dump(classification_rep, f, ensure_ascii=False, indent=2)

# Confusion matrix heat-map.
confusion_fig = plt.figure(figsize=(10, 8))
cm = confusion_matrix(labels, preds, labels=class_ids)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.savefig(f"{output_dir}/results/confusion_matrix.png")
plt.close(confusion_fig)  # release the figure once it is on disk
if os.environ.get("WANDB_DISABLED") != "true":
    wandb.log({"confusion_matrix": wandb.Image(f"{output_dir}/results/confusion_matrix.png")})
logger.info(f"混淆矩陣已保存至 {output_dir}/results/confusion_matrix.png")
# ✅ Save the final model together with its tokenizer.
final_model_path = f"{output_dir}/final_model"
for artifact in (model, tokenizer):
    artifact.save_pretrained(final_model_path)
logger.info(f"最終模型和分詞器已保存到 {final_model_path}")

# Optional upload to the Hugging Face Hub; flip the flag to enable it.
push_to_hub = False
if push_to_hub:
    from huggingface_hub import HfFolder, Repository
    # Provide your Hugging Face credentials first, e.g.
    # HfFolder.save_token("YOUR_HF_TOKEN")
    repo_name = f"chinese-topic-classifier-{timestamp}"
    model.push_to_hub(repo_name)
    tokenizer.push_to_hub(repo_name)
    logger.info(f"模型已上傳至Hugging Face Hub: {repo_name}")

# Upload the model directory as a wandb artifact (only when wandb is enabled).
if os.environ.get("WANDB_DISABLED") == "true":
    logger.info("wandb已禁用,跳過模型上傳")
else:
    try:
        model_artifact = wandb.Artifact('roberta-topic-model', type='model')
        model_artifact.add_dir(final_model_path)
        wandb.log_artifact(model_artifact)
        logger.info("已將模型上傳到wandb")
    except Exception as e:
        logger.warning(f"上傳模型到wandb時發生錯誤: {e}")
# ✅ Inference helper.
def predict(text, return_probs=False):
    """Classify ``text`` with the fine-tuned model.

    Returns the predicted topic name, or ``(topic name, {topic name:
    probability})`` when ``return_probs`` is true.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    encoded = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    encoded = {key: tensor.to(device) for key, tensor in encoded.items()}

    with torch.no_grad():
        logits = model(**encoded).logits
    probs = F.softmax(logits, dim=1)[0]
    predicted_topic = inv_topic_dict[torch.argmax(logits, dim=1).item()]

    if not return_probs:
        return predicted_topic
    probs_dict = {inv_topic_dict[i]: float(probs[i]) for i in range(len(unique_topics))}
    return predicted_topic, probs_dict
# ✅ Gradio interface (optional: skipped when gradio is not installed).
try:
    import gradio as gr
except ImportError:
    logger.info("未安裝Gradio,跳過界面創建")
else:
    def predict_for_gradio(text):
        """Return the predicted topic and all class probabilities as display text."""
        topic, probs = predict(text, return_probs=True)
        # Rank on the raw probabilities and format afterwards; the previous
        # version formatted first and re-parsed the rounded percent strings
        # in order to sort them.
        ranked = sorted(probs.items(), key=lambda item: item[1], reverse=True)
        parts = [f"預測主題: {topic}\n\n各類別機率:\n"]
        parts.extend(f"- {class_name}: {prob:.2%}\n" for class_name, prob in ranked)
        return "".join(parts)

    demo = gr.Interface(
        fn=predict_for_gradio,
        inputs=gr.Textbox(lines=5, placeholder="請輸入要分類的中文文本..."),
        outputs="text",
        title="中文主題分類器",
        description=f"此模型可將文本分類為以下主題: {', '.join(unique_topics)}",
        examples=[
            ["這篇文章探討了太陽能電池的最新研究進展。"],
            ["碳捕捉技術可以減少溫室氣體排放。"],
            ["社區參與對環保項目的成功至關重要。"]
        ]
    )
    # NOTE(review): when run as a plain script launch() presumably blocks
    # until the server stops, so nothing below would run while the demo is
    # being served — confirm this ordering is intended.
    demo.launch(share=True)
    logger.info("Gradio界面已啟動")
# ✅ Source of the standalone Flask API service, written to <output_dir>/api/app.py.
# The embedded program is indented as valid Python; without indentation the
# generated app.py cannot be imported. It reads the request body with
# get_json(silent=True) so that a missing or malformed JSON body reaches the
# explicit JSON 400 response.
api_code = '''
from flask import Flask, request, jsonify
from flask_cors import CORS
import torch
import json
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch.nn.functional as F
import os

app = Flask(__name__)
CORS(app)

# 全局變量
model = None
tokenizer = None
topic_dict = None
inv_topic_dict = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model():
    global model, tokenizer, topic_dict, inv_topic_dict
    model_path = "./final_model"
    if not os.path.exists(model_path):
        return {"error": f"模型路徑 {model_path} 不存在"}
    topic_dict_path = "./topic_dict.json"
    if not os.path.exists(topic_dict_path):
        return {"error": f"類別映射文件 {topic_dict_path} 不存在"}
    try:
        # 載入模型和分詞器
        model = AutoModelForSequenceClassification.from_pretrained(model_path)
        model.to(device)
        model.eval()
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        # 載入類別映射
        with open(topic_dict_path, "r", encoding="utf-8") as f:
            topic_dict = json.load(f)
        inv_topic_dict = {v: k for k, v in topic_dict.items()}
        return {"success": "模型載入成功"}
    except Exception as e:
        return {"error": f"載入模型時發生錯誤: {str(e)}"}


@app.route("/", methods=["GET"])
def index():
    return jsonify({"status": "API服務運行中", "endpoints": {"/predict": "文本分類預測"}})


@app.route("/predict", methods=["POST"])
def predict_topic():
    # 確保模型已載入
    global model, tokenizer, topic_dict, inv_topic_dict
    if model is None:
        result = load_model()
        if "error" in result:
            return jsonify(result), 500
    # 獲取請求數據
    data = request.get_json(silent=True)
    if not data or "text" not in data:
        return jsonify({"error": "請求必須包含'text'字段"}), 400
    text = data["text"]
    return_probs = data.get("return_probs", False)
    try:
        # 進行預測
        inputs = tokenizer(text, padding="max_length", truncation=True, max_length=128, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits
        probs = F.softmax(logits, dim=1)[0]
        prediction = torch.argmax(logits, dim=1).item()
        result = {"topic": inv_topic_dict[prediction]}
        if return_probs:
            result["probabilities"] = {inv_topic_dict[i]: float(probs[i]) for i in range(len(inv_topic_dict))}
        return jsonify(result)
    except Exception as e:
        return jsonify({"error": f"預測過程中發生錯誤: {str(e)}"}), 500


if __name__ == "__main__":
    # 預先載入模型
    load_model()
    app.run(host="0.0.0.0", port=5000, debug=False)
'''
with open(f"{output_dir}/api/app.py", "w", encoding="utf-8") as f:
    f.write(api_code)
# Launcher script for the API service.
# The literal starts directly with the shebang: a shebang is only honoured on
# the very first line of a file, and the previous literal began with a newline.
startup_script = '''#!/bin/bash
cd "$(dirname "$0")"
export PYTHONIOENCODING=utf-8
export FLASK_APP=app.py
flask run --host=0.0.0.0 --port=5000
'''
with open(f"{output_dir}/api/start_api.sh", "w", encoding="utf-8") as f:
    f.write(startup_script)
# Mark the script as executable.
os.chmod(f"{output_dir}/api/start_api.sh", 0o755)
# README for the API bundle, written to <output_dir>/api/README.md.
# The embedded shell/JSON/Python/JavaScript samples are indented so that they
# can be copied and run as-is. Literal braces are doubled because this is an
# f-string.
_class_table = json.dumps({v: k for k, v in topic_dict.items()}, ensure_ascii=False, indent=2)
_trained_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
readme = f'''
# 中文主題分類模型 API 服務
## 概述
這是一個使用預訓練語言模型訓練的中文主題分類API服務,可以將文本分類到以下類別:
{_class_table}
## 使用方法
### 啟動API服務
1. 確保已安裝所需套件: `pip install flask flask-cors transformers torch`
2. 運行啟動腳本: `./start_api.sh`
### API端點
- `GET /`: 檢查API狀態
- `POST /predict`: 進行文本分類
### 預測請求示例
```bash
curl -X POST http://localhost:5000/predict \\
  -H "Content-Type: application/json" \\
  -d '{{"text": "您的文本內容", "return_probs": true}}'
```
### 返回格式
```json
{{
  "topic": "預測的類別",
  "probabilities": {{
    "類別1": 0.8,
    "類別2": 0.1,
    "類別3": 0.05,
    "類別4": 0.03,
    "類別5": 0.02
  }}
}}
```
## 在其他應用中使用
### Python
```python
import requests

def predict_topic(text, return_probs=False):
    response = requests.post('http://localhost:5000/predict',
                             json={{'text': text, 'return_probs': return_probs}})
    return response.json()

# 使用示例
result = predict_topic("您的文本", return_probs=True)
print(f"預測類別: {{result['topic']}}")
if 'probabilities' in result:
    for topic, prob in result['probabilities'].items():
        print(f"{{topic}}: {{prob:.2f}}")
```
### JavaScript
```javascript
async function predictTopic(text, returnProbs = false) {{
  const response = await fetch('http://localhost:5000/predict', {{
    method: 'POST',
    headers: {{
      'Content-Type': 'application/json',
    }},
    body: JSON.stringify({{ text, return_probs: returnProbs }}),
  }});
  return await response.json();
}}

// 使用示例
predictTopic("您的文本", true).then(result => {{
  console.log(`預測類別: ${{result.topic}}`);
  if (result.probabilities) {{
    Object.entries(result.probabilities).forEach(([topic, prob]) => {{
      console.log(`${{topic}}: ${{prob.toFixed(2)}}`);
    }});
  }}
}});
```
## 訓練詳情
- 模型: {model_name}
- 訓練時間: {_trained_at}
- 訓練集大小: {len(train_dataset)}
- 測試集大小: {len(eval_dataset)}
- 最終測試集F1分數: {eval_results.get('eval_f1', 'N/A')}
'''
with open(f"{output_dir}/api/README.md", "w", encoding="utf-8") as f:
    f.write(readme)
# Collect everything the API needs inside the api/ folder.
import shutil

api_dir = f"{output_dir}/api"
api_model_dir = f"{api_dir}/final_model"
os.makedirs(api_model_dir, exist_ok=True)
shutil.copytree(final_model_path, api_model_dir, dirs_exist_ok=True)
shutil.copy(f"{output_dir}/topic_dict.json", f"{api_dir}/topic_dict.json")
logger.info(f"API服務代碼和文件已準備完成,位於 {output_dir}/api/")
| # 嘗試打包API文件夾 (在Hugging Face環境中可能不需要) | |
| try: | |
| import zipfile | |
def zip_directory(directory_path, zip_path):
    """Write every file under ``directory_path`` into a deflate-compressed zip.

    Archive member names are relative to the parent of ``directory_path``, so
    they start with the directory's own name.

    Args:
        directory_path: directory to archive; a trailing path separator is
            tolerated and does not change the archive layout.
        zip_path: destination archive. If it lies inside ``directory_path``
            the archive itself is left out of its own contents.
    """
    import zipfile  # local import keeps the function self-contained

    # abspath also strips a trailing separator; without that,
    # dirname("x/api/") is "x/api" and the top-level folder name would be
    # dropped from every member name.
    source_root = os.path.abspath(directory_path)
    archive_file = os.path.abspath(zip_path)
    parent_dir = os.path.dirname(source_root)
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, _, files in os.walk(source_root):
            for file in files:
                file_path = os.path.join(root, file)
                if os.path.abspath(file_path) == archive_file:
                    continue  # never add the archive being written to itself
                zipf.write(file_path, os.path.relpath(file_path, parent_dir))
| zip_directory(f"{output_dir}/api", f"{output_dir}/api.zip") | |
| logger.info(f"API服務文件已打包為 {output_dir}/api.zip") | |
| except Exception as e: | |
| logger.warning(f"打包API文件時發生錯誤: {e},但這不影響模型和API功能") | |
# ✅ Demonstrate the classifier on sample texts and, in a notebook, on user input.
import sys  # required for the sys.modules check below; missing from the file's imports

print("\n" + "="*50)
print("模型訓練完成!現在可以進行文本分類測試")
print("="*50 + "\n")


def _print_probabilities(probs):
    """Print one ``- <class>: <probability>`` line per class, most likely first."""
    print("各類別機率:")
    # Distinct loop names so the caller's `topic` variable is not shadowed.
    for class_name, prob in sorted(probs.items(), key=lambda x: x[1], reverse=True):
        print(f"- {class_name}: {prob:.4f}")


# A few sample texts so that results are shown without any user interaction.
sample_texts = [
    "這篇文章探討了太陽能電池的最新研究進展。",
    "碳捕捉技術可以減少溫室氣體排放。",
    "社區參與對環保項目的成功至關重要。"
]
print("示例文本分類結果:")
for i, text in enumerate(sample_texts, 1):
    topic, probs = predict(text, return_probs=True)
    print(f"\n示例 {i}: {text}")
    print(f"預測主題:{topic}")
    _print_probabilities(probs)

# Prompt for input only when running inside an IPython kernel.
if 'ipykernel' in sys.modules:
    text_input = input("\n請輸入要分類的文本:")
    if text_input:
        topic, probs = predict(text_input, return_probs=True)
        print(f"\n預測主題:{topic}")
        _print_probabilities(probs)
# ✅ Write a plain-text summary of the evaluation metrics.
summary_lines = [
    f"模型評估結果:\n",
    f"準確率: {eval_results.get('eval_accuracy', 'N/A')}\n",
    f"F1分數: {eval_results.get('eval_f1', 'N/A')}\n",
    f"精確率: {eval_results.get('eval_precision', 'N/A')}\n",
    f"召回率: {eval_results.get('eval_recall', 'N/A')}\n\n",
    "各類別F1分數:\n",
]
summary_lines.extend(
    f"{class_name}: {eval_results.get(f'eval_f1_{class_name}', 'N/A')}\n"
    for class_name in class_names
)
with open(f"{output_dir}/results/model_evaluation.txt", "w", encoding="utf-8") as f:
    f.writelines(summary_lines)

# ✅ Close the wandb run (only when wandb is enabled).
if os.environ.get("WANDB_DISABLED") != "true":
    try:
        wandb.finish()
        logger.info("wandb運行已完成")
    except Exception as e:
        logger.warning(f"結束wandb運行時發生錯誤: {e}")
# Instructions for publishing the model to the Hugging Face Hub, written to
# <output_dir>/hub_instructions.md. The embedded snippets are indented so they
# can be pasted and run, and the upload snippet imports `json`, which it uses.
# Literal braces are doubled because this is an f-string.
hub_instructions = f'''
# 將模型保存到Hugging Face Hub
如果您想將訓練好的模型分享到Hugging Face Hub,請按照以下步驟操作:
1. 確保您已登入Hugging Face:
```python
from huggingface_hub import login
login()  # 會提示您輸入token
```
2. 將模型上傳到Hub:
```python
import json

model_id = "your-username/chinese-topic-classifier"  # 替換為您的用戶名
# 上傳模型
model.push_to_hub(model_id)
# 上傳分詞器
tokenizer.push_to_hub(model_id)
# 上傳配置文件
with open("config.json", "w") as f:
    json.dump({{"model_name": "{model_name}",
               "num_classes": {len(unique_topics)},
               "classes": {list(topic_dict.keys())}
               }}, f)
from huggingface_hub import upload_file
upload_file(
    path_or_fileobj="config.json",
    path_in_repo="config.json",
    repo_id=model_id
)
```
3. 創建一個基於Gradio的Demo並上傳:
```python
%%writefile app.py
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch.nn.functional as F
import torch

# 定義模型ID (您上面使用的)
model_id = "your-username/chinese-topic-classifier"
# 載入模型和分詞器
model = AutoModelForSequenceClassification.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# 類別映射
topic_names = {list(topic_dict.keys())}

# 預測函數
def predict(text):
    inputs = tokenizer(text, padding="max_length", truncation=True, max_length=128, return_tensors="pt")
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    probs = F.softmax(logits, dim=1)[0]
    prediction = torch.argmax(logits, dim=1).item()
    # 格式化結果
    result = f"預測類別: {{topic_names[prediction]}}\\n\\n機率分布:\\n"
    for i, prob in enumerate(probs):
        result += f"- {{topic_names[i]}}: {{prob:.4f}}\\n"
    return result

# 創建Gradio界面
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=5, placeholder="請輸入中文文本..."),
    outputs="text",
    title="中文主題分類器",
    description="輸入中文文本,預測其所屬主題類別。",
    examples=[
        "這篇文章探討了太陽能電池的最新研究進展。",
        "碳捕捉技術可以減少溫室氣體排放。",
        "社區參與對環保項目的成功至關重要。"
    ]
)

if __name__ == "__main__":
    demo.launch()
# 然後上傳此應用到Hugging Face Spaces
```
'''
with open(f"{output_dir}/hub_instructions.md", "w", encoding="utf-8") as f:
    f.write(hub_instructions)
# Closing banner: where the artifacts are and how to call the API.
completion_banner = f"""
========================================
訓練和API服務準備完成!
========================================
訓練結果摘要:
- 模型保存在: {final_model_path}
- API服務代碼在: {output_dir}/api/
將模型部署到Hugging Face Hub:
- 說明文件: {output_dir}/hub_instructions.md
API使用示例:
curl -X POST http://localhost:5000/predict \\
-H "Content-Type: application/json" \\
-d '{{"text": "您的文本", "return_probs": true}}'
更多詳情請參閱: {output_dir}/api/README.md
"""
print(completion_banner)

# Extra deployment notes specific to Hugging Face Spaces.
spaces_deployment_notes = """
在Hugging Face平台上部署模型:
1. 創建一個新的Space:
- 前往 huggingface.co/new-space
- 選擇Gradio作為SDK
- 填寫名稱和描述
2. 上傳模型和app.py:
- 將訓練好的模型上傳到你的Hugging Face賬戶
- 根據hub_instructions.md中的說明創建app.py
- 上傳到你的Space
3. 配置Space:
- 在Space設置中添加依賴項: transformers, torch, gradio
完成這些步驟後,你將有一個公開可訪問的模型推理界面!
"""
print(spaces_deployment_notes)