Spaces:
Running
Running
File size: 6,825 Bytes
7ba4896 a28108a 7ba4896 afe5f61 84bb3ad 85bf665 afe5f61 0bbfcb4 7ba4896 85bf665 7ba4896 60f592c a28108a 85bf665 84bb3ad 0bbfcb4 60f592c 7ba4896 85bf665 afe5f61 85bf665 7ba4896 0f9c412 7ba4896 e111c61 7ba4896 85bf665 7ba4896 85bf665 ce94a9b e8a997e e111c61 e8a997e 7ba4896 e111c61 7ba4896 afe5f61 7ba4896 afe5f61 e111c61 ce94a9b e832ad0 afe5f61 85bf665 ce94a9b 72e2b13 afe5f61 7ba4896 e8a997e afe5f61 e8a997e 85bf665 7ba4896 e8a997e afe5f61 fe08a39 85bf665 042c85d 85bf665 042c85d 72e2b13 afe5f61 ce94a9b afe5f61 e832ad0 afe5f61 e832ad0 042c85d b86bcb5 afe5f61 85bf665 042c85d 85bf665 042c85d e6f4d9c 72e2b13 e111c61 afe5f61 b95a972 fe08a39 85bf665 afe5f61 85bf665 afe5f61 ce94a9b afe5f61 ce94a9b e8a997e 7ba4896 afe5f61 7ba4896 afe5f61 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 | import gradio as gr
import re
import os
import torch
import uvicorn
import json
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import BertTokenizer, AutoModelForSequenceClassification
from arabert.preprocess import ArabertPreprocessor
from tabulate import tabulate
# Model / tokenizer locations. HF_TOKEN is optional (read from the environment)
# and is passed when downloading the fine-tuned classifier weights.
MODEL_REPO = "kkAsmaa/ChildShield"
MODEL_NAME = "aubmindlab/bert-base-arabertv02-twitter"
SUB_FOLDER = "ChildShield"
HF_TOKEN = os.getenv("HF_TOKEN")

# NOTE: the original message began with a mis-encoded emoji whose bytes were
# lost; it is dropped rather than guessed.
print("Loading ChildShield Model Weights with Deep Window Auto-Logging Features...")

# The tokenizer comes from the base AraBERT checkpoint, while the classifier
# weights come from the ChildShield repository subfolder.
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_REPO, token=HF_TOKEN, subfolder=SUB_FOLDER
)
model.eval()  # inference mode: disables dropout

# AraBERT's own text normalisation, configured for the same base checkpoint.
arabic_prep = ArabertPreprocessor(model_name=MODEL_NAME)

app = FastAPI(title="ChildShield Backend API")
class InputData(BaseModel):
    """JSON request body accepted by ``POST /predict``."""

    # The raw text to classify.
    text: str
def clean_obfuscation(text):
    """Undo common tricks used to hide words from a text classifier.

    Steps, in order:
      1. Remove URLs, ``@mentions`` and ``#`` characters.
      2. Remove a single non-space, non-Arabic character wedged between two
         Arabic letters (e.g. ``ك.ل.ب`` -> ``كلب``).
      3. Remove spaces that split a word into single spaced-out letters
         (e.g. ``ك ل ب`` -> ``كلب``).
      4. Remove tatweel (U+0640) elongation runs.
      5. Cap any character repeated 3+ times at two copies.
      6. Replace remaining punctuation/symbols (except ``.``) with spaces.
      7. Collapse whitespace and trim.

    Args:
        text: Input value; converted with ``str()`` first.

    Returns:
        The cleaned string.
    """
    # Arabic letters U+0623 (أ) .. U+064A (ي). Written as escapes so the
    # pattern cannot be corrupted by a re-encoding of this source file (the
    # previous literal had been mangled into a Thai character range and
    # never matched Arabic).
    arabic = r"\u0623-\u064A"

    text = str(text)
    text = re.sub(r'https?://\S+|www\.\S+|@\S+|#', '', text)
    text = re.sub(rf'(?<=[{arabic}])[^\s{arabic}](?=[{arabic}])', '', text)
    text = re.sub(rf'(?<=[{arabic}])\s(?=[{arabic}]\s|[{arabic}]$)', '', text)
    # Tatweel is stripped before repeat-capping so elongations vanish entirely.
    text = re.sub(r'\u0640+', '', text)
    text = re.sub(r'(.)\1{2,}', r'\1\1', text)
    text = re.sub(r'[^\w\s\.]', ' ', text)
    text = re.sub(r'\s+', ' ', text)
    return text.strip()
