CLOUDYUL commited on
Commit
2db78dd
ยท
1 Parent(s): 369cc34

Add Gradio app.py and requirements.txt

Browse files
Files changed (1) hide show
  1. app.py +60 -0
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # serve-gradio/app.py
2
+
3
+ import gradio as gr
4
+ import torch
5
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
+
7
+ # โ”€โ”€โ”€ ๋ชจ๋ธ ๋กœ๋“œ โ”€โ”€โ”€
8
+ MODEL_ID = "CLOUDYUL/cleaner-detector" # ์ด๋ฏธ Hugging Face Hub์— ์˜ฌ๋ผ๊ฐ€ ์žˆ๋Š” ๋ชจ๋ธ
9
+ device = torch.device("cpu")
10
+
11
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
12
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
13
+ model.to(device)
14
+ model.eval()
15
+
16
+ def predict_toxicity(texts):
17
+ """
18
+ texts: ๋‹จ์ผ ๋ฌธ์ž์—ด ํ˜น์€ ๋ฌธ์ž์—ด ๋ฆฌ์ŠคํŠธ
19
+ ๋ฐ˜ํ™˜: [
20
+ { "text": "์ž…๋ ฅ ๋ฌธ์žฅ", "label": 0 or 1, "score": ํ™•๋ฅ (float) },
21
+ โ€ฆ
22
+ ]
23
+ """
24
+ if isinstance(texts, str):
25
+ texts = [texts]
26
+ results = []
27
+ for t in texts:
28
+ # ํ† ํฐํ™”
29
+ encoding = tokenizer(
30
+ t,
31
+ truncation=True,
32
+ padding="max_length",
33
+ max_length=128,
34
+ return_attention_mask=True,
35
+ return_tensors="pt",
36
+ )
37
+ input_ids = encoding["input_ids"].to(device)
38
+ attention_mask = encoding["attention_mask"].to(device)
39
+
40
+ # ๋ชจ๋ธ ์ถ”๋ก 
41
+ with torch.no_grad():
42
+ logits = model(input_ids=input_ids, attention_mask=attention_mask).logits[0]
43
+ # ์†Œํ”„ํŠธ๋งฅ์Šค๋กœ ํ™•๋ฅ  ๊ณ„์‚ฐ
44
+ probs = torch.softmax(logits, dim=-1).cpu().tolist()
45
+ label = int(probs.index(max(probs))) # 0: ์ •์ƒ, 1: ์•…ํ”Œ
46
+ score = float(max(probs))
47
+ results.append({"text": t, "label": label, "score": score})
48
+ return results
49
+
50
+ # โ”€โ”€โ”€ Gradio ์ธํ„ฐํŽ˜์ด์Šค ์ •์˜ โ”€โ”€โ”€
51
+ demo = gr.Interface(
52
+ fn=predict_toxicity,
53
+ inputs=gr.Textbox(lines=2, placeholder="์—ฌ๊ธฐ์— ํ…Œ์ŠคํŠธ ๋ฌธ์žฅ์„ ์ž…๋ ฅํ•˜์„ธ์š”"),
54
+ outputs=gr.JSON(label="Predictions"),
55
+ title="AGaRiCleaner Toxicity Detector",
56
+ description="๋ฌธ์žฅ์„ ์ž…๋ ฅํ•˜๋ฉด ์•…ํ”Œ ์—ฌ๋ถ€(label=0 ๋˜๋Š” 1)์™€ ํ™•๋ฅ (score)์„ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค."
57
+ )
58
+
59
+ if __name__ == "__main__":
60
+ demo.launch()