Spaces:
Running
Running
| import gradio as gr | |
| import re | |
| import os | |
| import torch | |
| import uvicorn | |
| import json | |
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| from transformers import BertTokenizer, AutoModelForSequenceClassification | |
| from arabert.preprocess import ArabertPreprocessor | |
| from tabulate import tabulate | |
# Hugging Face Hub locations. The classifier weights are loaded from
# SUB_FOLDER inside MODEL_REPO; the tokenizer and the AraBERT preprocessing
# rules are taken from the base checkpoint MODEL_NAME.
MODEL_REPO = "kkAsmaa/ChildShield"
MODEL_NAME = "aubmindlab/bert-base-arabertv02-twitter"
SUB_FOLDER = "ChildShield"
# Token passed when downloading MODEL_REPO; None when the variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# The original message started with an emoji that was corrupted beyond
# recovery by a mis-encoding of this file; it is intentionally dropped.
print("Loading ChildShield Model Weights with Deep Window Auto-Logging Features...")
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO, token=HF_TOKEN, subfolder=SUB_FOLDER)
# Inference only: put layers such as dropout into evaluation mode.
model.eval()
arabic_prep = ArabertPreprocessor(model_name=MODEL_NAME)

app = FastAPI(title="ChildShield Backend API")
class InputData(BaseModel):
    """Request payload accepted by ``predict``: the raw text to screen."""

    # Raw, un-preprocessed input text; cleaning happens in predict_safety_api.
    text: str
# Arabic letter range used by the de-obfuscation rules: U+0623 (ALEF WITH HAMZA
# ABOVE) through U+064A (YEH), i.e. the class [أ-ي]. It is spelled with escapes
# because a mis-decoded copy of this file turns the literal letters into
# unrelated Thai characters, which silently disables every rule below.
_AR = "\u0623-\u064A"

# URLs, @mentions and bare '#' characters.
_URL_MENTION_HASH_RE = re.compile(r"https?://\S+|www\.\S+|@\S+|#")
# One non-space, non-Arabic character wedged between two Arabic letters
# (e.g. "ك*ل*ب"); dropping it re-joins the word.
_INFIX_NOISE_RE = re.compile(rf"(?<=[{_AR}])[^\s{_AR}](?=[{_AR}])")
# A space after an Arabic letter when the next Arabic letter stands alone
# (followed by a space or the end of the text), e.g. "ك ل ب" -> "كلب".
_SPACED_LETTER_RE = re.compile(rf"(?<=[{_AR}])\s(?=[{_AR}]\s|[{_AR}]$)")
# Runs of tatweel/kashida (U+0640) used to stretch words.
# NOTE(review): reconstructed from a garbled copy where only the first UTF-8
# byte (0xD9) survived; tatweel is the most plausible original — confirm.
_TATWEEL_RE = re.compile("\u0640+")
# Any character repeated 3+ times.
_LONG_REPEAT_RE = re.compile(r"(.)\1{2,}")
# Anything that is not a word character, whitespace or '.'.
_PUNCT_RE = re.compile(r"[^\w\s\.]")
_SPACES_RE = re.compile(r"\s+")


def clean_obfuscation(text):
    """Undo common character-level tricks used to evade a text filter.

    Steps, in order: strip URLs/mentions/'#'; drop single junk characters
    inserted inside Arabic words; re-join Arabic words spelled out letter by
    letter; remove tatweel; cap character repeats at two; turn remaining
    punctuation (except '.') into spaces; collapse whitespace.

    Args:
        text: Input value; converted with ``str()`` first.

    Returns:
        The cleaned, stripped string.
    """
    text = str(text)
    text = _URL_MENTION_HASH_RE.sub("", text)
    text = _INFIX_NOISE_RE.sub("", text)
    text = _SPACED_LETTER_RE.sub("", text)
    text = _TATWEEL_RE.sub("", text)
    text = _LONG_REPEAT_RE.sub(r"\1\1", text)
    text = _PUNCT_RE.sub(" ", text)
    text = _SPACES_RE.sub(" ", text)
    return text.strip()
def full_preprocess(text):
    """Return *text* after obfuscation cleanup followed by AraBERT normalisation.

    This is the string that predict_safety_api tokenizes and classifies.
    """
    return arabic_prep.preprocess(clean_obfuscation(text))
