File size: 1,970 Bytes
2db78dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# serve-gradio/app.py

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# --- Model loading ---
# Model that is already published on the Hugging Face Hub.
MODEL_ID = "CLOUDYUL/cleaner-detector"
device = torch.device("cpu")

# Load the tokenizer and the sequence-classification model for MODEL_ID.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Module.to() returns the module itself, so loading and placement can be chained.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID).to(device)
model.eval()  # put the module in evaluation mode for inference

def predict_toxicity(texts):
    """Classify one or more sentences and report the winning class for each.

    Args:
        texts: A single string, or an iterable of strings. ``None`` or an
            empty iterable yields an empty result without running the model.

    Returns:
        A list with one dict per input sentence, in input order::

            {"text": <input sentence>, "label": <int>, "score": <float>}

        ``label`` is the index of the highest-probability class
        (0: normal, 1: malicious comment) and ``score`` is the softmax
        probability of that class.
    """
    if texts is None:
        return []
    # A bare string is a single sentence, not an iterable of characters.
    sentences = [texts] if isinstance(texts, str) else list(texts)

    # Upper bound on sentences per forward pass so a long list cannot
    # build one arbitrarily large tensor.
    batch_size = 32
    results = []
    for start in range(0, len(sentences), batch_size):
        chunk = sentences[start:start + batch_size]

        # Tokenize the whole chunk at once. padding=True pads only up to the
        # longest sentence in the chunk (still truncated at 128 tokens)
        # instead of always padding every input to 128 tokens.
        # NOTE(review): padded positions are excluded via attention_mask, so
        # scores are expected to match fixed-length padding up to
        # floating-point rounding - confirm against the deployed model.
        encoding = tokenizer(
            chunk,
            truncation=True,
            padding=True,
            max_length=128,
            return_attention_mask=True,
            return_tensors="pt",
        )
        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)

        # Model inference; gradients are not needed.
        with torch.no_grad():
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

        # Softmax over the class dimension, then take the best class and its
        # probability for every row in one call.
        probs = torch.softmax(logits, dim=-1)
        scores, labels = probs.max(dim=-1)
        results.extend(
            {"text": text, "label": label, "score": score}
            for text, label, score in zip(chunk, labels.tolist(), scores.tolist())
        )
    return results

# --- Gradio interface definition ---
# The user-facing strings are Korean. They had been stored as mis-decoded
# UTF-8 (mojibake) and are restored to the intended text here.
#   placeholder: "Enter a test sentence here"
#   description: "Enter a sentence to get whether it is a malicious comment
#                 (label=0 or 1) and the probability (score)."
demo = gr.Interface(
    fn=predict_toxicity,
    inputs=gr.Textbox(lines=2, placeholder="여기에 테스트 문장을 입력하세요"),
    outputs=gr.JSON(label="Predictions"),
    title="AGaRiCleaner Toxicity Detector",
    description="문장을 입력하면 악플 여부(label=0 또는 1)와 확률(score)을 반환합니다.",
)

# Start the web UI only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()