import gradio as gr
import torch
import numpy as np
import io
from scipy.io.wavfile import write
from transformers import pipeline
import time
from typing import Dict, List, Tuple
# --- TTS Engine ---
class FreeVoiceTTS:
    """Text-to-speech engine backed by a Silero model fetched via ``torch.hub``.

    The model is loaded lazily: ``text_to_speech`` calls ``load_silero_tts``
    itself when no model has been loaded yet, so a failed pre-load is retried
    on the next synthesis request.
    """

    def __init__(self):
        self.model = None  # not read anywhere in this class; kept for compatibility
        self.silero_model = None  # populated by load_silero_tts() on success
        self.device = "cpu"
        self.sample_rate = 24000

    def load_silero_tts(self) -> bool:
        """Load Silero TTS - lightweight and reliable.

        Returns:
            True when the model was loaded and stored on ``self.silero_model``;
            False when loading raised (the error is printed, not re-raised).
        """
        try:
            torch.set_num_threads(4)
            # The hub entry point yields a (model, example_text) pair; only the
            # model is needed here.
            model, _example_text = torch.hub.load(
                repo_or_dir='snakers4/silero-models',
                model='silero_tts',
                language='en',
                speaker='v3_en'
            )
        except Exception as e:
            print(f"Silero TTS loading failed: {e}")
            return False
        self.silero_model = model
        return True

    def text_to_speech(self, text: str) -> "Tuple[int, np.ndarray] | None":
        """Convert text to speech.

        Args:
            text: The sentence(s) to synthesise.

        Returns:
            ``(sample_rate, audio_numpy)`` on success, or ``None`` when the
            text is blank, the model cannot be loaded, or synthesis fails.
        """
        # Nothing to say: skip the model entirely instead of letting it raise.
        if not text or not text.strip():
            return None

        # getattr keeps this safe for instances created without __init__.
        if getattr(self, "silero_model", None) is None and not self.load_silero_tts():
            return None

        try:
            audio = self.silero_model.apply_tts(
                text=text,
                speaker='en_0',  # English female voice (per the original author)
                sample_rate=self.sample_rate
            )
            # apply_tts returns a torch tensor; detach and move to CPU so the
            # numpy conversion works regardless of grad/device state.
            return (self.sample_rate, audio.detach().cpu().numpy())
        except Exception as e:
            print(f"Silero TTS failed: {e}")
            return None
# --- STT Engine ---
class SpeechToText:
    """Speech-to-text wrapper around a ``transformers`` ASR pipeline.

    The pipeline (``openai/whisper-tiny``) is created by ``load_model`` and is
    loaded on demand by ``transcribe`` if it is not available yet.
    """

    def __init__(self):
        self.transcriber = None  # set by load_model() on success

    def load_model(self) -> bool:
        """Create the ASR pipeline.

        Returns:
            True when the pipeline was created; False when creation raised
            (the error is printed, not re-raised).
        """
        try:
            self.transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")
        except Exception as e:
            print(f"STT loading failed: {e}")
            return False
        return True

    def transcribe(self, audio_path: str) -> str:
        """Transcribe the audio file at ``audio_path``.

        Args:
            audio_path: Path of the recording; falsy values mean "no audio".

        Returns:
            The recognised text, or ``""`` when there is no audio, the model
            cannot be loaded, or transcription fails.
        """
        # Check the input first so an empty submission never triggers a model load.
        if not audio_path:
            return ""

        # Honour load_model()'s result instead of calling a None transcriber.
        if not self.transcriber and not self.load_model():
            return ""

        try:
            result = self.transcriber(audio_path)
            return result["text"]
        except Exception as e:
            print(f"Transcription failed: {e}")
            return ""
# --- Application Logic ---
# Initialize engines (module-level singletons used by the session handlers)
tts_engine = FreeVoiceTTS()
stt_engine = SpeechToText()

# Pre-load models at import time. Both loaders return False on failure rather
# than raising, and both engines retry loading on first use, so a failure here
# is not fatal -- but it must be reported instead of claiming success.
print("Loading AI Models...")
_tts_ready = tts_engine.load_silero_tts()
_stt_ready = stt_engine.load_model()
if _tts_ready and _stt_ready:
    print("Models Loaded.")
else:
    _failed = [name for name, ok in (("TTS", _tts_ready), ("STT", _stt_ready)) if not ok]
    print(f"Model loading incomplete; failed: {', '.join(_failed)}.")
