File size: 6,825 Bytes
7ba4896
 
a28108a
7ba4896
afe5f61
84bb3ad
85bf665
afe5f61
 
0bbfcb4
7ba4896
85bf665
 
7ba4896
 
 
60f592c
a28108a
 
85bf665
84bb3ad
0bbfcb4
60f592c
7ba4896
 
 
85bf665
afe5f61
 
 
 
 
85bf665
7ba4896
 
 
 
0f9c412
7ba4896
 
 
 
 
e111c61
7ba4896
 
85bf665
7ba4896
 
85bf665
ce94a9b
e8a997e
e111c61
e8a997e
7ba4896
 
 
e111c61
 
7ba4896
 
 
afe5f61
7ba4896
 
 
 
 
 
afe5f61
 
 
 
 
e111c61
ce94a9b
 
e832ad0
afe5f61
 
85bf665
ce94a9b
72e2b13
 
afe5f61
7ba4896
e8a997e
afe5f61
 
 
 
e8a997e
 
85bf665
7ba4896
 
 
e8a997e
afe5f61
 
 
 
 
 
 
 
 
 
 
fe08a39
85bf665
042c85d
 
 
 
 
85bf665
042c85d
 
72e2b13
 
afe5f61
ce94a9b
afe5f61
 
e832ad0
 
afe5f61
 
e832ad0
042c85d
b86bcb5
 
 
 
afe5f61
85bf665
 
042c85d
85bf665
 
 
 
 
 
 
 
042c85d
 
e6f4d9c
72e2b13
 
 
 
 
 
 
e111c61
afe5f61
 
 
 
 
 
 
 
 
 
b95a972
 
 
fe08a39
85bf665
afe5f61
 
 
 
 
85bf665
afe5f61
ce94a9b
afe5f61
ce94a9b
e8a997e
7ba4896
 
afe5f61
 
7ba4896
afe5f61
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import gradio as gr
import re
import os
import torch
import uvicorn
import json

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import BertTokenizer, AutoModelForSequenceClassification
from arabert.preprocess import ArabertPreprocessor
from tabulate import tabulate  


# Hugging Face repo from which the sequence-classification model is loaded.
MODEL_REPO = "kkAsmaa/ChildShield"
# Base AraBERT checkpoint: source of the tokenizer and of the preprocessor config.
MODEL_NAME = "aubmindlab/bert-base-arabertv02-twitter"
# Sub-directory of MODEL_REPO that holds the model files.
SUB_FOLDER = "ChildShield"
# Optional Hugging Face access token taken from the environment (None if unset).
HF_TOKEN = os.environ.get("HF_TOKEN")


print("🔄 Loading ChildShield Model Weights with Deep Window Auto-Logging Features...")
# Tokenizer and AraBERT preprocessing both come from the base checkpoint, while
# the classifier weights come from SUB_FOLDER of MODEL_REPO (HF_TOKEN is passed
# through for authenticated access; it is None when the env var is unset).
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_REPO, token=HF_TOKEN, subfolder=SUB_FOLDER
)
# Inference only: put the module in evaluation mode (disables dropout).
model.eval()
arabic_prep = ArabertPreprocessor(model_name=MODEL_NAME)


app = FastAPI(
    title="ChildShield Backend API",
)


class InputData(BaseModel):
    """Request body accepted by the POST /predict endpoint."""

    # Raw text to be screened.
    text: str


def clean_obfuscation(text):
    r"""Undo common tricks used to slip words past a text filter.

    Steps, applied in order:
      1. drop URLs, ``@mentions`` and ``#`` characters;
      2. drop a single non-space, non-Arabic-letter character wedged between
         two Arabic letters (``ك.ل.ب`` -> ``كلب``);
      3. drop a whitespace character between Arabic letters when the following
         letter stands alone, i.e. is followed by whitespace or the end of the
         text (``ك ل ب`` -> ``كلب``);
      4. remove tatweel / kashida elongation (U+0640);
      5. squeeze runs of 3+ identical characters down to 2;
      6. replace anything that is not a word character, whitespace or ``.``
         with a space;
      7. collapse whitespace runs and strip.

    The Arabic letter class is U+0623 (ALEF WITH HAMZA ABOVE) .. U+064A (YEH),
    i.e. ``[أ-ي]``. It is written with ``\u`` escapes so that an encoding
    round-trip of this source file cannot silently turn the class into
    unrelated characters, which would disable every Arabic-specific rule.

    Args:
        text: Input value; coerced with ``str()``.

    Returns:
        The cleaned string.
    """
    # Body of a character class covering Arabic letters U+0623..U+064A.
    ar = r"\u0623-\u064A"
    text = str(text)
    text = re.sub(r"https?://\S+|www\.\S+|@\S+|#", "", text)
    text = re.sub(rf"(?<=[{ar}])[^\s{ar}](?=[{ar}])", "", text)
    text = re.sub(rf"(?<=[{ar}])\s(?=[{ar}]\s|[{ar}]$)", "", text)
    text = re.sub(r"\u0640+", "", text)  # tatweel / kashida
    text = re.sub(r"(.)\1{2,}", r"\1\1", text)
    text = re.sub(r"[^\w\s\.]", " ", text)
    text = re.sub(r"\s+", " ", text)
    return text.strip()

