Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import plotly.graph_objects as go | |
| from faster_whisper import WhisperModel | |
| import requests | |
| import json | |
| import time | |
| import io | |
| from audio_recorder_streamlit import audio_recorder | |
| from PIL import Image | |
# --- 1. CONFIGURATION ---
st.set_page_config(page_title="SomAI", layout="wide", page_icon="🩺")

# Hugging Face Spaces serve apps at https://<owner>-<space-name>.hf.space
# (lowercased, hyphen-separated). The previous value put a "/" between the
# owner and the space name, so the URL's hostname was just "arshenoy" and
# every backend request failed before reaching the Space.
# NOTE(review): confirm the exact subdomain against the deployed backend Space.
BACKEND_API_URL = "https://arshenoy-somai-backend.hf.space"
@st.cache_resource(show_spinner=False)
def load_whisper():
    """Load the Whisper speech-to-text model once per server process.

    Streamlit re-executes the whole script on every widget interaction;
    without ``st.cache_resource`` this function would reload the model from
    disk on every rerun. ``tiny`` + int8 on CPU keeps the footprint small
    enough for a free Space.

    Returns:
        WhisperModel: a ready-to-use faster-whisper model instance.
    """
    print(">>> LOADING AUDIO SENSORS...")
    return WhisperModel("tiny", device="cpu", compute_type="int8")
# Initialise the speech-to-text model up front; the app cannot take voice
# input without it, so surface any failure and halt this script run.
try:
    whisper = load_whisper()
except Exception as exc:
    st.error(f"WHISPER FAILURE: {exc}")
    st.stop()
