kkAsmaa commited on
Commit
a63d299
·
verified ·
1 Parent(s): e6f4d9c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -9
app.py CHANGED
@@ -8,7 +8,6 @@ from pydantic import BaseModel
8
  from transformers import BertTokenizer, AutoModelForSequenceClassification
9
  from arabert.preprocess import ArabertPreprocessor
10
 
11
-
12
  MODEL_REPO = "kkAsmaa/ChildShield"
13
  MODEL_NAME = "aubmindlab/bert-base-arabertv02-twitter"
14
  SUB_FOLDER = "ChildShield"
@@ -20,7 +19,6 @@ model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO, token=HF_
20
  model.eval()
21
  arabic_prep = ArabertPreprocessor(model_name=MODEL_NAME)
22
 
23
-
24
  app = FastAPI(title="ChildShield Backend API")
25
 
26
  class InputData(BaseModel):
@@ -104,7 +102,6 @@ def predict_safety_api(text):
104
 
105
  final_prediction = "UNSAFE" if is_blocked else "SAFE"
106
 
107
-
108
  print("\n📊 ===== CHILDSHIELD REPORT =====")
109
  print(f"📥 Original Text:\n{text[:100]}")
110
  print(f"\n🧹 Cleaned Text:\n{cleaned_text[:100]}")
@@ -121,7 +118,23 @@ def predict_safety_api(text):
121
  except Exception as e:
122
  print(f"⚠️ [Logging Warning] Could not write to log file: {e}")
123
 
124
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  return {
126
  "original_text": text,
127
  "cleaned_text": cleaned_text,
@@ -133,16 +146,14 @@ def predict_safety_api(text):
133
  "windows_analysis": windows_analysis,
134
  "final_prediction": final_prediction,
135
  "blocked": is_blocked,
136
- "highest_unsafe_confidence": round(highest_unsafe_prob, 4)
137
  }
138
 
139
-
140
  @app.post("/predict")
141
  def predict(data: InputData):
142
  result = predict_safety_api(data.text)
143
  return result
144
 
145
-
146
  gradio_interface = gr.Interface(
147
  fn=predict_safety_api,
148
  inputs=gr.Textbox(lines=4, placeholder="Enter Arabic text to analyze..."),
@@ -150,9 +161,7 @@ gradio_interface = gr.Interface(
150
  title="ChildShield Production API Gate (Arabic Version)🛡️"
151
  )
152
 
153
-
154
  app = gr.mount_gradio_app(app, gradio_interface, path="/")
155
 
156
-
157
  if __name__ == "__main__":
158
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
8
  from transformers import BertTokenizer, AutoModelForSequenceClassification
9
  from arabert.preprocess import ArabertPreprocessor
10
 
 
11
  MODEL_REPO = "kkAsmaa/ChildShield"
12
  MODEL_NAME = "aubmindlab/bert-base-arabertv02-twitter"
13
  SUB_FOLDER = "ChildShield"
 
19
  model.eval()
20
  arabic_prep = ArabertPreprocessor(model_name=MODEL_NAME)
21
 
 
22
  app = FastAPI(title="ChildShield Backend API")
23
 
24
  class InputData(BaseModel):
 
102
 
103
  final_prediction = "UNSAFE" if is_blocked else "SAFE"
104
 
 
105
  print("\n📊 ===== CHILDSHIELD REPORT =====")
106
  print(f"📥 Original Text:\n{text[:100]}")
107
  print(f"\n🧹 Cleaned Text:\n{cleaned_text[:100]}")
 
118
  except Exception as e:
119
  print(f"⚠️ [Logging Warning] Could not write to log file: {e}")
120
 
121
+ unsafe_confidence_score = round(highest_unsafe_prob, 4)
122
+ safe_confidence_score = round(1.0 - highest_unsafe_prob, 4)
123
+ if is_blocked:
124
+ return {
125
+ "original_text": text,
126
+ "cleaned_text": cleaned_text,
127
+ "total_tokens": total_tokens_count,
128
+ "window_size": window_size,
129
+ "overlap": overlap,
130
+ "total_windows": total_windows_count,
131
+ "triggered_windows": triggered_windows,
132
+ "windows_analysis": windows_analysis,
133
+ "final_prediction": final_prediction,
134
+ "blocked": is_blocked,
135
+ "confidence": unsafe_confidence_score
136
+ }
137
+
138
  return {
139
  "original_text": text,
140
  "cleaned_text": cleaned_text,
 
146
  "windows_analysis": windows_analysis,
147
  "final_prediction": final_prediction,
148
  "blocked": is_blocked,
149
+ "confidence": safe_confidence_score
150
  }
151
 
 
152
  @app.post("/predict")
153
  def predict(data: InputData):
154
  result = predict_safety_api(data.text)
155
  return result
156
 
 
157
  gradio_interface = gr.Interface(
158
  fn=predict_safety_api,
159
  inputs=gr.Textbox(lines=4, placeholder="Enter Arabic text to analyze..."),
 
161
  title="ChildShield Production API Gate (Arabic Version)🛡️"
162
  )
163
 
 
164
  app = gr.mount_gradio_app(app, gradio_interface, path="/")
165
 
 
166
  if __name__ == "__main__":
167
  uvicorn.run(app, host="0.0.0.0", port=7860)