def full_preprocess(text):
    """Prepare raw text for tokenization.

    ``clean_obfuscation`` removes evasion tricks first, then the module-level
    ``arabic_prep`` (an ``ArabertPreprocessor`` configured for ``MODEL_NAME``)
    applies its preprocessing to the result.
    """
    return arabic_prep.preprocess(clean_obfuscation(text))


def _sliding_windows(token_ids, window_size, overlap):
    """Split ``token_ids`` into windows of at most ``window_size`` ids.

    Consecutive windows start ``window_size - overlap`` ids apart, so they share
    ``overlap`` ids, and the last window always ends at the final id. A sequence
    that already fits in one window (including an empty one) is returned as a
    single window.

    Raises:
        ValueError: if ``overlap >= window_size`` (the window would never advance).
    """
    if len(token_ids) <= window_size:
        return [token_ids]
    step = window_size - overlap
    if step <= 0:
        raise ValueError("overlap must be smaller than window_size")
    windows = []
    for start in range(0, len(token_ids), step):
        windows.append(token_ids[start:start + window_size])
        if start + window_size >= len(token_ids):
            break
    return windows


def _classify_window(window_ids):
    """Score one window of raw token ids; return ``(safe_prob, unsafe_prob)``.

    The ids are wrapped with the tokenizer's special tokens ([CLS] ... [SEP] for
    the BertTokenizer used here) and passed to the model as-is. Decoding the
    window back to text and re-tokenizing it with ``max_length`` equal to the
    window size would let the two added special tokens push the window's last
    tokens past the truncation limit, so the tail of the text could go
    unscored. The round trip could also re-split word pieces differently.

    NOTE(review): the model therefore sees up to window size + 2 positions
    (62). That is within a BERT backbone's max_position_embeddings, but confirm
    it against how this checkpoint was fine-tuned.
    """
    ids = tokenizer.build_inputs_with_special_tokens(list(window_ids))
    input_tensor = torch.tensor([ids], dtype=torch.long)
    with torch.no_grad():
        outputs = model(
            input_ids=input_tensor,
            attention_mask=torch.ones_like(input_tensor),
        )
    probs = torch.softmax(outputs.logits, dim=-1).flatten().tolist()
    # Class 0 is treated as SAFE and class 1 as UNSAFE throughout this module.
    return float(probs[0]), float(probs[1])


