Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -8,30 +8,31 @@ model_name = "SAVSNET/PetBERT_ICD"
|
|
| 8 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 9 |
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
| 10 |
|
| 11 |
-
# 2. 라벨 목록
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
|
|
|
|
|
|
| 22 |
def predict(text):
|
| 23 |
try:
|
| 24 |
-
#
|
| 25 |
translated = GoogleTranslator(source='auto', target='en').translate(text)
|
| 26 |
|
| 27 |
-
#
|
| 28 |
inputs = tokenizer(translated, return_tensors="pt", truncation=True, padding=True)
|
| 29 |
with torch.no_grad():
|
| 30 |
outputs = model(**inputs)
|
| 31 |
-
|
| 32 |
-
probs = torch.softmax(logits, dim=1).squeeze()
|
| 33 |
|
| 34 |
-
#
|
| 35 |
topk = torch.topk(probs, 3)
|
| 36 |
results = [
|
| 37 |
f"{LABELS.get(int(idx), f'Label {idx}')} ({prob:.1%})"
|
|
@@ -40,23 +41,28 @@ def predict(text):
|
|
| 40 |
]
|
| 41 |
|
| 42 |
if results:
|
| 43 |
-
|
| 44 |
else:
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
return summary_text
|
| 48 |
|
| 49 |
except Exception as e:
|
| 50 |
return f"์ค๋ฅ ๋ฐ์: {str(e)}"
|
| 51 |
|
| 52 |
-
#
|
| 53 |
demo = gr.Interface(
|
| 54 |
fn=predict,
|
| 55 |
inputs=gr.Textbox(label="๋ฐ๋ ค๋๋ฌผ ์ฆ์ ์
๋ ฅ", placeholder="์: ๊ฐ์์ง๊ฐ ์์ฃผ ๊ธฐ์นจํด"),
|
| 56 |
outputs=gr.Textbox(label="์์ธก ๊ฒฐ๊ณผ"),
|
| 57 |
title="๐พ PetBERT ICD ์์์ฌ ์์ธก๊ธฐ",
|
| 58 |
-
description="๋ฐ๋ ค๋๋ฌผ์ ์ฆ์ ๋ฌธ์ฅ์ ์
๋ ฅํ๋ฉด AI๊ฐ ์ง๋ณ ๊ฐ๋ฅ์ฑ์ ์์ธกํด๋๋ฆฝ๋๋ค."
|
|
|
|
|
|
|
| 59 |
)
|
| 60 |
|
| 61 |
-
#
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 9 |
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
| 10 |
|
| 11 |
+
# 2. ๋ชจ๋ธ์ id2label์์ ๋ผ๋ฒจ ๋ชฉ๋ก ์ถ์ถ
|
| 12 |
+
raw_labels = model.config.id2label # e.g., {'0': 'Digestive', '1': 'Respiratory', ...}
|
| 13 |
+
LABELS = {int(k): v for k, v in raw_labels.items()}
|
| 14 |
+
|
| 15 |
+
# 3. 콘솔에 출력
|
| 16 |
+
print("๐ PetBERT ICD ๋ผ๋ฒจ ๋ชฉ๋ก:")
|
| 17 |
+
for i in range(len(LABELS)):
|
| 18 |
+
print(f"{i}: {LABELS[i]}")
|
| 19 |
+
|
| 20 |
+
# 4. 화면에 출력할 라벨 목록 텍스트 준비
|
| 21 |
+
label_info = "\n".join([f"**{i}**: {LABELS[i]}" for i in range(len(LABELS))])
|
| 22 |
+
|
| 23 |
+
# 5. 예측 함수 정의
|
| 24 |
def predict(text):
|
| 25 |
try:
|
| 26 |
+
# 한글 → 영어 번역
|
| 27 |
translated = GoogleTranslator(source='auto', target='en').translate(text)
|
| 28 |
|
| 29 |
+
# 예측
|
| 30 |
inputs = tokenizer(translated, return_tensors="pt", truncation=True, padding=True)
|
| 31 |
with torch.no_grad():
|
| 32 |
outputs = model(**inputs)
|
| 33 |
+
probs = torch.softmax(outputs.logits, dim=1).squeeze()
|
|
|
|
| 34 |
|
| 35 |
+
# 상위 3개 결과 추출
|
| 36 |
topk = torch.topk(probs, 3)
|
| 37 |
results = [
|
| 38 |
f"{LABELS.get(int(idx), f'Label {idx}')} ({prob:.1%})"
|
|
|
|
| 41 |
]
|
| 42 |
|
| 43 |
if results:
|
| 44 |
+
return "์์ธก๋ ์ง๋ณ:\n" + "\n".join(results)
|
| 45 |
else:
|
| 46 |
+
return "์์ธก๋ ์ง๋ณ ์์ ๐ซฅ"
|
|
|
|
|
|
|
| 47 |
|
| 48 |
except Exception as e:
|
| 49 |
return f"์ค๋ฅ ๋ฐ์: {str(e)}"
|
| 50 |
|
| 51 |
+
# 6. Gradio UI 구성
|
| 52 |
demo = gr.Interface(
|
| 53 |
fn=predict,
|
| 54 |
inputs=gr.Textbox(label="๋ฐ๋ ค๋๋ฌผ ์ฆ์ ์
๋ ฅ", placeholder="์: ๊ฐ์์ง๊ฐ ์์ฃผ ๊ธฐ์นจํด"),
|
| 55 |
outputs=gr.Textbox(label="์์ธก ๊ฒฐ๊ณผ"),
|
| 56 |
title="๐พ PetBERT ICD ์์์ฌ ์์ธก๊ธฐ",
|
| 57 |
+
description="๋ฐ๋ ค๋๋ฌผ์ ์ฆ์ ๋ฌธ์ฅ์ ์
๋ ฅํ๋ฉด AI๊ฐ ์ง๋ณ ๊ฐ๋ฅ์ฑ์ ์์ธกํด๋๋ฆฝ๋๋ค.",
|
| 58 |
+
examples=["๊ณ ์์ด๊ฐ ๋ฐฅ์ ์ ๋จน๊ณ ์๊พธ ํ ํด์", "๊ฐ์์ง๊ฐ ์จ์ ํ๋ก์ด๊ณ ๊ธฐ์นจ์ ํด์"],
|
| 59 |
+
live=False
|
| 60 |
)
|
| 61 |
|
| 62 |
+
# 7. Launch + 라벨 목록 함께 출력
|
| 63 |
+
with gr.Blocks() as app:
|
| 64 |
+
gr.Markdown("### ๐ PetBERT ICD ์ง๋ณ ๋ผ๋ฒจ ๋ชฉ๋ก")
|
| 65 |
+
gr.Markdown(label_info)
|
| 66 |
+
demo.render()
|
| 67 |
+
|
| 68 |
+
app.launch()
|