# MedCodeMCP/src/app.py
import os
import gradio as gr
import openai
from transformers import pipeline
from llama_index import SimpleDirectoryReader, GPTVectorStoreIndex
from llama_index import HuggingFaceLLMPredictor
from src.parse_tabular import symptom_index
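# `symptom_index` is the prebuilt GPTVectorStoreIndex (presumably constructed from the parsed
# tabular ICD-10 data in src/parse_tabular.py).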
# --- Whisper ASR setup ---
asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    device=0,
    chunk_length_s=30,
)
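# Notes: the ASR pipeline returns a dict with a "text" key, consumed in transcribe_and_respond
# below; chunk_length_s=30 enables chunked long-form transcription, and device=0 assumes a GPU
# (use device=-1, or drop the argument, to run on CPU).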
# --- LlamaIndex utils import ---
from utils.llama_index_utils import get_llm_predictor, build_index, query_symptoms
# --- System prompt ---
SYSTEM_PROMPT = """
You are a medical assistant helping a user narrow down to the most likely ICD-10 code.
At each turn, EITHER ask one focused clarifying question (e.g. “Is your cough dry or productive?”)
or, if you have enough info, output a final JSON with fields:
{"diagnoses":[…], "confidences":[…]}.
"""
def transcribe_and_respond(audio_chunk, state):
    # Transcribe the incoming audio chunk
    result = asr(audio_chunk)
    text = result.get("text", "").strip()
    if not text:
        # Nothing usable was transcribed; leave the conversation unchanged
        return state, state
    # Append the user message (messages format: {"role": ..., "content": ...})
    state.append({"role": "user", "content": text})
    # Build the LLM predictor (you can swap in OpenAI / another Hugging Face model here)
    llm_predictor = HuggingFaceLLMPredictor(model_name_or_path=os.getenv("HF_MODEL", "gpt2-medium"))
    # Query the prebuilt index (`symptom_index` is the GPTVectorStoreIndex from src.parse_tabular)
    # with the system prompt plus the conversation so far
    prompt = SYSTEM_PROMPT + "\n" + "\n".join(
        f"{m['role']}: {m['content']}" for m in state
    )
    response = symptom_index.as_query_engine(
        llm_predictor=llm_predictor
    ).query(prompt)
    reply = response.response
    # Append the assistant message
    state.append({"role": "assistant", "content": reply})
    # Return the updated history to both the chatbot and the state
    return state, state
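# Note: the predictor is rebuilt on every audio chunk for simplicity; it could instead be created
# once at module load (e.g. via the imported utils.llama_index_utils helpers) and reused.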
# Build Gradio interface
demo = gr.Blocks()

with demo:
    gr.Markdown("# Symptom to ICD-10 Code Lookup (Audio Input)")
    # type="messages" expects the same {"role": ..., "content": ...} dicts stored in `state`
    chatbot = gr.Chatbot(label="Conversation", type="messages")
    state = gr.State([])
    # Use streaming audio input for real-time transcription
    mic = gr.Audio(
        sources=["microphone"],  # newer Gradio versions take `sources` (a list) instead of `source`
        type="filepath",
        streaming=True,
        label="Describe your symptoms",
    )
    mic.stream(
        fn=transcribe_and_respond,
        inputs=[mic, state],
        outputs=[chatbot, state],
        time_limit=60,
        stream_every=5,
        concurrency_limit=1,
    )
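# Streaming notes (based on current Gradio streaming behavior): stream_every=5 re-runs
# transcribe_and_respond on a fresh audio chunk roughly every 5 seconds, time_limit=60 caps each
# streaming session at one minute, and concurrency_limit=1 handles one stream at a time.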
if __name__ == "__main__":
    # mcp_server=True additionally exposes the app's endpoints as MCP tools (this needs a recent
    # Gradio release with the gradio[mcp] extra installed); 0.0.0.0:7860 is the standard binding
    # for a Hugging Face Space.
    demo.launch(
        server_name="0.0.0.0", server_port=7860, mcp_server=True
    )