# Provenance (copied from the repository web viewer):
# author kkAsmaa, commit 72e2b13 "Update app.py" (verified).
import gradio as gr
import re
import os
import torch
import uvicorn
import json
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import BertTokenizer, AutoModelForSequenceClassification
from arabert.preprocess import ArabertPreprocessor
from tabulate import tabulate
# Fine-tuned classifier weights are pulled from MODEL_REPO/SUB_FOLDER (with an
# optional access token); the tokenizer and the AraBERT text normaliser come
# from the base checkpoint MODEL_NAME.
MODEL_REPO = "kkAsmaa/ChildShield"
MODEL_NAME = "aubmindlab/bert-base-arabertv02-twitter"
SUB_FOLDER = "ChildShield"
HF_TOKEN = os.getenv("HF_TOKEN")  # None when the variable is unset

# Startup log line (its emoji was previously stored mis-encoded).
print("🔄 Loading ChildShield Model Weights with Deep Window Auto-Logging Features...")
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_REPO, token=HF_TOKEN, subfolder=SUB_FOLDER
)
model.eval()  # inference only: disables dropout
arabic_prep = ArabertPreprocessor(model_name=MODEL_NAME)
app = FastAPI(title="ChildShield Backend API")
class InputData(BaseModel):
    """JSON request body accepted by the ``/predict`` endpoint."""

    text: str  # raw text to classify
# Arabic letter range U+0621 (hamza) .. U+064A (yeh), written with escapes so
# the patterns survive any re-encoding of this file. The previous range began
# at U+0623, so hamza (U+0621) and alef-madda (U+0622) were treated as
# non-letters and deleted when they sat between two letters.
_AR_LETTERS = "\u0621-\u064A"
_URL_MENTION_HASH_RE = re.compile(r"https?://\S+|www\.\S+|@\S+|#")
_INFIX_NOISE_RE = re.compile(rf"(?<=[{_AR_LETTERS}])[^\s{_AR_LETTERS}](?=[{_AR_LETTERS}])")
_SPACED_LETTERS_RE = re.compile(
    rf"(?<=[{_AR_LETTERS}])\s(?=[{_AR_LETTERS}]\s|[{_AR_LETTERS}]$)"
)
_TATWEEL_RE = re.compile("\u0640+")
_LONG_REPEAT_RE = re.compile(r"(.)\1{2,}")
_SYMBOL_RE = re.compile(r"[^\w\s\.]")
_WHITESPACE_RE = re.compile(r"\s+")


def clean_obfuscation(text):
    """Undo common tricks used to hide Arabic words from a text classifier.

    Steps, in order:
      1. drop URLs, ``@mentions`` and ``#`` signs;
      2. drop a single non-letter character wedged between two Arabic letters
         (e.g. letters separated by dots are joined back together);
      3. re-join Arabic letters that are spelled out one by one with spaces;
      4. remove tatweel (U+0640) stretching;
      5. squeeze runs of 3+ identical characters down to 2;
      6. replace remaining symbols (anything except word characters,
         whitespace and ``.``) with a space, then collapse whitespace.

    Args:
        text: Input value. Non-strings are converted with ``str()``; ``None``
            yields an empty string rather than the literal ``"None"``.

    Returns:
        The cleaned, stripped string.
    """
    if text is None:
        return ""
    text = str(text)
    text = _URL_MENTION_HASH_RE.sub("", text)
    text = _INFIX_NOISE_RE.sub("", text)
    text = _SPACED_LETTERS_RE.sub("", text)
    text = _TATWEEL_RE.sub("", text)
    text = _LONG_REPEAT_RE.sub(r"\1\1", text)
    text = _SYMBOL_RE.sub(" ", text)
    text = _WHITESPACE_RE.sub(" ", text)
    return text.strip()
def full_preprocess(text):
    """Normalise raw input text before tokenization.

    First strips obfuscation tricks via ``clean_obfuscation``, then applies
    the module-level ``arabic_prep`` (an ``ArabertPreprocessor`` built for
    ``MODEL_NAME``) and returns its output.
    """
    return arabic_prep.preprocess(clean_obfuscation(text))