def full_preprocess(text):
    """Return ``text`` de-obfuscated and then normalised by the AraBERT preprocessor."""
    return arabic_prep.preprocess(clean_obfuscation(text))
def _sliding_windows(input_ids, window_size, overlap):
    """Split ``input_ids`` into windows of at most ``window_size`` ids overlapping by ``overlap``."""
    step = window_size - overlap
    if step <= 0:
        raise ValueError("overlap must be smaller than window_size")
    # A sequence that already fits (including an empty one) is a single window.
    if len(input_ids) <= window_size:
        return [list(input_ids)]
    windows = []
    for start in range(0, len(input_ids), step):
        windows.append(input_ids[start:start + window_size])
        # Stop once a window reaches the end; a further window would lie
        # entirely inside this one.
        if start + window_size >= len(input_ids):
            break
    return windows


def _classify_window(win_ids):
    """Score one window of token ids and return ``(safe_prob, unsafe_prob)``.

    The ids are wrapped with the tokenizer's special tokens and fed to the
    model directly. The previous decode -> re-tokenize path with
    ``max_length=60`` truncated the 60 content ids plus the 2 special tokens,
    silently dropping the last two tokens of every window.
    """
    ids = tokenizer.build_inputs_with_special_tokens(list(win_ids))
    input_tensor = torch.tensor([ids], dtype=torch.long)
    attention_mask = torch.ones_like(input_tensor)
    with torch.no_grad():
        logits = model(input_ids=input_tensor, attention_mask=attention_mask).logits
    probs = torch.softmax(logits, dim=-1).flatten().tolist()
    # NOTE(review): assumes label 0 = safe and label 1 = unsafe; confirm
    # against the model config's id2label.
    return float(probs[0]), float(probs[1])


def _print_inference_report(report, table_rows):
    """Print a human-readable report for one request to stdout.

    ``report`` is the dict built by ``predict_safety_api``; ``table_rows``
    holds one pre-formatted row per window (ID, preview, safe %, unsafe %,
    verdict).
    """
    separator = "========================================================================"
    if report["blocked"]:
        banner = "🚨 [BLOCK] CHILDSHIELD AI INFERENCE REPORT"
    else:
        banner = "✅ [PASS] CHILDSHIELD AI INFERENCE REPORT"
    original = str(report["original_text"]).strip()

    print(f"\n================ {banner} ================")
    print(f" Received Original Text:\n\"{original}\"")
    print(f"\n Preprocessed Cleaned Text:\n\"{report['cleaned_text']}\"")
    print(f"\n Total Page Tokens Count : {report['total_tokens']}")
    print(
        f" Total Sliding Windows Run : {report['total_windows']} Windows "
        f"(Size: {report['window_size']}, Overlap: {report['overlap']})"
    )
    print(f" Final Security Verdict : {report['final_prediction']}")
    print(f" Model Decision Confidence : {report['confidence']}")
    print(f" Triggered Windows ID : {report['triggered_windows']}")
    print("\n --- Windows Detailed Semantic Analysis Table ---")
    print(tabulate(
        table_rows,
        headers=["ID", "Window Text Preview", "Safe Prob", "Unsafe Prob", "Verdict"],
        tablefmt="grid",
    ))
    print(separator + "\n")
    print("\n=== Windows Text Inspection (النصوص الكاملة لتقسيم النوافذ) ===")
    for window in report["windows_analysis"]:
        print(f"[Window {window['window_id']} Full Text]:\n \"{window['window_text']}\"\n")
    print(separator + "\n")


