tmt3103 commited on
Commit
767d296
·
verified ·
1 Parent(s): 4eb8b78

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ # Load model
7
+ model_name = "tmt3103/BioASQ-yesno-PudMedBERT"
8
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
10
+
11
+ def predict_yesno(context, question):
12
+ inputs = tokenizer.encode_plus(
13
+ question,
14
+ context,
15
+ return_tensors="pt",
16
+ truncation=True,
17
+ max_length=512,
18
+ padding="max_length"
19
+ )
20
+
21
+ with torch.no_grad():
22
+ outputs = model(**inputs)
23
+ logits = outputs.logits
24
+ probs = F.softmax(logits, dim=1).squeeze()
25
+ pred_id = logits.argmax().item()
26
+
27
+ label = "Yes" if pred_id == 1 else "No"
28
+ return f"{label} (Confidence: {probs[pred_id]:.2f})"
29
+
30
+ # Gradio
31
+ with gr.Blocks() as demo:
32
+ gr.Markdown("#BioASQ Yes/No Question Answering")
33
+ gr.Markdown("""
34
+ This demo uses a fine-tuned BERT model to answer biomedical yes/no questions based on context.<br>
35
+ **Instructions**:
36
+ 1. Paste the context (e.g., PubMed abstract).
37
+ 2. Type your yes/no question.
38
+ 3. Click 'Predict' to get the answer.
39
+ """)
40
+
41
+ with gr.Row():
42
+ with gr.Column():
43
+ context_input = gr.Textbox(label="Context", lines=8, placeholder="Paste biomedical context here...")
44
+ question_input = gr.Textbox(label="Yes/No Question", lines=2, placeholder="Enter your question here...")
45
+ predict_button = gr.Button("Predict")
46
+ with gr.Column():
47
+ output = gr.Textbox(label="Prediction")
48
+
49
+ predict_button.click(
50
+ fn=predict_yesno,
51
+ inputs=[context_input, question_input],
52
+ outputs=output
53
+ )
54
+
55
+ demo.launch()