Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,397 +1,486 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
#
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
except ImportError:
|
| 17 |
-
subprocess.check_call([sys.executable, "-m", "pip", "install", p])
|
| 18 |
-
|
| 19 |
-
install(["mido", "midi2audio", "pydub", "gradio", "requests"])
|
| 20 |
-
|
| 21 |
-
# ------------------------------------------------------------------
|
| 22 |
-
# 1. Imports
|
| 23 |
-
# ------------------------------------------------------------------
|
| 24 |
-
import random, requests, json, tempfile, mido, re, os, shutil, gradio as gr
|
| 25 |
-
from mido import Message, MidiFile, MidiTrack, MetaMessage
|
| 26 |
-
from midi2audio import FluidSynth
|
| 27 |
-
from pydub import AudioSegment
|
| 28 |
-
|
| 29 |
-
# ------------------------------------------------------------------
|
| 30 |
-
# 2. Config
|
| 31 |
-
# ------------------------------------------------------------------
|
| 32 |
-
MISTRAL_API_URL = "https://api.mistral.ai/v1/chat/completions"
|
| 33 |
MISTRAL_API_KEY = "yQdfM8MLbX9uhInQ7id4iUTwN4h4pDLX"
|
| 34 |
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
"
|
| 38 |
-
|
| 39 |
-
"
|
| 40 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
}
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
NOTE_MAP = {"C":0,"C#":1,"DB":1,"D":2,"D#":3,"EB":3,"E":4,"F":5,"F#":6,"GB":6,
|
| 50 |
-
"G":7,"G#":8,"AB":8,"A":9,"A#":10,"BB":10,"B":11}
|
| 51 |
-
|
| 52 |
-
def sanitise_note_name(n: str) -> str:
|
| 53 |
-
"""Convert common variations to canonical form."""
|
| 54 |
-
n = n.strip().upper()
|
| 55 |
-
# Map full words
|
| 56 |
-
n = re.sub(r'\bFLAT\b', 'B', n, flags=re.I)
|
| 57 |
-
n = re.sub(r'\bSHARP\b', '#', n, flags=re.I)
|
| 58 |
-
# Map dash/space forms
|
| 59 |
-
n = re.sub(r'([A-G])\s*-\s*FLAT', r'\1B', n, flags=re.I)
|
| 60 |
-
n = re.sub(r'([A-G])\s*-\s*SHARP', r'\1#', n, flags=re.I)
|
| 61 |
-
return n
|
| 62 |
-
|
| 63 |
-
def note_name_to_midi(n: str) -> int:
|
| 64 |
-
n = sanitise_note_name(n)
|
| 65 |
-
m = re.fullmatch(r"([A-G][#B]?)(\d)", n)
|
| 66 |
-
if not m:
|
| 67 |
-
raise ValueError(f"Invalid note: {n}")
|
| 68 |
-
pitch, octave = m.groups()
|
| 69 |
-
pitch = pitch.replace('B', 'B') # keep BB as B-flat
|
| 70 |
-
if pitch not in NOTE_MAP:
|
| 71 |
-
raise ValueError(f"Invalid pitch: {pitch}")
|
| 72 |
-
return NOTE_MAP[pitch] + (int(octave) + 1) * 12
|
| 73 |
-
|
| 74 |
-
# ------------------------------------------------------------------
|
| 75 |
-
# 4. Integer-only scaler – unchanged
|
| 76 |
-
# ------------------------------------------------------------------
|
| 77 |
-
def scale_json_durations(json_data, target_units: int) -> list:
|
| 78 |
-
ints = [[n, max(1, int(round(d)))] for n, d in json_data]
|
| 79 |
-
total = sum(d for _, d in ints)
|
| 80 |
-
deficit = target_units - total
|
| 81 |
-
if deficit > 0:
|
| 82 |
-
for i in range(deficit):
|
| 83 |
-
ints[i % len(ints)][1] += 1
|
| 84 |
-
elif deficit < 0:
|
| 85 |
-
for i in range(-deficit):
|
| 86 |
-
if ints[i % len(ints)][1] > 1:
|
| 87 |
-
ints[i % len(ints)][1] -= 1
|
| 88 |
-
return ints
|
| 89 |
-
|
| 90 |
-
# ------------------------------------------------------------------
|
| 91 |
-
# 5. JSON → MIDI – unchanged
|
| 92 |
-
# ------------------------------------------------------------------
|
| 93 |
-
def json_to_midi(json_data, instrument, tempo, time_signature, measures):
|
| 94 |
-
mid = MidiFile(ticks_per_beat=TICKS_PER_BEAT)
|
| 95 |
-
track = MidiTrack(); mid.tracks.append(track)
|
| 96 |
-
program = INSTRUMENT_PROGRAMS.get(instrument, 56)
|
| 97 |
-
num, denom = map(int, time_signature.split('/'))
|
| 98 |
-
|
| 99 |
-
track.append(MetaMessage('time_signature', numerator=num, denominator=denom, time=0))
|
| 100 |
-
track.append(MetaMessage('set_tempo', tempo=mido.bpm2tempo(tempo), time=0))
|
| 101 |
-
track.append(Message('program_change', program=program, time=0))
|
| 102 |
-
|
| 103 |
-
for note_name, dur_units in json_data:
|
| 104 |
-
note_num = note_name_to_midi(note_name)
|
| 105 |
-
ticks = max(int(dur_units * TICKS_PER_UNIT), 1)
|
| 106 |
-
vel = random.randint(60, 100)
|
| 107 |
-
track.append(Message('note_on', note=note_num, velocity=vel, time=0))
|
| 108 |
-
track.append(Message('note_off', note=note_num, velocity=vel, time=ticks))
|
| 109 |
-
return mid
|
| 110 |
-
|
| 111 |
-
# ------------------------------------------------------------------
|
| 112 |
-
# 6. Prompt helpers – always include strict format
|
| 113 |
-
# ------------------------------------------------------------------
|
| 114 |
-
INSTRUMENT_PROGRAMS = {"Piano":0,"Trumpet":56,"Violin":40,"Clarinet":71,"Flute":73}
|
| 115 |
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
"
|
| 120 |
-
"
|
| 121 |
-
"
|
| 122 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
}
|
| 124 |
-
|
| 125 |
-
numerator = int(time_sig.split('/')[0])
|
| 126 |
-
total = measures * numerator
|
| 127 |
-
notes, durs = [], []
|
| 128 |
-
i = 0
|
| 129 |
-
while sum(durs) < total:
|
| 130 |
-
notes.append(pat[i % len(pat)])
|
| 131 |
-
durs.append(1)
|
| 132 |
-
i += 1
|
| 133 |
-
if sum(durs) > total:
|
| 134 |
-
durs[-1] -= sum(durs) - total
|
| 135 |
-
return json.dumps([[n,d] for n,d in zip(notes,durs)])
|
| 136 |
-
|
| 137 |
-
# ------------------------------------------------------------------
|
| 138 |
-
# 7. Mistral query – strengthened prompt
|
| 139 |
-
# ------------------------------------------------------------------
|
| 140 |
-
def query_mistral(prompt, instrument, level, key, time_sig, measures):
|
| 141 |
-
headers = {"Authorization": f"Bearer {MISTRAL_API_KEY}", "Content-Type": "application/json"}
|
| 142 |
-
numerator = int(time_sig.split('/')[0])
|
| 143 |
-
target_units = measures * numerator
|
| 144 |
-
|
| 145 |
-
strict_format = (
|
| 146 |
-
"Use ONLY standard note names like 'C4', 'F#5', 'Bb3'. "
|
| 147 |
-
"Use ONLY integer durations representing quarter-note beats: "
|
| 148 |
-
"1 = quarter, 2 = half, 4 = whole. "
|
| 149 |
-
f"Sum MUST equal exactly {target_units}. "
|
| 150 |
-
"Output ONLY a JSON array of [note, integer_duration] pairs. "
|
| 151 |
-
"No prose, no explanation."
|
| 152 |
-
)
|
| 153 |
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
else:
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
return json.dumps(scaled, indent=2), static_mp3, str(tempo), midi, f"{dur:.2f}s", ts
|
| 216 |
-
except Exception as e:
|
| 217 |
-
return f"Error: {e}", None, str(tempo), None, "0", ts
|
| 218 |
-
|
| 219 |
-
def get_soundfont(instrument):
|
| 220 |
-
os.makedirs("soundfonts", exist_ok=True)
|
| 221 |
-
sf2 = f"soundfonts/{instrument}.sf2"
|
| 222 |
-
if not os.path.exists(sf2):
|
| 223 |
-
open(sf2,"wb").write(requests.get(SOUNDFONT_URLS[instrument]).content)
|
| 224 |
-
return sf2
|
| 225 |
-
|
| 226 |
-
# ------------------------------------------------------------------
|
| 227 |
-
# 10. Chat helper – unchanged
|
| 228 |
-
# ------------------------------------------------------------------
|
| 229 |
-
def chat(msg, hist, instr, lvl):
|
| 230 |
-
if not msg.strip(): return "", hist
|
| 231 |
-
msgs=[{"role":"system","content":f"You are a {instr} teacher for {lvl} students."}]
|
| 232 |
-
for u,a in hist: msgs.extend([{"role":"user","content":u},{"role":"assistant","content":a}])
|
| 233 |
-
msgs.append({"role":"user","content":msg})
|
| 234 |
-
headers={"Authorization":f"Bearer {MISTRAL_API_KEY}","Content-Type":"application/json"}
|
| 235 |
-
payload={"model":"mistral-medium","messages":msgs,"max_tokens":500}
|
| 236 |
-
try:
|
| 237 |
-
r=requests.post(MISTRAL_API_URL,headers=headers,json=payload)
|
| 238 |
-
r.raise_for_status()
|
| 239 |
-
reply=r.json()["choices"][0]["message"]["content"]
|
| 240 |
-
hist.append((msg,reply))
|
| 241 |
-
return "", hist
|
| 242 |
-
except Exception as e:
|
| 243 |
-
hist.append((msg,f"Error: {e}"))
|
| 244 |
-
return "", hist
|
| 245 |
-
|
| 246 |
-
# ------------------------------------------------------------------
|
| 247 |
-
# 11. Gradio UI – unchanged
|
| 248 |
-
# ------------------------------------------------------------------
|
| 249 |
-
def ui():
|
| 250 |
-
with gr.Blocks(title="Adaptive Music Exercise Generator – Integer Durations") as demo:
|
| 251 |
-
gr.Markdown("# 🎼 Adaptive Music Exercise Generator – Integer Durations")
|
| 252 |
-
mode = gr.Radio(["Exercise Parameters","Exercise Prompt"], value="Exercise Parameters", label="Mode")
|
| 253 |
-
with gr.Row():
|
| 254 |
-
with gr.Column(scale=1):
|
| 255 |
-
with gr.Group(visible=True) as params_group:
|
| 256 |
-
instrument=gr.Dropdown(["Trumpet","Piano","Violin","Clarinet","Flute"],value="Trumpet")
|
| 257 |
-
level = gr.Radio(["Beginner","Intermediate","Advanced"], value="Intermediate")
|
| 258 |
-
key = gr.Dropdown(["C Major","G Major","D Major","F Major","Bb Major","A Minor","E Minor"],value="C Major")
|
| 259 |
-
ts = gr.Dropdown(["3/4","4/4"], value="4/4", label="Time Signature")
|
| 260 |
-
measures = gr.Radio([4,8], value=4, label="Length (measures)")
|
| 261 |
-
with gr.Group(visible=False) as prompt_group:
|
| 262 |
-
prompt_txt = gr.Textbox("", lines=3, label="Custom prompt")
|
| 263 |
-
measures2 = gr.Radio([4,8], value=4, label="Length (measures)")
|
| 264 |
-
gen_btn = gr.Button("Generate Exercise", variant="primary")
|
| 265 |
-
with gr.Column(scale=2):
|
| 266 |
-
with gr.Tabs():
|
| 267 |
-
with gr.TabItem("Player"):
|
| 268 |
-
audio=gr.Audio(autoplay=True)
|
| 269 |
-
bpm=gr.Textbox(label="Tempo (BPM)")
|
| 270 |
-
ts_disp=gr.Textbox(label="Time Signature")
|
| 271 |
-
dur=gr.Textbox(label="Duration")
|
| 272 |
-
with gr.TabItem("Data"):
|
| 273 |
-
js_out=gr.Code(language="json")
|
| 274 |
-
with gr.TabItem("MIDI"):
|
| 275 |
-
midi_file=gr.File()
|
| 276 |
-
dl_btn=gr.Button("Download MIDI")
|
| 277 |
-
with gr.TabItem("Chat"):
|
| 278 |
-
chatbot=gr.Chatbot(height=400)
|
| 279 |
-
chat_in=gr.Textbox(label="Ask the AI")
|
| 280 |
-
chat_btn=gr.Button("Send")
|
| 281 |
-
|
| 282 |
-
mode.change(lambda m:{params_group:gr.update(visible=m=="Exercise Parameters"),
|
| 283 |
-
prompt_group:gr.update(visible=m=="Exercise Prompt")},
|
| 284 |
-
[mode],[params_group,prompt_group])
|
| 285 |
-
|
| 286 |
-
def caller(m,i,l,k,t,ms,p,ms2):
|
| 287 |
-
real_ms = ms2 if m=="Exercise Prompt" else ms
|
| 288 |
-
return generate_exercise(i,l,k,60,t,real_ms,p or "",m)
|
| 289 |
-
gen_btn.click(caller,[mode,instrument,level,key,ts,measures,prompt_txt,measures2],
|
| 290 |
-
[js_out,audio,bpm,gr.State(),dur,ts_disp])
|
| 291 |
-
|
| 292 |
-
def save(json_txt,instr,ts):
|
| 293 |
-
data=safe_parse(json_txt)
|
| 294 |
-
if not data: return None
|
| 295 |
-
num=int(ts.split('/')[0])
|
| 296 |
-
target=sum(int(d) for _,d in data)
|
| 297 |
-
measures_est=max(1,round(target/num))
|
| 298 |
-
scaled=scale_json_durations(data,measures_est*num)
|
| 299 |
-
midi=json_to_midi(scaled,instr,60,ts,measures_est)
|
| 300 |
-
path="static/exercise.mid"
|
| 301 |
-
midi.save(path)
|
| 302 |
-
return path
|
| 303 |
-
dl_btn.click(save,[js_out,instrument,ts],[midi_file])
|
| 304 |
|
| 305 |
-
|
| 306 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 307 |
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
}
|
| 324 |
-
response = requests.post(MISTRAL_API_URL, headers=headers, json=payload, timeout=10)
|
| 325 |
-
response.raise_for_status()
|
| 326 |
-
return response.json()['choices'][0]['message']['content']
|
| 327 |
-
except:
|
| 328 |
-
return None
|
| 329 |
-
|
| 330 |
-
def generate_response(name, message):
|
| 331 |
-
"""Generate medical response with fallback"""
|
| 332 |
-
# First try local knowledge base
|
| 333 |
-
local_response = get_local_advice(message)
|
| 334 |
-
if local_response:
|
| 335 |
-
return f"Hello {name},\n\n{local_response}\n\nWishing you good health,\nDr. Alex"
|
| 336 |
-
|
| 337 |
-
# Then try API
|
| 338 |
-
prompt = f"""As Dr. Alex, provide structured advice for {name} who reports: "{message}"
|
| 339 |
-
|
| 340 |
-
Format response with these sections:
|
| 341 |
-
|
| 342 |
-
[Assessment] - Brief professional evaluation
|
| 343 |
-
[Self-Care Recommendations] - 3-5 specific actionable steps
|
| 344 |
-
[When to Seek Medical Care] - Clear warning signs
|
| 345 |
-
[Follow-up] - Monitoring advice
|
| 346 |
-
|
| 347 |
-
Use bullet points, professional yet compassionate tone."""
|
| 348 |
-
|
| 349 |
-
api_response = query_mistral(prompt)
|
| 350 |
-
if api_response:
|
| 351 |
-
return api_response
|
| 352 |
-
|
| 353 |
-
# Final fallback
|
| 354 |
-
common_conditions = "\n".join([f"- {cond}" for cond in MEDICAL_KNOWLEDGE.keys()])
|
| 355 |
-
return f"""Hello {name},
|
| 356 |
-
|
| 357 |
-
I'm currently unable to access detailed medical databases. For general advice:
|
| 358 |
-
|
| 359 |
-
Common conditions I can advise on:
|
| 360 |
-
{common_conditions}
|
| 361 |
-
|
| 362 |
-
For immediate concerns:
|
| 363 |
-
• Contact your local healthcare provider
|
| 364 |
-
• In emergencies, call your local emergency number
|
| 365 |
-
|
| 366 |
-
Please try asking about one of these common conditions or consult a healthcare professional.
|
| 367 |
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import seaborn as sns
|
| 6 |
+
import requests
|
| 7 |
+
|
| 8 |
+
# Set page config
|
| 9 |
+
st.set_page_config(
|
| 10 |
+
page_title="Heart Disease Prediction",
|
| 11 |
+
page_icon="❤️",
|
| 12 |
+
layout="wide"
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
# Mistral API Key
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
MISTRAL_API_KEY = "yQdfM8MLbX9uhInQ7id4iUTwN4h4pDLX"
|
| 17 |
|
| 18 |
+
# Feature information with medical ranges
|
| 19 |
+
feature_info = {
|
| 20 |
+
"age": {"label": "Age", "desc": "Age in years", "min": 20, "max": 100, "value": 50,
|
| 21 |
+
"healthy_range": "20-60", "risk_range": "60+"},
|
| 22 |
+
"sex": {"label": "Sex", "desc": "Biological sex", "options": ["Female", "Male"]},
|
| 23 |
+
"cp": {"label": "Chest Pain Type", "desc": "Type of chest pain",
|
| 24 |
+
"options": ["Typical angina", "Atypical angina", "Non-anginal pain", "Asymptomatic"],
|
| 25 |
+
"healthy": "Typical angina", "risk": "Asymptomatic"},
|
| 26 |
+
"trestbps": {"label": "Resting BP", "desc": "Resting blood pressure (mm Hg)",
|
| 27 |
+
"min": 80, "max": 200, "value": 120,
|
| 28 |
+
"healthy_range": "90-120", "risk_range": "140+"},
|
| 29 |
+
"chol": {"label": "Cholesterol", "desc": "Serum cholesterol (mg/dl)",
|
| 30 |
+
"min": 100, "max": 600, "value": 200,
|
| 31 |
+
"healthy_range": "< 200", "risk_range": "240+"},
|
| 32 |
+
"fbs": {"label": "Fasting Blood Sugar", "desc": "> 120 mg/dl", "options": ["No", "Yes"],
|
| 33 |
+
"healthy": "No", "risk": "Yes"},
|
| 34 |
+
"restecg": {"label": "Resting ECG", "desc": "Resting electrocardiographic results",
|
| 35 |
+
"options": ["Normal", "ST-T wave abnormality", "Left ventricular hypertrophy"],
|
| 36 |
+
"healthy": "Normal", "risk": "Left ventricular hypertrophy"},
|
| 37 |
+
"thalach": {"label": "Max Heart Rate", "desc": "Maximum heart rate achieved",
|
| 38 |
+
"min": 60, "max": 220, "value": 150,
|
| 39 |
+
"healthy_range": "60-100 (resting)", "risk_range": "< 120 (exercise)"},
|
| 40 |
+
"exang": {"label": "Exercise Angina", "desc": "Exercise induced angina", "options": ["No", "Yes"],
|
| 41 |
+
"healthy": "No", "risk": "Yes"},
|
| 42 |
+
"oldpeak": {"label": "ST Depression", "desc": "ST depression induced by exercise",
|
| 43 |
+
"min": 0.0, "max": 6.0, "value": 1.0, "step": 0.1,
|
| 44 |
+
"healthy_range": "0-1", "risk_range": "2+"},
|
| 45 |
+
"slope": {"label": "ST Slope", "desc": "Slope of peak exercise ST segment",
|
| 46 |
+
"options": ["Upsloping", "Flat", "Downsloping"],
|
| 47 |
+
"healthy": "Upsloping", "risk": "Downsloping"},
|
| 48 |
+
"ca": {"label": "Major Vessels", "desc": "Number of major vessels colored by fluoroscopy",
|
| 49 |
+
"options": ["0", "1", "2", "3"],
|
| 50 |
+
"healthy": "0", "risk": "3"},
|
| 51 |
+
"thal": {"label": "Thalassemia", "desc": "Blood disorder called thalassemia",
|
| 52 |
+
"options": ["Normal", "Fixed defect", "Reversible defect"],
|
| 53 |
+
"healthy": "Normal", "risk": "Fixed defect"}
|
| 54 |
}
|
| 55 |
|
| 56 |
+
# Risk assessment thresholds
|
| 57 |
+
RISK_THRESHOLDS = {
|
| 58 |
+
"Low": 0.3,
|
| 59 |
+
"Moderate": 0.6,
|
| 60 |
+
"High": 1.0
|
| 61 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
+
# Sample healthy and unhealthy profiles
|
| 64 |
+
SAMPLE_PROFILES = {
|
| 65 |
+
"Healthy": {
|
| 66 |
+
"age": 45, "sex": "Female", "cp": "Typical angina",
|
| 67 |
+
"trestbps": 110, "chol": 180, "fbs": "No",
|
| 68 |
+
"restecg": "Normal", "thalach": 160, "exang": "No",
|
| 69 |
+
"oldpeak": 0.5, "slope": "Upsloping", "ca": "0",
|
| 70 |
+
"thal": "Normal"
|
| 71 |
+
},
|
| 72 |
+
"Moderate Risk": {
|
| 73 |
+
"age": 58, "sex": "Male", "cp": "Atypical angina",
|
| 74 |
+
"trestbps": 135, "chol": 230, "fbs": "No",
|
| 75 |
+
"restecg": "ST-T wave abnormality", "thalach": 140, "exang": "No",
|
| 76 |
+
"oldpeak": 1.5, "slope": "Flat", "ca": "1",
|
| 77 |
+
"thal": "Reversible defect"
|
| 78 |
+
},
|
| 79 |
+
"High Risk": {
|
| 80 |
+
"age": 65, "sex": "Male", "cp": "Asymptomatic",
|
| 81 |
+
"trestbps": 180, "chol": 280, "fbs": "Yes",
|
| 82 |
+
"restecg": "Left ventricular hypertrophy", "thalach": 120, "exang": "Yes",
|
| 83 |
+
"oldpeak": 3.0, "slope": "Downsloping", "ca": "3",
|
| 84 |
+
"thal": "Fixed defect"
|
| 85 |
}
|
| 86 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
+
def calculate_risk(inputs):
|
| 89 |
+
"""Calculate heart disease risk based on clinical parameters"""
|
| 90 |
+
risk_factors = 0
|
| 91 |
+
total_possible = 13 # Total parameters being evaluated
|
| 92 |
+
risk_details = []
|
| 93 |
+
risk_params = []
|
| 94 |
+
|
| 95 |
+
if inputs['age'] > 60:
|
| 96 |
+
risk_factors += 1
|
| 97 |
+
risk_details.append(f"- Age: {inputs['age']} (Risk: 60+)")
|
| 98 |
+
risk_params.append('age')
|
| 99 |
+
if inputs['sex'] == 'Male':
|
| 100 |
+
risk_factors += 1
|
| 101 |
+
risk_details.append("- Sex: Male (Higher risk)")
|
| 102 |
+
risk_params.append('sex')
|
| 103 |
+
if inputs['cp'] == 'Asymptomatic':
|
| 104 |
+
risk_factors += 1
|
| 105 |
+
risk_details.append("- Chest Pain: Asymptomatic (Highest risk)")
|
| 106 |
+
risk_params.append('cp')
|
| 107 |
+
elif inputs['cp'] in ['Atypical angina', 'Non-anginal pain']:
|
| 108 |
+
risk_factors += 0.5
|
| 109 |
+
risk_details.append(f"- Chest Pain: {inputs['cp']} (Moderate risk)")
|
| 110 |
+
if inputs['trestbps'] >= 140:
|
| 111 |
+
risk_factors += 1
|
| 112 |
+
risk_details.append(f"- Blood Pressure: {inputs['trestbps']} mmHg (Stage 2 Hypertension)")
|
| 113 |
+
risk_params.append('trestbps')
|
| 114 |
+
elif inputs['trestbps'] >= 130:
|
| 115 |
+
risk_factors += 0.5
|
| 116 |
+
risk_details.append(f"- Blood Pressure: {inputs['trestbps']} mmHg (Elevated)")
|
| 117 |
+
if inputs['chol'] >= 240:
|
| 118 |
+
risk_factors += 1
|
| 119 |
+
risk_details.append(f"- Cholesterol: {inputs['chol']} mg/dL (High)")
|
| 120 |
+
risk_params.append('chol')
|
| 121 |
+
elif inputs['chol'] >= 200:
|
| 122 |
+
risk_factors += 0.5
|
| 123 |
+
risk_details.append(f"- Cholesterol: {inputs['chol']} mg/dL (Borderline High)")
|
| 124 |
+
if inputs['fbs'] == 'Yes':
|
| 125 |
+
risk_factors += 1
|
| 126 |
+
risk_details.append("- Fasting Blood Sugar > 120 mg/dL")
|
| 127 |
+
risk_params.append('fbs')
|
| 128 |
+
if inputs['restecg'] == 'Left ventricular hypertrophy':
|
| 129 |
+
risk_factors += 1
|
| 130 |
+
risk_details.append("- ECG: Left Ventricular Hypertrophy")
|
| 131 |
+
risk_params.append('restecg')
|
| 132 |
+
elif inputs['restecg'] == 'ST-T wave abnormality':
|
| 133 |
+
risk_factors += 0.5
|
| 134 |
+
risk_details.append("- ECG: ST-T Wave Abnormality")
|
| 135 |
+
if inputs['thalach'] < 120:
|
| 136 |
+
risk_factors += 1
|
| 137 |
+
risk_details.append(f"- Max Heart Rate: {inputs['thalach']} (Low)")
|
| 138 |
+
risk_params.append('thalach')
|
| 139 |
+
if inputs['exang'] == 'Yes':
|
| 140 |
+
risk_factors += 1
|
| 141 |
+
risk_details.append("- Exercise Induced Angina: Yes")
|
| 142 |
+
risk_params.append('exang')
|
| 143 |
+
if inputs['oldpeak'] >= 2.0:
|
| 144 |
+
risk_factors += 1
|
| 145 |
+
risk_details.append(f"- ST Depression: {inputs['oldpeak']} (High)")
|
| 146 |
+
risk_params.append('oldpeak')
|
| 147 |
+
elif inputs['oldpeak'] >= 1.0:
|
| 148 |
+
risk_factors += 0.5
|
| 149 |
+
risk_details.append(f"- ST Depression: {inputs['oldpeak']} (Moderate)")
|
| 150 |
+
if inputs['slope'] == 'Downsloping':
|
| 151 |
+
risk_factors += 1
|
| 152 |
+
risk_details.append("- ST Slope: Downsloping (Highest risk)")
|
| 153 |
+
risk_params.append('slope')
|
| 154 |
+
elif inputs['slope'] == 'Flat':
|
| 155 |
+
risk_factors += 0.5
|
| 156 |
+
risk_details.append("- ST Slope: Flat (Moderate risk)")
|
| 157 |
+
if inputs['ca'] == '3':
|
| 158 |
+
risk_factors += 1
|
| 159 |
+
risk_details.append("- Major Vessels: 3 (Highest risk)")
|
| 160 |
+
risk_params.append('ca')
|
| 161 |
+
elif inputs['ca'] in ['1', '2']:
|
| 162 |
+
risk_factors += 0.5
|
| 163 |
+
risk_details.append(f"- Major Vessels: {inputs['ca']} (Moderate risk)")
|
| 164 |
+
if inputs['thal'] == 'Fixed defect':
|
| 165 |
+
risk_factors += 1
|
| 166 |
+
risk_details.append("- Thalassemia: Fixed Defect")
|
| 167 |
+
risk_params.append('thal')
|
| 168 |
+
elif inputs['thal'] == 'Reversible defect':
|
| 169 |
+
risk_factors += 0.5
|
| 170 |
+
risk_details.append("- Thalassemia: Reversible Defect")
|
| 171 |
+
|
| 172 |
+
probability = min(risk_factors / total_possible, 0.99) # Cap at 99%
|
| 173 |
+
|
| 174 |
+
if probability <= RISK_THRESHOLDS['Low']:
|
| 175 |
+
risk_level = 'Low'
|
| 176 |
+
color = "green"
|
| 177 |
+
icon = "✅"
|
| 178 |
+
advice = "Maintain your healthy lifestyle with regular check-ups."
|
| 179 |
+
elif probability <= RISK_THRESHOLDS['Moderate']:
|
| 180 |
+
risk_level = 'Moderate'
|
| 181 |
+
color = "orange"
|
| 182 |
+
icon = "⚠️"
|
| 183 |
+
advice = "Consider lifestyle changes and consult your doctor."
|
| 184 |
else:
|
| 185 |
+
risk_level = 'High'
|
| 186 |
+
color = "red"
|
| 187 |
+
icon = "❗"
|
| 188 |
+
advice = "Please consult a cardiologist immediately."
|
| 189 |
+
|
| 190 |
+
return risk_level, probability, color, icon, advice, risk_details, risk_params
|
| 191 |
+
|
| 192 |
+
# App layout
|
| 193 |
+
st.title("❤️ Heart Disease Risk Assessment")
|
| 194 |
+
st.markdown("""
|
| 195 |
+
This tool evaluates your heart disease risk based on clinical parameters and provides personalized insights.
|
| 196 |
+
""")
|
| 197 |
+
|
| 198 |
+
# Create tabs
|
| 199 |
+
tab1, tab2, tab3 = st.tabs(["Risk Assessment", "Parameter Analysis", "Health Guidance"])
|
| 200 |
+
|
| 201 |
+
with tab1:
|
| 202 |
+
with st.form("heart_form"):
|
| 203 |
+
st.subheader("Enter Your Health Parameters")
|
| 204 |
+
col1, col2, col3 = st.columns(3)
|
| 205 |
+
inputs = {}
|
| 206 |
+
|
| 207 |
+
# Column 1
|
| 208 |
+
with col1:
|
| 209 |
+
inputs["age"] = st.number_input(
|
| 210 |
+
feature_info["age"]["label"],
|
| 211 |
+
min_value=feature_info["age"]["min"],
|
| 212 |
+
max_value=feature_info["age"]["max"],
|
| 213 |
+
value=st.session_state.get("age", feature_info["age"]["value"]),
|
| 214 |
+
help=f"{feature_info['age']['desc']}. Healthy range: {feature_info['age']['healthy_range']}"
|
| 215 |
+
)
|
| 216 |
+
inputs["sex"] = st.selectbox(
|
| 217 |
+
feature_info["sex"]["label"],
|
| 218 |
+
feature_info["sex"]["options"],
|
| 219 |
+
index=feature_info["sex"]["options"].index(st.session_state.get("sex", feature_info["sex"]["options"][0])),
|
| 220 |
+
help=feature_info["sex"]["desc"]
|
| 221 |
+
)
|
| 222 |
+
inputs["cp"] = st.selectbox(
|
| 223 |
+
feature_info["cp"]["label"],
|
| 224 |
+
feature_info["cp"]["options"],
|
| 225 |
+
index=feature_info["cp"]["options"].index(st.session_state.get("cp", feature_info["cp"]["options"][0])),
|
| 226 |
+
help=f"{feature_info['cp']['desc']}. Healthy: {feature_info['cp']['healthy']}"
|
| 227 |
+
)
|
| 228 |
+
inputs["trestbps"] = st.number_input(
|
| 229 |
+
feature_info["trestbps"]["label"],
|
| 230 |
+
min_value=feature_info["trestbps"]["min"],
|
| 231 |
+
max_value=feature_info["trestbps"]["max"],
|
| 232 |
+
value=st.session_state.get("trestbps", feature_info["trestbps"]["value"]),
|
| 233 |
+
help=f"{feature_info['trestbps']['desc']}. Healthy: {feature_info['trestbps']['healthy_range']}"
|
| 234 |
+
)
|
| 235 |
+
inputs["chol"] = st.number_input(
|
| 236 |
+
feature_info["chol"]["label"],
|
| 237 |
+
min_value=feature_info["chol"]["min"],
|
| 238 |
+
max_value=feature_info["chol"]["max"],
|
| 239 |
+
value=st.session_state.get("chol", feature_info["chol"]["value"]),
|
| 240 |
+
help=f"{feature_info['chol']['desc']}. Healthy: {feature_info['chol']['healthy_range']}"
|
| 241 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
|
| 243 |
+
# Column 2
|
| 244 |
+
with col2:
|
| 245 |
+
inputs["fbs"] = st.selectbox(
|
| 246 |
+
feature_info["fbs"]["label"],
|
| 247 |
+
feature_info["fbs"]["options"],
|
| 248 |
+
index=feature_info["fbs"]["options"].index(st.session_state.get("fbs", feature_info["fbs"]["options"][0])),
|
| 249 |
+
help=f"{feature_info['fbs']['desc']}. Healthy: {feature_info['fbs']['healthy']}"
|
| 250 |
+
)
|
| 251 |
+
inputs["restecg"] = st.selectbox(
|
| 252 |
+
feature_info["restecg"]["label"],
|
| 253 |
+
feature_info["restecg"]["options"],
|
| 254 |
+
index=feature_info["restecg"]["options"].index(st.session_state.get("restecg", feature_info["restecg"]["options"][0])),
|
| 255 |
+
help=f"{feature_info['restecg']['desc']}. Healthy: {feature_info['restecg']['healthy']}"
|
| 256 |
+
)
|
| 257 |
+
inputs["thalach"] = st.number_input(
|
| 258 |
+
feature_info["thalach"]["label"],
|
| 259 |
+
min_value=feature_info["thalach"]["min"],
|
| 260 |
+
max_value=feature_info["thalach"]["max"],
|
| 261 |
+
value=st.session_state.get("thalach", feature_info["thalach"]["value"]),
|
| 262 |
+
help=f"{feature_info['thalach']['desc']}. Healthy: {feature_info['thalach']['healthy_range']}"
|
| 263 |
+
)
|
| 264 |
+
inputs["exang"] = st.selectbox(
|
| 265 |
+
feature_info["exang"]["label"],
|
| 266 |
+
feature_info["exang"]["options"],
|
| 267 |
+
index=feature_info["exang"]["options"].index(st.session_state.get("exang", feature_info["exang"]["options"][0])),
|
| 268 |
+
help=f"{feature_info['exang']['desc']}. Healthy: {feature_info['exang']['healthy']}"
|
| 269 |
+
)
|
| 270 |
+
inputs["oldpeak"] = st.number_input(
|
| 271 |
+
feature_info["oldpeak"]["label"],
|
| 272 |
+
min_value=feature_info["oldpeak"]["min"],
|
| 273 |
+
max_value=feature_info["oldpeak"]["max"],
|
| 274 |
+
value=st.session_state.get("oldpeak", feature_info["oldpeak"]["value"]),
|
| 275 |
+
step=feature_info["oldpeak"]["step"],
|
| 276 |
+
help=f"{feature_info['oldpeak']['desc']}. Healthy: {feature_info['oldpeak']['healthy_range']}"
|
| 277 |
+
)
|
| 278 |
|
| 279 |
+
# Column 3
|
| 280 |
+
with col3:
|
| 281 |
+
inputs["slope"] = st.selectbox(
|
| 282 |
+
feature_info["slope"]["label"],
|
| 283 |
+
feature_info["slope"]["options"],
|
| 284 |
+
index=feature_info["slope"]["options"].index(st.session_state.get("slope", feature_info["slope"]["options"][0])),
|
| 285 |
+
help=f"{feature_info['slope']['desc']}. Healthy: {feature_info['slope']['healthy']}"
|
| 286 |
+
)
|
| 287 |
+
inputs["ca"] = st.selectbox(
|
| 288 |
+
feature_info["ca"]["label"],
|
| 289 |
+
feature_info["ca"]["options"],
|
| 290 |
+
index=feature_info["ca"]["options"].index(st.session_state.get("ca", feature_info["ca"]["options"][0])),
|
| 291 |
+
help=f"{feature_info['ca']['desc']}. Healthy: {feature_info['ca']['healthy']}"
|
| 292 |
+
)
|
| 293 |
+
inputs["thal"] = st.selectbox(
|
| 294 |
+
feature_info["thal"]["label"],
|
| 295 |
+
feature_info["thal"]["options"],
|
| 296 |
+
index=feature_info["thal"]["options"].index(st.session_state.get("thal", feature_info["thal"]["options"][0])),
|
| 297 |
+
help=f"{feature_info['thal']['desc']}. Healthy: {feature_info['thal']['healthy']}"
|
| 298 |
+
)
|
| 299 |
|
| 300 |
+
submitted = st.form_submit_button("Assess My Heart Disease Risk")
|
| 301 |
+
|
| 302 |
+
if submitted:
|
| 303 |
+
risk_level, prediction_proba, color, icon, advice, risk_details, risk_params = calculate_risk(inputs)
|
| 304 |
+
|
| 305 |
+
st.subheader("Risk Assessment Results")
|
| 306 |
+
cols = st.columns(3)
|
| 307 |
+
with cols[1]:
|
| 308 |
+
st.metric("Heart Disease Risk", f"{prediction_proba * 100:.1f}%", f"{risk_level} Risk", delta_color="off")
|
| 309 |
+
risk_meter = st.progress(prediction_proba)
|
| 310 |
+
st.markdown(f"""
|
| 311 |
+
<div style="background-color:#f0f2f6;padding:15px;border-radius:10px">
|
| 312 |
+
<h4 style="color:{color};text-align:center">{icon} {risk_level} Risk Category</h4>
|
| 313 |
+
<p style="text-align:center">{advice}</p>
|
| 314 |
+
</div>
|
| 315 |
+
""", unsafe_allow_html=True)
|
| 316 |
+
|
| 317 |
+
st.subheader("Risk Factor Analysis")
|
| 318 |
+
total_factors = 13
|
| 319 |
+
risk_factors = len(risk_details)
|
| 320 |
+
st.write(f"**{risk_factors} out of {total_factors}** parameters show elevated risk:")
|
| 321 |
+
if risk_details:
|
| 322 |
+
st.markdown("\n".join(risk_details))
|
| 323 |
+
else:
|
| 324 |
+
st.success("All parameters are within healthy ranges!")
|
| 325 |
+
|
| 326 |
+
st.session_state.results = {
|
| 327 |
+
"risk_level": risk_level,
|
| 328 |
+
"probability": prediction_proba,
|
| 329 |
+
"risk_factors": risk_details,
|
| 330 |
+
"risk_params": risk_params,
|
| 331 |
+
"inputs": inputs
|
| 332 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
|
| 334 |
+
with tab2:
|
| 335 |
+
st.header("Parameter Analysis")
|
| 336 |
+
if "results" not in st.session_state:
|
| 337 |
+
st.warning("Please complete the risk assessment first.")
|
| 338 |
+
else:
|
| 339 |
+
st.subheader("Health Parameter Radar Chart")
|
| 340 |
+
radar_params = ["age", "trestbps", "chol", "thalach", "oldpeak"]
|
| 341 |
+
radar_labels = [feature_info[p]["label"] for p in radar_params]
|
| 342 |
+
radar_values = [st.session_state.results["inputs"][p] for p in radar_params]
|
| 343 |
+
|
| 344 |
+
max_values = {
|
| 345 |
+
"age": 100,
|
| 346 |
+
"trestbps": 200,
|
| 347 |
+
"chol": 600,
|
| 348 |
+
"thalach": 220,
|
| 349 |
+
"oldpeak": 6.0
|
| 350 |
+
}
|
| 351 |
+
normalized_values = [v / max_values[p] for v, p in zip(radar_values, radar_params)]
|
| 352 |
+
angles = np.linspace(0, 2 * np.pi, len(radar_labels), endpoint=False).tolist()
|
| 353 |
+
angles += angles[:1]
|
| 354 |
+
normalized_values += normalized_values[:1]
|
| 355 |
+
fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
|
| 356 |
+
ax.fill(angles, normalized_values, color='red', alpha=0.25)
|
| 357 |
+
ax.plot(angles, normalized_values, color='red', linewidth=2)
|
| 358 |
+
ax.set_xticks(angles[:-1])
|
| 359 |
+
ax.set_xticklabels(radar_labels)
|
| 360 |
+
ax.set_yticklabels([])
|
| 361 |
+
ax.set_title("Your Health Parameters (Normalized)", pad=20)
|
| 362 |
+
st.pyplot(fig)
|
| 363 |
+
|
| 364 |
+
st.subheader("Parameter Comparison")
|
| 365 |
+
comparison_data = []
|
| 366 |
+
for param in ["age", "trestbps", "chol", "thalach", "oldpeak"]:
|
| 367 |
+
info = feature_info[param]
|
| 368 |
+
status = "⚠️ Out of Range" if param in st.session_state.results["risk_params"] else "✅ Within Range"
|
| 369 |
+
comparison_data.append({
|
| 370 |
+
"Parameter": info["label"],
|
| 371 |
+
"Your Value": st.session_state.results["inputs"][param],
|
| 372 |
+
"Healthy Range": info.get("healthy_range", info.get("healthy", "")),
|
| 373 |
+
"Status": status
|
| 374 |
+
})
|
| 375 |
+
st.table(pd.DataFrame(comparison_data))
|
| 376 |
+
|
| 377 |
+
with tab3:
|
| 378 |
+
st.header("Personalized Health Guidance")
|
| 379 |
+
if "results" not in st.session_state:
|
| 380 |
+
st.warning("Please complete the risk assessment first.")
|
| 381 |
+
else:
|
| 382 |
+
risk_level = st.session_state.results["risk_level"]
|
| 383 |
+
|
| 384 |
+
# Basic guidance
|
| 385 |
+
if risk_level == "Low":
|
| 386 |
+
st.success("""
|
| 387 |
+
**Your heart health looks good!**
|
| 388 |
+
Maintain these healthy habits:
|
| 389 |
+
- Continue regular exercise (150 mins/week moderate activity)
|
| 390 |
+
- Eat a balanced diet rich in fruits, vegetables, and whole grains
|
| 391 |
+
- Get 7-9 hours of quality sleep nightly
|
| 392 |
+
- Manage stress through relaxation techniques
|
| 393 |
+
- Annual check-ups with your doctor
|
| 394 |
+
""")
|
| 395 |
+
elif risk_level == "Moderate":
|
| 396 |
+
st.warning("""
|
| 397 |
+
**Your heart health needs attention**
|
| 398 |
+
Recommended actions:
|
| 399 |
+
- Increase physical activity (aim for 30 mins/day)
|
| 400 |
+
- Reduce sodium (<1500mg/day) and saturated fat intake
|
| 401 |
+
- Quit smoking if applicable
|
| 402 |
+
- Limit alcohol to 1 drink/day (women) or 2/day (men)
|
| 403 |
+
- Monitor blood pressure regularly
|
| 404 |
+
- Schedule a doctor's visit within 3 months
|
| 405 |
+
- Consider cholesterol screening
|
| 406 |
+
""")
|
| 407 |
+
else:
|
| 408 |
+
st.error("""
|
| 409 |
+
**Your heart health requires immediate attention**
|
| 410 |
+
Critical next steps:
|
| 411 |
+
- Consult a cardiologist within 1 month
|
| 412 |
+
- Begin a supervised exercise program
|
| 413 |
+
- Strict dietary modifications (Mediterranean diet recommended)
|
| 414 |
+
- Medication may be needed (doctor will advise)
|
| 415 |
+
- Regular monitoring of blood pressure and cholesterol
|
| 416 |
+
- Consider cardiac rehabilitation program
|
| 417 |
+
- Emergency care if chest pain or shortness of breath occurs
|
| 418 |
+
""")
|
| 419 |
+
|
| 420 |
+
# AI-Powered Recommendations
|
| 421 |
+
st.subheader("AI-Powered Recommendations")
|
| 422 |
+
if st.button("Get Personalized AI Advice"):
|
| 423 |
+
with st.spinner("Generating AI recommendations..."):
|
| 424 |
+
try:
|
| 425 |
+
# Prepare the prompt for Mistral API
|
| 426 |
+
prompt = f"""
|
| 427 |
+
Provide detailed, personalized heart health recommendations for a patient with the following profile:
|
| 428 |
+
- Age: {st.session_state.results["inputs"]["age"]}
|
| 429 |
+
- Sex: {st.session_state.results["inputs"]["sex"]}
|
| 430 |
+
- Risk Level: {risk_level}
|
| 431 |
+
- Key Risk Factors: {', '.join(st.session_state.results["risk_params"])}
|
| 432 |
+
|
| 433 |
+
Provide specific, actionable advice on:
|
| 434 |
+
1. Lifestyle modifications
|
| 435 |
+
2. Dietary recommendations
|
| 436 |
+
3. Exercise suggestions
|
| 437 |
+
4. Monitoring suggestions
|
| 438 |
+
5. When to seek medical attention
|
| 439 |
+
|
| 440 |
+
Format the response with clear headings and bullet points.
|
| 441 |
+
"""
|
| 442 |
+
|
| 443 |
+
# Call Mistral API
|
| 444 |
+
headers = {
|
| 445 |
+
"Authorization": f"Bearer {MISTRAL_API_KEY}",
|
| 446 |
+
"Content-Type": "application/json"
|
| 447 |
+
}
|
| 448 |
+
|
| 449 |
+
data = {
|
| 450 |
+
"model": "mistral-tiny",
|
| 451 |
+
"messages": [{"role": "user", "content": prompt}]
|
| 452 |
+
}
|
| 453 |
+
|
| 454 |
+
response = requests.post(
|
| 455 |
+
"https://api.mistral.ai/v1/chat/completions",
|
| 456 |
+
headers=headers,
|
| 457 |
+
json=data
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
if response.status_code == 200:
|
| 461 |
+
ai_response = response.json()["choices"][0]["message"]["content"]
|
| 462 |
+
st.markdown("### AI Health Advisor Recommendations")
|
| 463 |
+
st.write(ai_response)
|
| 464 |
+
else:
|
| 465 |
+
st.error(f"Failed to get AI recommendations. Status code: {response.status_code}")
|
| 466 |
+
|
| 467 |
+
except Exception as e:
|
| 468 |
+
st.error(f"An error occurred: {str(e)}")
|
| 469 |
+
|
| 470 |
+
# Sidebar with sample profiles
|
| 471 |
+
st.sidebar.header("Quick Start")
|
| 472 |
+
profile = st.sidebar.selectbox("Load sample profile:", list(SAMPLE_PROFILES.keys()))
|
| 473 |
+
if st.sidebar.button("Load Profile"):
|
| 474 |
+
for key in list(st.session_state.keys()):
|
| 475 |
+
if key in feature_info:
|
| 476 |
+
del st.session_state[key]
|
| 477 |
+
for key, value in SAMPLE_PROFILES[profile].items():
|
| 478 |
+
st.session_state[key] = value
|
| 479 |
+
st.rerun()
|
| 480 |
+
|
| 481 |
+
# Disclaimer
|
| 482 |
+
st.sidebar.markdown("""
|
| 483 |
+
### Important Disclaimer
|
| 484 |
+
This tool provides risk assessment only and is not a substitute for professional medical advice.
|
| 485 |
+
Always consult with a qualified healthcare provider for medical concerns.
|
| 486 |
+
""")
|