# --- 3. NEW NEON STYLE CSS ---
# Global theme injected as a raw <style> block; unsafe_allow_html is required
# for this. The .chat-bubble / .user-bubble / .ai-bubble classes defined here
# are referenced by the chat-rendering markdown calls further down the script.
st.markdown("""
<style>
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&family=JetBrains+Mono:wght@400;700&display=swap');
/* BASE THEME - Deep Space Black */
.stApp {
background-color: #050505;
background-image: radial-gradient(circle at 50% 50%, #1a1a1a 0%, #000000 100%);
font-family: 'Inter', sans-serif;
}
/* GLASS SIDEBAR */
section[data-testid="stSidebar"] {
background: rgba(10, 10, 10, 0.7);
backdrop-filter: blur(12px);
border-right: 1px solid rgba(255, 255, 255, 0.08);
}
/* NEON METRICS */
div[data-testid="metric-container"] {
background: rgba(255, 255, 255, 0.03);
border: 1px solid rgba(255, 255, 255, 0.1);
border-radius: 12px;
padding: 15px;
transition: 0.3s;
}
div[data-testid="metric-container"]:hover {
border-color: #00ff80;
box-shadow: 0 0 15px rgba(0, 255, 128, 0.1);
}
/* TEXT GLOW */
h1, h2, h3 {
font-family: 'JetBrains Mono', monospace;
letter-spacing: -0.5px;
color: #fff;
text-shadow: 0 0 10px rgba(255, 255, 255, 0.2);
}
/* CHAT BUBBLES - Updated for new dark background */
.chat-bubble {
padding: 12px 16px;
border-radius: 8px;
margin-bottom: 10px;
line-height: 1.5;
font-family: 'Inter', sans-serif;
font-size: 16px;
color: #e0e0e0;
}
.user-bubble {
background-color: #004d26; /* Darker green for user */
margin-left: 20%;
border-radius: 12px 12px 0 12px;
text-align: right;
}
.ai-bubble {
background-color: #1a1a1a; /* Dark gray for AI */
border: 1px solid rgba(255, 255, 255, 0.1);
margin-right: 20%;
border-radius: 12px 12px 12px 0;
text-align: left;
}
/* REMOVE JUNK */
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
</style>
""", unsafe_allow_html=True)
# --- 4. SESSION STATE ---
# Seed per-session defaults exactly once; later reruns keep existing values.
_SESSION_DEFAULTS = {
    "history": [],                        # chat transcript: {"role", "content"} dicts
    "risk_score": 0,                      # 0-100 gauge value from /analyze
    "risk_summary": "Pending Analysis",   # human-readable assessment text
    "mode": "GENERAL",                    # persona toggle: GENERAL or THERAPY
}
for _key, _value in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _value
# --- 5. SIDEBAR ---
# Patient intake form + clinical risk analysis. The "RUN CLINICAL ANALYSIS"
# button POSTs the vitals to the backend /analyze endpoint and stores the
# result in session state so it survives reruns; the gauge below renders
# whatever score is currently stored.
# NOTE(review): the original indentation was lost in this paste; the gauge
# and assessment card are rendered inside the sidebar here — confirm layout.
with st.sidebar:
    st.markdown("### 🧬 Patient Intake & Vitals")
    with st.expander("Patient Profile", expanded=True):
        name = st.text_input("Name", "Patient X")
        age = st.slider("Age", 18, 90, 20)
        condition = st.text_input("Primary Condition", "Diabetes/Hypertension")
    with st.expander("Clinical Vitals", expanded=True):
        bp = st.number_input("Systolic BP", 90, 220, 110)
        glucose = st.number_input("Glucose", 70, 400, 110)
        sleep = st.slider("Sleep Quality (0-10)", 0, 10, 4)
        meds = st.slider("Missed Doses (Last 7 Days)", 0, 7, 3)
    if st.button("RUN CLINICAL ANALYSIS", type="primary", use_container_width=True):
        with st.spinner("Analyzing Clinical Markers..."):
            try:
                # API CALL TO /analyze ENDPOINT
                payload = {
                    "age": age,
                    "condition": condition,
                    "sleep_quality": sleep,
                    "missed_doses": meds,
                    "systolic_bp": bp,
                    "glucose": glucose
                }
                response = requests.post(f"{BACKEND_API_URL}/analyze", json=payload, timeout=30)
                response.raise_for_status()  # Raises an HTTPError for bad responses (4xx or 5xx)
                data = response.json()
                # NOTE(review): assumes the backend response carries
                # 'numeric_score' and 'risk_summary' keys; a KeyError here
                # falls through to the generic handler below.
                st.session_state.risk_score = data['numeric_score']
                st.session_state.risk_summary = data['risk_summary']
            except requests.exceptions.RequestException as req_err:
                # Network/HTTP failure: report the status code when a response
                # object exists (HTTPError), otherwise 'N/A' (e.g. ConnectionError).
                st.error(f"API Error: Cannot connect to backend (Code: {req_err.response.status_code if hasattr(req_err, 'response') and req_err.response else 'N/A'}). Ensure Space 2 is running.")
                st.session_state.risk_score = 0
                st.session_state.risk_summary = "Backend service unavailable."
            except Exception as e:
                # Catch-all for JSON decoding / missing-key errors.
                st.error(f"Analysis Failed: {e}")
                st.session_state.risk_score = 0
                st.session_state.risk_summary = "Processing error."
    # Risk gauge: colour thresholds are green < 40, amber < 80, red otherwise.
    val = st.session_state.risk_score
    color = "#00ff80" if val < 40 else "#ffc300" if val < 80 else "#ff3300"  # Neon color scheme
    # Gauge Chart
    fig = go.Figure(go.Indicator(
        mode="gauge+number",
        value=val,
        number={'font': {'size': 40, 'color': 'white'}},
        gauge={
            'axis': {'range': [0, 100], 'tickcolor': '#333333'},
            'bar': {'color': color},
            'bgcolor': "rgba(26, 26, 26, 0.7)",
            'bordercolor': "#333333",
            # Faint background bands matching the three risk zones.
            'steps': [
                {'range': [0, 40], 'color': 'rgba(0, 255, 128, 0.1)'},
                {'range': [40, 80], 'color': 'rgba(255, 195, 0, 0.1)'},
                {'range': [80, 100], 'color': 'rgba(255, 51, 0, 0.1)'},
            ]
        }
    ))
    # Transparent chart background so the app's dark theme shows through.
    fig.update_layout(
        height=250,
        margin=dict(l=10,r=10,t=30,b=10),
        paper_bgcolor="rgba(0,0,0,0)",
        font={'color': 'white', 'family': 'JetBrains Mono'}
    )
    st.plotly_chart(fig, use_container_width=True)
    # Assessment card: left border and heading reuse the gauge colour.
    st.markdown(f"""
    <div style="background: rgba(255, 255, 255, 0.05); border: 1px solid rgba(255, 255, 255, 0.1); border-radius: 12px; padding: 15px; margin-top: 15px; border-left: 4px solid {color};">
        <h5 style="margin:0; color: {color}; font-family: 'JetBrains Mono', monospace;">CLINICAL ASSESSMENT</h5>
        <p style="margin-top:5px; font-size: 0.9rem; color: #ddd;">{st.session_state.risk_summary}</p>
    </div>
    """, unsafe_allow_html=True)
# --- 6. MAIN CHAT ---
# Single full-width column hosting the assistant header, the two persona
# toggle buttons, and the scrolling chat history.
(col_main,) = st.columns(1)
with col_main:
    st.markdown("## 🧠 SomAI Assistant")
    c1, c2 = st.columns(2)
    # Highlight whichever persona is currently active.
    guide_style = "primary" if st.session_state.mode == "GENERAL" else "secondary"
    therapy_style = "primary" if st.session_state.mode == "THERAPY" else "secondary"
    if c1.button("🩺 Medical Guide", use_container_width=True, type=guide_style):
        st.session_state.mode = "GENERAL"
    if c2.button("🫂 Therapist", use_container_width=True, type=therapy_style):
        st.session_state.mode = "THERAPY"
    # Replay the stored transcript inside a fixed-height scroll area.
    chat_container = st.container(height=400)
    for msg in st.session_state.history:
        if msg['role'] == "user":
            bubble = "user-bubble"
        else:
            bubble = "ai-bubble"
        chat_container.markdown(
            f"<div class='chat-bubble {bubble}'>{msg['content']}</div>",
            unsafe_allow_html=True,
        )
