Bennie12 committed on
Commit
e543e92
·
verified ·
1 Parent(s): f7cb5ba

Update bert_explainer.py

Browse files
Files changed (1) hide show
  1. bert_explainer.py +198 -190
bert_explainer.py CHANGED
@@ -1,190 +1,198 @@
1
- import os
2
- os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
3
- import jieba
4
- import torch
5
- import re
6
- import easyocr
7
- import io
8
- import numpy as np
9
- from PIL import Image
10
- from huggingface_hub import hf_hub_download
11
- from transformers import BertTokenizer
12
- from AI_Model_architecture import BertLSTM_CNN_Classifier
13
- from lime.lime_text import LimeTextExplainer
14
-
15
- HF_TOKEN = os.environ.get("HF_TOKEN")
16
-
17
-
18
- # OCR 模組
19
- reader = easyocr.Reader(['ch_tra', 'en'], gpu=torch.cuda.is_available())
20
-
21
- # 設定裝置(GPU 優先)
22
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
-
24
- # 載入模型與 tokenizer
25
- def load_model_and_tokenizer():
26
- global model, tokenizer
27
-
28
- if os.path.exists("model.pth"):
29
- print("✅ 已找到 model.pth 載入模型")
30
- model_path = "model.pth"
31
- else:
32
- print("🚀 未找到 model.pth")
33
- model_path = hf_hub_download(repo_id="Bennie12/Bert-Lstm-Cnn-ScamDetecter",
34
- filename="model.pth",
35
- token=HF_TOKEN)
36
-
37
- model = BertLSTM_CNN_Classifier()
38
- model.load_state_dict(torch.load(model_path, map_location=device))
39
- model.to(device)
40
- model.eval()
41
-
42
- tokenizer = BertTokenizer.from_pretrained("ckiplab/bert-base-chinese", use_fast=False)
43
-
44
- return model, tokenizer
45
-
46
- model, tokenizer = load_model_and_tokenizer()
47
- model.eval()
48
-
49
- # 預測單一句子的分類結果
50
- def predict_single_sentence(model, tokenizer, sentence, max_len=256):
51
- sentence = re.sub(r"\s+", "", sentence)
52
- sentence = re.sub(r"[^\u4e00-\u9fffA-Za-z0-9。,!?:/._-]", "", sentence)
53
-
54
- encoded = tokenizer(sentence, return_tensors="pt", truncation=True, padding="max_length", max_length=max_len)
55
- encoded = {k: v.to(device) for k, v in encoded.items()}
56
-
57
- with torch.no_grad():
58
- output = model(encoded["input_ids"], encoded["attention_mask"], encoded["token_type_ids"])
59
- prob = torch.sigmoid(output).item()
60
- label = int(prob > 0.5)
61
- risk = "🟢 低風險(正常)"
62
- if prob > 0.9:
63
- risk = "🔴 高風險(極可能是詐騙)"
64
- elif prob > 0.5:
65
- risk = "🟡 中風險(可疑)"
66
-
67
- pre_label = '詐騙' if label == 1 else '正常'
68
-
69
- return {
70
- "label": pre_label,
71
- "prob": prob,
72
- "risk": risk
73
- }
74
-
75
- # 提供 LIME 用的 predict_proba
76
- def predict_proba(texts):
77
- # tokenizer 批次處理
78
- encoded = tokenizer(
79
- texts,
80
- return_tensors="pt",
81
- padding=True,
82
- truncation=True,
83
- max_length=256
84
- )
85
-
86
- # 移動到 GPU 或 CPU
87
- encoded = {k: v.to(device) for k, v in encoded.items()}
88
-
89
- with torch.no_grad():
90
- outputs = model(encoded["input_ids"], encoded["attention_mask"], encoded["token_type_ids"])
91
- # outputs shape: (batch_size,)
92
- probs = torch.sigmoid(outputs).cpu().numpy()
93
-
94
-
95
- # 轉成 LIME 格式:(N, 2)
96
- probs_2d = np.vstack([1-probs, probs]).T
97
- return probs_2d
98
-
99
-
100
-
101
- # 初始化 LIME explainer
102
- class_names = ['正常', '詐騙']
103
- lime_explainer = LimeTextExplainer(class_names=class_names)
104
-
105
- # 擷取可疑詞彙 (改用 LIME)
106
-
107
- def suspicious_tokens(text, explainer=lime_explainer, top_k=5):
108
- try:
109
- explanation = explainer.explain_instance(text, predict_proba, num_features=top_k, num_samples=200)
110
- keywords = [word for word, weight in explanation.as_list()]
111
- return keywords
112
- except Exception as e:
113
- print("⚠ LIME 失敗,啟用 fallback:", e)
114
- fallback = ["繳費", "終止", "逾期", "限時", "驗證碼"]
115
- return [kw for kw in fallback if kw in text]
116
-
117
-
118
- # 文字清理
119
- def clean_text(text):
120
- text = re.sub(r"https?://\S+", "", text)
121
- text = re.sub(r"[a-zA-Z0-9:/.%\-_=+]{4,}", "", text)
122
- text = re.sub(r"\+?\d[\d\s\-]{5,}", "", text)
123
- text = re.sub(r"[^一-龥。,!?、]", "", text)
124
- sentences = re.split(r"[。!?]", text)
125
- cleaned = "。".join(sentences[:4])
126
- return cleaned[:300]
127
-
128
- # 高亮顯示
129
- def highlight_keywords(text, keywords, prob):
130
-
131
- if prob < 0.15: # 低風險完全不標註
132
- return text
133
-
134
- # 決定標註顏色
135
- if prob >= 0.65:
136
- css_class = 'red-highlight'
137
- else:
138
- css_class = 'yellow-highlight'
139
- for word in keywords:
140
- if len(word.strip()) >= 2:
141
- text = text.replace(word, f"<span class='{css_class}'>{word}</span>")
142
- return text
143
-
144
-
145
-
146
-
147
- # 文字分析主流程
148
- def analyze_text(text):
149
- cleaned_text = clean_text(text)
150
- result = predict_single_sentence(model, tokenizer, cleaned_text)
151
- label = result["label"]
152
- prob = result["prob"]
153
- risk = result["risk"]
154
-
155
- suspicious = suspicious_tokens(cleaned_text)
156
- # 依照可疑度做不同標註
157
- highlighted_text = highlight_keywords(text, suspicious, prob)
158
- # 低風險下不回傳 suspicious_keywords
159
- if prob < 0.15:
160
- suspicious = []
161
-
162
- print(f"\n📩 訊息內容:{text}")
163
- print(f"✅ 預測結果:{label}")
164
- print(f"📊 信心值:{round(prob*100, 2)}")
165
- print(f"⚠️ 風險等級:{risk}")
166
- print(f"可疑關鍵字擷取: {suspicious}")
167
-
168
- return {
169
- "status": label,
170
- "confidence": round(prob * 100, 2),
171
- "suspicious_keywords": suspicious,
172
- "highlighted_text": highlighted_text
173
- }
174
-
175
- # 圖片 OCR 分析
176
- def analyze_image(file_bytes):
177
- image = Image.open(io.BytesIO(file_bytes))
178
- image_np = np.array(image)
179
- results = reader.readtext(image_np)
180
-
181
- text = ' '.join([res[1] for res in results]).strip()
182
-
183
- if not text:
184
- return {
185
- "status" : "無法辨識文字",
186
- "confidence" : 0.0,
187
- "suspicious_keywords" : ["圖片中無可辨識的中文英文"],
188
- "highlighted_text": "無法辨識可疑內容"
189
- }
190
- return analyze_text(text)
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
3
+ import jieba
4
+ import torch
5
+ import re
6
+ import easyocr
7
+ import io
8
+ import numpy as np
9
+ from PIL import Image
10
+ from huggingface_hub import hf_hub_download
11
+ from transformers import BertTokenizer
12
+ from AI_Model_architecture import BertLSTM_CNN_Classifier
13
+ from lime.lime_text import LimeTextExplainer
14
+
15
+ HF_TOKEN = os.environ.get("HF_TOKEN")
16
+
17
+
18
# OCR module: EasyOCR reader for Traditional Chinese + English.
# NOTE(review): easyocr.Reader fetches its detection/recognition models on
# first construction — this runs at import time of this module; confirm
# that is acceptable for startup latency.
reader = easyocr.Reader(['ch_tra', 'en'], gpu=torch.cuda.is_available())