# Probability above which a window is classified UNSAFE; a single UNSAFE
# window blocks the whole text.
_UNSAFE_THRESHOLD = 0.50


def _split_windows(input_ids, size, overlap):
    """Split *input_ids* into windows of at most *size* ids starting every
    ``size - overlap`` ids.

    A sequence no longer than *size* (including an empty one) becomes a single
    window. Otherwise windows are emitted until one reaches the end of the
    sequence; that last window may be shorter than *size*.
    """
    if len(input_ids) <= size:
        return [input_ids]
    step = size - overlap
    windows = []
    for start in range(0, len(input_ids), step):
        windows.append(input_ids[start:start + size])
        if start + size >= len(input_ids):
            break
    return windows


def _score_window(win_ids, model_len):
    """Classify one window of content token ids; return (safe_prob, unsafe_prob).

    The ids are wrapped in the tokenizer's special tokens and padded to
    *model_len* directly. Feeding ids avoids a decode/re-encode round trip,
    which is lossy: a window that starts inside a word begins with a "##"
    continuation piece that re-tokenizes into different tokens.
    """
    inputs = tokenizer.prepare_for_model(
        win_ids,
        add_special_tokens=True,
        padding="max_length",
        truncation=True,
        max_length=model_len,
        return_tensors="pt",
        prepend_batch_axis=True,
    )
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=-1).flatten().tolist()
    # Label order assumed to be 0 = safe, 1 = unsafe (as in the original
    # indexing) — TODO confirm against model.config.id2label.
    return float(probs[0]), float(probs[1])


def _log_report(result, table_rows):
    """Print the human-readable inference report for one request to stdout."""
    if result["blocked"]:
        banner = "\N{POLICE CARS REVOLVING LIGHT} [BLOCK] CHILDSHIELD AI INFERENCE REPORT"
    else:
        banner = "\N{WHITE HEAVY CHECK MARK} [PASS] CHILDSHIELD AI INFERENCE REPORT"
    print(f"\n================ {banner} ================")
    print(f" Received Original Text:\n\"{result['original_text'].strip()}\"")
    print(f"\n Preprocessed Cleaned Text:\n\"{result['cleaned_text']}\"")
    print(f"\n Total Page Tokens Count : {result['total_tokens']}")
    print(
        f" Total Sliding Windows Run : {result['total_windows']} Windows "
        f"(Size: {result['window_size']}, Overlap: {result['overlap']})"
    )
    print(f" Final Security Verdict : {result['final_prediction']}")
    print(f" Model Decision Confidence : {result['confidence']}")
    print(f" Triggered Windows ID : {result['triggered_windows']}")
    print("\n --- Windows Detailed Semantic Analysis Table ---")
    print(tabulate(table_rows, headers=["ID", "Window Text Preview", "Safe Prob", "Unsafe Prob", "Verdict"], tablefmt="grid"))
    print("========================================================================\n")
    print("\n=== Windows Text Inspection (النصوص الكاملة لتقسيم النوافذ) ===")
    for win in result["windows_analysis"]:
        print(f"[Window {win['window_id']} Full Text]:\n \"{win['window_text']}\"\n")
    print("========================================================================\n")


