File size: 8,258 Bytes
0f5ffe4
 
 
698b30e
0f5ffe4
 
 
 
 
0c698cb
0f5ffe4
0c698cb
 
0f5ffe4
 
1294687
 
dd2abf1
0c698cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f5ffe4
 
0c698cb
0f5ffe4
 
 
 
0c698cb
0f5ffe4
 
0c698cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f5ffe4
1294687
0c698cb
 
 
 
 
 
98ea36e
 
0c698cb
 
cf0cedb
0c698cb
 
 
 
 
 
 
 
 
698b30e
 
 
 
 
 
0c698cb
 
 
 
 
 
698b30e
0c698cb
698b30e
1294687
 
0c698cb
1294687
0c698cb
 
 
1294687
0c698cb
 
 
 
dd2abf1
1294687
0c698cb
1294687
 
0f5ffe4
 
0c698cb
0f5ffe4
 
 
0c698cb
 
 
 
 
 
 
0f5ffe4
0c698cb
 
 
 
1294687
dd2abf1
1294687
0c698cb
 
1294687
0c698cb
0f5ffe4
 
 
0c698cb
 
0f5ffe4
0c698cb
 
 
 
 
 
 
 
0f5ffe4
0c698cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f5ffe4
 
 
 
efdf187
052069b
0c698cb
 
 
 
 
0f5ffe4
0c698cb
 
 
 
 
 
 
 
 
 
 
0f5ffe4
dd2abf1
63c9bcc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import os
import json
import numpy as np
import textwrap
from tokenizers import Tokenizer
import onnxruntime as ort
from huggingface_hub import hf_hub_download
import gradio as gr


class ONNXInferencePipeline:
    """Two-stage content filter: a banned-keyword pre-filter, then an ONNX text classifier.

    Model, tokenizer, and hyperparameter artifacts are downloaded from the
    Hugging Face Hub when the pipeline is constructed.
    """

    def __init__(self, repo_id, repo_type="model"):
        """Download artifacts from *repo_id* and build the ONNX inference session.

        Args:
            repo_id: Hugging Face Hub repository id holding the artifacts.
            repo_type: Hub repository type (defaults to "model").
        """
        # Read token from env. In a Space, HF_TOKEN can be set in the Secrets panel.
        hf_token = os.getenv("HF_TOKEN")

        # Load banned keywords list from the 'banned' env var (see load_banned_keywords).
        self.banned_keywords = self.load_banned_keywords()
        print(f"Loaded {len(self.banned_keywords)} banned keywords")

        # Download artifacts. Newer huggingface_hub uses token=, not use_auth_token=
        self.onnx_path = hf_hub_download(
            repo_id=repo_id,
            filename="model.onnx",
            token=hf_token,
            repo_type=repo_type,
        )
        self.tokenizer_path = hf_hub_download(
            repo_id=repo_id,
            filename="train_bpe_tokenizer.json",
            token=hf_token,
            repo_type=repo_type,
        )
        self.config_path = hf_hub_download(
            repo_id=repo_id,
            filename="hyperparameters.json",
            token=hf_token,
            repo_type=repo_type,
        )

        # Load configuration
        with open(self.config_path, "r") as f:
            self.config = json.load(f)

        # Initialize tokenizer and the fixed sequence length the model expects.
        self.tokenizer = Tokenizer.from_file(self.tokenizer_path)
        self.max_len = int(self.config.get("max_len", 256))

        # Initialize ONNX runtime session.
        # Spaces CPU runtime typically uses CPUExecutionProvider; prefer CUDA when present.
        providers = ort.get_available_providers()
        if "CUDAExecutionProvider" in providers:
            use_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            use_providers = ["CPUExecutionProvider"]

        sess_options = ort.SessionOptions()
        # Reduce memory and improve cold start a bit
        sess_options.enable_mem_pattern = False
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        self.session = ort.InferenceSession(self.onnx_path, sess_options=sess_options, providers=use_providers)

        # Cache model input name to avoid mismatches like input vs input_ids
        self.input_name = self.session.get_inputs()[0].name
        print(f"ONNX model input name detected: {self.input_name}")

        # Label order may come from config; index 0 is treated as the "rejected" class
        # by the keyword-filter shortcut in predict().
        self.class_labels = self.config.get("class_labels", ["Inappropriate Content", "Appropriate"])

    def load_banned_keywords(self):
        """
        Load banned keywords from env var named 'banned'.
        Supports two formats:
          1) Python code snippet that returns a list (your current method)
          2) JSON array of strings

        Returns:
            list[str]: the banned keywords, or [] when unset/unparseable.
        """
        code_str = os.getenv("banned")
        if not code_str:
            print("Environment variable 'banned' is not set. Using empty list.")
            return []

        # Try JSON first — the safe format.
        try:
            parsed = json.loads(code_str)
            if isinstance(parsed, list) and all(isinstance(x, str) for x in parsed):
                return parsed
        except Exception:
            pass

        # Fallback to executable code that returns a list.
        # SECURITY NOTE: exec() on env-var content runs arbitrary code; acceptable only
        # because the 'banned' secret is operator-controlled, never user input.
        local_vars = {}
        wrapped_code = f"""
def get_banned_keywords():
{textwrap.indent(code_str, '    ')}
"""
        try:
            exec(wrapped_code, {}, local_vars)
            result = local_vars["get_banned_keywords"]()
            if isinstance(result, list):
                return [str(x) for x in result]
            print("Loaded banned keywords code did not return a list. Using empty list.")
            return []
        except Exception as e:
            print(f"Error loading banned keywords from code: {e}")
            return []

    def contains_banned_keyword(self, text):
        """Check if the input text contains any banned keywords as whole words.

        Matching is case-insensitive and punctuation-insensitive. Single-word
        keywords match against the set of words; keywords containing spaces are
        matched as whole-word phrases (bug fix: previously a multi-word keyword
        could never match, since it was compared against single-word tokens).
        """
        text_lower = text.lower()
        # Normalize: every non-alphanumeric character becomes a word separator.
        words = "".join(c if c.isalnum() else " " for c in text_lower).split()
        word_set = set(words)
        # Space-padded normalized text lets phrases match on word boundaries.
        normalized = " " + " ".join(words) + " "

        for keyword in self.banned_keywords:
            kw = str(keyword).lower().strip()
            if not kw:
                continue
            if kw in word_set:
                print(f"Keyword detected: '{keyword}'")
                return True
            if " " in kw:
                # Normalize the phrase the same way as the text, then require
                # a whole-word match inside the padded normalized text.
                kw_phrase = " ".join(
                    "".join(c if c.isalnum() else " " for c in kw).split()
                )
                if kw_phrase and f" {kw_phrase} " in normalized:
                    print(f"Keyword detected: '{keyword}'")
                    return True
        print("Keywords Passed. No inappropriate keywords found")
        return False

    def preprocess(self, text):
        """Tokenize *text* and return a (1, max_len) int64 array, truncated/padded.

        NOTE(review): padding uses id 0 — assumes the tokenizer's pad id is 0;
        confirm against the tokenizer config.
        """
        encoding = self.tokenizer.encode(text)
        ids = encoding.ids[: self.max_len]
        padding = [0] * (self.max_len - len(ids))
        return np.array(ids + padding, dtype=np.int64).reshape(1, -1)

    @staticmethod
    def softmax(logits):
        """Numerically stable softmax over the last axis of a (batch, C) array."""
        x = logits - np.max(logits, axis=1, keepdims=True)
        e = np.exp(x)
        return e / np.sum(e, axis=1, keepdims=True)

    def predict(self, text):
        """Classify *text*; returns {"label": str, "probabilities": list[float]}.

        The rule-based keyword filter runs first; on a hit the model is skipped
        and the rejected class (index 0) is returned with full confidence.
        """
        snippet = text[:50].replace("\n", " ")
        print(f"\nProcessing input: '{snippet}...' ({len(text)} characters)")

        # First rule based filter
        if self.contains_banned_keyword(text):
            print("Input rejected by keyword filter")
            return {
                "label": self.class_labels[0],
                "probabilities": [1.0, 0.0] if len(self.class_labels) == 2 else [1.0] * len(self.class_labels),
            }

        # Preprocess
        input_array = self.preprocess(text)

        # Run inference. Use detected input name
        outputs = self.session.run(None, {self.input_name: input_array})

        # Post process
        logits = outputs[0]
        probs = self.softmax(logits)
        pred_idx = int(np.argmax(probs))
        label = self.class_labels[pred_idx] if pred_idx < len(self.class_labels) else str(pred_idx)

        print(f"Model Passed. Result: {label} (Confidence: {probs[0][pred_idx]:.2%})")
        return {"label": label, "probabilities": probs[0].tolist()}


