kkAsmaa committed on
Commit
afe5f61
·
verified ·
1 Parent(s): d4e530e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -43
app.py CHANGED
@@ -2,20 +2,30 @@ import gradio as gr
2
  import re
3
  import os
4
  import torch
 
 
 
5
  from transformers import BertTokenizer, AutoModelForSequenceClassification
6
  from arabert.preprocess import ArabertPreprocessor
7
 
 
8
  MODEL_REPO = "kkAsmaa/ChildShield"
9
  MODEL_NAME = "aubmindlab/bert-base-arabertv02-twitter"
10
  SUB_FOLDER = "ChildShield"
11
  HF_TOKEN = os.getenv("HF_TOKEN")
12
- print("πŸ”„ Loading model weights from the secured ChildShield subfolder...")
13
 
 
14
  tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
15
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO, token=HF_TOKEN, subfolder=SUB_FOLDER)
16
  model.eval()
17
  arabic_prep = ArabertPreprocessor(model_name=MODEL_NAME)
18
 
 
 
 
 
 
 
19
  def clean_obfuscation(text):
20
  text = str(text)
21
  text = re.sub(r'https?://\S+|www\.\S+|@\S+|#', '', text)
@@ -36,83 +46,105 @@ def predict_safety_api(text):
36
  """
37
  Arabic text classification gateway utilizing a custom sliding window configuration with 20 token overlap.
38
  """
39
- print(f"[Incoming text to evaluate]: {text}")
40
  cleaned_text = full_preprocess(text)
41
-
42
  full_encodings = tokenizer(cleaned_text, add_special_tokens=False, return_attention_mask=False)
43
  input_ids = full_encodings['input_ids']
44
-
45
  total_tokens_count = len(input_ids)
46
 
47
  window_size = 60
48
  overlap = 20
49
- windows = []
50
  step = window_size - overlap
 
51
 
52
  if len(input_ids) <= window_size:
53
  windows = [input_ids]
54
  else:
55
  for i in range(0, len(input_ids), step):
56
  window = input_ids[i:i + window_size]
57
- if len(window) > 0: windows.append(window)
58
- if i + window_size >= len(input_ids): break
59
-
 
 
60
  total_windows_count = len(windows)
61
-
62
  is_blocked = False
63
  highest_unsafe_prob = 0.0
64
- triggered_sentences = []
 
65
 
66
- for win_ids in windows:
67
  window_text = tokenizer.decode(win_ids, skip_special_tokens=True)
68
-
69
  inputs = tokenizer(
70
- window_text,
71
- return_tensors="pt",
72
- truncation=True,
73
- padding="max_length",
74
  max_length=60
75
  )
76
-
77
  with torch.no_grad():
78
  outputs = model(**inputs)
79
-
80
  probs = torch.softmax(outputs.logits, dim=-1).flatten().tolist()
81
 
82
- unsafe_p = float(probs[1])
 
 
 
 
 
 
 
 
 
 
83
 
84
- if unsafe_p > 0.50:
85
  is_blocked = True
86
- highest_unsafe_prob = max(highest_unsafe_prob, unsafe_p)
87
- if window_text not in triggered_sentences:
88
- triggered_sentences.append(window_text)
89
-
90
- if is_blocked:
91
- return {
92
- "verdict": "UNSAFE",
93
- "block": True,
94
- "confidence": f"{highest_unsafe_prob * 100:.2f}%",
95
- "total_tokens": total_tokens_count,
96
- "total_windows": total_windows_count,
97
- "triggered_phrases": triggered_sentences
98
- }
 
99
 
100
- safe_p = 1.0 - highest_unsafe_prob
101
  return {
102
- "verdict": "SAFE",
103
- "block": False,
104
- "confidence": f"{safe_p * 100:.2f}%",
105
- "total_tokens": total_tokens_count,
106
- "total_windows": total_windows_count,
107
- "triggered_phrases": []
 
 
 
 
 
108
  }
109
 
110
- interface = gr.Interface(
 
 
 
 
 
 
 
111
  fn=predict_safety_api,
112
- inputs=gr.Textbox(lines=3, placeholder="Enter text to analyze..."),
113
  outputs=gr.JSON(label="Guard Response Object"),
114
  title="ChildShield Production API Gate (Arabic Version)πŸ›‘οΈ"
115
  )
116
 
 
 
 
 
117
  if __name__ == "__main__":
118
- interface.launch()
 
2
  import re
3
  import os
4
  import torch
5
+ import uvicorn
6
+ from fastapi import FastAPI
7
+ from pydantic import BaseModel
8
  from transformers import BertTokenizer, AutoModelForSequenceClassification
9
  from arabert.preprocess import ArabertPreprocessor
10
 
11
+
12
  MODEL_REPO = "kkAsmaa/ChildShield"
13
  MODEL_NAME = "aubmindlab/bert-base-arabertv02-twitter"
14
  SUB_FOLDER = "ChildShield"
15
  HF_TOKEN = os.getenv("HF_TOKEN")
 
