| | |
"""
Simple script for creating a Chinese sentiment-analysis model.

Builds an inference-ready model on top of bert-base-chinese.
"""
| |
|
| | from transformers import ( |
| | BertTokenizer, |
| | BertForSequenceClassification, |
| | pipeline |
| | ) |
| | import torch |
| |
|
def create_model():
    """Build a two-label sentiment classifier from ``bert-base-chinese``.

    Loads the tokenizer and a ``BertForSequenceClassification`` model from the
    ``bert-base-chinese`` checkpoint, configured with the label mapping
    0 -> NEGATIVE, 1 -> POSITIVE.

    Returns:
        tuple: ``(model, tokenizer)`` as returned by the two
        ``from_pretrained`` calls.
    """
    checkpoint = "bert-base-chinese"
    id_to_label = {0: "NEGATIVE", 1: "POSITIVE"}
    # Inverse mapping, derived so the two dictionaries cannot drift apart.
    label_to_id = {label: index for index, label in id_to_label.items()}

    print("正在載入 bert-base-chinese...")

    tokenizer = BertTokenizer.from_pretrained(checkpoint)

    # NOTE(review): no fine-tuning happens in this function; the
    # classification head is presumably freshly initialised - confirm before
    # relying on the predictions.
    model = BertForSequenceClassification.from_pretrained(
        checkpoint,
        num_labels=len(id_to_label),
        id2label=id_to_label,
        label2id=label_to_id,
    )

    print("✅ 模型載入完成!")
    return model, tokenizer
| |
|
def save_model(model, tokenizer, save_path="./"):
    """Save the model and tokenizer, then list the files in ``save_path``.

    Args:
        model: Object exposing ``save_pretrained(path)``.
        tokenizer: Object exposing ``save_pretrained(path)``.
        save_path: Target directory; defaults to the current directory.

    Returns:
        None. Progress and the resulting file listing are printed to stdout.
    """
    # Function-scope import: ``os`` is not among the file's top-level imports.
    import os

    print(f"正在保存模型到 {save_path}...")

    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)

    print("✅ 模型保存完成!")

    # A real newline ("\n") so a blank line precedes the listing; the escaped
    # form "\\n" would print a literal backslash followed by "n".
    print("\n生成的檔案:")
    # Every non-hidden entry of the directory is listed (not only the files
    # written above), in sorted order.
    for entry in sorted(os.listdir(save_path)):
        if not entry.startswith("."):
            print(f" 📄 {entry}")
| |
|
def test_model(model_path="./"):
    """Smoke-test inference on a few sample sentences.

    Builds a ``text-classification`` pipeline from ``model_path`` (used for
    both the model and the tokenizer), classifies four hard-coded Chinese
    sentences and prints the top label and score for each.

    Args:
        model_path: Location passed to ``pipeline`` as both ``model`` and
            ``tokenizer``; defaults to the current directory.

    Returns:
        None. Any exception raised while building or running the pipeline is
        caught and reported on stdout rather than propagated, so this
        function is best-effort by design.
    """
    # Real newlines ("\n") are used in the headers below; the escaped form
    # "\\n" would print a literal backslash followed by "n".
    print("\n=== 測試模型推理 ===")

    # Sample inputs live outside the ``try`` so that only pipeline
    # construction and inference are guarded.
    test_texts = [
        "這個產品真的很棒!我很喜歡。",
        "質量太差了,完全不值得購買。",
        "還不錯,可以考慮。",
        "非常滿意這次的服務體驗。",
    ]

    try:
        classifier = pipeline(
            "text-classification",
            model=model_path,
            tokenizer=model_path,
        )

        print("\n推理結果:")
        for i, text in enumerate(test_texts, 1):
            # The first element of the pipeline's result holds the
            # prediction that is reported.
            top = classifier(text)[0]
            label = top["label"]
            score = top["score"]

            print(f"{i}. 文本: {text}")
            print(f" 預測: {label} (信心度: {score:.4f})")
            print()

        print("✅ 推理測試完成!")

    except Exception as e:
        # Deliberately broad: a failed smoke test is reported, not fatal.
        print(f"❌ 推理測試失敗: {e}")
| |
|
if __name__ == "__main__":
    # Script entry point: create the model, save it to the current directory,
    # run the inference smoke test, then print follow-up instructions.
    print("🚀 開始創建中文情感分析模型...")

    try:
        model, tokenizer = create_model()
        save_model(model, tokenizer)
        test_model()

        # Real newlines ("\n") here; the escaped form "\\n" would print a
        # literal backslash followed by "n".
        print("\n" + "=" * 50)
        print("🎉 模型創建成功!")
        print("\n📋 下一步:")
        print("1. git add . && git commit -m 'Add trained model'")
        print("2. git push origin main")
        print("3. 其他人可以使用:")
        print(" from transformers import pipeline")
        print(" classifier = pipeline('text-classification', model='sk413025/my-awesome-model')")

    except Exception as e:
        # Top-level boundary: report the failure and the hint that follows it.
        print(f"❌ 錯誤: {e}")
        print("請確保網路連接正常,能夠下載 bert-base-chinese 模型")
| |
|