def predict_safety_api(text):
    """Classify Arabic ``text`` as SAFE/UNSAFE with overlapping token windows.

    The text is cleaned with ``full_preprocess``, tokenized without special
    tokens, and split into 60-token windows overlapping by 20 tokens. Every
    window is scored independently. The text is blocked (UNSAFE) if any window's
    unsafe probability exceeds 0.50. A detailed report is printed to stdout.

    Args:
        text: Raw input text (a ``str``).

    Returns:
        A dict with the original and cleaned text, the token and window counts,
        the window size and overlap, per-window analysis, the 1-based IDs of the
        windows that triggered a block, ``final_prediction``, ``blocked``, the
        highest unsafe probability, and a formatted ``confidence`` string. The
        confidence is the highest unsafe probability among triggering windows
        for UNSAFE, and the highest safe probability among all windows for SAFE.
    """
    window_size = 60
    overlap = 20
    unsafe_threshold = 0.50

    cleaned_text = full_preprocess(text)
    # Special tokens are added per window in _classify_window, not here.
    input_ids = tokenizer(
        cleaned_text, add_special_tokens=False, return_attention_mask=False
    )["input_ids"]
    total_tokens_count = len(input_ids)

    windows = _sliding_windows(input_ids, window_size, overlap)
    total_windows_count = len(windows)

    is_blocked = False
    highest_unsafe_prob = 0.0
    highest_safe_prob = 0.0
    windows_analysis = []
    triggered_windows = []
    windows_table_data = []
    full_windows_text_list = []

    for window_id, win_ids in enumerate(windows, start=1):
        # Decoded text is only for the report; the model scores the ids directly.
        window_text = tokenizer.decode(win_ids, skip_special_tokens=True)
        safe_prob, unsafe_prob = _classify_window(win_ids)
        prediction = "UNSAFE" if unsafe_prob > unsafe_threshold else "SAFE"

        windows_analysis.append({
            "window_id": window_id,
            "window_text": window_text,
            "safe_probability": round(safe_prob, 4),
            "unsafe_probability": round(unsafe_prob, 4),
            "prediction": prediction,
        })

        preview = window_text[:45] + "..." if len(window_text) > 45 else window_text
        verdict = f"❌ {prediction}" if prediction == "UNSAFE" else f"✅ {prediction}"
        windows_table_data.append([
            f"Win {window_id}",
            preview,
            f"{safe_prob * 100:.2f}%",
            f"{unsafe_prob * 100:.2f}%",
            verdict,
        ])

        full_windows_text_list.append(
            f"📖 [Window {window_id} Full Text]:\n   \"{window_text}\"\n"
        )

        if prediction == "UNSAFE":
            # A single unsafe window is enough to block the whole text.
            is_blocked = True
            highest_unsafe_prob = max(highest_unsafe_prob, unsafe_prob)
            triggered_windows.append(window_id)
        else:
            highest_safe_prob = max(highest_safe_prob, safe_prob)

    final_prediction = "UNSAFE" if is_blocked else "SAFE"
    winning_probability = highest_unsafe_prob if is_blocked else highest_safe_prob
    formatted_confidence = f"{winning_probability * 100:.2f}%"

    alert_banner = (
        "🚨 [BLOCK] CHILDSHIELD AI INFERENCE REPORT"
        if is_blocked
        else "✅ [PASS] CHILDSHIELD AI INFERENCE REPORT"
    )
    print(f"\n================ {alert_banner} ================")
    print(f" Received Original Text:\n\"{text.strip()}\"")
    print(f"\n Preprocessed Cleaned Text:\n\"{cleaned_text}\"")
    print(f"\n Total Page Tokens Count  : {total_tokens_count}")
    print(f" Total Sliding Windows Run : {total_windows_count} Windows (Size: {window_size}, Overlap: {overlap})")
    print(f" Final Security Verdict    : {final_prediction}")
    print(f" Model Decision Confidence : {formatted_confidence}")
    print(f" Triggered Windows ID      : {triggered_windows}")
    print("\n --- Windows Detailed Semantic Analysis Table ---")
    print(tabulate(
        windows_table_data,
        headers=["ID", "Window Text Preview", "Safe Prob", "Unsafe Prob", "Verdict"],
        tablefmt="grid",
    ))
    print("========================================================================\n")

    print("\n🔍 === Windows Text Inspection (النصوص الكاملة لتقسيم النوافذ) ===")
    for window_log in full_windows_text_list:
        print(window_log)
    print("========================================================================\n")

    return {
        "original_text": text,
        "cleaned_text": cleaned_text,
        "total_tokens": total_tokens_count,
        "window_size": window_size,
        "overlap": overlap,
        "total_windows": total_windows_count,
        "triggered_windows": triggered_windows,
        "windows_analysis": windows_analysis,
        "final_prediction": final_prediction,
        "blocked": is_blocked,
        "highest_unsafe_confidence": round(highest_unsafe_prob, 4),
        "confidence": formatted_confidence,
    }


@app.post("/predict")
def predict(data: InputData):
    """POST /predict: run the windowed classifier on ``data.text``.

    Returns the report dict produced by ``predict_safety_api`` unchanged.
    """
    return predict_safety_api(data.text)


# Browser UI over the same classifier: a text box in, and the dict returned by
# predict_safety_api rendered as JSON out.
gradio_interface = gr.Interface(
    fn=predict_safety_api,
    inputs=gr.Textbox(lines=4, placeholder="Enter Arabic text to analyze..."),
    outputs=gr.JSON(label="Guard Response Object"),
    title="ChildShield Production API Gate (Arabic Version)🛡️",
)

# Mount the Gradio UI at the root path of the FastAPI app.
app = gr.mount_gradio_app(app, gradio_interface, path="/")

if __name__ == "__main__":
    # Serve the FastAPI app (with the Gradio UI mounted) on all interfaces.
    uvicorn.run(app, port=7860, host="0.0.0.0")