# MedCodeMCP/src/app.py
import os
from pathlib import Path
from huggingface_hub import hf_hub_download
import gradio as gr
from llama_index.core import Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.llama_cpp import LlamaCPP
try:
    from .parse_tabular import create_symptom_index  # package-relative import (python -m src.app)
except ImportError:  # fallback when app.py is run directly as a script
    from parse_tabular import create_symptom_index
import json
import psutil
from typing import Tuple, Dict, Optional
import torch
from gtts import gTTS
import io
import base64
# Model options mapped to their requirements
MODEL_OPTIONS = {
"tiny": {
"name": "TinyLlama-1.1B-Chat-v1.0.Q4_K_M.gguf",
"repo": "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
"vram_req": 2, # GB
"ram_req": 4 # GB
},
"small": {
"name": "phi-2.Q4_K_M.gguf",
"repo": "TheBloke/phi-2-GGUF",
"vram_req": 4,
"ram_req": 8
},
"medium": {
"name": "mistral-7b-instruct-v0.1.Q4_K_M.gguf",
"repo": "TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
"vram_req": 6,
"ram_req": 16
}
}
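# The vram_req / ram_req values above are rough minimums (in GB) for these 4-bit
# Q4_K_M GGUF builds; select_best_model() below applies similar thresholds when it
# picks a tier for the detected hardware.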
def get_system_specs() -> Dict[str, float]:
"""Get system specifications."""
# Get RAM
ram_gb = psutil.virtual_memory().total / (1024**3)
# Get GPU info if available
gpu_vram_gb = 0
if torch.cuda.is_available():
try:
# Query GPU memory in bytes and convert to GB
gpu_vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
except Exception as e:
print(f"Warning: Could not get GPU memory: {e}")
return {
"ram_gb": ram_gb,
"gpu_vram_gb": gpu_vram_gb
}
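# On CPU-only hosts (e.g. a free Hugging Face Space) gpu_vram_gb stays 0, so the
# selection below falls back to the RAM check.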
def select_best_model() -> Tuple[str, str]:
"""Select the best model based on system specifications."""
specs = get_system_specs()
print(f"\nSystem specifications:")
print(f"RAM: {specs['ram_gb']:.1f} GB")
print(f"GPU VRAM: {specs['gpu_vram_gb']:.1f} GB")
# Prioritize GPU if available
if specs['gpu_vram_gb'] >= 4: # You have 6GB, so this should work
model_tier = "small" # phi-2 should work well on RTX 2060
elif specs['ram_gb'] >= 8:
model_tier = "small"
else:
model_tier = "tiny"
selected = MODEL_OPTIONS[model_tier]
print(f"\nSelected model tier: {model_tier}")
print(f"Model: {selected['name']}")
return selected['name'], selected['repo']
# Set up model paths
MODEL_NAME, REPO_ID = select_best_model()
BASE_DIR = os.path.dirname(os.path.dirname(__file__))
MODEL_DIR = os.path.join(BASE_DIR, "models")
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_NAME)
def ensure_model(model_name: Optional[str] = None, repo_id: Optional[str] = None) -> str:
"""Ensures model is available, downloading only if needed."""
# Determine environment and set cache directory
if os.path.exists("/home/user"):
# HF Space environment
cache_dir = "/home/user/.cache/models"
else:
# Local development environment
cache_dir = os.path.join(BASE_DIR, "models")
# Create cache directory if it doesn't exist
try:
os.makedirs(cache_dir, exist_ok=True)
except Exception as e:
print(f"Warning: Could not create cache directory {cache_dir}: {e}")
# Fall back to temporary directory if needed
cache_dir = os.path.join("/tmp", "models")
os.makedirs(cache_dir, exist_ok=True)
# Get model details
if not model_name or not repo_id:
model_option = MODEL_OPTIONS["small"] # default to small model
model_name = model_option["name"]
repo_id = model_option["repo"]
# Ensure model_name and repo_id are not None
if model_name is None:
raise ValueError("model_name cannot be None")
if repo_id is None:
raise ValueError("repo_id cannot be None")
# Check if model already exists in cache
model_path = os.path.join(cache_dir, model_name)
if os.path.exists(model_path):
print(f"\nUsing cached model: {model_path}")
return model_path
print(f"\nDownloading model {model_name} from {repo_id}...")
try:
model_path = hf_hub_download(
repo_id=repo_id,
filename=model_name,
cache_dir=cache_dir,
local_dir=cache_dir
)
print(f"Model downloaded successfully to {model_path}")
return model_path
except Exception as e:
print(f"Error downloading model: {str(e)}")
raise
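# Illustrative call (names taken from MODEL_OPTIONS above):
#   ensure_model("phi-2.Q4_K_M.gguf", "TheBloke/phi-2-GGUF")
# returns the local path of the cached or freshly downloaded GGUF file.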
# Ensure the selected model is downloaded
model_path = ensure_model(MODEL_NAME, REPO_ID)
# Configure local LLM with LlamaCPP
print("\nInitializing LLM...")
llm = LlamaCPP(
model_path=model_path,
temperature=0.7,
max_new_tokens=256,
context_window=2048
)
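# NOTE: no GPU offload is configured here; if llama-cpp-python is built with CUDA
# support, passing model_kwargs={"n_gpu_layers": -1} to LlamaCPP should move layers
# onto the GPU detected above. Left at the CPU-safe default.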
print("LLM initialized successfully")
# Configure global settings
print("\nConfiguring settings...")
Settings.llm = llm
Settings.embed_model = HuggingFaceEmbedding(
model_name="sentence-transformers/all-MiniLM-L6-v2"
)
print("Settings configured")
# Create the index at startup
print("\nCreating symptom index...")
symptom_index = create_symptom_index()
print("Index created successfully")
# --- System prompt ---
SYSTEM_PROMPT = """
You are a medical assistant helping a user narrow down to the most likely ICD-10 code.
At each turn, EITHER ask one focused clarifying question (e.g. "Is your cough dry or productive?")
or, if you have enough info, output a final JSON with fields:
{"diagnoses":[…], "confidences":[…]}.
"""
def process_speech(audio_path, history):
"""Process speech input and convert to text."""
try:
if not audio_path:
return []
        # Treat the incoming value as the transcript. (With type="filepath" the Gradio
        # Audio component actually delivers a file path, so a real deployment would run
        # a speech-to-text step, e.g. Whisper, on it first.)
        transcript = audio_path
# Query the symptom index
diagnosis_query = f"""
Given these symptoms: '{transcript}'
Identify the most likely ICD-10 diagnoses and key questions to differentiate between them.
Focus only on symptoms mentioned and their clinical implications.
"""
response = symptom_index.as_query_engine().query(diagnosis_query)
# Format response
formatted_response = {
"diagnoses": [],
"confidences": [],
"follow_up": str(response)
}
return [
{"role": "user", "content": transcript},
{"role": "assistant", "content": json.dumps(formatted_response)}
]
except Exception as e:
print(f"Error processing speech: {e}")
return []
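# The two entries returned by process_speech use the openai-style
# {"role": ..., "content": ...} shape expected by gr.Chatbot(type="messages") below.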
def text_to_speech(text):
"""Convert text to speech and return audio HTML element."""
tts = gTTS(text=text, lang='en')
audio_fp = io.BytesIO()
tts.write_to_fp(audio_fp)
audio_b64 = base64.b64encode(audio_fp.getvalue()).decode()
return f'<audio src="data:audio/mp3;base64,{audio_b64}" autoplay></audio>'
def format_response_for_user(response_dict):
"""Convert JSON response to user-friendly format."""
diagnoses = response_dict.get("diagnoses", [])
confidences = response_dict.get("confidences", [])
follow_up = response_dict.get("follow_up", "")
message = ""
if diagnoses and confidences:
for d, c in zip(diagnoses, confidences):
conf_percent = int(c * 100)
message += f"Possible diagnosis ({conf_percent}% confidence): {d}\n"
if follow_up:
message += f"\n{follow_up}"
return message
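# Illustrative example (hypothetical values):
#   format_response_for_user({"diagnoses": ["J20.9 Acute bronchitis"],
#                             "confidences": [0.8],
#                             "follow_up": "Is your cough dry or productive?"})
# -> "Possible diagnosis (80% confidence): J20.9 Acute bronchitis\n\nIs your cough dry or productive?"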
# Build enhanced Gradio interface
with gr.Blocks(
theme="default",
css="""
* {
font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI',
Roboto, Ubuntu, 'Helvetica Neue', Arial, sans-serif;
}
code, pre {
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas,
'Liberation Mono', 'Courier New', monospace;
}
"""
) as demo:
gr.Markdown("""
# 🏥 Medical Symptom to ICD-10 Code Assistant
## About
This application is part of the Agents+MCP Hackathon. It helps medical professionals
and patients understand potential diagnoses based on described symptoms.
### How it works:
1. Click the microphone button and describe your symptoms
2. The AI will analyze your description and suggest possible diagnoses
3. Answer follow-up questions to refine the diagnosis
""")
with gr.Row():
with gr.Column(scale=2):
            # Microphone input and live transcript sit above the chat history
with gr.Row():
microphone = gr.Audio(
type="filepath", # Use filepath to get the audio file path
label="Describe your symptoms",
streaming=True
)
transcript_box = gr.Textbox(
label="Transcribed Text",
interactive=False,
show_label=True
)
clear_btn = gr.Button("Clear Chat", variant="secondary")
chatbot = gr.Chatbot(
label="Medical Consultation",
height=500,
container=True,
type="messages" # This is now properly supported by our message format
)
with gr.Column(scale=1):
with gr.Accordion("Advanced Settings", open=False):
api_key = gr.Textbox(
label="OpenAI API Key (optional)",
type="password",
placeholder="sk-..."
)
model_selector = gr.Dropdown(
choices=list(MODEL_OPTIONS.keys()),
label="Model Tier",
value="small",
interactive=True
)
temperature = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.7,
label="Temperature"
)
# Event handlers
clear_btn.click(lambda: None, None, chatbot, queue=False)
def enhanced_process_speech(audio_path, history, api_key=None, model_tier="small", temp=0.7):
"""Handle speech processing and chat formatting."""
        history = history or []  # the chatbot value can be None on the first call
        if not audio_path:
            return history
# Process the new audio input
new_messages = process_speech(audio_path, history)
if not new_messages:
return history
try:
# Format last assistant response
assistant_response = new_messages[-1]["content"]
response_dict = json.loads(assistant_response)
formatted_text = format_response_for_user(response_dict)
# Add to history with proper message format
return history + [
{"role": "user", "content": new_messages[0]["content"]},
{"role": "assistant", "content": formatted_text}
]
except Exception as e:
print(f"Error formatting response: {e}")
return history
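    # Stream wiring: enhanced_process_speech runs once per streamed audio chunk; any
    # chunk that fails to parse leaves the chat history unchanged.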
microphone.stream(
fn=enhanced_process_speech,
inputs=[
microphone,
chatbot,
api_key,
model_selector,
temperature
],
outputs=chatbot,
show_progress="hidden"
)
# Add footer with social links
gr.Markdown("""
---
### 👋 About the Creator
Hi! I'm Graham Paasch, an experienced technology professional!
🎥 **Check out my YouTube channel** for more tech content:
[Subscribe to my channel](https://www.youtube.com/channel/UCg3oUjrSYcqsL9rGk1g_lPQ)
💼 **Looking for a skilled developer?**
I'm currently seeking new opportunities! View my experience and connect on [LinkedIn](https://www.linkedin.com/in/grahampaasch/)
⭐ If you found this tool helpful, please consider:
- Subscribing to my YouTube channel
- Connecting on LinkedIn
- Sharing this tool with others in healthcare tech
""")
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
mcp_server=True
)