# petbert / app.py
# bookdabang's picture
# Update app.py
# be4280c verified
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from deep_translator import GoogleTranslator
# 1. Load the tokenizer and the sequence-classification model for the
#    checkpoint named by model_name. Both are created once, at import time,
#    and are used by predict() below.
model_name = "SAVSNET/PetBERT_ICD"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
# 2. Label table: classifier output index (0-19) -> Korean display name of
#    the ICD category. The values are user-facing text; the trailing comments
#    give an English gloss for maintainers.
LABELS = {
    0: "귀 또는 유양돌기 질환",  # diseases of the ear or mastoid process
    1: "정신, 행동 또는 신경발달 장애",  # mental, behavioural or neurodevelopmental disorders
    2: "혈액 또는 조혈기관 질환",  # diseases of the blood or blood-forming organs
    3: "순환기계 질환",  # circulatory system diseases
    4: "치과 질환",  # dental diseases
    5: "발달 이상",  # developmental anomalies
    6: "소화기계 질환",  # digestive system diseases
    7: "내분비, 영양 또는 대사 질환",  # endocrine, nutritional or metabolic diseases
    8: "면역계 질환",  # immune system diseases
    9: "특정 감염성 또는 기생충 질환",  # certain infectious or parasitic diseases
    10: "피부 질환",  # skin diseases
    11: "근골격계 또는 결합조직 질환",  # musculoskeletal or connective tissue diseases
    12: "신생물(종양)",  # neoplasms (tumours)
    13: "신경계 질환",  # nervous system diseases
    14: "시각계 질환",  # visual system diseases
    15: "주산기 기원의 특정 상태",  # certain conditions originating in the perinatal period
    16: "임신, 출산 또는 산후기 상태",  # pregnancy, childbirth or puerperium conditions
    17: "호흡기계 질환",  # respiratory system diseases
    18: "외상, 중독 또는 외부 원인 결과",  # injury, poisoning or consequences of external causes
    19: "비뇨생식기계 질환",  # genitourinary system diseases
}
# 3. Prediction function.
_TOP_K = 3  # maximum number of categories listed in the answer
_MIN_PROB = 0.1  # categories scoring at or below this are not shown
_NO_PREDICTION = "예측된 질병 없음 🫥"  # "No disease predicted"


def predict(text):
    """Return a summary of the most likely disease categories for *text*.

    The symptom description is machine-translated to English, tokenized with
    the module-level ``tokenizer`` and scored by the module-level ``model``.
    Up to ``_TOP_K`` categories whose score exceeds ``_MIN_PROB`` are listed,
    one per line, using the display names in ``LABELS`` (falling back to
    ``"Label <index>"`` for an index that has no entry).

    Args:
        text: Free-text symptom description; the source language is
            auto-detected by the translator.

    Returns:
        A user-facing string: the list of predicted categories with their
        scores, a "no disease predicted" message when nothing qualifies (or
        the input is blank), or an error message when translation or
        inference fails. This function is the UI callback, so it reports
        failures as text instead of raising.
    """
    # A blank box carries no symptoms: skip the translation request and the
    # model call instead of reporting a prediction for an empty sequence.
    if text is None or (isinstance(text, str) and not text.strip()):
        return _NO_PREDICTION

    try:
        # Translate the input (e.g. Korean) to English before classification.
        translated = GoogleTranslator(source='auto', target='en').translate(text)
        if not translated:
            # Nothing usable came back; there is nothing to classify.
            return _NO_PREDICTION

        # Tokenize and run the model without tracking gradients.
        inputs = tokenizer(translated, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            logits = model(**inputs).logits

        # One input sentence -> one row of logits. Indexing (rather than
        # squeeze()) keeps the row 1-D even for a single-label model.
        row = logits[0]

        # Multi-label checkpoints score each label independently (sigmoid);
        # anything else keeps the softmax distribution over labels.
        problem_type = getattr(getattr(model, "config", None), "problem_type", None)
        if problem_type == "multi_label_classification":
            probs = torch.sigmoid(row)
        else:
            probs = torch.softmax(row, dim=-1)

        # Never request more entries than the model has labels.
        top = torch.topk(probs, min(_TOP_K, probs.numel()))
        results = [
            f"{LABELS.get(idx, f'Label {idx}')} ({prob:.1%})"
            for idx, prob in zip(top.indices.tolist(), top.values.tolist())
            if prob > _MIN_PROB
        ]
        if not results:
            return _NO_PREDICTION
        # Header reads "Predicted diseases:".
        return "예측된 질병:\n" + "\n".join(results)
    except Exception as e:
        # UI boundary: surface the failure ("Error occurred: ...") as text.
        return f"오류 발생: {e}"
# 4. Gradio UI ๊ตฌ์„ฑ
demo = gr.Interface(
fn=predict,
inputs=gr.Textbox(label="๋ฐ˜๋ ค๋™๋ฌผ ์ฆ์ƒ ์ž…๋ ฅ", placeholder="์˜ˆ: ๊ฐ•์•„์ง€๊ฐ€ ์ž์ฃผ ๊ธฐ์นจํ•ด"),
outputs=gr.Textbox(label="์˜ˆ์ธก ๊ฒฐ๊ณผ"),
title="๐Ÿพ PetBERT ICD ์ˆ˜์˜์‚ฌ ์˜ˆ์ธก๊ธฐ",
description="๋ฐ˜๋ ค๋™๋ฌผ์˜ ์ฆ์ƒ ๋ฌธ์žฅ์„ ์ž…๋ ฅํ•˜๋ฉด AI๊ฐ€ ์งˆ๋ณ‘ ๊ฐ€๋Šฅ์„ฑ์„ ์˜ˆ์ธกํ•ด๋“œ๋ฆฝ๋‹ˆ๋‹ค."
)
# 5. Launch the app.
demo.launch()