thong0710 commited on
Commit
ab41517
·
verified ·
1 Parent(s): 1d90061

push speed

Browse files
Files changed (1) hide show
  1. app/safety_check.py +132 -73
app/safety_check.py CHANGED
@@ -8,117 +8,176 @@ from PIL import Image
8
  import torch
9
  import re
10
  from urllib.parse import urlparse, unquote
 
 
11
 
12
- # ==== Load models 1 lần duy nhất ====
13
- # Text độc hại
14
  detox_model = Detoxify('original')
15
 
16
- # NSFW
17
  nsfw_model_id = "Falconsai/nsfw_image_detection"
18
  nsfw_processor = AutoProcessor.from_pretrained(nsfw_model_id)
19
  nsfw_model = AutoModelForImageClassification.from_pretrained(nsfw_model_id)
20
 
21
- # Bạo lực
22
  violence_model_id = "jaranohaal/vit-base-violence-detection"
23
- violence_processor = ViTFeatureExtractor.from_pretrained(violence_model_id)
24
  violence_model = ViTForImageClassification.from_pretrained(violence_model_id)
 
25
 
26
- # Caption ảnh
27
  blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
28
  blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
29
 
30
- # AI phân loại URL
31
- url_classifier = pipeline("zero-shot-classification")
32
 
33
- # ==== HÀM CHÍNH ====
34
 
 
 
35
  def is_prompt_safe(prompt: str):
36
  results = detox_model.predict(prompt)
37
  threshold = 0.5
38
  flagged = {label: score for label, score in results.items() if score > threshold}
39
- return False if flagged else True, list(flagged.keys())
 
 
 
40
 
 
41
  def generate_caption(image: Image.Image):
42
  inputs = blip_processor(images=image, return_tensors="pt")
43
  with torch.no_grad():
44
- output = blip_model.generate(**inputs)
45
- return blip_processor.decode(output[0], skip_special_tokens=True)
46
-
47
- def check_nsfw_image(image: Image.Image):
48
- inputs = nsfw_processor(images=image, return_tensors="pt")
49
- with torch.no_grad():
50
- logits = nsfw_model(**inputs).logits
51
- probs = torch.nn.functional.softmax(logits, dim=1)[0]
52
- labels = list(nsfw_model.config.id2label.values())
53
- pred_idx = probs.argmax().item()
54
- return labels[pred_idx], probs[pred_idx].item() * 100
55
-
56
- def check_violence_image(image: Image.Image):
57
- inputs = violence_processor(images=image, return_tensors="pt")
58
- with torch.no_grad():
59
- logits = violence_model(**inputs).logits
60
- probs = torch.nn.functional.softmax(logits, dim=1)[0]
61
- labels = ["Non-Violent", "Violent"]
62
- pred_idx = probs.argmax().item()
63
- return labels[pred_idx], probs[pred_idx].item() * 100
64
-
65
- def analyze_image(image: Image.Image) -> str:
66
- nsfw_label, nsfw_score = check_nsfw_image(image)
67
- violence_label, violence_score = check_violence_image(image)
68
- caption = generate_caption(image)
69
-
70
- result = f"🖼️ Mô tả ảnh: {caption}\n\n"
71
-
72
- # NSFW
73
- if nsfw_label.lower() in ["porn", "hentai", "sex", "nsfw"] and nsfw_score > 60:
74
- result += f"🚨 NSFW: {nsfw_label} ({nsfw_score:.2f}%)\n"
75
- else:
76
- result += f"✅ An toàn NSFW: {nsfw_label} ({nsfw_score:.2f}%)\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
- # Bạo lực
79
- if violence_label == "Violent" and violence_score > 80:
80
- result += f"🚨 Bạo lực: {violence_label} ({violence_score:.2f}%)"
81
- else:
82
- result += f"✅ An toàn bạo lực: {violence_label} ({violence_score:.2f}%)"
 
83
 
84
- return result
85
 
86
- def check_url(url: str) -> str:
 
 
87
  try:
88
- decoded = unquote(url)
89
- parsed = urlparse(decoded)
90
  warnings = []
91
 
92
- # Kiểm tra dấu hiệu đáng ngờ
93
- if re.match(r'^https?://\d{1,3}(\.\d{1,3}){3}', decoded):
94
- warnings.append("🚨 Truy cập IP trực tiếp")
95
- if re.search(r'\.(exe|msi|bat|js|jar|apk|dmg)(\?|$)', parsed.path.lower()):
96
- warnings.append("🚨 URL chứa file thực thi")
 
97
  if 'redirect' in parsed.path.lower() or 'url=' in parsed.query.lower():
98
- warnings.append("⚠️ Chứa chức năng chuyển hướng")
 
99
  if re.search(r'%[0-9a-f]{2}|[\x00-\x1f\x7f]', url):
100
- warnings.append("🚨 tự mã hóa bất thường")
 
101
  if '@' in parsed.netloc:
102
- warnings.append("🚨 URL giả mạo domain (dùng @)")
103
- if any(k in parsed.netloc.lower() for k in ['login', 'secure', 'account']):
104
- warnings.append("⚠️ Domain trông giống dịch vụ đăng nhập")
 
 
 
