jerrynnms committed on
Commit
310d6ab
·
verified ·
1 Parent(s): dab03c4

Update bert_explainer.py

Browse files
Files changed (1) hide show
  1. bert_explainer.py +73 -67
bert_explainer.py CHANGED
@@ -1,67 +1,73 @@
1
- import torch
2
- from AI_Model_architecture import BertLSTM_CNN_Classifier, BertPreprocessor
3
- from transformers import BertTokenizer
4
- import re
5
- import requests
6
- import os
7
-
8
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
-
10
- # Google Drive 載入 model.pth
11
- def load_model_from_drive():
12
- model_url = "https://drive.google.com/uc?export=download&id=1UXkOqMPUiPUIbsy8iENHUqbNFLEHcFFg" # 替換為你的檔案 ID
13
- response = requests.get(model_url)
14
- if response.status_code == 200:
15
- with open("model.pth", "wb") as f:
16
- f.write(response.content)
17
- return True
18
- return False
19
-
20
- if not os.path.exists("model.pth"):
21
- if not load_model_from_drive():
22
- raise FileNotFoundError("無法從 Google Drive 載入 model.pth")
23
-
24
- model = BertLSTM_CNN_Classifier()
25
- model.load_state_dict(torch.load("model.pth", map_location=device))
26
- model.to(device)
27
- model.eval()
28
-
29
- tokenizer = BertTokenizer.from_pretrained("ckiplab/bert-base-chinese")
30
-
31
- def predict_single_sentence(model, tokenizer, sentence, max_len=256):
32
- model.eval()
33
- with torch.no_grad():
34
- sentence = re.sub(r"\s+", "", sentence)
35
- sentence = re.sub(r"[^\u4e00-\u9fffA-Za-z0-9。,!?:/.\-]", "", sentence)
36
-
37
- encoded = tokenizer(sentence, return_tensors="pt", truncation=True, padding="max_length", max_length=max_len)
38
- input_ids = encoded["input_ids"].to(device)
39
- attention_mask = encoded["attention_mask"].to(device)
40
- token_type_ids = encoded["token_type_ids"].to(device)
41
-
42
- output = model(input_ids, attention_mask, token_type_ids)
43
- prob = output.item()
44
- label = int(prob > 0.5)
45
-
46
- if prob > 0.9:
47
- risk = "🔴 高風險(極可能是詐騙)"
48
- elif prob > 0.5:
49
- risk = "🟡 中風險(可疑)"
50
- else:
51
- risk = "🟢 低風險(正常)"
52
-
53
- pre_label = "詐騙" if label == 1 else "正常"
54
-
55
- print(f"\n📩 訊息內容:{sentence}")
56
- print(f"✅ 預測結果:{pre_label}")
57
- print(f"📊 信心值:{round(prob*100, 2)}")
58
- print(f"⚠️ 風險等級:{risk}")
59
- return pre_label, prob, risk
60
-
61
- def analyze_text(text):
62
- label, prob, risk = predict_single_sentence(model, tokenizer, text)
63
- return {
64
- "status": label,
65
- "confidence": round(prob*100, 2),
66
- "suspicious_keywords": [risk]
67
- }
 
 
 
 
 
 
 
1
+ import torch
2
+ from AI_Model_architecture import BertLSTM_CNN_Classifier, BertPreprocessor
3
+ from transformers import BertTokenizer
4
+ import re
5
+ import requests
6
+ import os
7
+
8
# Prefer the GPU when one is visible to torch; otherwise run on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Writable destination for the downloaded checkpoint
# (/tmp is the writable scratch area on Hugging Face Spaces).
model_path = "/tmp/model.pth"
12
+
13
+ # Google Drive 載入 model.pth
14
# Fetch model.pth from Google Drive into the writable model_path.
def load_model_from_drive():
    """Download the model checkpoint from Google Drive to ``model_path``.

    Returns:
        bool: True if the file was downloaded and written successfully;
        False on a non-200 response or any network error.
    """
    model_url = "https://drive.google.com/uc?export=download&id=1UXkOqMPUiPUIbsy8iENHUqbNFLEHcFFg"
    try:
        # Stream so the (potentially large) checkpoint is never held fully in
        # memory, and set a timeout so a stalled request cannot hang the app.
        with requests.get(model_url, stream=True, timeout=60) as response:
            if response.status_code != 200:
                return False
            with open(model_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        return True
    except requests.RequestException:
        # Network failure: report as "not loaded" so the caller can decide
        # (the module-level check raises FileNotFoundError on False).
        return False
22
+
23
# Ensure the checkpoint exists locally; download it from Google Drive on first run.
if not os.path.exists(model_path):
    if not load_model_from_drive():
        raise FileNotFoundError("❌ 無法從 Google Drive 載入 model.pth")

# Build the classifier, load the downloaded weights onto the selected device,
# and switch to inference mode (disables dropout/batch-norm updates).
model = BertLSTM_CNN_Classifier()
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()

# Chinese BERT tokenizer matching the pretrained backbone used by the classifier.
tokenizer = BertTokenizer.from_pretrained("ckiplab/bert-base-chinese")
36
+
37
def predict_single_sentence(model, tokenizer, sentence, max_len=256):
    """Classify one Chinese message and report its scam probability.

    Args:
        model: classifier returning a single scalar score per input.
        tokenizer: BERT tokenizer producing input_ids / attention_mask /
            token_type_ids tensors.
        sentence: raw message text; whitespace and unsupported characters
            are stripped before tokenization.
        max_len: padded/truncated sequence length (default 256).

    Returns:
        tuple: (predicted label string, raw probability, risk-level string).
    """
    model.eval()
    with torch.no_grad():
        # Drop all whitespace, then keep only CJK, alphanumerics and common punctuation.
        sentence = re.sub(
            r"[^\u4e00-\u9fffA-Za-z0-9。,!?:/.\-]", "", re.sub(r"\s+", "", sentence)
        )

        enc = tokenizer(
            sentence,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=max_len,
        )
        ids, mask, types = (
            enc[key].to(device)
            for key in ("input_ids", "attention_mask", "token_type_ids")
        )

        # Model output is treated as a scalar probability in [0, 1].
        prob = model(ids, mask, types).item()
        label = 1 if prob > 0.5 else 0

        # Map the probability onto a three-tier risk level.
        risk = (
            "🔴 高風險(極可能是詐騙)" if prob > 0.9
            else "🟡 中風險(可疑)" if prob > 0.5
            else "🟢 低風險(正常)"
        )
        pre_label = "正常" if label == 0 else "詐騙"

        print(f"\n📩 訊息內容:{sentence}")
        print(f"✅ 預測結果:{pre_label}")
        print(f"📊 信心值:{round(prob*100, 2)}")
        print(f"⚠️ 風險等級:{risk}")
        return pre_label, prob, risk
66
+
67
def analyze_text(text):
    """Run scam detection on ``text`` and package the result for the API layer.

    Returns:
        dict: ``status`` (predicted label), ``confidence`` (probability as a
        percentage rounded to 2 decimals), and ``suspicious_keywords``
        (single-element list carrying the risk-level string).
    """
    pre_label, probability, risk_level = predict_single_sentence(model, tokenizer, text)
    result = {
        "status": pre_label,
        "confidence": round(probability * 100, 2),
        "suspicious_keywords": [risk_level],
    }
    return result