# Viva question bank, keyed by topic id ("upper_limb", "lower_limb",
# "cardiology", "neuroanatomy"). Each topic maps to an ordered list of
# question dicts; a session walks the list from index 0 to the end.
# Keys of each question dict:
#   "question"   - prompt added to the chat and spoken via the TTS engine
#   "key_points" - phrases evaluate_answer() matches against the user's answer
#   "follow_up"  - follow-up question appended to the feedback text
#   "difficulty" - "medium" or "hard"; not read by any code visible in this file
QUESTION_BANK = {
"upper_limb": [
{
"question": "Describe the course and distribution of the median nerve from its origin to the hand.",
"key_points": ["brachial plexus roots C5-T1", "medial and lateral cords", "carpal tunnel", "LOAF muscles"],
"follow_up": "What clinical condition results from median nerve compression at the wrist?",
"difficulty": "medium"
},
{
"question": "Explain the brachial plexus in detail, including its major branches.",
"key_points": ["roots, trunks, divisions, cords, branches", "mnemonic: Real Texans Drink Cold Beer", "musculocutaneous, axillary, radial, median, ulnar nerves"],
"follow_up": "Which cord of the brachial plexus is most vulnerable in shoulder dislocations?",
"difficulty": "hard"
},
{
"question": "What are the muscles of the rotator cuff and their functions?",
"key_points": ["supraspinatus", "infraspinatus", "teres minor", "subscapularis", "SITS mnemonic"],
"follow_up": "Which rotator cuff muscle is most commonly injured?",
"difficulty": "medium"
}
],
"lower_limb": [
{
"question": "Trace the course of the sciatic nerve from its origin to its terminal branches.",
"key_points": ["L4-S3 roots", "passes through greater sciatic foramen", "divides into tibial and common fibular nerves", "innervates hamstrings"],
"follow_up": "What are the clinical manifestations of sciatic nerve injury?",
"difficulty": "medium"
},
{
"question": "Describe the boundaries and contents of the femoral triangle.",
"key_points": ["inguinal ligament", "sartorius", "adductor longus", "femoral nerve, artery, vein", "NAVY arrangement"],
"follow_up": "Why is the femoral triangle important clinically?",
"difficulty": "medium"
}
],
"cardiology": [
{
"question": "Describe the blood supply to the heart and the coronary circulation.",
"key_points": ["left and right coronary arteries", "circumflex artery", "left anterior descending", "coronary sinus"],
"follow_up": "Which coronary artery is most commonly involved in myocardial infarction?",
"difficulty": "medium"
},
{
"question": "Explain the conduction system of the heart.",
"key_points": ["SA node", "AV node", "bundle of His", "bundle branches", "Purkinje fibers"],
"follow_up": "What is the clinical significance of the AV node?",
"difficulty": "hard"
}
],
"neuroanatomy": [
{
"question": "Describe the blood supply of the brain.",
"key_points": ["internal carotid arteries", "vertebral arteries", "circle of Willis", "anterior, middle, posterior cerebral arteries"],
"follow_up": "What is the clinical consequence of middle cerebral artery occlusion?",
"difficulty": "hard"
},
{
"question": "Name the twelve cranial nerves and their basic functions.",
"key_points": ["olfactory, optic, oculomotor, trochlear, trigeminal, abducens, facial, vestibulocochlear, glossopharyngeal, vagus, accessory, hypoglossal"],
"follow_up": "Which cranial nerve has the longest intracranial course?",
"difficulty": "medium"
}
]
}
def start_session(topic):
    """Start a viva session for ``topic`` and ask its first question.

    Args:
        topic: A key of ``QUESTION_BANK`` (e.g. ``"upper_limb"``).

    Returns:
        A 6-tuple, one value per output wired to the topic buttons:
        ``(session_state, chat_history, session_info_text,
        session_view_update, topic_view_update, professor_audio)``.
        When the topic is missing, unknown, or has no questions, the session
        state is ``None``, the topic selection stays visible, and no audio is
        returned.
    """

    def _rejected(message):
        # Every return path must yield all six outputs; the original early
        # return supplied only five, which breaks the event handler.
        return (
            None,
            [],
            message,
            gr.update(visible=False),
            gr.update(visible=True),
            None,
        )

    if not topic:
        return _rejected("Please select a topic first.")

    questions = QUESTION_BANK.get(topic)
    if not questions:
        # Unknown topic or empty question list: report it instead of raising
        # KeyError/IndexError inside the UI callback.
        return _rejected("No questions are available for this topic.")

    session_state = {
        "topic": topic,
        "question_index": 0,
        "score": 0,
        "history": [],
        "current_question_data": questions[0]
    }
    first_question = session_state["current_question_data"]["question"]

    # Generate audio for first question (None if synthesis is unavailable)
    audio = tts_engine.text_to_speech(first_question)

    return (
        session_state,
        [(None, first_question)],  # Chat history: (user, professor) pairs
        f"Topic: {topic.replace('_', ' ').title()}",
        gr.update(visible=True),   # Show session
        gr.update(visible=False),  # Hide topic selection
        audio                      # Auto-play question
    )