105
  if parsed.scheme == 'http':
106
- warnings.append("⚠️ Kết nối không mã hóa (HTTP)")
107
 
108
- # Phân tích bằng AI
109
- ai_result = url_classifier(url, candidate_labels=["malicious", "safe"])
110
  ai_label = ai_result["labels"][0]
111
  ai_score = ai_result["scores"][0] * 100
112
 
113
- # Kết luận
114
- result = f"🔗 URL: {url}\n"
 
 
 
 
 
 
 
 
 
 
115
  if warnings or ai_label == "malicious":
116
- result += f"🚨 KHÔNG AN TOÀN\n\n📢 Cảnh báo:\n" + "\n".join(f"- {w}" for w in warnings)
117
  else:
118
- result += "✅ An toàn\n"
119
-
120
- result += f"\n🤖 AI đánh giá: {ai_label} ({ai_score:.2f}%)"
121
- return result
122
 
123
  except Exception as e:
124
- return f"⚠️ Lỗi phân tích URL: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  import torch
9
  import re
10
  from urllib.parse import urlparse, unquote
11
+ from functools import lru_cache
12
+ import threading
13
 
14
+ # Load models
 
15
  detox_model = Detoxify('original')
16
 
 
17
  nsfw_model_id = "Falconsai/nsfw_image_detection"
18
  nsfw_processor = AutoProcessor.from_pretrained(nsfw_model_id)
19
  nsfw_model = AutoModelForImageClassification.from_pretrained(nsfw_model_id)
20
 
 
21
  violence_model_id = "jaranohaal/vit-base-violence-detection"
 
22
  violence_model = ViTForImageClassification.from_pretrained(violence_model_id)
23
+ violence_processor = ViTFeatureExtractor.from_pretrained(violence_model_id)
24
 
 
25
  blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
26
  blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
27
 
28
+ # Load hình zero-shot nhẹ
29
+ classifier = pipeline("zero-shot-classification", model="MoritzLaurer/distilbert-zero-shot-v1")
30
 
 
31
 
32
+ # ========================== TEXT (Prompt) ==========================
33
+ @lru_cache(maxsize=128)
34
  def is_prompt_safe(prompt: str):
35
  results = detox_model.predict(prompt)
36
  threshold = 0.5
37
  flagged = {label: score for label, score in results.items() if score > threshold}
38
+ if flagged:
39
+ return False, list(flagged.keys())
40
+ return True, []
41
+
42
 
43
+ # ========================== IMAGE (NSFW + Violence) ==========================
44
  def generate_caption(image: Image.Image):
45
  inputs = blip_processor(images=image, return_tensors="pt")
46
  with torch.no_grad():
47
+ out = blip_model.generate(**inputs)
48
+ caption = blip_processor.decode(out[0], skip_special_tokens=True)
49
+ return caption
50
+
51
+ def analyze_image(image: Image.Image):
52
+ result = {
53
+ "nsfw": "",
54
+ "violence": ""
55
+ }
56
+
57
+ def nsfw_task():
58
+ nsfw_inputs = nsfw_processor(images=image, return_tensors="pt")
59
+ with torch.no_grad():
60
+ nsfw_outputs = nsfw_model(**nsfw_inputs)
61
+ nsfw_probs = torch.nn.functional.softmax(nsfw_outputs.logits, dim=1)[0]
62
+ nsfw_labels = list(nsfw_model.config.id2label.values())
63
+ nsfw_pred = nsfw_probs.argmax().item()
64
+ nsfw_label = nsfw_labels[nsfw_pred]
65
+ nsfw_score = nsfw_probs[nsfw_pred].item() * 100
66
+ caption = generate_caption(image)
67
+
68
+ if nsfw_label.lower() in ["porn", "hentai", "sex", "nsfw"]:
69
+ result["nsfw"] = f"""\ud83d\udea8 Ảnh KHÔNG an toàn (NSFW):\n- Loại: {nsfw_label}\n- Độ chính xác: {nsfw_score:.2f}%\n- tả: {caption}"""
70
+ else:
71
+ result["nsfw"] = f"""✅ Ảnh an toàn (NSFW):\n- Loại: {nsfw_label}\n- Độ chính xác: {nsfw_score:.2f}%\n- Mô tả: {caption}"""
72
+
73
+ def violence_task():
74
+ violence_inputs = violence_processor(images=image, return_tensors="pt")
75
+ with torch.no_grad():
76
+ violence_outputs = violence_model(**violence_inputs)
77
+ violence_probs = torch.nn.functional.softmax(violence_outputs.logits, dim=1)[0]
78
+ violence_labels = ["Non-Violent", "Violent"]
79
+ violence_pred = violence_probs.argmax().item()
80
+ violence_label = violence_labels[violence_pred]
81
+ violence_score = violence_probs[violence_pred].item() * 100
82
+ caption = generate_caption(image)
83
+
84
+ is_violent = False
85
+ if violence_label.lower() == "non-violent" and violence_score > 50:
86
+ is_violent = True
87
+ elif violence_label.lower() == "violent" and violence_score > 80:
88
+ is_violent = True
89
+
90
+ if is_violent:
91
+ result["violence"] = f"""\ud83d\udea8 Ảnh KHÔNG an toàn (Bạo lực):\n- Loại: {violence_label}\n- Độ chính xác: {violence_score:.2f}%\n- Mô tả: {caption}"""
92
+ else:
93
+ result["violence"] = f"""✅ Ảnh an toàn (Bạo lực):\n- Loại: {violence_label}\n- Độ chính xác: {violence_score:.2f}%\n- Mô tả: {caption}"""
94
 