16
 
17
+ print("πŸ”„ Loading ChildShield Model Weights...")
18
  tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
19
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO, token=HF_TOKEN, subfolder=SUB_FOLDER)
20
  model.eval()
21
  arabic_prep = ArabertPreprocessor(model_name=MODEL_NAME)
22
 
23
+
24
+ app = FastAPI(title="ChildShield Backend API")
25
+
26
+ class InputData(BaseModel):
27
+ text: str
28
+
29
  def clean_obfuscation(text):
30
  text = str(text)
31
  text = re.sub(r'https?://\S+|www\.\S+|@\S+|#', '', text)
 
46
  """
47
  Arabic text classification gateway utilizing a custom sliding window configuration with 20 token overlap.
48
  """
 
49
  cleaned_text = full_preprocess(text)
 
50
  full_encodings = tokenizer(cleaned_text, add_special_tokens=False, return_attention_mask=False)
51
  input_ids = full_encodings['input_ids']
 
52
  total_tokens_count = len(input_ids)
53
 
54
  window_size = 60
55
  overlap = 20
 
56
  step = window_size - overlap
57
+ windows = []
58
 
59
  if len(input_ids) <= window_size:
60
  windows = [input_ids]
61
  else:
62
  for i in range(0, len(input_ids), step):
63
  window = input_ids[i:i + window_size]
64
+ if len(window) > 0:
65
+ windows.append(window)
66
+ if i + window_size >= len(input_ids):
67
+ break
68
+
69
  total_windows_count = len(windows)
 
70
  is_blocked = False
71
  highest_unsafe_prob = 0.0
72
+ windows_analysis = []
73
+ triggered_windows = []
74
 
75
+ for idx, win_ids in enumerate(windows):
76
  window_text = tokenizer.decode(win_ids, skip_special_tokens=True)
 
77
  inputs = tokenizer(
78
+ window_text,
79
+ return_tensors="pt",
80
+ truncation=True,
81
+ padding="max_length",
82
  max_length=60
83
  )
 
84
  with torch.no_grad():
85
  outputs = model(**inputs)
 
86
  probs = torch.softmax(outputs.logits, dim=-1).flatten().tolist()
87
 
88
+ safe_prob = float(probs[0])
89
+ unsafe_prob = float(probs[1])
90
+ prediction = "UNSAFE" if unsafe_prob > 0.50 else "SAFE"
91
+
92
+ windows_analysis.append({
93
+ "window_id": idx + 1,
94
+ "window_text": window_text,
95
+ "safe_probability": round(safe_prob, 4),
96
+ "unsafe_probability": round(unsafe_prob, 4),
97
+ "prediction": prediction
98
+ })
99
 
100
+ if unsafe_prob > 0.50:
101
  is_blocked = True
102
+ highest_unsafe_prob = max(highest_unsafe_prob, unsafe_prob)
103
+ triggered_windows.append(idx + 1)
104
+
105
+ final_prediction = "UNSAFE" if is_blocked else "SAFE"
106
+
107
+
108
+ print("\nπŸ“Š ===== CHILDSHIELD REPORT =====")
109
+ print(f"πŸ“₯ Original Text:\n{text[:100]}")
110
+ print(f"\n🧹 Cleaned Text:\n{cleaned_text[:100]}")
111
+ print(f"\nπŸ”‘ Total Tokens: {total_tokens_count}")
112
+ print(f"πŸͺŸ Total Windows: {total_windows_count}")
113
+ print(f"🚨 Final Verdict: {final_prediction}")
114
+ print(f"πŸ›‘ Triggered Windows ID: {triggered_windows}")
115
+ print("=================================\n")
116
 
 
117
  return {
118
+ "original_text": text,
119
+ "cleaned_text": cleaned_text,
120
+ "total_tokens": total_tokens_count,
121
+ "window_size": window_size,
122
+ "overlap": overlap,
123
+ "total_windows": total_windows_count,
124
+ "triggered_windows": triggered_windows,
125
+ "windows_analysis": windows_analysis,
126
+ "final_prediction": final_prediction,
127
+ "blocked": is_blocked,
128
+ "highest_unsafe_confidence": round(highest_unsafe_prob, 4)
129
  }
130
 
131
+
132
+ @app.post("/predict")
133
+ def predict(data: InputData):
134
+ result = predict_safety_api(data.text)
135
+ return result
136
+
137
+
138
+ gradio_interface = gr.Interface(
139
  fn=predict_safety_api,
140
+ inputs=gr.Textbox(lines=4, placeholder="Enter Arabic text to analyze..."),
141
  outputs=gr.JSON(label="Guard Response Object"),
142
  title="ChildShield Production API Gate (Arabic Version)πŸ›‘οΈ"
143
  )
144
 
145
+
146
+ app = gr.mount_gradio_app(app, gradio_interface, path="/")
147
+
148
+
149
  if __name__ == "__main__":
150
+ uvicorn.run(app, host="0.0.0.0", port=7860)