# --- ROBUST AUDIO INPUT (Hold and Speak) ---
st.markdown("---")
st.markdown("🎙️ **Hold & Speak:**")
audio_bytes = audio_recorder(
    text="",
    recording_color="#ff3300",
    neutral_color="#00ff80",
    icon_name="microphone",
    icon_size="3x",
    initial_time=0
)
user_query = None
# 1. VOICE PROCESSING
# audio_recorder returns the same clip's bytes on every rerun. Without a
# guard, the st.rerun() issued after the LLM reply would re-transcribe and
# re-answer the clip forever, appending duplicate messages. Remember the
# last clip processed and skip anything already handled.
if audio_bytes and audio_bytes != st.session_state.get("last_audio"):
    st.session_state["last_audio"] = audio_bytes
    with st.spinner("🔊 Transcribing Voice..."):
        audio_file = io.BytesIO(audio_bytes)
        segments, info = whisper.transcribe(audio_file, beam_size=5)
        # faster-whisper yields segments lazily; join their text pieces.
        user_query = " ".join(segment.text for segment in segments).strip()
    if not user_query:
        # Silence / unintelligible audio: tell the user and end this run.
        st.warning("Could not detect speech. Please speak clearly.")
        st.stop()
    st.session_state.history.append({"role": "user", "content": user_query})
    # Echo the transcribed query immediately (history was rendered above).
    chat_container.markdown(f"<div class='chat-bubble user-bubble'>{user_query}</div>", unsafe_allow_html=True)
# 2. TEXT PROCESSING
# Typed chat box as a fallback for users who prefer not to use voice.
typed_message = st.chat_input("...or type a message")
if typed_message:
    user_query = typed_message
    st.session_state.history.append({"role": "user", "content": user_query})
    # Show the new user bubble right away; history above was already drawn.
    chat_container.markdown(
        f"<div class='chat-bubble user-bubble'>{user_query}</div>",
        unsafe_allow_html=True,
    )
# 3. QUERY LOGIC (API Call)
# When either input path produced a query, POST it to the backend /generate
# endpoint and emulate token streaming by revealing the reply a few words at
# a time inside a placeholder, then rerun so the full history re-renders.
if user_query:
    # --- LLM Response Generation (Streaming Emulation) ---
    placeholder = chat_container.empty()
    full_resp = ""
    with placeholder.container():
        with st.spinner("Thinking..."):
            try:
                # API Call to /generate ENDPOINT
                payload = {
                    "query": user_query,
                    "age": age,
                    "condition": condition,
                    "mode": st.session_state.mode
                }
                response = requests.post(f"{BACKEND_API_URL}/generate", json=payload, timeout=60)
                response.raise_for_status()
                # NOTE(review): assumes the response JSON carries a
                # 'generated_text' key; a KeyError falls to the handler below.
                data = response.json()
                raw_text = data['generated_text']
                # Fake streaming: reveal the answer 5 words per frame.
                chunk_size = 5
                words = raw_text.split()
                for i in range(0, len(words), chunk_size):
                    chunk = " ".join(words[i:i + chunk_size])
                    full_resp += chunk + " "
                    placeholder.markdown(f"<div class='chat-bubble ai-bubble'>{full_resp}▌</div>", unsafe_allow_html=True)
                    time.sleep(0.05)  # Adjust for speed
                # Final render without the cursor glyph.
                placeholder.markdown(f"<div class='chat-bubble ai-bubble'>{raw_text}</div>", unsafe_allow_html=True)
                st.session_state.history.append({"role": "assistant", "content": raw_text})
                suggestions = data.get('suggestions', [])
                if suggestions:
                    st.markdown("---")
                    st.markdown("💡 **Next Steps:**")
                    suggestion_cols = st.columns(len(suggestions))
                    for i, sug in enumerate(suggestions):
                        # NOTE(review): these buttons have no click handler, and
                        # the st.rerun() below removes them from the page almost
                        # immediately — confirm the intended suggestion UX.
                        suggestion_cols[i].button(sug, key=f"sug_{i}_{len(st.session_state.history)}", use_container_width=True)
            except requests.exceptions.RequestException as req_err:
                # Transport/HTTP failure: surface it and keep it in the
                # transcript so the user sees it after the rerun.
                error_msg = f"API Error: {req_err}. Check backend service health."
                st.error(error_msg)
                st.session_state.history.append({"role": "assistant", "content": error_msg})
                placeholder.markdown(f"<div class='chat-bubble ai-bubble'>{error_msg}</div>", unsafe_allow_html=True)
            except Exception as e:
                # JSON decoding / missing-key / rendering failures.
                error_msg = f"LLM Generation Failed: {e}"
                st.error(error_msg)
                st.session_state.history.append({"role": "assistant", "content": error_msg})
                placeholder.markdown(f"<div class='chat-bubble ai-bubble'>{error_msg}</div>", unsafe_allow_html=True)
    # Rerun so the appended messages render via the normal history loop.
    st.rerun()