def predict_safety_api(text):
    """Classify Arabic ``text`` as SAFE or UNSAFE using overlapping token windows.

    The text is cleaned with ``full_preprocess`` and tokenized once without
    special tokens. The token ids are split into windows of 60 that overlap
    by 20, and each window is scored independently. The text is blocked if
    any window's unsafe probability exceeds 0.50.

    Args:
        text: Raw user text (FastAPI request body or Gradio textbox).

    Returns:
        A JSON-serialisable dict with the original and cleaned text, token and
        window counts, per-window probabilities (``windows_analysis``), the
        1-based ids of windows that triggered a block, the final verdict,
        ``blocked``, the highest unsafe probability, and the winning
        confidence as a percentage string. When blocked, the confidence is
        the highest unsafe probability; otherwise it is the highest safe
        probability.

    Side effects:
        Prints a detailed inference report to stdout.
    """
    window_size = 60
    overlap = 20
    unsafe_threshold = 0.50
    preview_chars = 45

    cleaned_text = full_preprocess(text)
    input_ids = tokenizer(
        cleaned_text, add_special_tokens=False, return_attention_mask=False
    )["input_ids"]
    windows = _sliding_windows(input_ids, window_size, overlap)

    is_blocked = False
    highest_unsafe_prob = 0.0
    highest_safe_prob = 0.0
    triggered_windows = []
    windows_analysis = []
    table_rows = []

    for window_id, win_ids in enumerate(windows, start=1):
        window_text = tokenizer.decode(win_ids, skip_special_tokens=True)
        safe_prob, unsafe_prob = _classify_window(win_ids)
        is_unsafe = unsafe_prob > unsafe_threshold
        prediction = "UNSAFE" if is_unsafe else "SAFE"

        windows_analysis.append({
            "window_id": window_id,
            "window_text": window_text,
            "safe_probability": round(safe_prob, 4),
            "unsafe_probability": round(unsafe_prob, 4),
            "prediction": prediction,
        })
        if len(window_text) > preview_chars:
            preview = window_text[:preview_chars] + "..."
        else:
            preview = window_text
        table_rows.append([
            f"Win {window_id}",
            preview,
            f"{safe_prob * 100:.2f}%",
            f"{unsafe_prob * 100:.2f}%",
            f"❌ {prediction}" if is_unsafe else f"✅ {prediction}",
        ])

        if is_unsafe:
            is_blocked = True
            highest_unsafe_prob = max(highest_unsafe_prob, unsafe_prob)
            triggered_windows.append(window_id)
        else:
            highest_safe_prob = max(highest_safe_prob, safe_prob)

    winning_probability = highest_unsafe_prob if is_blocked else highest_safe_prob
    result = {
        "original_text": text,
        "cleaned_text": cleaned_text,
        "total_tokens": len(input_ids),
        "window_size": window_size,
        "overlap": overlap,
        "total_windows": len(windows),
        "triggered_windows": triggered_windows,
        "windows_analysis": windows_analysis,
        "final_prediction": "UNSAFE" if is_blocked else "SAFE",
        "blocked": is_blocked,
        "highest_unsafe_confidence": round(highest_unsafe_prob, 4),
        "confidence": f"{winning_probability * 100:.2f}%",
    }
    _print_inference_report(result, table_rows)
    return result
@app.post("/predict")
def predict(data: InputData):
    """``POST /predict``: classify ``data.text`` and return the full inference report."""
    return predict_safety_api(data.text)
# Interactive UI around the same inference function used by POST /predict.
gradio_interface = gr.Interface(
    fn=predict_safety_api,
    inputs=gr.Textbox(lines=4, placeholder="Enter Arabic text to analyze..."),
    outputs=gr.JSON(label="Guard Response Object"),
    # Title emoji restored from a mis-encoded byte sequence (it is U+1F6E1 U+FE0F).
    title="ChildShield Production API Gate (Arabic Version)🛡️",
)

# Mount the UI at "/" on the existing FastAPI app. This happens after
# /predict was registered; presumably that route is matched first — verify
# if routes are reordered.
app = gr.mount_gradio_app(app, gradio_interface, path="/")

if __name__ == "__main__":
    # 7860 is the default; an explicit PORT environment variable overrides it.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))
|