def predict_safety_api(text, *, window_size=60, overlap=20):
    """Classify Arabic *text* as SAFE or UNSAFE using overlapping token windows.

    The text is preprocessed (full_preprocess), tokenized without special
    tokens, split into overlapping windows, and each window is classified
    independently. The text is blocked if any window's unsafe probability
    exceeds _UNSAFE_THRESHOLD. A report is printed to stdout.

    Args:
        text: Raw input text.
        window_size: Model input length per window in tokens, *including* the
            tokenizer's special tokens (for BERT, [CLS] and [SEP], leaving
            ``window_size - 2`` content tokens per window).
        overlap: Number of content tokens shared by consecutive windows.

    Returns:
        dict with the original and cleaned text, token/window counts, the
        per-window analysis, ``final_prediction`` ("SAFE"/"UNSAFE"),
        ``blocked``, ``highest_unsafe_confidence`` (maximum unsafe probability
        over all windows, rounded to 4 places) and ``confidence`` (the verdict's
        probability as a percentage string).

    Raises:
        ValueError: if *overlap* is negative or leaves no room for the window
            to advance.
    """
    cleaned_text = full_preprocess(text)
    input_ids = tokenizer(cleaned_text, add_special_tokens=False, return_attention_mask=False)["input_ids"]
    total_tokens_count = len(input_ids)

    # Reserve room for [CLS]/[SEP] so no content token is truncated away.
    content_len = window_size - tokenizer.num_special_tokens_to_add(pair=False)
    if not 0 <= overlap < content_len:
        raise ValueError(
            f"overlap must be in [0, {content_len}) for window_size={window_size}; got {overlap}"
        )
    windows = _split_windows(input_ids, content_len, overlap)

    windows_analysis = []
    table_rows = []
    triggered_windows = []
    safe_probs = []
    unsafe_probs = []
    for window_id, win_ids in enumerate(windows, start=1):
        safe_prob, unsafe_prob = _score_window(win_ids, window_size)
        safe_probs.append(safe_prob)
        unsafe_probs.append(unsafe_prob)
        prediction = "UNSAFE" if unsafe_prob > _UNSAFE_THRESHOLD else "SAFE"
        if prediction == "UNSAFE":
            triggered_windows.append(window_id)

        # Decoded text is for reporting only; the model consumed win_ids.
        window_text = tokenizer.decode(win_ids, skip_special_tokens=True)
        windows_analysis.append({
            "window_id": window_id,
            "window_text": window_text,
            "safe_probability": round(safe_prob, 4),
            "unsafe_probability": round(unsafe_prob, 4),
            "prediction": prediction,
        })
        mark = "\N{CROSS MARK}" if prediction == "UNSAFE" else "\N{WHITE HEAVY CHECK MARK}"
        table_rows.append([
            f"Win {window_id}",
            window_text[:45] + "..." if len(window_text) > 45 else window_text,
            f"{safe_prob * 100:.2f}%",
            f"{unsafe_prob * 100:.2f}%",
            f"{mark} {prediction}",
        ])

    is_blocked = bool(triggered_windows)
    final_prediction = "UNSAFE" if is_blocked else "SAFE"
    # Tracked over *all* windows, so a near-miss below the threshold is still
    # visible to callers of a SAFE verdict.
    highest_unsafe_prob = max(unsafe_probs)
    # A SAFE verdict requires every window to be safe, so it is only as
    # certain as its least-safe window.
    winning_probability = highest_unsafe_prob if is_blocked else min(safe_probs)

    result = {
        "original_text": text,
        "cleaned_text": cleaned_text,
        "total_tokens": total_tokens_count,
        "window_size": window_size,
        "overlap": overlap,
        "total_windows": len(windows),
        "triggered_windows": triggered_windows,
        "windows_analysis": windows_analysis,
        "final_prediction": final_prediction,
        "blocked": is_blocked,
        "highest_unsafe_confidence": round(highest_unsafe_prob, 4),
        "confidence": f"{winning_probability * 100:.2f}%",
    }
    _log_report(result, table_rows)
    return result
# Without this decorator the function was never registered on the FastAPI app,
# so the "ChildShield Backend API" exposed no JSON endpoint at all.
# It is registered before Gradio is mounted at "/", so Starlette matches this
# route ahead of the Gradio mount.
@app.post("/predict")
def predict(data: InputData):
    """JSON endpoint: screen ``data.text`` and return the predict_safety_api report dict."""
    return predict_safety_api(data.text)
# Gradio UI over the same inference function; its output is the report dict
# rendered as JSON.
gradio_interface = gr.Interface(
    fn=predict_safety_api,
    inputs=gr.Textbox(lines=4, placeholder="Enter Arabic text to analyze..."),
    outputs=gr.JSON(label="Guard Response Object"),
    # Shield emoji spelled by name so re-encoding the file cannot garble it.
    title="ChildShield Production API Gate (Arabic Version)\N{SHIELD}\N{VARIATION SELECTOR-16}",
)
# Serve the Gradio UI from the root path of the FastAPI app.
app = gr.mount_gradio_app(app, gradio_interface, path="/")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)