import streamlit as st import plotly.graph_objects as go from faster_whisper import WhisperModel import requests import json import time import io from audio_recorder_streamlit import audio_recorder from PIL import Image # --- 1. CONFIGURATION --- st.set_page_config(page_title="SomAI", layout="wide", page_icon="🩺") BACKEND_API_URL = "https://arshenoy/somAI-backend.hf.space" @st.cache_resource def load_whisper(): print(">>> LOADING AUDIO SENSORS...") whisper_model = WhisperModel("tiny", device="cpu", compute_type="int8") return whisper_model try: whisper = load_whisper() except Exception as e: st.error(f"WHISPER FAILURE: {e}") st.stop() # --- 3. NEW NEON STYLE CSS --- st.markdown(""" """, unsafe_allow_html=True) # --- 4. SESSION STATE --- if 'history' not in st.session_state: st.session_state.history = [] if 'risk_score' not in st.session_state: st.session_state.risk_score = 0 if 'risk_summary' not in st.session_state: st.session_state.risk_summary = "Pending Analysis" if 'mode' not in st.session_state: st.session_state.mode = "GENERAL" # --- 5. 
# SIDEBAR ---
# Patient intake form plus the trigger for the backend risk analysis.
with st.sidebar:
    st.markdown("### 🧬 Patient Intake & Vitals")

    # Demographics; `age` and `condition` are reused by the chat /generate
    # payload later in the script, so these names must not change.
    with st.expander("Patient Profile", expanded=True):
        name = st.text_input("Name", "Patient X")
        age = st.slider("Age", 18, 90, 20)
        condition = st.text_input("Primary Condition", "Diabetes/Hypertension")

    # Raw vitals fed verbatim to the /analyze endpoint.
    with st.expander("Clinical Vitals", expanded=True):
        bp = st.number_input("Systolic BP", 90, 220, 110)
        glucose = st.number_input("Glucose", 70, 400, 110)
        sleep = st.slider("Sleep Quality (0-10)", 0, 10, 4)
        meds = st.slider("Missed Doses (Last 7 Days)", 0, 7, 3)

    if st.button("RUN CLINICAL ANALYSIS", type="primary", use_container_width=True):
        with st.spinner("Analyzing Clinical Markers..."):
            try:
                # API CALL TO /analyze ENDPOINT
                payload = {
                    "age": age,
                    "condition": condition,
                    "sleep_quality": sleep,
                    "missed_doses": meds,
                    "systolic_bp": bp,
                    "glucose": glucose
                }
                response = requests.post(f"{BACKEND_API_URL}/analyze", json=payload, timeout=30)
                response.raise_for_status()  # Raises an HTTPError for bad responses (4xx or 5xx)
                data = response.json()
                # Backend contract (assumed from usage): JSON body carrying
                # 'numeric_score' (0-100) and 'risk_summary' (display text).
                st.session_state.risk_score = data['numeric_score']
                st.session_state.risk_summary = data['risk_summary']
            except requests.exceptions.RequestException as req_err:
                # `response` is None for pure connection failures, hence the
                # guarded status-code lookup in the message.
                st.error(f"API Error: Cannot connect to backend (Code: {req_err.response.status_code if hasattr(req_err, 'response') and req_err.response else 'N/A'}). Ensure Space 2 is running.")
                st.session_state.risk_score = 0
                st.session_state.risk_summary = "Backend service unavailable."
            except Exception as e:
                # Catch-all (e.g. malformed JSON / missing keys) so the UI
                # degrades to a neutral score instead of crashing the rerun.
                st.error(f"Analysis Failed: {e}")
                st.session_state.risk_score = 0
                st.session_state.risk_summary = "Processing error."
# Risk gauge: maps the 0-100 session risk score onto a traffic-light palette.
# NOTE(review): original indentation was lost in this file — this section may
# have lived inside the sidebar `with` block; confirm intended placement.
val = st.session_state.risk_score
color = "#00ff80" if val < 40 else "#ffc300" if val < 80 else "#ff3300"  # Neon color scheme

