# Hugging Face Spaces page residue ("Spaces: Sleeping") — not part of the app code.
import os
import re

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# ---- constants -------------------------------------------------------------
KModelId = "LikoKIko/OpenCensor-H1-Mini"  # HF Hub id of the classifier
KMaxLen = 256                             # tokenizer max sequence length
KThreshold = 0.17                         # decision cutoff on the sigmoid prob
KDevice = "cuda" if torch.cuda.is_available() else "cpu"

# use all available CPU cores for intra-op parallelism (never less than 1)
torch.set_num_threads(max(1, os.cpu_count() or 1))

# ---- load once at import time ----------------------------------------------
tok = AutoTokenizer.from_pretrained(KModelId)
model = AutoModelForSequenceClassification.from_pretrained(
    KModelId, num_labels=1  # single-logit head: binary decision via sigmoid
).to(KDevice).eval()

# warmup: one dummy forward pass so weights are fully materialized before the
# first real request hits the endpoint
with torch.inference_mode():
    _warm = tok("砖诇讜诐", return_tensors="pt", padding="max_length",
                truncation=True, max_length=KMaxLen).to(KDevice)
    _ = torch.sigmoid(model(**_warm).logits).item()

# ---- helpers ---------------------------------------------------------------
# Compile the whitespace pattern once instead of on every call; PEP 8 (E731)
# also prefers a `def` over an assigned lambda (gives a real __name__/docstring).
_WS = re.compile(r"\s+")


def clean(s) -> str:
    """Collapse whitespace runs to single spaces and strip the ends."""
    return _WS.sub(" ", str(s)).strip()
def check(txt: str) -> str:
    """Classify a Hebrew text snippet for profanity.

    Args:
        txt: raw user input; whitespace is normalized before tokenization.

    Returns:
        A human-readable string with the sigmoid probability and the binary
        label obtained by thresholding at ``KThreshold`` (or a prompt when
        the input is empty after cleaning).
    """
    txt = clean(txt)
    if not txt:
        return "Type something first."
    batch = tok(
        txt,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=KMaxLen,
    ).to(KDevice)
    # inference_mode(): the warmup pass already runs without autograd, but the
    # original request path did not — without this, every call builds an
    # unused gradient graph (wasted memory and time on each prediction).
    with torch.inference_mode():
        prob = torch.sigmoid(model(**batch).logits).item()
    label = 1 if prob >= KThreshold else 0
    return f"Prob: {prob:.4f} | Label: {label} (cutoff={KThreshold})"
# ---- ui ---------------------------------------------------------------------
# Build the Gradio interface; component creation order fixes the on-page layout.
with gr.Blocks(title="Hebrew Profanity Detector") as demo:
    gr.Markdown(
        "## Hebrew Profanity Detector\nEnter Hebrew text. Output: probability and label."
    )
    text_in = gr.Textbox(lines=4, label="Hebrew text")
    result_out = gr.Textbox(label="Result", interactive=False)
    run_btn = gr.Button("Check")
    # wire the button to the classifier; exposed over the HTTP API as /predict
    run_btn.click(check, inputs=text_in, outputs=result_out, api_name="/predict")
    example_rows = [["讝讛 讚讘专 诪爪讜讬谉"], ["!讬砖 诇讬 讞专讗 讞讘专"]]
    gr.Examples(
        examples=example_rows,
        inputs=text_in,
        outputs=result_out,
        fn=check,
        cache_examples=False,  # run examples live; don't precompute at startup
    )
if __name__ == "__main__":
    # Spaces/containers inject PORT; fall back to Gradio's default 7860.
    serve_port = int(os.environ.get("PORT", "7860"))
    demo.launch(server_name="0.0.0.0", server_port=serve_port, show_error=True)