# Select the compute device (prefer GPU when CUDA is available).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+
24
+ # 載入模型與 tokenizer
25
def load_model_and_tokenizer():
    """Load the scam-detection classifier weights and the CKIP BERT tokenizer.

    Prefers a local ``model.pth``; otherwise downloads the checkpoint from
    the Hugging Face Hub (authenticated via ``HF_TOKEN`` when set). Also
    rebinds the module-level ``model``/``tokenizer`` names via ``global``
    as a side effect.

    Returns:
        (model, tokenizer) tuple, with the model already on ``device``
        and switched to eval mode.
    """
    global model, tokenizer

    local_ckpt = "model.pth"
    if os.path.exists(local_ckpt):
        print("✅ 已找到 model.pth 載入模型")
        ckpt_path = local_ckpt
    else:
        print("🚀 未找到 model.pth")
        ckpt_path = hf_hub_download(
            repo_id="Bennie12/Bert-Lstm-Cnn-ScamDetecter",
            filename="model.pth",
            token=HF_TOKEN,
        )

    model = BertLSTM_CNN_Classifier()
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state)
    model.to(device)
    model.eval()  # inference mode: disables dropout etc.

    tokenizer = BertTokenizer.from_pretrained("ckiplab/bert-base-chinese", use_fast=False)
    return model, tokenizer