# Gauge Chart
fig = go.Figure(go.Indicator(
    mode="gauge+number",
    value=val,
    number={'font': {'size': 40, 'color': 'white'}},
    gauge={
        'axis': {'range': [0, 100], 'tickcolor': '#333333'},
        'bar': {'color': color},  # needle/bar tracks the same threshold color
        'bgcolor': "rgba(26, 26, 26, 0.7)",
        'bordercolor': "#333333",
        # Faint background bands matching the three risk tiers.
        'steps': [
            {'range': [0, 40], 'color': 'rgba(0, 255, 128, 0.1)'},
            {'range': [40, 80], 'color': 'rgba(255, 195, 0, 0.1)'},
            {'range': [80, 100], 'color': 'rgba(255, 51, 0, 0.1)'},
        ]
    }
))
fig.update_layout(
    height=250,
    margin=dict(l=10, r=10, t=30, b=10),
    paper_bgcolor="rgba(0,0,0,0)",  # transparent so the app theme shows through
    font={'color': 'white', 'family': 'JetBrains Mono'}
)
st.plotly_chart(fig, use_container_width=True)

# Assessment card. The closing quotes/arguments of this markdown call follow on
# the next source line; the original HTML wrapper markup appears to have been
# stripped from this file — TODO(review): restore the panel markup.
st.markdown(f"""
CLINICAL ASSESSMENT

{st.session_state.risk_summary}

""", unsafe_allow_html=True) # --- 6. MAIN CHAT --- col_main = st.columns(1)[0] with col_main: st.markdown("## 🧠 SomAI Assistant") c1, c2 = st.columns(2) if c1.button("🩺 Medical Guide", use_container_width=True, type="primary" if st.session_state.mode == "GENERAL" else "secondary"): st.session_state.mode = "GENERAL" if c2.button("🫂 Therapist", use_container_width=True, type="primary" if st.session_state.mode == "THERAPY" else "secondary"): st.session_state.mode = "THERAPY" # Display History chat_container = st.container(height=400) for msg in st.session_state.history: div_class = "user-bubble" if msg['role'] == "user" else "ai-bubble" chat_container.markdown(f"
{msg['content']}
", unsafe_allow_html=True) # --- ROBUST AUDIO INPUT (Hold and Speak) --- st.markdown("---") st.markdown("🎙️ **Hold & Speak:**") audio_bytes = audio_recorder( text="", recording_color="#ff3300", neutral_color="#00ff80", icon_name="microphone", icon_size="3x", initial_time=0 ) user_query = None # 1. VOICE PROCESSING if audio_bytes: with st.spinner("🔊 Transcribing Voice..."): audio_file = io.BytesIO(audio_bytes) segments, info = whisper.transcribe(audio_file, beam_size=5) text_list = [segment.text for segment in segments] user_query = " ".join(text_list).strip() if not user_query: st.warning("Could not detect speech. Please speak clearly.") st.stop() st.session_state.history.append({"role": "user", "content": user_query}) chat_container.markdown(f"
{user_query}
", unsafe_allow_html=True) # 2. TEXT PROCESSING text_input = st.chat_input("...or type a message") if text_input: user_query = text_input st.session_state.history.append({"role": "user", "content": user_query}) chat_container.markdown(f"
{user_query}
", unsafe_allow_html=True) # 3. QUERY LOGIC (API Call) if user_query: # --- LLM Response Generation (Streaming Emulation) --- placeholder = chat_container.empty() full_resp = "" with placeholder.container(): with st.spinner("Thinking..."): try: # API Call to /generate ENDPOINT payload = { "query": user_query, "age": age, "condition": condition, "mode": st.session_state.mode } response = requests.post(f"{BACKEND_API_URL}/generate", json=payload, timeout=60) response.raise_for_status() data = response.json() raw_text = data['generated_text'] chunk_size = 5 words = raw_text.split() for i in range(0, len(words), chunk_size): chunk = " ".join(words[i:i + chunk_size]) full_resp += chunk + " " placeholder.markdown(f"
{full_resp}▌
", unsafe_allow_html=True) time.sleep(0.05) # Adjust for speed placeholder.markdown(f"
{raw_text}
", unsafe_allow_html=True) st.session_state.history.append({"role": "assistant", "content": raw_text}) suggestions = data.get('suggestions', []) if suggestions: st.markdown("---") st.markdown("💡 **Next Steps:**") suggestion_cols = st.columns(len(suggestions)) for i, sug in enumerate(suggestions): suggestion_cols[i].button(sug, key=f"sug_{i}_{len(st.session_state.history)}", use_container_width=True) except requests.exceptions.RequestException as req_err: error_msg = f"API Error: {req_err}. Check backend service health." st.error(error_msg) st.session_state.history.append({"role": "assistant", "content": error_msg}) placeholder.markdown(f"
{error_msg}
", unsafe_allow_html=True) except Exception as e: error_msg = f"LLM Generation Failed: {e}" st.error(error_msg) st.session_state.history.append({"role": "assistant", "content": error_msg}) placeholder.markdown(f"
{error_msg}
", unsafe_allow_html=True) st.rerun()