def process_response(audio_input, text_input, session_state, history):
    """Grade the user's answer and advance the session to the next question.

    Args:
        audio_input: Path of the recorded answer, or a falsy value if none.
        text_input: Typed answer, or a falsy value if none.
        session_state: Dict created by ``start_session``, or ``None`` when no
            session is active.
        history: Chat history as a list of ``(user, professor)`` pairs; it is
            extended in place. ``None`` is treated as an empty history.

    Returns:
        A 5-tuple matching the submit outputs:
        ``(session_state, history, text_input_value, audio_input_value,
        professor_audio)``. ``session_state`` becomes ``None`` once the last
        question has been answered.
    """
    if history is None:
        history = []

    if not session_state:
        return session_state, history, "Error: No active session", None, None

    # Determine user answer. Audio takes precedence, but if the recording
    # transcribes to nothing, fall back to the typed text instead of dropping it.
    user_answer = ""
    if audio_input:
        user_answer = (stt_engine.transcribe(audio_input) or "").strip()
    if not user_answer and text_input:
        user_answer = text_input.strip()

    if not user_answer:
        return session_state, history, "", None, None  # No input

    # Evaluate Answer
    question_data = session_state["current_question_data"]
    score, feedback = evaluate_answer(user_answer, question_data)

    # Update State
    session_state["score"] += score
    session_state["history"].append({
        "question": question_data["question"],
        "answer": user_answer,
        "feedback": feedback,
        "score": score
    })

    # Update Chat History
    history.append((user_answer, feedback))

    # Prepare Next Question
    session_state["question_index"] += 1
    topic_questions = QUESTION_BANK[session_state["topic"]]

    if session_state["question_index"] < len(topic_questions):
        next_question_data = topic_questions[session_state["question_index"]]
        session_state["current_question_data"] = next_question_data
        next_q_text = next_question_data["question"]
        history.append((None, next_q_text))
        # Generate audio for next question
        next_audio = tts_engine.text_to_speech(next_q_text)
    else:
        # End of session: each question is worth up to 10 points
        final_score = session_state["score"]
        count = len(topic_questions)
        avg = final_score / count if count > 0 else 0
        end_msg = f"Session Complete! Final Score: {final_score:.1f}/{count*10} (Avg: {avg:.1f})"
        history.append((None, end_msg))
        next_audio = tts_engine.text_to_speech(end_msg)
        session_state = None  # Reset state

    return (
        session_state,
        history,
        "",    # Clear text input
        None,  # Clear audio input
        next_audio
    )
# Connective words that occur inside key points but carry no meaning of their
# own; matching on them would credit a key point for almost any answer.
_STOPWORDS = frozenset({
    "a", "an", "and", "are", "at", "by", "from", "in", "into", "is", "its",
    "of", "on", "or", "the", "through", "to", "with",
})

# Punctuation stripped from the edges of key-point words ("roots," -> "roots").
_EDGE_PUNCTUATION = ",.;:!?()\"'"


def _keywords(point: str) -> List[str]:
    """Return the lowercase words of a key point that are worth matching."""
    words = [w.strip(_EDGE_PUNCTUATION) for w in point.lower().split()]
    words = [w for w in words if w]
    meaningful = [w for w in words if w not in _STOPWORDS]
    # A key point made only of stopwords still needs something to match on.
    return meaningful or words


def _is_covered(point: str, answer_lower: str) -> bool:
    """Return True if any keyword of ``point`` occurs in the lowercased answer."""
    return any(word in answer_lower for word in _keywords(point))