45
+
46
# Build the model/tokenizer once at import time. The loader also assigns
# these names via `global`, so this assignment doubles as documentation.
model, tokenizer = load_model_and_tokenizer()
model.eval()  # redundant (the loader already calls eval()) but harmless
48
+
49
+ # 預測單一句子的分類結果
50
def predict_single_sentence(model, tokenizer, sentence, max_len=256):
    """Classify one message and return its label, probability and risk tier.

    Returns a dict with keys ``label`` (詐騙/正常), ``prob`` (sigmoid
    scam probability in [0, 1]) and ``risk`` (human-readable tier).
    """
    # Drop whitespace, then every char outside the allowed CJK/ASCII set.
    cleaned = re.sub(
        r"[^\u4e00-\u9fffA-Za-z0-9。,!?:/._-]", "",
        re.sub(r"\s+", "", sentence),
    )

    batch = tokenizer(cleaned, return_tensors="pt", truncation=True,
                      padding="max_length", max_length=max_len)
    batch = {name: tensor.to(device) for name, tensor in batch.items()}

    with torch.no_grad():
        logits = model(batch["input_ids"], batch["attention_mask"], batch["token_type_ids"])
        prob = torch.sigmoid(logits).item()

    label = int(prob > 0.5)

    # Map the scam probability onto a three-tier risk description.
    if prob > 0.9:
        risk = "🔴 高風險(極可能是詐騙)"
    elif prob > 0.5:
        risk = "🟡 中風險(可疑)"
    else:
        risk = "🟢 低風險(正常)"

    return {
        "label": '詐騙' if label == 1 else '正常',
        "prob": prob,
        "risk": risk,
    }
74
+
75
+ # 提供 LIME 用的 predict_proba
76
def predict_proba(texts):
    """LIME adapter: map a batch of strings to an (N, 2) probability array.

    Column 0 is P(正常), column 1 is P(詐騙) — the column order expected
    by LimeTextExplainer given ``class_names = ['正常', '詐騙']``.
    """
    # Tokenize the whole batch at once.
    batch = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=256,
    )
    # Move every tensor onto the model's device (GPU or CPU).
    batch = {name: tensor.to(device) for name, tensor in batch.items()}

    with torch.no_grad():
        logits = model(batch["input_ids"], batch["attention_mask"], batch["token_type_ids"])
        scam_probs = torch.sigmoid(logits).cpu().numpy()  # shape: (batch,)

    # Stack [P(normal), P(scam)] column-wise into the (N, 2) LIME format.
    return np.column_stack([1 - scam_probs, scam_probs])
98
+
99
+
100
+
101
# LIME text explainer over the two output classes.
# Index 0 = '正常' (normal), index 1 = '詐騙' (scam) — must match the
# column order produced by predict_proba().
class_names = ['正常', '詐騙']
lime_explainer = LimeTextExplainer(class_names=class_names)
104
+
105
+ # 擷取可疑詞彙 (改用 LIME)
106
+
107
def suspicious_tokens(text, explainer=lime_explainer, top_k=5):
    """Extract up to ``top_k`` scam-indicative keywords from ``text`` via LIME.

    Only features with a positive LIME weight are kept: a positive weight
    pushes the prediction toward the '詐騙' class, while negative-weight
    words argue for '正常' and would be misleading if reported (and later
    highlighted) as "suspicious". Falls back to a small static keyword
    list when the LIME explanation fails for any reason.
    """
    try:
        explanation = explainer.explain_instance(
            text, predict_proba, num_features=top_k, num_samples=200
        )
        # Keep only words that raise the scam probability (weight > 0).
        keywords = [word for word, weight in explanation.as_list() if weight > 0]
        return keywords
    except Exception as e:
        print("⚠ LIME 失敗,啟用 fallback:", e)
        fallback = ["繳費", "終止", "逾期", "限時", "驗證碼"]
        return [kw for kw in fallback if kw in text]
