bookdabang commited on
Commit
aae2559
ยท
verified ยท
1 Parent(s): 7f8f27b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -8
app.py CHANGED
@@ -1,26 +1,51 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
 
4
 
 
 
5
  tokenizer = AutoTokenizer.from_pretrained("SAVSNET/PetBERT_ICD")
6
  model = AutoModelForSequenceClassification.from_pretrained("SAVSNET/PetBERT_ICD")
 
7
 
 
8
  def predict(text):
 
 
9
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
 
 
10
  with torch.no_grad():
11
  logits = model(**inputs).logits
 
 
12
  probs = torch.sigmoid(logits).squeeze().tolist()
 
13
 
14
  labels = model.config.id2label
15
- result = {labels[str(i)]: f"{p:.1%}" for i, p in enumerate(probs) if p > 0.5}
 
 
 
 
 
 
 
 
 
 
 
16
  return result or {"์˜ˆ์ธก๋œ ์งˆ๋ณ‘ ์—†์Œ": "๐Ÿคท"}
17
 
 
18
  demo = gr.Interface(
19
- fn=predict,
20
- inputs="text",
21
- outputs="json", # ๋˜๋Š” "text"
22
- title="๐Ÿพ PetBERT ICD",
23
- description="๋ฐ˜๋ ค๋™๋ฌผ ์ฆ์ƒ ๋ฌธ์žฅ์„ ์ž…๋ ฅํ•˜๋ฉด ์งˆ๋ณ‘ ์ฝ”๋“œ๋ณ„๋กœ ์˜ˆ์ธก ํ™•๋ฅ ์„ ๋ณด์—ฌ์ค๋‹ˆ๋‹ค."
24
  )
25
 
26
- demo.launch()
 
 
 
1
  import gradio as gr
 
2
  import torch
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
 
5
+ # โœ… ๋ชจ๋ธ ๋กœ๋”ฉ
6
+ print("๐Ÿ“ฆ ๋ชจ๋ธ ๋กœ๋”ฉ ์‹œ์ž‘...")
7
  tokenizer = AutoTokenizer.from_pretrained("SAVSNET/PetBERT_ICD")
8
  model = AutoModelForSequenceClassification.from_pretrained("SAVSNET/PetBERT_ICD")
9
+ print("โœ… ๋ชจ๋ธ ๋กœ๋”ฉ ์™„๋ฃŒ!")
10
 
11
+ # โœ… ์˜ˆ์ธก ํ•จ์ˆ˜ ์ •์˜
12
  def predict(text):
13
+ print("\n๐Ÿ“ฅ ์ž…๋ ฅ ํ…์ŠคํŠธ:", text)
14
+
15
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
16
+ print("๐Ÿงช Tokenized inputs:", inputs)
17
+
18
  with torch.no_grad():
19
  logits = model(**inputs).logits
20
+ print("๐Ÿ“Š Logits:", logits)
21
+
22
  probs = torch.sigmoid(logits).squeeze().tolist()
23
+ print("๐Ÿ“ˆ ํ™•๋ฅ ๊ฐ’:", probs)
24
 
25
  labels = model.config.id2label
26
+ print("๐Ÿท๏ธ ๋ผ๋ฒจ ๋ชฉ๋ก:", labels)
27
+
28
+ result = {}
29
+ for i, p in enumerate(probs):
30
+ try:
31
+ label = labels[str(i)] # โ— str(i) ์ค‘์š”
32
+ if p > 0.5:
33
+ result[label] = f"{p:.1%}"
34
+ except Exception as e:
35
+ print(f"โŒ ๋ผ๋ฒจ ๋งคํ•‘ ์˜ค๋ฅ˜: {i} โ†’ {e}")
36
+
37
+ print("โœ… ์ตœ์ข… ์˜ˆ์ธก ๊ฒฐ๊ณผ:", result)
38
  return result or {"์˜ˆ์ธก๋œ ์งˆ๋ณ‘ ์—†์Œ": "๐Ÿคท"}
39
 
40
+ # โœ… Gradio ์ธํ„ฐํŽ˜์ด์Šค ์ •์˜
41
  demo = gr.Interface(
42
+ fn=predict, # ์—ฌ๊ธฐ์„œ predict()๊ฐ€ ํ˜ธ์ถœ๋จ
43
+ inputs=gr.Textbox(label="๋ฐ˜๋ ค๋™๋ฌผ ์ฆ์ƒ ์ž…๋ ฅ", placeholder="์˜ˆ: ๊ธฐ์นจ์„ ์ž์ฃผ ํ•ด์š”"),
44
+ outputs=gr.JSON(label="์˜ˆ์ธก ์งˆ๋ณ‘ ๋ชฉ๋ก"),
45
+ title="๐Ÿพ PetBERT ICD ์ˆ˜์˜์‚ฌ ์˜ˆ์ธก๊ธฐ",
46
+ description="๋ฐ˜๋ ค๋™๋ฌผ์˜ ์ฆ์ƒ ๋ฌธ์žฅ์„ ์ž…๋ ฅํ•˜๋ฉด AI๊ฐ€ ์งˆ๋ณ‘ ๊ฐ€๋Šฅ์„ฑ์„ ์˜ˆ์ธกํ•ฉ๋‹ˆ๋‹ค."
47
  )
48
 
49
+ # โœ… ์•ฑ ์‹คํ–‰
50
+ if __name__ == "__main__":
51
+ demo.launch()