import os

import gradio as gr
from transformers import pipeline
from llama_index import HuggingFaceLLMPredictor  # legacy (pre-0.10) llama_index API
from src.parse_tabular import symptom_index

# --- LlamaIndex project helpers ---
# Available for swapping in place of the inline predictor built below.
from utils.llama_index_utils import get_llm_predictor, build_index, query_symptoms

# --- Whisper ASR setup ---
# device=0 assumes a GPU; use device=-1 to fall back to CPU.
asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    device=0,
    chunk_length_s=30,
)

# --- LLM predictor ---
# Built once at module load instead of on every streamed audio chunk
# (you can swap OpenAI / HuggingFace here); HF_MODEL overrides the default.
llm_predictor = HuggingFaceLLMPredictor(
    model_name_or_path=os.getenv("HF_MODEL", "gpt2-medium")
)

# --- System prompt ---
SYSTEM_PROMPT = """
You are a medical assistant helping a user narrow down to the most likely ICD-10 code.
At each turn, EITHER ask one focused clarifying question
(e.g. "Is your cough dry or productive?") or, if you have enough info,
output a final JSON with fields: {"diagnoses":[…], "confidences":[…]}.
"""


def transcribe_and_respond(audio_chunk, state):
    """Transcribe one streamed audio chunk, query the index, and update the chat."""
    # Gradio may deliver an empty chunk at stream start/end.
    if audio_chunk is None:
        return state, state

    # Transcribe the audio chunk.
    result = asr(audio_chunk)
    text = result.get("text", "").strip()
    if not text:
        # Nothing intelligible yet; leave the conversation unchanged.
        return state, state

    # Append the user message (openai-style dicts, matching Chatbot(type="messages")).
    state.append({"role": "user", "content": text})

    # Combine the system prompt with the conversation so far, then query the
    # index (`symptom_index` is assumed to be a GPTVectorStoreIndex).
    prompt = SYSTEM_PROMPT + "\n".join(
        f"{msg['role']}: {msg['content']}" for msg in state
    )
    response = symptom_index.as_query_engine(
        llm_predictor=llm_predictor
    ).query(prompt)
    reply = response.response

    # Append the assistant message and return the history to both outputs,
    # so the state is preserved rather than reset.
    state.append({"role": "assistant", "content": reply})
    return state, state


# --- Gradio interface ---
with gr.Blocks() as demo:
    gr.Markdown("# Symptom to ICD-10 Code Lookup (Audio Input)")
    chatbot = gr.Chatbot(label="Conversation", type="messages")
    state = gr.State([])
    # Streaming microphone input for near-real-time transcription.
    mic = gr.Audio(
        sources=["microphone"],
        type="filepath",
        streaming=True,
        label="Describe your symptoms",
    )
    mic.stream(
        fn=transcribe_and_respond,
        inputs=[mic, state],
        outputs=[chatbot, state],
        time_limit=60,
        stream_every=5,
        concurrency_limit=1,
    )

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        mcp_server=True,
    )
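
# Usage note (a sketch of assumptions, not verified against this repo's pins):
# the script targets the legacy pre-0.10 llama_index API and a Gradio release
# new enough to support `mcp_server=True` (Gradio 5.x, with the optional MCP
# extra installed, e.g. `pip install "gradio[mcp]"`). Assuming this file is
# saved as app.py, run it locally with:
#
#   HF_MODEL=gpt2-medium python app.py   # HF_MODEL is optional; it overrides the default predictor
#
# then open http://localhost:7860 in a browser for the UI.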