Abaryan committed on
Commit
8c5712e
·
verified ·
1 Parent(s): dee81c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +188 -54
app.py CHANGED
@@ -6,11 +6,10 @@ import random
6
  import re
7
 
8
  # Load model and tokenizer
9
- # model_name = "rgb2gbr/GRPO_BioMedmcqa_Qwen2.5-0.5B"
10
  model_name = "rgb2gbr/BioXP-0.5B-MedMCQA"
11
 
12
  SYSTEM_PROMPT = """
13
- You're a medical expert. Answer the question with careful analysis and explain why the selected option is correct in 150 words without reapeating.
14
  Respond in the following format:
15
  <answer>
16
  [correct answer]
@@ -45,15 +44,32 @@ def get_random_question():
45
  question_data.get('exp', None) # Explanation
46
  )
47
 
48
- def predict(question: str, option_a: str, option_b: str, option_c: str, option_d: str,
49
  correct_option: int = None, explanation: str = None,
50
  temperature: float = 0.6, top_p: float = 0.9, max_tokens: int = 256):
51
- # Format the question with options
52
- formatted_question = f"Question: {question}\n\nOptions:\nA. {option_a}\nB. {option_b}\nC. {option_c}\nD. {option_d}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  # Create chat-style prompt
55
  prompt = [
56
- {'role': 'system', 'content': SYSTEM_PROMPT},
57
  {'role': 'user', 'content': formatted_question}
58
  ]
59
 
@@ -69,20 +85,23 @@ def predict(question: str, option_a: str, option_b: str, option_c: str, option_d
69
  max_new_tokens=max_tokens,
70
  temperature=temperature,
71
  top_p=top_p,
72
- # repetition_penalty=1.1,
73
  )
74
 
75
- # Get only the generated response (excluding the prompt)
76
  generated_ids = generated_ids[0, model_inputs.input_ids.shape[1]:]
77
  model_response = tokenizer.decode(generated_ids, skip_special_tokens=True)
78
 
79
- # Format output with evaluation if available
80
- output = model_response
 
 
81
 
82
- if correct_option is not None:
83
- correct_letter = chr(65 + correct_option) # Convert 0-3 to A-D
84
- # Extract answer from model response for evaluation
85
- answer_match = re.search(r"<answer>\s*([A-D])\s*</answer>", model_response, re.IGNORECASE)
 
 
86
  model_answer = answer_match.group(1).upper() if answer_match else "Not found"
87
 
88
  is_correct = model_answer == correct_letter
@@ -95,61 +114,106 @@ def predict(question: str, option_a: str, option_b: str, option_c: str, option_d
95
 
96
  return output
97
 
98
- # Create Gradio interface with Blocks for more control
99
- with gr.Blocks(title="Medical-QA (MedMCQA) Predictor") as demo:
100
- gr.Markdown("# Medical-QA (MedMCQA) Predictor")
101
- gr.Markdown("Get a random medical question or enter your own question and options.")
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  with gr.Row():
104
- with gr.Column():
105
- # Input fields
106
- question = gr.Textbox(label="Question", lines=3, interactive=True)
 
 
 
 
 
 
107
 
108
- # Options in an expandable accordion
109
- with gr.Accordion("Options", open=False):
110
- option_a = gr.Textbox(label="Option A", interactive=True)
111
- option_b = gr.Textbox(label="Option B", interactive=True)
112
- option_c = gr.Textbox(label="Option C", interactive=True)
113
- option_d = gr.Textbox(label="Option D", interactive=True)
114
-
115
- # Generation parameters
116
- with gr.Accordion("Generation Parameters", open=False):
117
- temperature = gr.Slider(
118
- minimum=0.1,
119
- maximum=1.0,
120
- value=0.6,
121
- step=0.1,
122
- label="Temperature",
123
- info="Higher values make output more random, lower values more focused"
124
  )
125
- top_p = gr.Slider(
126
- minimum=0.1,
127
- maximum=1.0,
128
- value=0.9,
129
- step=0.1,
130
- label="Top P",
131
- info="Higher values allow more diverse tokens, lower values more focused"
132
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  max_tokens = gr.Slider(
134
  minimum=50,
135
  maximum=512,
136
  value=256,
137
  step=32,
138
- label="Max Tokens",
139
- info="Maximum length of the generated response (recommended: 256)"
140
  )
141
 
142
- # Hidden fields for correct answer and explanation
143
  correct_option = gr.Number(visible=False)
144
  expert_explanation = gr.Textbox(visible=False)
145
 
146
- # Buttons
147
  with gr.Row():
148
- predict_btn = gr.Button("Predict", variant="primary")
149
- random_btn = gr.Button("Get Random Question", variant="secondary")
150
-
151
- # Output
152
- output = gr.Textbox(label="Model's Answer", lines=10)
 
 
 
 
 
153
 
154
  # Set up button actions
155
  predict_btn.click(
@@ -168,6 +232,76 @@ with gr.Blocks(title="Medical-QA (MedMCQA) Predictor") as demo:
168
  outputs=[question, option_a, option_b, option_c, option_d, correct_option, expert_explanation]
169
  )
170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  # Launch the app
172
  if __name__ == "__main__":
173
- demo.launch()
 
6
  import re
7
 
8
  # Load model and tokenizer
 
9
  model_name = "rgb2gbr/BioXP-0.5B-MedMCQA"
10
 
11
  SYSTEM_PROMPT = """