# Sliding-window settings. _WINDOW_SIZE counts content tokens only; [CLS] and
# [SEP] are added on top of each window when it is scored (see _score_window).
_WINDOW_SIZE = 60
_WINDOW_OVERLAP = 20
# A window is UNSAFE when its unsafe-class probability is strictly above this.
_UNSAFE_THRESHOLD = 0.50


def _split_into_windows(input_ids, window_size, overlap):
    """Split ``input_ids`` into overlapping slices that together cover every id.

    A sequence no longer than ``window_size`` (including an empty one) yields
    a single window. Otherwise windows start every ``window_size - overlap``
    ids, and the last window ends exactly at the final id.
    """
    if len(input_ids) <= window_size:
        return [input_ids]
    step = window_size - overlap
    windows = []
    for start in range(0, len(input_ids), step):
        windows.append(input_ids[start:start + window_size])
        if start + window_size >= len(input_ids):
            break
    return windows


def _score_window(win_ids):
    """Return ``(safe_prob, unsafe_prob)`` for one window of content token ids."""
    # Feed the ids straight to the model wrapped in [CLS]/[SEP]. Decoding the
    # window to text and re-tokenizing with max_length=60 counted the two
    # special tokens against the limit, silently dropping the window's last
    # content tokens, and could re-split sub-words at window edges.
    ids = tokenizer.build_inputs_with_special_tokens(list(win_ids))
    input_tensor = torch.tensor([ids], dtype=torch.long)
    with torch.no_grad():
        logits = model(
            input_ids=input_tensor,
            attention_mask=torch.ones_like(input_tensor),
        ).logits
    # Label order as in the original code: index 0 = SAFE, index 1 = UNSAFE.
    probs = torch.softmax(logits, dim=-1).flatten().tolist()
    return float(probs[0]), float(probs[1])


def _print_report(result, table_rows):
    """Print a human-readable inference report for one request to stdout."""
    banner = (
        "🚨 [BLOCK] CHILDSHIELD AI INFERENCE REPORT"
        if result["blocked"]
        else "✅ [PASS] CHILDSHIELD AI INFERENCE REPORT"
    )
    print(f"\n================ {banner} ================")
    print(f" Received Original Text:\n\"{result['original_text'].strip()}\"")
    print(f"\n Preprocessed Cleaned Text:\n\"{result['cleaned_text']}\"")
    print(f"\n Total Page Tokens Count : {result['total_tokens']}")
    print(
        f" Total Sliding Windows Run : {result['total_windows']} Windows "
        f"(Size: {result['window_size']}, Overlap: {result['overlap']})"
    )
    print(f" Final Security Verdict : {result['final_prediction']}")
    print(f" Model Decision Confidence : {result['confidence']}")
    print(f" Triggered Windows ID : {result['triggered_windows']}")
    print("\n --- Windows Detailed Semantic Analysis Table ---")
    print(
        tabulate(
            table_rows,
            headers=["ID", "Window Text Preview", "Safe Prob", "Unsafe Prob", "Verdict"],
            tablefmt="grid",
        )
    )
    print("========================================================================\n")
    print("\n🔍 === Windows Text Inspection (النصوص الكاملة لتقسيم النوافذ) ===")
    for entry in result["windows_analysis"]:
        print(f"📖 [Window {entry['window_id']} Full Text]:\n \"{entry['window_text']}\"\n")
    print("========================================================================\n")


