File size: 2,574 Bytes
12c8869
aae2559
a35d9e4
12c8869
 
 
 
 
 
 
be4280c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8038f86
be4280c
12c8869
 
be4280c
12c8869
 
be4280c
12c8869
 
 
be4280c
 
12c8869
be4280c
12c8869
5aa25da
 
12c8869
5aa25da
 
 
 
be4280c
5aa25da
be4280c
 
 
12c8869
 
5aa25da
12c8869
be4280c
7f8f27b
12c8869
 
5aa25da
12c8869
be4280c
7f8f27b
 
be4280c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from deep_translator import GoogleTranslator

# 1. Load the model and tokenizer.
# Both are created once at module level from the same checkpoint name and are
# reused by every predict() call.
model_name = "SAVSNET/PetBERT_ICD"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# 2. Label table: classifier output index (ICD 0-19) -> Korean display name.
# The English gloss after each entry is for maintainers only; the Korean string
# is what predict() shows to the user.
LABELS = {
    0: "귀 또는 유양돌기 질환",  # ear or mastoid process
    1: "정신, 행동 또는 신경발달 장애",  # mental, behavioural or neurodevelopmental
    2: "혈액 또는 조혈기관 질환",  # blood or blood-forming organs
    3: "순환기계 질환",  # circulatory system
    4: "치과 질환",  # dental
    5: "발달 이상",  # developmental anomalies
    6: "소화기계 질환",  # digestive system
    7: "내분비, 영양 또는 대사 질환",  # endocrine, nutritional or metabolic
    8: "면역계 질환",  # immune system
    9: "특정 감염성 또는 기생충 질환",  # certain infectious or parasitic
    10: "피부 질환",  # skin
    11: "근골격계 또는 결합조직 질환",  # musculoskeletal or connective tissue
    12: "신생물(종양)",  # neoplasms (tumours)
    13: "신경계 질환",  # nervous system
    14: "시각계 질환",  # visual system
    15: "주산기 기원의 특정 상태",  # conditions originating in the perinatal period
    16: "임신, 출산 또는 산후기 상태",  # pregnancy, childbirth or the puerperium
    17: "호흡기계 질환",  # respiratory system
    18: "외상, 중독 또는 외부 원인 결과",  # injury, poisoning or external causes
    19: "비뇨생식기계 질환",  # genitourinary system
}

# 3. Prediction function
def predict(text):
    """Predict likely disease categories for a free-text pet symptom description.

    The text is machine-translated to English, tokenized, and scored by the
    module-level ``model``. Softmax probabilities are computed over the label
    dimension and the highest-scoring labels above a fixed threshold are
    reported using the names in ``LABELS``.

    Args:
        text: Symptom description entered by the user. The translator is asked
            to auto-detect the source language.

    Returns:
        A Korean, human-readable string: the predicted labels with percentages
        (one per line), a "no disease predicted" message when nothing clears
        the threshold, a prompt to enter symptoms when the input is blank, or
        an error message if any step raised.
    """
    top_k = 3  # report at most this many labels
    min_prob = 0.1  # labels at or below this probability are hidden

    try:
        # Nothing to classify: ask for input rather than scoring an empty string.
        if not text or not text.strip():
            return "증상을 입력해 주세요."

        # Translate to English (source language auto-detected).
        translated = GoogleTranslator(source='auto', target='en').translate(text)

        # Tokenize and run the model without tracking gradients.
        inputs = tokenizer(translated, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            logits = model(**inputs).logits

        # One input means a batch of one: take row 0 explicitly so the result
        # is always a 1-D vector of per-label probabilities.
        probs = torch.softmax(logits, dim=-1)[0]

        # Keep the best labels; clamp k so a head with fewer classes still works.
        top = torch.topk(probs, min(top_k, probs.numel()))
        # Convert to plain Python numbers so the dict lookup and the fallback
        # name ("Label 7") use an int rather than a tensor repr.
        results = [
            f"{LABELS.get(idx, f'Label {idx}')} ({prob:.1%})"
            for idx, prob in zip(top.indices.tolist(), top.values.tolist())
            if prob > min_prob
        ]

        if not results:
            return "예측된 질병 없음 🫥"
        return "예측된 질병:\n" + "\n".join(results)

    except Exception as e:
        # UI boundary: show the failure in the output box instead of raising.
        return f"오류 발생: {e}"

# 4. Gradio UI ๊ตฌ์„ฑ
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(label="๋ฐ˜๋ ค๋™๋ฌผ ์ฆ์ƒ ์ž…๋ ฅ", placeholder="์˜ˆ: ๊ฐ•์•„์ง€๊ฐ€ ์ž์ฃผ ๊ธฐ์นจํ•ด"),
    outputs=gr.Textbox(label="์˜ˆ์ธก ๊ฒฐ๊ณผ"),
    title="๐Ÿพ PetBERT ICD ์ˆ˜์˜์‚ฌ ์˜ˆ์ธก๊ธฐ",
    description="๋ฐ˜๋ ค๋™๋ฌผ์˜ ์ฆ์ƒ ๋ฌธ์žฅ์„ ์ž…๋ ฅํ•˜๋ฉด AI๊ฐ€ ์งˆ๋ณ‘ ๊ฐ€๋Šฅ์„ฑ์„ ์˜ˆ์ธกํ•ด๋“œ๋ฆฝ๋‹ˆ๋‹ค."
)

# 5. Run the app.
# Called at module top level, so the interface starts whenever this file is
# executed (or imported).
demo.launch()