12
+ You are a medical expert. Answer the medical question with careful analysis and explain why the selected option is correct in 200 words without repeating.
13
  Respond in the following format:
14
  <answer>
15
  [correct answer]
 
44
  question_data.get('exp', None) # Explanation
45
  )
46
 
47
+ def predict(question: str, option_a: str = "", option_b: str = "", option_c: str = "", option_d: str = "",
48
  correct_option: int = None, explanation: str = None,
49
  temperature: float = 0.6, top_p: float = 0.9, max_tokens: int = 256):
50
+
51
+ # Determine if this is an MCQ by checking if any option is provided
52
+ # Only treat as MCQ if at least one option is non-empty
53
+ is_mcq = any(opt.strip() for opt in [option_a, option_b, option_c, option_d])
54
+
55
+ if is_mcq:
56
+ # Format MCQ question with only non-empty options
57
+ options = []
58
+ if option_a.strip(): options.append(f"A. {option_a}")
59
+ if option_b.strip(): options.append(f"B. {option_b}")
60
+ if option_c.strip(): options.append(f"C. {option_c}")
61
+ if option_d.strip(): options.append(f"D. {option_d}")
62
+
63
+ formatted_question = f"Question: {question}\n\nOptions:\n" + "\n".join(options)
64
+ system_prompt = SYSTEM_PROMPT
65
+ else:
66
+ # Format regular question
67
+ formatted_question = f"Question: {question}"
68
+ system_prompt = SYSTEM_PROMPT
69
 
70
  # Create chat-style prompt
71
  prompt = [
72
+ {'role': 'system', 'content': system_prompt},
73
  {'role': 'user', 'content': formatted_question}
74
  ]
75
 
 
85
  max_new_tokens=max_tokens,
86
  temperature=temperature,
87
  top_p=top_p,
 
88
  )
89
 
90
+ # Get only the generated response
91
  generated_ids = generated_ids[0, model_inputs.input_ids.shape[1]:]
92
  model_response = tokenizer.decode(generated_ids, skip_special_tokens=True)
93
 
94
+ # Clean up the response by removing tags and formatting
95
+ cleaned_response = model_response
96
+ cleaned_response = re.sub(r'<answer>\s*([A-D])\s*</answer>', r'Answer: \1', cleaned_response, flags=re.IGNORECASE)
97
+ cleaned_response = re.sub(r'<reasoning>\s*(.*?)\s*</reasoning>', r'Reasoning:\n\1', cleaned_response, flags=re.IGNORECASE | re.DOTALL)
98
 
99
+ # Format output with evaluation if available (only for MCQs)
100
+ output = cleaned_response
101
+
102
+ if is_mcq and correct_option is not None:
103
+ correct_letter = chr(65 + correct_option)
104
+ answer_match = re.search(r"Answer:\s*([A-D])", cleaned_response, re.IGNORECASE)
105
  model_answer = answer_match.group(1).upper() if answer_match else "Not found"
106
 
107
  is_correct = model_answer == correct_letter
 
114
 
115
  return output
116
 
117
+ # Create Gradio interface with mobile-optimized design
118
+ with gr.Blocks(
119
+ title="BioXP Medical MCQ Assistant",
120
+ theme=gr.themes.Soft(
121
+ primary_hue="blue",
122
+ secondary_hue="blue",
123
+ neutral_hue="slate",
124
+ radius_size="md",
125
+ font=["Inter", "ui-sans-serif", "system-ui", "sans-serif"],
126
+ )
127
+ ) as demo:
128
+ gr.Markdown("""
129
+ # BioXP Medical MCQ Assistant
130
+ A specialized AI assistant for medical multiple-choice questions.
131
+ """)
132
 
133
  with gr.Row():
134
+ with gr.Column(scale=1):
135
+ # Input fields with mobile-friendly spacing
136
+ question = gr.Textbox(
137
+ label="Medical Question",
138
+ placeholder="Enter your medical question here...",
139
+ lines=3,
140
+ interactive=True,
141
+ elem_classes=["mobile-input"]
142
+ )
143
 