def predict_safety_api(text):
    """Classify Arabic text as SAFE/UNSAFE using overlapping token windows.

    The text is cleaned with ``full_preprocess``, tokenized once without
    special tokens, and split into windows of ``_WINDOW_SIZE`` tokens that
    overlap by ``_WINDOW_OVERLAP`` tokens. Each window is scored on its own.
    The text is blocked if ANY window's unsafe probability exceeds
    ``_UNSAFE_THRESHOLD``. A report is printed to stdout.

    Args:
        text: Raw input text; ``None`` is treated as an empty string.

    Returns:
        dict with keys ``original_text``, ``cleaned_text``, ``total_tokens``,
        ``window_size``, ``overlap``, ``total_windows``,
        ``triggered_windows`` (1-based ids of UNSAFE windows),
        ``windows_analysis`` (per-window text, probabilities, verdict),
        ``final_prediction`` (``"SAFE"``/``"UNSAFE"``), ``blocked``,
        ``highest_unsafe_confidence``, and ``confidence``.
        ``highest_unsafe_confidence`` is the max unsafe probability among
        triggered windows, or 0.0 if none triggered. ``confidence`` is a
        formatted percentage of the winning class's highest window
        probability.
    """
    if text is None:
        text = ""
    cleaned_text = full_preprocess(text)
    input_ids = tokenizer(
        cleaned_text, add_special_tokens=False, return_attention_mask=False
    )["input_ids"]
    windows = _split_into_windows(input_ids, _WINDOW_SIZE, _WINDOW_OVERLAP)

    is_blocked = False
    highest_unsafe_prob = 0.0
    highest_safe_prob = 0.0
    triggered_windows = []
    windows_analysis = []
    table_rows = []

    for window_id, win_ids in enumerate(windows, start=1):
        # Decoded text is for reporting only; scoring uses the ids directly.
        window_text = tokenizer.decode(win_ids, skip_special_tokens=True)
        safe_prob, unsafe_prob = _score_window(win_ids)
        is_unsafe = unsafe_prob > _UNSAFE_THRESHOLD
        prediction = "UNSAFE" if is_unsafe else "SAFE"

        windows_analysis.append({
            "window_id": window_id,
            "window_text": window_text,
            "safe_probability": round(safe_prob, 4),
            "unsafe_probability": round(unsafe_prob, 4),
            "prediction": prediction,
        })
        preview = window_text[:45] + "..." if len(window_text) > 45 else window_text
        table_rows.append([
            f"Win {window_id}",
            preview,
            f"{safe_prob * 100:.2f}%",
            f"{unsafe_prob * 100:.2f}%",
            f"❌ {prediction}" if is_unsafe else f"✅ {prediction}",
        ])

        if is_unsafe:
            is_blocked = True
            highest_unsafe_prob = max(highest_unsafe_prob, unsafe_prob)
            triggered_windows.append(window_id)
        else:
            highest_safe_prob = max(highest_safe_prob, safe_prob)

    winning_probability = highest_unsafe_prob if is_blocked else highest_safe_prob
    result = {
        "original_text": text,
        "cleaned_text": cleaned_text,
        "total_tokens": len(input_ids),
        "window_size": _WINDOW_SIZE,
        "overlap": _WINDOW_OVERLAP,
        "total_windows": len(windows),
        "triggered_windows": triggered_windows,
        "windows_analysis": windows_analysis,
        "final_prediction": "UNSAFE" if is_blocked else "SAFE",
        "blocked": is_blocked,
        "highest_unsafe_confidence": round(highest_unsafe_prob, 4),
        "confidence": f"{winning_probability * 100:.2f}%",
    }
    _print_report(result, table_rows)
    return result
@app.post("/predict")
def predict(data: InputData):
    """HTTP handler: run ``predict_safety_api`` on ``data.text`` and return its dict."""
    return predict_safety_api(data.text)
# Browser UI over the same predict_safety_api used by the /predict endpoint.
# The title's shield emoji was previously stored mis-encoded.
gradio_interface = gr.Interface(
    fn=predict_safety_api,
    inputs=gr.Textbox(lines=4, placeholder="Enter Arabic text to analyze..."),
    outputs=gr.JSON(label="Guard Response Object"),
    title="ChildShield Production API Gate (Arabic Version)🛡️",
)
# Mounted at "/" as a catch-all. The /predict route is registered earlier in
# this file, and Starlette matches routes in registration order.
app = gr.mount_gradio_app(app, gradio_interface, path="/")
if __name__ == "__main__":
    # Serve the combined FastAPI + Gradio app on all interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")