116
+
117
+
118
+ # 文字清理
119
def clean_text(text):
    """Normalize a raw message down to its first few Chinese sentences.

    Strips URLs, long latin/digit runs, phone-number-like sequences and
    any remaining non-Chinese characters, then keeps at most the first
    four sentences and caps the result at 300 characters.
    """
    stripped = text
    # Applied in order: URLs → long alphanumeric runs → phone-like digit
    # sequences → everything outside CJK chars and Chinese punctuation.
    for pattern in (
        r"https?://\S+",
        r"[a-zA-Z0-9:/.%\-_=+]{4,}",
        r"\+?\d[\d\s\-]{5,}",
        r"[^一-龥。,!?、]",
    ):
        stripped = re.sub(pattern, "", stripped)
    sentences = re.split(r"[。!?]", stripped)
    return "。".join(sentences[:4])[:300]
127
+
128
+ # 高亮顯示
129
+
130
def highlight_keywords(text, keywords, prob):
    """Wrap suspicious words in ``text`` with a colored <span> highlight.

    The highlight color follows the model confidence ``prob``: red for
    prob >= 0.65, yellow otherwise; below 0.15 nothing is marked at all.
    Each LIME keyword is further segmented with jieba so that long
    phrases are highlighted at word granularity.
    """
    if prob < 0.15:  # low risk: leave the text untouched
        return text

    # Choose the highlight color from the confidence level.
    css_class = 'red-highlight' if prob >= 0.65 else 'yellow-highlight'

    # Collect the distinct jieba sub-words (>= 2 chars) present in the text.
    words = set()
    for phrase in keywords:
        for word in jieba.cut(phrase):
            word = word.strip()
            if len(word) >= 2 and word in text:
                words.add(word)

    if not words:
        return text

    # Substitute everything in ONE regex pass. Sequential str.replace()
    # could match text inside previously inserted <span> markup (nesting /
    # corrupting the HTML) and a short word could clobber a longer,
    # already-wrapped keyword. Longest-first alternation + one re.sub
    # avoids both problems.
    pattern = "|".join(re.escape(w) for w in sorted(words, key=len, reverse=True))
    return re.sub(
        pattern,
        lambda m: f"<span class='{css_class}'>{m.group(0)}</span>",
        text,
    )
151
+
152
+
153
+
154
+
155
+ # 文字分析主流程
156
def analyze_text(text):
    """Full text pipeline: clean, classify, explain with LIME, highlight.

    Returns a dict with the predicted status, confidence percentage,
    suspicious keywords and an HTML-highlighted copy of the input text.
    """
    normalized = clean_text(text)
    prediction = predict_single_sentence(model, tokenizer, normalized)
    label = prediction["label"]
    prob = prediction["prob"]
    risk = prediction["risk"]

    keywords = suspicious_tokens(normalized)
    # Highlight the ORIGINAL text so the user sees their own message marked up.
    highlighted = highlight_keywords(text, keywords, prob)
    if prob < 0.15:
        keywords = []  # low risk: do not surface any "suspicious" words

    confidence = round(prob * 100, 2)
    print(f"\n📩 訊息內容:{text}")
    print(f" 預測結果:{label}")
    print(f"📊 信心值:{confidence}")
    print(f"⚠️ 風險等級:{risk}")
    print(f"可疑關鍵字擷取: {keywords}")

    return {
        "status": label,
        "confidence": confidence,
        "suspicious_keywords": keywords,
        "highlighted_text": highlighted,
    }
182
+
183
+ # 圖片 OCR 分析
184
def analyze_image(file_bytes):
    """OCR an image (raw bytes) and run the extracted text through analyze_text().

    Returns the analyze_text() result dict, or a placeholder payload when
    no text could be recognized in the image.
    """
    # Force RGB: easyocr expects a 3-channel array, while uploads may be
    # RGBA (PNG with alpha), palette-mode or grayscale images.
    image = Image.open(io.BytesIO(file_bytes)).convert("RGB")
    image_np = np.array(image)
    results = reader.readtext(image_np)

    # Each easyocr result is (bbox, text, confidence); keep the text only.
    text = ' '.join(res[1] for res in results).strip()

    if not text:
        return {
            "status" : "無法辨識文字",
            "confidence" : 0.0,
            "suspicious_keywords" : ["圖片中無可辨識的中文英文"],
            "highlighted_text": "無法辨識可疑內容"
        }
    return analyze_text(text)