File size: 1,967 Bytes
3ffc83b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import re
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# constants
# Hugging Face Hub id of the fine-tuned Hebrew profanity classifier.
KModelId = "LikoKIko/OpenCensor-H1-Mini"
# Tokenizer padding/truncation length for every request.
KMaxLen = 256
# Decision cutoff on the sigmoid probability; prob >= cutoff -> label 1.
KThreshold = 0.17
# Prefer GPU when available, otherwise run on CPU.
KDevice = "cuda" if torch.cuda.is_available() else "cpu"
# Use all available CPU cores for intra-op parallelism (at least 1).
torch.set_num_threads(max(1, os.cpu_count() or 1))

# load once at import time so requests never pay the download/load cost
tok = AutoTokenizer.from_pretrained(KModelId)
# num_labels=1: single-logit head, turned into a probability via sigmoid later.
model = AutoModelForSequenceClassification.from_pretrained(
    KModelId, num_labels=1
).to(KDevice).eval()

# warmup to force weights load (and CUDA init) before the first real request
with torch.inference_mode():
    # NOTE(review): the literal below looks like mojibake of Hebrew text
    # (mis-decoded UTF-8) — any short string works for warmup, but confirm
    # the file's intended encoding.
    _warm = tok("砖诇讜诐", return_tensors="pt", padding="max_length",
                truncation=True, max_length=KMaxLen).to(KDevice)
    _ = torch.sigmoid(model(**_warm).logits).item()

# helpers
def clean(s) -> str:
    """Normalize input text: collapse whitespace runs to single spaces, strip ends.

    Accepts any value; non-strings are converted with str() first, so the
    function never raises on unexpected input types from the UI layer.
    """
    # def instead of a named lambda (PEP 8 E731); behavior is unchanged.
    return re.sub(r"\s+", " ", str(s)).strip()

@torch.inference_mode()
def check(txt: str) -> str:
    """Classify *txt* and return a formatted probability/label string.

    Empty (or whitespace-only) input gets a prompt message instead of a
    model call. Otherwise the text is tokenized to a fixed-length batch,
    scored by the single-logit model, and thresholded at KThreshold.
    """
    text = clean(txt)
    if not text:
        return "Type something first."

    encoded = tok(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=KMaxLen,
    )
    encoded = encoded.to(KDevice)

    # Single logit -> probability via sigmoid; label is 1 iff prob >= cutoff.
    logits = model(**encoded).logits
    prob = torch.sigmoid(logits).item()
    label = int(prob >= KThreshold)
    return f"Prob: {prob:.4f} | Label: {label} (cutoff={KThreshold})"

# ui
with gr.Blocks(title="Hebrew Profanity Detector") as demo:
    gr.Markdown("## Hebrew Profanity Detector\nEnter Hebrew text. Output: probability and label.")
    # Multi-line input box for the text to classify.
    inp = gr.Textbox(lines=4, label="Hebrew text")
    # Read-only box showing the formatted probability/label result.
    out = gr.Textbox(label="Result", interactive=False)
    btn = gr.Button("Check")
    # Wire the button to check(); also exposed as the /predict API endpoint.
    btn.click(check, inputs=inp, outputs=out, api_name="/predict")
    # NOTE(review): the example strings below look like mojibake-encoded
    # Hebrew (mis-decoded UTF-8) — confirm the file's encoding; they are
    # kept verbatim here because they are runtime values shown in the UI.
    gr.Examples(
        examples=[["讝讛 讚讘专 诪爪讜讬谉"], ["!讬砖 诇讬 讞专讗 讞讘专"]],
        inputs=inp,
        outputs=out,
        fn=check,
        cache_examples=False,
    )

if __name__ == "__main__":
    # Honor a PORT env var (e.g. on hosted platforms); 7860 is Gradio's default.
    server_port = int(os.environ.get("PORT", "7860"))
    # Bind to all interfaces so the app is reachable inside containers.
    demo.launch(server_name="0.0.0.0", server_port=server_port, show_error=True)