# NOTE(review): this span previously contained Hugging Face file-viewer chrome
# (status banner, file-size line, git-blame hashes, line-number gutter) that is
# not valid Python; replaced with this comment so the module can be imported.
import os
import json
import numpy as np
import textwrap
from tokenizers import Tokenizer
import onnxruntime as ort
from huggingface_hub import hf_hub_download
import gradio as gr
class ONNXInferencePipeline:
    """Two-stage text moderation pipeline.

    Stage 1 is a rule-based blocklist loaded from the ``banned`` env var;
    stage 2 is an ONNX text classifier whose artifacts (model, tokenizer,
    hyperparameters) are downloaded from the Hugging Face Hub at
    construction time.
    """

    def __init__(self, repo_id, repo_type="model"):
        """Download model artifacts from the Hub and build the ONNX session.

        Performs network I/O. ``HF_TOKEN`` (a Space secret) is used for
        private repos; ``None`` is fine for public ones.
        """
        # Read token from env. In a Space, HF_TOKEN can be set in the Secrets panel.
        hf_token = os.getenv("HF_TOKEN")
        # Load banned keywords list for the rule-based pre-filter.
        self.banned_keywords = self.load_banned_keywords()
        print(f"Loaded {len(self.banned_keywords)} banned keywords")
        # Download artifacts. Newer huggingface_hub uses token=, not use_auth_token=
        self.onnx_path = hf_hub_download(
            repo_id=repo_id,
            filename="model.onnx",
            token=hf_token,
            repo_type=repo_type
        )
        self.tokenizer_path = hf_hub_download(
            repo_id=repo_id,
            filename="train_bpe_tokenizer.json",
            token=hf_token,
            repo_type=repo_type
        )
        self.config_path = hf_hub_download(
            repo_id=repo_id,
            filename="hyperparameters.json",
            token=hf_token,
            repo_type=repo_type
        )
        # Load configuration
        with open(self.config_path, "r") as f:
            self.config = json.load(f)
        # Initialize tokenizer
        self.tokenizer = Tokenizer.from_file(self.tokenizer_path)
        self.max_len = int(self.config.get("max_len", 256))
        # Initialize ONNX runtime session.
        # Spaces CPU runtime typically uses CPUExecutionProvider.
        providers = ort.get_available_providers()
        if "CUDAExecutionProvider" in providers:
            use_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            use_providers = ["CPUExecutionProvider"]
        sess_options = ort.SessionOptions()
        # Reduce memory and improve cold start a bit
        sess_options.enable_mem_pattern = False
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        self.session = ort.InferenceSession(self.onnx_path, sess_options=sess_options, providers=use_providers)
        # Cache model input name to avoid mismatches like input vs input_ids
        self.input_name = self.session.get_inputs()[0].name
        print(f"ONNX model input name detected: {self.input_name}")
        # If you want label order from config, you can read it.
        # Default assumes index 0 = inappropriate, index 1 = appropriate.
        self.class_labels = self.config.get("class_labels", ["Inappropriate Content", "Appropriate"])

    def load_banned_keywords(self):
        """
        Load banned keywords from env var named 'banned'.
        Supports two formats (tried in this order):
        1) JSON array of strings
        2) Python code snippet whose body returns a list
        Returns [] on any failure, so moderation falls back to the model alone.
        """
        code_str = os.getenv("banned")
        if not code_str:
            print("Environment variable 'banned' is not set. Using empty list.")
            return []
        # Try JSON first
        try:
            parsed = json.loads(code_str)
            if isinstance(parsed, list) and all(isinstance(x, str) for x in parsed):
                return parsed
        except Exception:
            pass
        # Fallback to executable code that returns a list.
        # FIX: dedent the snippet before re-indenting it — a pasted snippet that
        # carries its own leading whitespace would otherwise raise
        # IndentationError when wrapped in a function body.
        # NOTE(review): exec() of env-supplied code is only acceptable because the
        # Space owner controls this secret; never feed user input through here.
        local_vars = {}
        wrapped_code = "def get_banned_keywords():\n" + textwrap.indent(
            textwrap.dedent(code_str), "    "
        )
        try:
            exec(wrapped_code, {}, local_vars)
            result = local_vars["get_banned_keywords"]()
            if isinstance(result, list):
                return [str(x) for x in result]
            print("Loaded banned keywords code did not return a list. Using empty list.")
            return []
        except Exception as e:
            print(f"Error loading banned keywords from code: {e}")
            return []

    def contains_banned_keyword(self, text):
        """Check if the input text contains any banned keywords as whole words.

        Punctuation is treated as a word separator, so "idiot!" still matches
        the keyword "idiot". Multi-word keywords are matched as whole-word
        phrases. Returns True on the first hit.
        """
        text_lower = text.lower()
        # Replace every non-alphanumeric char with a space so punctuation
        # cannot defeat whole-word matching.
        normalized_words = "".join(c if c.isalnum() else " " for c in text_lower).split()
        word_set = set(normalized_words)
        # Space-padded, single-spaced rendering used for phrase matching;
        # the padding keeps substring search aligned to word boundaries.
        phrase_text = " " + " ".join(normalized_words) + " "
        for keyword in self.banned_keywords:
            kw = str(keyword).lower().strip()
            if not kw:
                continue
            if " " in kw:
                # FIX: multi-word entries previously could never match, because
                # they were compared only against the set of single tokens.
                # Normalize the keyword the same way and search word-aligned.
                kw_phrase = " " + " ".join(
                    "".join(c if c.isalnum() else " " for c in kw).split()
                ) + " "
                if kw_phrase.strip() and kw_phrase in phrase_text:
                    print(f"Keyword detected: '{keyword}'")
                    return True
            elif kw in word_set:
                print(f"Keyword detected: '{keyword}'")
                return True
        print("Keywords Passed. No inappropriate keywords found")
        return False

    def preprocess(self, text):
        """Tokenize, truncate to max_len, and right-pad to a (1, max_len) int64 array.

        NOTE(review): assumes 0 is the tokenizer's padding id — confirm against
        the tokenizer config.
        """
        encoding = self.tokenizer.encode(text)
        ids = encoding.ids[: self.max_len]
        padding = [0] * (self.max_len - len(ids))
        return np.array(ids + padding, dtype=np.int64).reshape(1, -1)

    @staticmethod
    def softmax(logits):
        """Row-wise, numerically stable softmax for a 2-D logits array."""
        x = logits - np.max(logits, axis=1, keepdims=True)
        e = np.exp(x)
        return e / np.sum(e, axis=1, keepdims=True)

    def predict(self, text):
        """Classify *text*; returns {"label": str, "probabilities": list[float]}.

        The keyword filter runs first and short-circuits with full confidence
        on the "inappropriate" label; otherwise the ONNX model decides.
        """
        snippet = text[:50].replace("\n", " ")
        print(f"\nProcessing input: '{snippet}...' ({len(text)} characters)")
        # First rule based filter
        if self.contains_banned_keyword(text):
            print("Input rejected by keyword filter")
            return {
                "label": self.class_labels[0],
                "probabilities": [1.0, 0.0] if len(self.class_labels) == 2 else [1.0] * len(self.class_labels),
            }
        # Preprocess
        input_array = self.preprocess(text)
        # Run inference. Use detected input name
        outputs = self.session.run(None, {self.input_name: input_array})
        # Post process. Flat argmax is safe because the batch size is 1.
        logits = outputs[0]
        probs = self.softmax(logits)
        pred_idx = int(np.argmax(probs))
        label = self.class_labels[pred_idx] if pred_idx < len(self.class_labels) else str(pred_idx)
        print(f"Model Passed. Result: {label} (Confidence: {probs[0][pred_idx]:.2%})")
        return {"label": label, "probabilities": probs[0].tolist()}