def evaluate_answer(answer: str, question_data: Dict) -> Tuple[float, str]:
    """Simple keyword matching evaluation.

    A key point counts as covered when at least one of its keywords (words
    with edge punctuation stripped, connective words ignored) appears as a
    substring of the lowercased answer.

    Args:
        answer: The user's answer text.
        question_data: Question dict with a ``"key_points"`` list and an
            optional ``"follow_up"`` string.

    Returns:
        ``(score, feedback)`` where ``score`` is a float in ``[0, 10]``
        proportional to the share of key points covered, and ``feedback``
        ends with the question's follow-up (empty string if absent).
    """
    answer_lower = (answer or "").lower()
    key_points = question_data["key_points"]
    follow_up = question_data.get('follow_up', '')

    # Avoid dividing by zero for a question that defines no key points.
    if not key_points:
        return 0.0, f"No key points are defined for this question. {follow_up}"

    missed = [p for p in key_points if not _is_covered(p, answer_lower)]
    covered_points = len(key_points) - len(missed)
    score = 10.0 * covered_points / len(key_points)

    if score >= 8:
        feedback = f"Excellent! {follow_up}"
    elif score >= 5:
        feedback = f"Good. You missed some details. {follow_up}"
    else:
        feedback = f"Key points missed: {', '.join(missed[:2])}. {follow_up}"
    return score, feedback
# --- Gradio UI ---
with gr.Blocks(title="Anatomy Viva Voce", theme=gr.themes.Soft()) as demo:
    state = gr.State(None)  # Session state (dict from start_session, or None)

    # The emoji was mojibake ("π§") in the original title; restored to the
    # brain emoji its byte pattern most plausibly encoded.
    gr.Markdown("# 🧠 Anatomy Viva Voce Simulator")
    gr.Markdown("Practice medical anatomy with an AI Professor. Speak or type your answers!")

    # Topic Selection View
    with gr.Group(visible=True) as topic_view:
        gr.Markdown("### Select a Topic to Begin")
        with gr.Row():
            btn_upper = gr.Button("Upper Limb", variant="primary")
            btn_lower = gr.Button("Lower Limb", variant="primary")
            btn_cardio = gr.Button("Cardiology", variant="primary")
            btn_neuro = gr.Button("Neuroanatomy", variant="primary")

    # Session View (hidden until a topic is chosen)
    with gr.Group(visible=False) as session_view:
        session_info = gr.Markdown("Topic: ...")
        chatbot = gr.Chatbot(label="Viva Session", height=400)

        # Professor Audio Output (Hidden player, auto-played via return)
        professor_audio = gr.Audio(label="Professor's Voice", autoplay=True, visible=False)

        with gr.Row():
            with gr.Column(scale=4):
                txt_input = gr.Textbox(
                    show_label=False,
                    placeholder="Type your answer here...",
                    lines=2
                )
            with gr.Column(scale=1):
                # NOTE(review): `source=` is the Gradio 3.x keyword; Gradio 4
                # uses `sources=[...]` -- confirm against the pinned version.
                audio_input = gr.Audio(
                    source="microphone",
                    type="filepath",
                    label="Voice Answer",
                    show_label=False
                )

        with gr.Row():
            submit_btn = gr.Button("Submit Answer", variant="primary")
            end_btn = gr.Button("End Session", variant="stop")

    # Event Handlers
    topic_buttons = [btn_upper, btn_lower, btn_cardio, btn_neuro]
    topics = ["upper_limb", "lower_limb", "cardiology", "neuroanatomy"]

    # Each button passes its own topic id to start_session through a
    # dedicated State component.
    for btn, topic in zip(topic_buttons, topics):
        btn.click(
            fn=start_session,
            inputs=[gr.State(topic)],
            outputs=[state, chatbot, session_info, session_view, topic_view, professor_audio]
        )

    # Submit via the button or by pressing Enter in the textbox.
    submit_inputs = [audio_input, txt_input, state, chatbot]
    submit_outputs = [state, chatbot, txt_input, audio_input, professor_audio]

    submit_btn.click(fn=process_response, inputs=submit_inputs, outputs=submit_outputs)
    txt_input.submit(fn=process_response, inputs=submit_inputs, outputs=submit_outputs)
    # Audio is deliberately NOT auto-submitted on `change`: the user may want
    # to re-record, and process_response itself clears audio_input, which
    # would re-fire `change` and run the handler a second time with no input
    # (returning None for professor_audio). Record, then click Submit.

    def reset_ui():
        """Clear the session state and chat, and return to topic selection."""
        return None, [], gr.update(visible=False), gr.update(visible=True)

    end_btn.click(
        fn=reset_ui,
        inputs=None,
        outputs=[state, chatbot, session_view, topic_view]
    )
if __name__ == "__main__":
    # Listen on all network interfaces (0.0.0.0), port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)