bookdabang commited on
Commit
8038f86
ยท
verified ยท
1 Parent(s): 5aa25da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -24
app.py CHANGED
@@ -8,30 +8,31 @@ model_name = "SAVSNET/PetBERT_ICD"
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
10
 
11
- # 2. ๋ผ๋ฒจ ๋ชฉ๋ก (ICD ์ฝ”๋“œ ๋“ฑ์€ ์˜ˆ์‹œ)
12
- LABELS = {
13
- 0: "์†Œํ™”๊ธฐ ์งˆํ™˜",
14
- 1: "ํ˜ธํก๊ธฐ ์งˆํ™˜",
15
- 2: "ํ”ผ๋ถ€ ์งˆํ™˜",
16
- 3: "์ •์‹  ํ–‰๋™ ์ด์ƒ",
17
- 4: "๊ฐ์—ผ์„ฑ ์งˆํ™˜",
18
- 5: "๊ธฐํƒ€ ์งˆ๋ณ‘"
19
- }
20
-
21
- # 3. ์˜ˆ์ธก ํ•จ์ˆ˜ ์ •์˜
 
 
22
  def predict(text):
23
  try:
24
- # ๋ฒˆ์—ญ (ํ•œ๊ธ€ -> ์˜์–ด)
25
  translated = GoogleTranslator(source='auto', target='en').translate(text)
26
 
27
- # ํ† ํฐํ™” ๋ฐ ๋ชจ๋ธ ์˜ˆ์ธก
28
  inputs = tokenizer(translated, return_tensors="pt", truncation=True, padding=True)
29
  with torch.no_grad():
30
  outputs = model(**inputs)
31
- logits = outputs.logits
32
- probs = torch.softmax(logits, dim=1).squeeze()
33
 
34
- # ์˜ˆ์ธก ๊ฒฐ๊ณผ ์ƒ์œ„ 3๊ฐœ ์ถ”์ถœ
35
  topk = torch.topk(probs, 3)
36
  results = [
37
  f"{LABELS.get(int(idx), f'Label {idx}')} ({prob:.1%})"
@@ -40,23 +41,28 @@ def predict(text):
40
  ]
41
 
42
  if results:
43
- summary_text = "์˜ˆ์ธก๋œ ์งˆ๋ณ‘:\n" + "\n".join(results)
44
  else:
45
- summary_text = "์˜ˆ์ธก๋œ ์งˆ๋ณ‘ ์—†์Œ ๐Ÿซฅ"
46
-
47
- return summary_text
48
 
49
  except Exception as e:
50
  return f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}"
51
 
52
- # 4. Gradio UI ๊ตฌ์„ฑ
53
  demo = gr.Interface(
54
  fn=predict,
55
  inputs=gr.Textbox(label="๋ฐ˜๋ ค๋™๋ฌผ ์ฆ์ƒ ์ž…๋ ฅ", placeholder="์˜ˆ: ๊ฐ•์•„์ง€๊ฐ€ ์ž์ฃผ ๊ธฐ์นจํ•ด"),
56
  outputs=gr.Textbox(label="์˜ˆ์ธก ๊ฒฐ๊ณผ"),
57
  title="๐Ÿพ PetBERT ICD ์ˆ˜์˜์‚ฌ ์˜ˆ์ธก๊ธฐ",
58
- description="๋ฐ˜๋ ค๋™๋ฌผ์˜ ์ฆ์ƒ ๋ฌธ์žฅ์„ ์ž…๋ ฅํ•˜๋ฉด AI๊ฐ€ ์งˆ๋ณ‘ ๊ฐ€๋Šฅ์„ฑ์„ ์˜ˆ์ธกํ•ด๋“œ๋ฆฝ๋‹ˆ๋‹ค."
 
 
59
  )
60
 
61
- # 5. ์•ฑ ์‹คํ–‰
62
- demo.launch()
 
 
 
 
 
 
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
10
 