# Gradio glue
def gradio_predict(text):
    """Gradio callback: classify *text* with the shared pipeline and format the label."""
    outcome = PIPELINE.predict(text)
    label = outcome["label"]
    return f"Prediction: {label}\n"
# Create pipeline at import so the Space is ready.
# NOTE: this downloads Hub artifacts (network I/O) at import time, so a
# download failure prevents the module from importing at all.
print("Initializing content filter pipeline...")
PIPELINE = ONNXInferencePipeline(repo_id="iimran/abuse-detector", repo_type="model")
print("Pipeline initialized successfully")
if __name__ == "__main__":
    # Required in Spaces. PORT is injected. Bind to 0.0.0.0
    # Demo inputs, grouped by expected verdict; concatenated in display order.
    offensive_examples = [
        "Congrats, you fuckbrain arsehole, you have outdone yourself in stupidity. A real cock up of a human. Should we clap for your bollocks faced greatness or just pity you?",
        "You are a mad bastard, but I would still grab a beer with you. Mess around all you like, you cockheaded legend. Your arsehole antics are bloody brilliant.",
        "Your mother should have done better raising such a useless idiot.",
    ]
    neutral_examples = [
        "Hello HR, I hope this message finds you well. I am writing to express my gratitude for the opportunity to interview for the Financial Analyst position last week. It was a pleasure to meet you and learn more about the role and your team.",
        "Thank you for your time and consideration. Please reach out if you need anything. I would be happy to discuss further.",
        "The weather today is lovely, and I am looking forward to a productive day at work.",
    ]
    mixed_examples = [
        "I appreciate your help, but honestly, you are such a clueless idiot sometimes. Still, thanks for trying.",
    ]
    app_description = (
        "Abuse detector identifies inappropriate content in text. "
        "It analyzes input for Australian slang and abusive language. "
        "It is trained on a compact dataset. It may not catch highly nuanced language, "
        "but it detects common day to day offensive language."
    )
    iface = gr.Interface(
        fn=gradio_predict,
        inputs=gr.Textbox(lines=7, placeholder="Enter text here..."),
        outputs="text",
        title="Abuse Detector - Offensive Language Detector",
        description=app_description,
        examples=offensive_examples + neutral_examples + mixed_examples,
    )
    print("Launching Gradio interface...")
    # PORT (injected by the platform) wins; fall back to Gradio's default 7860.
    iface.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))