144
+ # Options in a mobile-friendly accordion
145
+ with gr.Accordion("Options", open=True):
146
+ option_a = gr.Textbox(
147
+ label="Option A",
148
+ placeholder="Enter option A...",
149
+ interactive=True,
150
+ elem_classes=["mobile-input"]
151
+ )
152
+ option_b = gr.Textbox(
153
+ label="Option B",
154
+ placeholder="Enter option B...",
155
+ interactive=True,
156
+ elem_classes=["mobile-input"]
 
 
 
157
  )
158
+ option_c = gr.Textbox(
159
+ label="Option C",
160
+ placeholder="Enter option C...",
161
+ interactive=True,
162
+ elem_classes=["mobile-input"]
 
 
163
  )
164
+ option_d = gr.Textbox(
165
+ label="Option D",
166
+ placeholder="Enter option D...",
167
+ interactive=True,
168
+ elem_classes=["mobile-input"]
169
+ )
170
+
171
+ # Generation parameters in a collapsible section
172
+ with gr.Accordion("Advanced Settings", open=False):
173
+ with gr.Row():
174
+ with gr.Column(scale=1):
175
+ temperature = gr.Slider(
176
+ minimum=0.1,
177
+ maximum=1.0,
178
+ value=0.6,
179
+ step=0.1,
180
+ label="Temperature",
181
+ info="Higher = more creative, Lower = more focused"
182
+ )
183
+ with gr.Column(scale=1):
184
+ top_p = gr.Slider(
185
+ minimum=0.1,
186
+ maximum=1.0,
187
+ value=0.9,
188
+ step=0.1,
189
+ label="Top P",
190
+ info="Controls response diversity"
191
+ )
192
  max_tokens = gr.Slider(
193
  minimum=50,
194
  maximum=512,
195
  value=256,
196
  step=32,
197
+ label="Max Response Length",
198
+ info="Maximum length of the response"
199
  )
200
 
201
+ # Hidden fields
202
  correct_option = gr.Number(visible=False)
203
  expert_explanation = gr.Textbox(visible=False)
204
 
205
+ # Buttons with mobile-friendly spacing
206
  with gr.Row():
207
+ predict_btn = gr.Button("Get Answer", variant="primary", size="lg", elem_classes=["mobile-button"])
208
+ random_btn = gr.Button("Random Question", variant="secondary", size="lg", elem_classes=["mobile-button"])
209
+
210
+ with gr.Column(scale=1):
211
+ # Output with mobile-friendly styling
212
+ output = gr.Textbox(
213
+ label="Model's Response",
214
+ lines=12,
215
+ elem_classes=["response-box", "mobile-output"]
216
+ )
217
 
218
  # Set up button actions
219
  predict_btn.click(
 
232
  outputs=[question, option_a, option_b, option_c, option_d, correct_option, expert_explanation]
233
  )
234
 
235
+ # Add mobile-optimized CSS
236
+ gr.HTML("""
237
+ <style>
238
+ /* Mobile-friendly base styles */
239
+ .container {
240
+ max-width: 100%;
241
+ padding: 0.5rem;
242
+ }
243
+
244
+ /* Input styling */
245
+ .mobile-input textarea {
246
+ font-size: 1rem;
247
+ padding: 0.75rem;
248
+ border-radius: 0.5rem;
249
+ min-height: 2.5rem;
250
+ }
251
+
252
+ /* Button styling */
253
+ .mobile-button {
254
+ width: 100%;
255
+ margin: 0.5rem 0;
256
+ padding: 0.75rem;
257
+ font-size: 1rem;
258
+ font-weight: 500;
259
+ }
260
+
261
+ /* Response box styling */
262
+ .response-box {
263
+ font-family: 'Inter', sans-serif;
264
+ line-height: 1.6;
265
+ }
266
+ .response-box textarea {
267
+ font-size: 1rem;
268
+ padding: 1rem;
269
+ border-radius: 0.5rem;
270
+ }
271
+
272
+ /* Mobile-specific adjustments */
273
+ @media (max-width: 768px) {
274
+ .gr-form {
275
+ padding: 0.75rem;
276
+ }
277
+ .gr-box {
278
+ margin: 0.5rem 0;
279
+ }
280
+ .gr-button {
281
+ min-height: 2.5rem;
282
+ }
283
+ .gr-accordion {
284
+ margin: 0.5rem 0;
285
+ }
286
+ .gr-input {
287
+ margin-bottom: 0.5rem;
288
+ }
289
+ }
290
+
291
+ /* Dark mode support */
292
+ @media (prefers-color-scheme: dark) {
293
+ .gr-box {
294
+ background-color: #1a1a1a;
295
+ }
296
+ .mobile-input textarea,
297
+ .response-box textarea {
298
+ background-color: #2a2a2a;
299
+ color: #ffffff;
300
+ }
301
+ }
302
+ </style>
303
+ """)
304
+
305
  # Launch the app
306
  if __name__ == "__main__":
307
+ demo.launch(share=False)