Spaces:
Sleeping
Sleeping
| import os | |
| from pathlib import Path | |
| from huggingface_hub import hf_hub_download | |
| import gradio as gr | |
| from llama_index.core import Settings | |
| from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
| from llama_index.llms.llama_cpp import LlamaCPP | |
| from .parse_tabular import create_symptom_index # Use relative import | |
| import json | |
| import psutil | |
| from typing import Tuple, Dict | |
| import torch | |
| from gtts import gTTS | |
| import io | |
| import base64 | |
| # Model options mapped to their requirements | |
| MODEL_OPTIONS = { | |
| "tiny": { | |
| "name": "TinyLlama-1.1B-Chat-v1.0.Q4_K_M.gguf", | |
| "repo": "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", | |
| "vram_req": 2, # GB | |
| "ram_req": 4 # GB | |
| }, | |
| "small": { | |
| "name": "phi-2.Q4_K_M.gguf", | |
| "repo": "TheBloke/phi-2-GGUF", | |
| "vram_req": 4, | |
| "ram_req": 8 | |
| }, | |
| "medium": { | |
| "name": "mistral-7b-instruct-v0.1.Q4_K_M.gguf", | |
| "repo": "TheBloke/Mistral-7B-Instruct-v0.1-GGUF", | |
| "vram_req": 6, | |
| "ram_req": 16 | |
| } | |
| } | |
| def get_system_specs() -> Dict[str, float]: | |
| """Get system specifications.""" | |
| # Get RAM | |
| ram_gb = psutil.virtual_memory().total / (1024**3) | |
| # Get GPU info if available | |
| gpu_vram_gb = 0 | |
| if torch.cuda.is_available(): | |
| try: | |
| # Query GPU memory in bytes and convert to GB | |
| gpu_vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3) | |
| except Exception as e: | |
| print(f"Warning: Could not get GPU memory: {e}") | |
| return { | |
| "ram_gb": ram_gb, | |
| "gpu_vram_gb": gpu_vram_gb | |
| } | |
| def select_best_model() -> Tuple[str, str]: | |
| """Select the best model based on system specifications.""" | |
| specs = get_system_specs() | |
| print(f"\nSystem specifications:") | |
| print(f"RAM: {specs['ram_gb']:.1f} GB") | |
| print(f"GPU VRAM: {specs['gpu_vram_gb']:.1f} GB") | |
| # Prioritize GPU if available | |
| if specs['gpu_vram_gb'] >= 4: # You have 6GB, so this should work | |
| model_tier = "small" # phi-2 should work well on RTX 2060 | |
| elif specs['ram_gb'] >= 8: | |
| model_tier = "small" | |
| else: | |
| model_tier = "tiny" | |
| selected = MODEL_OPTIONS[model_tier] | |
| print(f"\nSelected model tier: {model_tier}") | |
| print(f"Model: {selected['name']}") | |
| return selected['name'], selected['repo'] | |
| # Set up model paths | |
| MODEL_NAME, REPO_ID = select_best_model() | |
| BASE_DIR = os.path.dirname(os.path.dirname(__file__)) | |
| MODEL_DIR = os.path.join(BASE_DIR, "models") | |
| MODEL_PATH = os.path.join(MODEL_DIR, MODEL_NAME) | |
| from typing import Optional | |
| def ensure_model(model_name: Optional[str] = None, repo_id: Optional[str] = None) -> str: | |
| """Ensures model is available, downloading only if needed.""" | |
| # Determine environment and set cache directory | |
| if os.path.exists("/home/user"): | |
| # HF Space environment | |
| cache_dir = "/home/user/.cache/models" | |
| else: | |
| # Local development environment | |
| cache_dir = os.path.join(BASE_DIR, "models") | |
| # Create cache directory if it doesn't exist | |
| try: | |
| os.makedirs(cache_dir, exist_ok=True) | |
| except Exception as e: | |
| print(f"Warning: Could not create cache directory {cache_dir}: {e}") | |
| # Fall back to temporary directory if needed | |
| cache_dir = os.path.join("/tmp", "models") | |
| os.makedirs(cache_dir, exist_ok=True) | |
| # Get model details | |
| if not model_name or not repo_id: | |
| model_option = MODEL_OPTIONS["small"] # default to small model | |
| model_name = model_option["name"] | |
| repo_id = model_option["repo"] | |
| # Ensure model_name and repo_id are not None | |
| if model_name is None: | |
| raise ValueError("model_name cannot be None") | |
| if repo_id is None: | |
| raise ValueError("repo_id cannot be None") | |
| # Check if model already exists in cache | |
| model_path = os.path.join(cache_dir, model_name) | |
| if os.path.exists(model_path): | |
| print(f"\nUsing cached model: {model_path}") | |
| return model_path | |
| print(f"\nDownloading model {model_name} from {repo_id}...") | |
| try: | |
| model_path = hf_hub_download( | |
| repo_id=repo_id, | |
| filename=model_name, | |
| cache_dir=cache_dir, | |
| local_dir=cache_dir | |
| ) | |
| print(f"Model downloaded successfully to {model_path}") | |
| return model_path | |
| except Exception as e: | |
| print(f"Error downloading model: {str(e)}") | |
| raise | |
| # Ensure model is downloaded | |
| model_path = ensure_model() | |
| # Configure local LLM with LlamaCPP | |
| print("\nInitializing LLM...") | |
| llm = LlamaCPP( | |
| model_path=model_path, | |
| temperature=0.7, | |
| max_new_tokens=256, | |
| context_window=2048 | |
| ) | |
| print("LLM initialized successfully") | |
| # Configure global settings | |
| print("\nConfiguring settings...") | |
| Settings.llm = llm | |
| Settings.embed_model = HuggingFaceEmbedding( | |
| model_name="sentence-transformers/all-MiniLM-L6-v2" | |
| ) | |
| print("Settings configured") | |
| # Create the index at startup | |
| print("\nCreating symptom index...") | |
| symptom_index = create_symptom_index() | |
| print("Index created successfully") | |
| # --- System prompt --- | |
| SYSTEM_PROMPT = """ | |
| You are a medical assistant helping a user narrow down to the most likely ICD-10 code. | |
| At each turn, EITHER ask one focused clarifying question (e.g. "Is your cough dry or productive?") | |
| or, if you have enough info, output a final JSON with fields: | |
| {"diagnoses":[…], "confidences":[…]}. | |
| """ | |
| def process_speech(audio_path, history): | |
| """Process speech input and convert to text.""" | |
| try: | |
| if not audio_path: | |
| return [] | |
| # The audio_path now contains the transcribed text directly from Gradio | |
| transcript = audio_path | |
| # Query the symptom index | |
| diagnosis_query = f""" | |
| Given these symptoms: '{transcript}' | |
| Identify the most likely ICD-10 diagnoses and key questions to differentiate between them. | |
| Focus only on symptoms mentioned and their clinical implications. | |
| """ | |
| response = symptom_index.as_query_engine().query(diagnosis_query) | |
| # Format response | |
| formatted_response = { | |
| "diagnoses": [], | |
| "confidences": [], | |
| "follow_up": str(response) | |
| } | |
| return [ | |
| {"role": "user", "content": transcript}, | |
| {"role": "assistant", "content": json.dumps(formatted_response)} | |
| ] | |
| except Exception as e: | |
| print(f"Error processing speech: {e}") | |
| return [] | |
| def text_to_speech(text): | |
| """Convert text to speech and return audio HTML element.""" | |
| tts = gTTS(text=text, lang='en') | |
| audio_fp = io.BytesIO() | |
| tts.write_to_fp(audio_fp) | |
| audio_b64 = base64.b64encode(audio_fp.getvalue()).decode() | |
| return f'<audio src="data:audio/mp3;base64,{audio_b64}" autoplay></audio>' | |
| def format_response_for_user(response_dict): | |
| """Convert JSON response to user-friendly format.""" | |
| diagnoses = response_dict.get("diagnoses", []) | |
| confidences = response_dict.get("confidences", []) | |
| follow_up = response_dict.get("follow_up", "") | |
| message = "" | |
| if diagnoses and confidences: | |
| for d, c in zip(diagnoses, confidences): | |
| conf_percent = int(c * 100) | |
| message += f"Possible diagnosis ({conf_percent}% confidence): {d}\n" | |
| if follow_up: | |
| message += f"\n{follow_up}" | |
| return message | |
| # Build enhanced Gradio interface | |
| with gr.Blocks( | |
| theme="default", | |
| css=""" | |
| * { | |
| font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', | |
| Roboto, Ubuntu, 'Helvetica Neue', Arial, sans-serif; | |
| } | |
| code, pre { | |
| font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, | |
| 'Liberation Mono', 'Courier New', monospace; | |
| } | |
| """ | |
| ) as demo: | |
| gr.Markdown(""" | |
| # 🏥 Medical Symptom to ICD-10 Code Assistant | |
| ## About | |
| This application is part of the Agents+MCP Hackathon. It helps medical professionals | |
| and patients understand potential diagnoses based on described symptoms. | |
| ### How it works: | |
| 1. Click the microphone button and describe your symptoms | |
| 2. The AI will analyze your description and suggest possible diagnoses | |
| 3. Answer follow-up questions to refine the diagnosis | |
| """) | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| # Moved microphone row above chatbot | |
| with gr.Row(): | |
| microphone = gr.Audio( | |
| type="filepath", # Use filepath to get the audio file path | |
| label="Describe your symptoms", | |
| streaming=True | |
| ) | |
| transcript_box = gr.Textbox( | |
| label="Transcribed Text", | |
| interactive=False, | |
| show_label=True | |
| ) | |
| clear_btn = gr.Button("Clear Chat", variant="secondary") | |
| chatbot = gr.Chatbot( | |
| label="Medical Consultation", | |
| height=500, | |
| container=True, | |
| type="messages" # This is now properly supported by our message format | |
| ) | |
| with gr.Column(scale=1): | |
| with gr.Accordion("Advanced Settings", open=False): | |
| api_key = gr.Textbox( | |
| label="OpenAI API Key (optional)", | |
| type="password", | |
| placeholder="sk-..." | |
| ) | |
| model_selector = gr.Dropdown( | |
| choices=list(MODEL_OPTIONS.keys()), | |
| label="Model Tier", | |
| value="small", | |
| interactive=True | |
| ) | |
| temperature = gr.Slider( | |
| minimum=0.1, | |
| maximum=1.0, | |
| value=0.7, | |
| label="Temperature" | |
| ) | |
| # Event handlers | |
| clear_btn.click(lambda: None, None, chatbot, queue=False) | |
| def enhanced_process_speech(audio_path, history, api_key=None, model_tier="small", temp=0.7): | |
| """Handle speech processing and chat formatting.""" | |
| if not audio_path: | |
| return history | |
| # Process the new audio input | |
| new_messages = process_speech(audio_path, history) | |
| if not new_messages: | |
| return history | |
| try: | |
| # Format last assistant response | |
| assistant_response = new_messages[-1]["content"] | |
| response_dict = json.loads(assistant_response) | |
| formatted_text = format_response_for_user(response_dict) | |
| # Add to history with proper message format | |
| return history + [ | |
| {"role": "user", "content": new_messages[0]["content"]}, | |
| {"role": "assistant", "content": formatted_text} | |
| ] | |
| except Exception as e: | |
| print(f"Error formatting response: {e}") | |
| return history | |
| microphone.stream( | |
| fn=enhanced_process_speech, | |
| inputs=[ | |
| microphone, | |
| chatbot, | |
| api_key, | |
| model_selector, | |
| temperature | |
| ], | |
| outputs=chatbot, | |
| show_progress="hidden" | |
| ) | |
| # Add footer with social links | |
| gr.Markdown(""" | |
| --- | |
| ### 👋 About the Creator | |
| Hi! I'm Graham Paasch, an experienced technology professional! | |
| 🎥 **Check out my YouTube channel** for more tech content: | |
| [Subscribe to my channel](https://www.youtube.com/channel/UCg3oUjrSYcqsL9rGk1g_lPQ) | |
| 💼 **Looking for a skilled developer?** | |
| I'm currently seeking new opportunities! View my experience and connect on [LinkedIn](https://www.linkedin.com/in/grahampaasch/) | |
| ⭐ If you found this tool helpful, please consider: | |
| - Subscribing to my YouTube channel | |
| - Connecting on LinkedIn | |
| - Sharing this tool with others in healthcare tech | |
| """) | |
| if __name__ == "__main__": | |
| demo.launch( | |
| server_name="0.0.0.0", | |
| server_port=7860, | |
| mcp_server=True | |
| ) | |