# Gradio glue
def gradio_predict(text):
    """Run the module-level PIPELINE on *text* and format the label for Gradio."""
    outcome = PIPELINE.predict(text)
    label = outcome["label"]
    return "Prediction: " + label + "\n"


# Create pipeline at import so the Space is ready
# NOTE: construction downloads artifacts from the Hub, so import blocks on
# network I/O; any failure here aborts startup rather than serving a broken app.
print("Initializing content filter pipeline...")
PIPELINE = ONNXInferencePipeline(repo_id="iimran/abuse-detector", repo_type="model")
print("Pipeline initialized successfully")


if __name__ == "__main__":
    # Required in Spaces. PORT is injected. Bind to 0.0.0.0
    # Build the UI: one free-text input, plain-text output, canned examples below.
    iface = gr.Interface(
        fn=gradio_predict,
        inputs=gr.Textbox(lines=7, placeholder="Enter text here..."),
        outputs="text",
        title="Abuse Detector - Offensive Language Detector",
        description=(
            "Abuse detector identifies inappropriate content in text. "
            "It analyzes input for Australian slang and abusive language. "
            "It is trained on a compact dataset. It may not catch highly nuanced language, "
            "but it detects common day to day offensive language."
        ),
        examples=[
            # Explicitly offensive examples
            "Congrats, you fuckbrain arsehole, you have outdone yourself in stupidity. A real cock up of a human. Should we clap for your bollocks faced greatness or just pity you?",
            "You are a mad bastard, but I would still grab a beer with you. Mess around all you like, you cockheaded legend. Your arsehole antics are bloody brilliant.",
            "Your mother should have done better raising such a useless idiot.",
            # Neutral or appropriate examples
            "Hello HR, I hope this message finds you well. I am writing to express my gratitude for the opportunity to interview for the Financial Analyst position last week. It was a pleasure to meet you and learn more about the role and your team.",
            "Thank you for your time and consideration. Please reach out if you need anything. I would be happy to discuss further.",
            "The weather today is lovely, and I am looking forward to a productive day at work.",
            # Mixed
            "I appreciate your help, but honestly, you are such a clueless idiot sometimes. Still, thanks for trying."
        ],
    )
    print("Launching Gradio interface...")
    # NOTE(review): assumes the host injects PORT; falls back to 7860 when unset.
    iface.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))