11
+ # 2. ๋ชจ๋ธ์˜ id2label์—์„œ ๋ผ๋ฒจ ๋ชฉ๋ก ์ถ”์ถœ
12
+ raw_labels = model.config.id2label # e.g., {'0': 'Digestive', '1': 'Respiratory', ...}
13
+ LABELS = {int(k): v for k, v in raw_labels.items()}
14
+
15
+ # 3. ์ฝ˜์†”์— ์ถœ๋ ฅ
16
+ print("๐Ÿ” PetBERT ICD ๋ผ๋ฒจ ๋ชฉ๋ก:")
17
+ for i in range(len(LABELS)):
18
+ print(f"{i}: {LABELS[i]}")
19
+
20
+ # 4. ํ™”๋ฉด์— ์ถœ๋ ฅํ•  ๋ผ๋ฒจ ๋ชฉ๋ก ํ…์ŠคํŠธ ์ค€๋น„
21
+ label_info = "\n".join([f"**{i}**: {LABELS[i]}" for i in range(len(LABELS))])
22
+
23
+ # 5. ์˜ˆ์ธก ํ•จ์ˆ˜ ์ •์˜
24
  def predict(text):
25
  try:
26
+ # ํ•œ๊ธ€ โ†’ ์˜์–ด ๋ฒˆ์—ญ
27
  translated = GoogleTranslator(source='auto', target='en').translate(text)
28
 
29
+ # ์˜ˆ์ธก
30
  inputs = tokenizer(translated, return_tensors="pt", truncation=True, padding=True)
31
  with torch.no_grad():
32
  outputs = model(**inputs)
33
+ probs = torch.softmax(outputs.logits, dim=1).squeeze()
 
34
 
35
+ # ์ƒ์œ„ 3๊ฐœ ๊ฒฐ๊ณผ ์ถ”์ถœ
36
  topk = torch.topk(probs, 3)
37
  results = [
38
  f"{LABELS.get(int(idx), f'Label {idx}')} ({prob:.1%})"
 
41
  ]
42
 
43
  if results:
44
+ return "์˜ˆ์ธก๋œ ์งˆ๋ณ‘:\n" + "\n".join(results)
45
  else:
46
+ return "์˜ˆ์ธก๋œ ์งˆ๋ณ‘ ์—†์Œ ๐Ÿซฅ"
 
 
47
 
48
  except Exception as e:
49
  return f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}"
50
 
51
+ # 6. Gradio UI ๊ตฌ์„ฑ
52
  demo = gr.Interface(
53
  fn=predict,
54
  inputs=gr.Textbox(label="๋ฐ˜๋ ค๋™๋ฌผ ์ฆ์ƒ ์ž…๋ ฅ", placeholder="์˜ˆ: ๊ฐ•์•„์ง€๊ฐ€ ์ž์ฃผ ๊ธฐ์นจํ•ด"),
55
  outputs=gr.Textbox(label="์˜ˆ์ธก ๊ฒฐ๊ณผ"),
56
  title="๐Ÿพ PetBERT ICD ์ˆ˜์˜์‚ฌ ์˜ˆ์ธก๊ธฐ",
57
+ description="๋ฐ˜๋ ค๋™๋ฌผ์˜ ์ฆ์ƒ ๋ฌธ์žฅ์„ ์ž…๋ ฅํ•˜๋ฉด AI๊ฐ€ ์งˆ๋ณ‘ ๊ฐ€๋Šฅ์„ฑ์„ ์˜ˆ์ธกํ•ด๋“œ๋ฆฝ๋‹ˆ๋‹ค.",
58
+ examples=["๊ณ ์–‘์ด๊ฐ€ ๋ฐฅ์„ ์•ˆ ๋จน๊ณ  ์ž๊พธ ํ† ํ•ด์š”", "๊ฐ•์•„์ง€๊ฐ€ ์ˆจ์„ ํ—๋–ก์ด๊ณ  ๊ธฐ์นจ์„ ํ•ด์š”"],
59
+ live=False
60
  )
61
 
62
+ # 7. Launch + ๋ผ๋ฒจ ๋ชฉ๋ก ํ•จ๊ป˜ ์ถœ๋ ฅ
63
+ with gr.Blocks() as app:
64
+ gr.Markdown("### ๐Ÿ“‹ PetBERT ICD ์งˆ๋ณ‘ ๋ผ๋ฒจ ๋ชฉ๋ก")
65
+ gr.Markdown(label_info)
66
+ demo.render()
67
+
68
+ app.launch()