95
+ # Chạy song song
96
+ t1 = threading.Thread(target=nsfw_task)
97
+ t2 = threading.Thread(target=violence_task)
98
+ t1.start(); t2.start()
99
+ t1.join(); t2.join()
100
+ return result["nsfw"], result["violence"]
101
 
 
102
 
103
+ # ========================== URL ==========================
104
+ @lru_cache(maxsize=128)
105
+ def check_url(url: str):
106
  try:
107
+ decoded_url = unquote(url)
108
+ parsed = urlparse(decoded_url)
109
  warnings = []
110
 
111
+ if re.match(r'^https?://\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', decoded_url):
112
+ warnings.append("\ud83d\udea8 Nguy hiểm: Truy cập trực tiếp bằng IP (thường dùng cho tấn công)")
113
+
114
+ if re.search(r'\\.(exe|msi|bat|js|jar|apk|dmg)(\\?|$)', parsed.path.lower()):
115
+ warnings.append("\ud83d\udea8 Nguy hiểm: URL chứa file thực thi có thể độc hại")
116
+
117
  if 'redirect' in parsed.path.lower() or 'url=' in parsed.query.lower():
118
+ warnings.append("\u26a0\ufe0f Cảnh báo: URL chứa chức năng redirect (có thể lừa đảo)")
119
+
120
  if re.search(r'%[0-9a-f]{2}|[\x00-\x1f\x7f]', url):
121
+ warnings.append("\ud83d\udea8 Nguy hiểm: URL chứa ký tự mã hóa đáng ngờ (có thể tấn công)")
122
+
123
  if '@' in parsed.netloc:
124
+ warnings.append("\ud83d\udea8 Lừa đảo: URL chứa kỹ thuật giả mạo domain (user@fake-domain)")
125
+
126
+ deceptive_domains = ['login', 'secure', 'account', 'verify', 'update']
127
+ if any(keyword in parsed.netloc.lower() for keyword in deceptive_domains):
128
+ warnings.append("\u26a0\ufe0f Cảnh báo: Domain có dấu hiệu giả mạo dịch vụ đăng nhập")
129
+
130
  if parsed.scheme == 'http':
131
+ warnings.append("\u26a0\ufe0f Cảnh báo: Kết nối không mã hóa (HTTP)")
132
 
133
+ ai_result = classifier(url, candidate_labels=["malicious", "safe"])
 
134
  ai_label = ai_result["labels"][0]
135
  ai_score = ai_result["scores"][0] * 100
136
 
137
+ report = {
138
+ "url": url,
139
+ "decoded_url": decoded_url,
140
+ "domain": parsed.netloc,
141
+ "path": parsed.path,
142
+ "warnings": warnings,
143
+ "ai_analysis": {
144
+ "label": ai_label,
145
+ "confidence": ai_score
146
+ }
147
+ }
148
+
149
  if warnings or ai_label == "malicious":
150
+ return format_report(report, is_safe=False)
151
  else:
152
+ return format_report(report, is_safe=True)
 
 
 
153
 
154
  except Exception as e:
155
+ return f"\u26a0\ufe0f Lỗi khi phân tích URL: {str(e)}"
156
+
157
+ def format_report(report: dict, is_safe: bool):
158
+ warning_text = "\n".join(f"- {w}" for w in report["warnings"]) if report["warnings"] else "- Không phát hiện cảnh báo"
159
+
160
+ if not is_safe:
161
+ return f"""\ud83d\udea8 URL KHÔNG AN TOÀN
162
+ \ud83d\udd0d Phân tích chi tiết:
163
+ • URL gốc: {report['url']}
164
+ • Domain: {report['domain']}
165
+ • Đường dẫn: {report['path']}
166
+
167
+ \ud83d\udce3 CẢNH BÁO:
168
+ {warning_text}
169
+
170
+ 🤖 Phân tích AI:
171
+ - Kết quả: {report['ai_analysis']['label']}
172
+ - Độ tin cậy: {report['ai_analysis']['confidence']:.2f}%
173
+
174
+ 🛡️ Khuyến nghị: KHÔNG TRUY CẬP!"""
175
+ else:
176
+ return f"""✅ URL AN TOÀN
177
+ \ud83d\udd0d Phân tích chi tiết:
178
+ • URL gốc: {report['url']}
179
+ • Domain: {report['domain']}
180
+
181
+ 🤖 Phân tích AI:
182
+ - Kết quả: {report['ai_analysis']['label']}
183
+ - Độ tin cậy: {report['ai_analysis']['confidence']:.2f}%"""