SHIKARICHACHA commited on
Commit
6ffc385
·
verified ·
1 Parent(s): 516f715

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +473 -384
app.py CHANGED
@@ -1,397 +1,486 @@
1
- """
2
- Adaptive Music Exercise Generator – Integer Durations Only
3
- ========================================================
4
- 1 unit = 1 quarter-note (integer only)
5
- """
6
-
7
- # ------------------------------------------------------------------
8
- # 0. Auto-install deps (unchanged)
9
- # ------------------------------------------------------------------
10
- import sys, subprocess
11
- from typing import Dict, Optional, Tuple, List
12
-
13
- def install(pkgs: List[str]):
14
- for p in pkgs:
15
- try: __import__(p)
16
- except ImportError:
17
- subprocess.check_call([sys.executable, "-m", "pip", "install", p])
18
-
19
- install(["mido", "midi2audio", "pydub", "gradio", "requests"])
20
-
21
- # ------------------------------------------------------------------
22
- # 1. Imports
23
- # ------------------------------------------------------------------
24
- import random, requests, json, tempfile, mido, re, os, shutil, gradio as gr
25
- from mido import Message, MidiFile, MidiTrack, MetaMessage
26
- from midi2audio import FluidSynth
27
- from pydub import AudioSegment
28
-
29
- # ------------------------------------------------------------------
30
- # 2. Config
31
- # ------------------------------------------------------------------
32
- MISTRAL_API_URL = "https://api.mistral.ai/v1/chat/completions"
33
  MISTRAL_API_KEY = "yQdfM8MLbX9uhInQ7id4iUTwN4h4pDLX"
34
 
35
- SOUNDFONT_URLS = {
36
- "Trumpet": "https://github.com/FluidSynth/fluidsynth/raw/master/sf2/Trumpet.sf2",
37
- "Piano": "https://musical-artifacts.com/artifacts/2719/GeneralUser_GS_1.471.sf2",
38
- "Violin": "https://musical-artifacts.com/artifacts/2744/SalC5Light.sf2",
39
- "Clarinet": "https://musical-artifacts.com/artifacts/2744/SalC5Light.sf2",
40
- "Flute": "https://musical-artifacts.com/artifacts/2744/SalC5Light.sf2",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  }
42
 
43
- TICKS_PER_BEAT = 480
44
- TICKS_PER_UNIT = TICKS_PER_BEAT
45
-
46
- # ------------------------------------------------------------------
47
- # 3. Note helpers – fixed regex & cleaner
48
- # ------------------------------------------------------------------
49
- NOTE_MAP = {"C":0,"C#":1,"DB":1,"D":2,"D#":3,"EB":3,"E":4,"F":5,"F#":6,"GB":6,
50
- "G":7,"G#":8,"AB":8,"A":9,"A#":10,"BB":10,"B":11}
51
-
52
- def sanitise_note_name(n: str) -> str:
53
- """Convert common variations to canonical form."""
54
- n = n.strip().upper()
55
- # Map full words
56
- n = re.sub(r'\bFLAT\b', 'B', n, flags=re.I)
57
- n = re.sub(r'\bSHARP\b', '#', n, flags=re.I)
58
- # Map dash/space forms
59
- n = re.sub(r'([A-G])\s*-\s*FLAT', r'\1B', n, flags=re.I)
60
- n = re.sub(r'([A-G])\s*-\s*SHARP', r'\1#', n, flags=re.I)
61
- return n
62
-
63
- def note_name_to_midi(n: str) -> int:
64
- n = sanitise_note_name(n)
65
- m = re.fullmatch(r"([A-G][#B]?)(\d)", n)
66
- if not m:
67
- raise ValueError(f"Invalid note: {n}")
68
- pitch, octave = m.groups()
69
- pitch = pitch.replace('B', 'B') # keep BB as B-flat
70
- if pitch not in NOTE_MAP:
71
- raise ValueError(f"Invalid pitch: {pitch}")
72
- return NOTE_MAP[pitch] + (int(octave) + 1) * 12
73
-
74
- # ------------------------------------------------------------------
75
- # 4. Integer-only scaler – unchanged
76
- # ------------------------------------------------------------------
77
- def scale_json_durations(json_data, target_units: int) -> list:
78
- ints = [[n, max(1, int(round(d)))] for n, d in json_data]
79
- total = sum(d for _, d in ints)
80
- deficit = target_units - total
81
- if deficit > 0:
82
- for i in range(deficit):
83
- ints[i % len(ints)][1] += 1
84
- elif deficit < 0:
85
- for i in range(-deficit):
86
- if ints[i % len(ints)][1] > 1:
87
- ints[i % len(ints)][1] -= 1
88
- return ints
89
-
90
- # ------------------------------------------------------------------
91
- # 5. JSON → MIDI – unchanged
92
- # ------------------------------------------------------------------
93
- def json_to_midi(json_data, instrument, tempo, time_signature, measures):
94
- mid = MidiFile(ticks_per_beat=TICKS_PER_BEAT)
95
- track = MidiTrack(); mid.tracks.append(track)
96
- program = INSTRUMENT_PROGRAMS.get(instrument, 56)
97
- num, denom = map(int, time_signature.split('/'))
98
-
99
- track.append(MetaMessage('time_signature', numerator=num, denominator=denom, time=0))
100
- track.append(MetaMessage('set_tempo', tempo=mido.bpm2tempo(tempo), time=0))
101
- track.append(Message('program_change', program=program, time=0))
102
-
103
- for note_name, dur_units in json_data:
104
- note_num = note_name_to_midi(note_name)
105
- ticks = max(int(dur_units * TICKS_PER_UNIT), 1)
106
- vel = random.randint(60, 100)
107
- track.append(Message('note_on', note=note_num, velocity=vel, time=0))
108
- track.append(Message('note_off', note=note_num, velocity=vel, time=ticks))
109
- return mid
110
-
111
- # ------------------------------------------------------------------
112
- # 6. Prompt helpers – always include strict format
113
- # ------------------------------------------------------------------
114
- INSTRUMENT_PROGRAMS = {"Piano":0,"Trumpet":56,"Violin":40,"Clarinet":71,"Flute":73}
115
 
116
- def fallback(instrument, level, key, time_sig, measures):
117
- patterns = {
118
- "Trumpet":["C4","D4","E4","G4"],
119
- "Piano":["C4","E4","G4","C5"],
120
- "Violin":["G4","A4","B4","D5"],
121
- "Clarinet":["E4","F4","G4","Bb4"],
122
- "Flute":["A4","B4","C5","E5"],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  }
124
- pat = patterns.get(instrument, patterns["Trumpet"])
125
- numerator = int(time_sig.split('/')[0])
126
- total = measures * numerator
127
- notes, durs = [], []
128
- i = 0
129
- while sum(durs) < total:
130
- notes.append(pat[i % len(pat)])
131
- durs.append(1)
132
- i += 1
133
- if sum(durs) > total:
134
- durs[-1] -= sum(durs) - total
135
- return json.dumps([[n,d] for n,d in zip(notes,durs)])
136
-
137
- # ------------------------------------------------------------------
138
- # 7. Mistral query – strengthened prompt
139
- # ------------------------------------------------------------------
140
- def query_mistral(prompt, instrument, level, key, time_sig, measures):
141
- headers = {"Authorization": f"Bearer {MISTRAL_API_KEY}", "Content-Type": "application/json"}
142
- numerator = int(time_sig.split('/')[0])
143
- target_units = measures * numerator
144
-
145
- strict_format = (
146
- "Use ONLY standard note names like 'C4', 'F#5', 'Bb3'. "
147
- "Use ONLY integer durations representing quarter-note beats: "
148
- "1 = quarter, 2 = half, 4 = whole. "
149
- f"Sum MUST equal exactly {target_units}. "
150
- "Output ONLY a JSON array of [note, integer_duration] pairs. "
151
- "No prose, no explanation."
152
- )
153
 
154
- system = f"You are an expert {instrument.lower()} teacher."
155
- if prompt.strip():
156
- user = f"{prompt}\n\n{strict_format}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  else:
158
- style = random.choice(["simple","jazzy","technical"])
159
- tech = random.choice(["with long tones","with slurs","with double tonguing"])
160
- user = f"Create a {style} {instrument.lower()} exercise in {key}, {time_sig}, {tech}. {strict_format}"
161
-
162
- payload = {
163
- "model":"mistral-medium",
164
- "messages":[{"role":"system","content":system},{"role":"user","content":user}],
165
- "max_tokens":800,
166
- "temperature":0.6,
167
- }
168
- try:
169
- r = requests.post(MISTRAL_API_URL, headers=headers, json=payload)
170
- r.raise_for_status()
171
- return r.json()["choices"][0]["message"]["content"].strip()
172
- except Exception as e:
173
- print("Mistral error:", e)
174
- return fallback(instrument, level, key, time_sig, measures)
175
-
176
- # ------------------------------------------------------------------
177
- # 8. Safe JSON parse – unchanged
178
- # ------------------------------------------------------------------
179
- def safe_parse(text):
180
- try:
181
- text = text.replace("'",'"')
182
- m = re.search(r"\[(\s*\[.*?\]\s*,?)*\]", text, re.DOTALL)
183
- return json.loads(m.group(0) if m else text)
184
- except Exception:
185
- return None
186
-
187
- # ------------------------------------------------------------------
188
- # 9. Main generation – unchanged
189
- # ------------------------------------------------------------------
190
- def generate_exercise(instrument, level, key, tempo, ts, measures, custom_prompt, mode):
191
- try:
192
- prompt_in = custom_prompt if mode=="Exercise Prompt" else ""
193
- raw = query_mistral(prompt_in, instrument, level, key, ts, measures)
194
- parsed = safe_parse(raw)
195
- if not parsed:
196
- return "Invalid JSON", None, str(tempo), None, "0", ts
197
-
198
- numerator = int(ts.split('/')[0])
199
- target_units = measures * numerator
200
- scaled = scale_json_durations(parsed, target_units)
201
- midi = json_to_midi(scaled, instrument, tempo, ts, measures)
202
-
203
- # MIDI → MP3
204
- with tempfile.NamedTemporaryFile(delete=False, suffix=".mid") as f:
205
- midi.save(f.name)
206
- wav = f.name.replace(".mid",".wav")
207
- mp3 = f.name.replace(".mid",".mp3")
208
- sf2 = get_soundfont(instrument)
209
- import subprocess as sp
210
- sp.run(["fluidsynth","-ni",sf2,f.name,"-F",wav,"-r","44100"], check=True, capture_output=True)
211
- AudioSegment.from_wav(wav).export(mp3, format="mp3")
212
- static_mp3 = os.path.join("static", os.path.basename(mp3))
213
- shutil.move(mp3, static_mp3)
214
- dur = AudioSegment.from_mp3(static_mp3).duration_seconds
215
- return json.dumps(scaled, indent=2), static_mp3, str(tempo), midi, f"{dur:.2f}s", ts
216
- except Exception as e:
217
- return f"Error: {e}", None, str(tempo), None, "0", ts
218
-
219
- def get_soundfont(instrument):
220
- os.makedirs("soundfonts", exist_ok=True)
221
- sf2 = f"soundfonts/{instrument}.sf2"
222
- if not os.path.exists(sf2):
223
- open(sf2,"wb").write(requests.get(SOUNDFONT_URLS[instrument]).content)
224
- return sf2
225
-
226
- # ------------------------------------------------------------------
227
- # 10. Chat helper – unchanged
228
- # ------------------------------------------------------------------
229
- def chat(msg, hist, instr, lvl):
230
- if not msg.strip(): return "", hist
231
- msgs=[{"role":"system","content":f"You are a {instr} teacher for {lvl} students."}]
232
- for u,a in hist: msgs.extend([{"role":"user","content":u},{"role":"assistant","content":a}])
233
- msgs.append({"role":"user","content":msg})
234
- headers={"Authorization":f"Bearer {MISTRAL_API_KEY}","Content-Type":"application/json"}
235
- payload={"model":"mistral-medium","messages":msgs,"max_tokens":500}
236
- try:
237
- r=requests.post(MISTRAL_API_URL,headers=headers,json=payload)
238
- r.raise_for_status()
239
- reply=r.json()["choices"][0]["message"]["content"]
240
- hist.append((msg,reply))
241
- return "", hist
242
- except Exception as e:
243
- hist.append((msg,f"Error: {e}"))
244
- return "", hist
245
-
246
- # ------------------------------------------------------------------
247
- # 11. Gradio UI – unchanged
248
- # ------------------------------------------------------------------
249
- def ui():
250
- with gr.Blocks(title="Adaptive Music Exercise Generator – Integer Durations") as demo:
251
- gr.Markdown("# 🎼 Adaptive Music Exercise Generator – Integer Durations")
252
- mode = gr.Radio(["Exercise Parameters","Exercise Prompt"], value="Exercise Parameters", label="Mode")
253
- with gr.Row():
254
- with gr.Column(scale=1):
255
- with gr.Group(visible=True) as params_group:
256
- instrument=gr.Dropdown(["Trumpet","Piano","Violin","Clarinet","Flute"],value="Trumpet")
257
- level = gr.Radio(["Beginner","Intermediate","Advanced"], value="Intermediate")
258
- key = gr.Dropdown(["C Major","G Major","D Major","F Major","Bb Major","A Minor","E Minor"],value="C Major")
259
- ts = gr.Dropdown(["3/4","4/4"], value="4/4", label="Time Signature")
260
- measures = gr.Radio([4,8], value=4, label="Length (measures)")
261
- with gr.Group(visible=False) as prompt_group:
262
- prompt_txt = gr.Textbox("", lines=3, label="Custom prompt")
263
- measures2 = gr.Radio([4,8], value=4, label="Length (measures)")
264
- gen_btn = gr.Button("Generate Exercise", variant="primary")
265
- with gr.Column(scale=2):
266
- with gr.Tabs():
267
- with gr.TabItem("Player"):
268
- audio=gr.Audio(autoplay=True)
269
- bpm=gr.Textbox(label="Tempo (BPM)")
270
- ts_disp=gr.Textbox(label="Time Signature")
271
- dur=gr.Textbox(label="Duration")
272
- with gr.TabItem("Data"):
273
- js_out=gr.Code(language="json")
274
- with gr.TabItem("MIDI"):
275
- midi_file=gr.File()
276
- dl_btn=gr.Button("Download MIDI")
277
- with gr.TabItem("Chat"):
278
- chatbot=gr.Chatbot(height=400)
279
- chat_in=gr.Textbox(label="Ask the AI")
280
- chat_btn=gr.Button("Send")
281
-
282
- mode.change(lambda m:{params_group:gr.update(visible=m=="Exercise Parameters"),
283
- prompt_group:gr.update(visible=m=="Exercise Prompt")},
284
- [mode],[params_group,prompt_group])
285
-
286
- def caller(m,i,l,k,t,ms,p,ms2):
287
- real_ms = ms2 if m=="Exercise Prompt" else ms
288
- return generate_exercise(i,l,k,60,t,real_ms,p or "",m)
289
- gen_btn.click(caller,[mode,instrument,level,key,ts,measures,prompt_txt,measures2],
290
- [js_out,audio,bpm,gr.State(),dur,ts_disp])
291
-
292
- def save(json_txt,instr,ts):
293
- data=safe_parse(json_txt)
294
- if not data: return None
295
- num=int(ts.split('/')[0])
296
- target=sum(int(d) for _,d in data)
297
- measures_est=max(1,round(target/num))
298
- scaled=scale_json_durations(data,measures_est*num)
299
- midi=json_to_midi(scaled,instr,60,ts,measures_est)
300
- path="static/exercise.mid"
301
- midi.save(path)
302
- return path
303
- dl_btn.click(save,[js_out,instrument,ts],[midi_file])
304
 
305
- chat_btn.click(chat,[chat_in,chatbot,instrument,level],[chat_in,chatbot])
306
- return demo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
 
308
- if __name__ == "__main__":
309
- ui().launch() response += "⚠️ [When to Seek Medical Care]\n" + "\n".join(["• " + item for item in knowledge["warnings"]]) + "\n\n"
310
- response += "📅 [Follow-up]\n• Re-evaluate in 2-3 days if not improving\n• See doctor if symptoms worsen or persist beyond 5 days"
311
- return response
312
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
 
314
- def query_mistral(prompt):
315
- """Try API first, fallback to local knowledge"""
316
- try:
317
- headers = {"Authorization": f"Bearer {MISTRAL_API_KEY}", "Content-Type": "application/json"}
318
- payload = {
319
- "model": "mistral-tiny",
320
- "messages": [{"role": "user", "content": prompt}],
321
- "temperature": 0.7,
322
- "max_tokens": 1024
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
  }
324
- response = requests.post(MISTRAL_API_URL, headers=headers, json=payload, timeout=10)
325
- response.raise_for_status()
326
- return response.json()['choices'][0]['message']['content']
327
- except:
328
- return None
329
-
330
- def generate_response(name, message):
331
- """Generate medical response with fallback"""
332
- # First try local knowledge base
333
- local_response = get_local_advice(message)
334
- if local_response:
335
- return f"Hello {name},\n\n{local_response}\n\nWishing you good health,\nDr. Alex"
336
-
337
- # Then try API
338
- prompt = f"""As Dr. Alex, provide structured advice for {name} who reports: "{message}"
339
-
340
- Format response with these sections:
341
-
342
- [Assessment] - Brief professional evaluation
343
- [Self-Care Recommendations] - 3-5 specific actionable steps
344
- [When to Seek Medical Care] - Clear warning signs
345
- [Follow-up] - Monitoring advice
346
-
347
- Use bullet points, professional yet compassionate tone."""
348
-
349
- api_response = query_mistral(prompt)
350
- if api_response:
351
- return api_response
352
-
353
- # Final fallback
354
- common_conditions = "\n".join([f"- {cond}" for cond in MEDICAL_KNOWLEDGE.keys()])
355
- return f"""Hello {name},
356
-
357
- I'm currently unable to access detailed medical databases. For general advice:
358
-
359
- Common conditions I can advise on:
360
- {common_conditions}
361
-
362
- For immediate concerns:
363
- • Contact your local healthcare provider
364
- • In emergencies, call your local emergency number
365
-
366
- Please try asking about one of these common conditions or consult a healthcare professional.
367
 
368
- Best regards,
369
- Dr. Alex"""
370
-
371
- # Gradio Interface
372
- with gr.Blocks(title="Dr. Alex Medical Advisor", theme=gr.themes.Soft()) as demo:
373
- gr.Markdown("# 🩺 Dr. Alex - General Health Advisor")
374
- name = gr.Textbox(label="Your Name", placeholder="Enter your name")
375
-
376
- with gr.Tab("Health Consultation"):
377
- gr.Markdown("## Describe your health concern")
378
- msg = gr.Textbox(label="Symptoms/Question", lines=3)
379
- submit_btn = gr.Button("Get Medical Advice")
380
- output = gr.Textbox(label="Dr. Alex's Advice", lines=15, interactive=False)
381
-
382
- submit_btn.click(
383
- fn=lambda name, msg: generate_response(name, msg),
384
- inputs=[name, msg],
385
- outputs=output
386
- )
387
-
388
- gr.Markdown("### Common Health Topics")
389
- with gr.Row():
390
- for condition in MEDICAL_KNOWLEDGE:
391
- gr.Button(condition.capitalize()).click(
392
- fn=lambda name, cond=condition: generate_response(name, cond),
393
- inputs=[name],
394
- outputs=output
395
- )
396
-
397
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import pandas as pd
4
+ import matplotlib.pyplot as plt
5
+ import seaborn as sns
6
+ import requests
7
+
8
+ # Set page config
9
+ st.set_page_config(
10
+ page_title="Heart Disease Prediction",
11
+ page_icon="❤️",
12
+ layout="wide"
13
+ )
14
+
15
+ # Mistral API Key
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  MISTRAL_API_KEY = "yQdfM8MLbX9uhInQ7id4iUTwN4h4pDLX"
17
 
18
+ # Feature information with medical ranges
19
+ feature_info = {
20
+ "age": {"label": "Age", "desc": "Age in years", "min": 20, "max": 100, "value": 50,
21
+ "healthy_range": "20-60", "risk_range": "60+"},
22
+ "sex": {"label": "Sex", "desc": "Biological sex", "options": ["Female", "Male"]},
23
+ "cp": {"label": "Chest Pain Type", "desc": "Type of chest pain",
24
+ "options": ["Typical angina", "Atypical angina", "Non-anginal pain", "Asymptomatic"],
25
+ "healthy": "Typical angina", "risk": "Asymptomatic"},
26
+ "trestbps": {"label": "Resting BP", "desc": "Resting blood pressure (mm Hg)",
27
+ "min": 80, "max": 200, "value": 120,
28
+ "healthy_range": "90-120", "risk_range": "140+"},
29
+ "chol": {"label": "Cholesterol", "desc": "Serum cholesterol (mg/dl)",
30
+ "min": 100, "max": 600, "value": 200,
31
+ "healthy_range": "< 200", "risk_range": "240+"},
32
+ "fbs": {"label": "Fasting Blood Sugar", "desc": "> 120 mg/dl", "options": ["No", "Yes"],
33
+ "healthy": "No", "risk": "Yes"},
34
+ "restecg": {"label": "Resting ECG", "desc": "Resting electrocardiographic results",
35
+ "options": ["Normal", "ST-T wave abnormality", "Left ventricular hypertrophy"],
36
+ "healthy": "Normal", "risk": "Left ventricular hypertrophy"},
37
+ "thalach": {"label": "Max Heart Rate", "desc": "Maximum heart rate achieved",
38
+ "min": 60, "max": 220, "value": 150,
39
+ "healthy_range": "60-100 (resting)", "risk_range": "< 120 (exercise)"},
40
+ "exang": {"label": "Exercise Angina", "desc": "Exercise induced angina", "options": ["No", "Yes"],
41
+ "healthy": "No", "risk": "Yes"},
42
+ "oldpeak": {"label": "ST Depression", "desc": "ST depression induced by exercise",
43
+ "min": 0.0, "max": 6.0, "value": 1.0, "step": 0.1,
44
+ "healthy_range": "0-1", "risk_range": "2+"},
45
+ "slope": {"label": "ST Slope", "desc": "Slope of peak exercise ST segment",
46
+ "options": ["Upsloping", "Flat", "Downsloping"],
47
+ "healthy": "Upsloping", "risk": "Downsloping"},
48
+ "ca": {"label": "Major Vessels", "desc": "Number of major vessels colored by fluoroscopy",
49
+ "options": ["0", "1", "2", "3"],
50
+ "healthy": "0", "risk": "3"},
51
+ "thal": {"label": "Thalassemia", "desc": "Blood disorder called thalassemia",
52
+ "options": ["Normal", "Fixed defect", "Reversible defect"],
53
+ "healthy": "Normal", "risk": "Fixed defect"}
54
  }
55
 
56
+ # Risk assessment thresholds
57
+ RISK_THRESHOLDS = {
58
+ "Low": 0.3,
59
+ "Moderate": 0.6,
60
+ "High": 1.0
61
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
+ # Sample healthy and unhealthy profiles
64
+ SAMPLE_PROFILES = {
65
+ "Healthy": {
66
+ "age": 45, "sex": "Female", "cp": "Typical angina",
67
+ "trestbps": 110, "chol": 180, "fbs": "No",
68
+ "restecg": "Normal", "thalach": 160, "exang": "No",
69
+ "oldpeak": 0.5, "slope": "Upsloping", "ca": "0",
70
+ "thal": "Normal"
71
+ },
72
+ "Moderate Risk": {
73
+ "age": 58, "sex": "Male", "cp": "Atypical angina",
74
+ "trestbps": 135, "chol": 230, "fbs": "No",
75
+ "restecg": "ST-T wave abnormality", "thalach": 140, "exang": "No",
76
+ "oldpeak": 1.5, "slope": "Flat", "ca": "1",
77
+ "thal": "Reversible defect"
78
+ },
79
+ "High Risk": {
80
+ "age": 65, "sex": "Male", "cp": "Asymptomatic",
81
+ "trestbps": 180, "chol": 280, "fbs": "Yes",
82
+ "restecg": "Left ventricular hypertrophy", "thalach": 120, "exang": "Yes",
83
+ "oldpeak": 3.0, "slope": "Downsloping", "ca": "3",
84
+ "thal": "Fixed defect"
85
  }
86
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
+ def calculate_risk(inputs):
89
+ """Calculate heart disease risk based on clinical parameters"""
90
+ risk_factors = 0
91
+ total_possible = 13 # Total parameters being evaluated
92
+ risk_details = []
93
+ risk_params = []
94
+
95
+ if inputs['age'] > 60:
96
+ risk_factors += 1
97
+ risk_details.append(f"- Age: {inputs['age']} (Risk: 60+)")
98
+ risk_params.append('age')
99
+ if inputs['sex'] == 'Male':
100
+ risk_factors += 1
101
+ risk_details.append("- Sex: Male (Higher risk)")
102
+ risk_params.append('sex')
103
+ if inputs['cp'] == 'Asymptomatic':
104
+ risk_factors += 1
105
+ risk_details.append("- Chest Pain: Asymptomatic (Highest risk)")
106
+ risk_params.append('cp')
107
+ elif inputs['cp'] in ['Atypical angina', 'Non-anginal pain']:
108
+ risk_factors += 0.5
109
+ risk_details.append(f"- Chest Pain: {inputs['cp']} (Moderate risk)")
110
+ if inputs['trestbps'] >= 140:
111
+ risk_factors += 1
112
+ risk_details.append(f"- Blood Pressure: {inputs['trestbps']} mmHg (Stage 2 Hypertension)")
113
+ risk_params.append('trestbps')
114
+ elif inputs['trestbps'] >= 130:
115
+ risk_factors += 0.5
116
+ risk_details.append(f"- Blood Pressure: {inputs['trestbps']} mmHg (Elevated)")
117
+ if inputs['chol'] >= 240:
118
+ risk_factors += 1
119
+ risk_details.append(f"- Cholesterol: {inputs['chol']} mg/dL (High)")
120
+ risk_params.append('chol')
121
+ elif inputs['chol'] >= 200:
122
+ risk_factors += 0.5
123
+ risk_details.append(f"- Cholesterol: {inputs['chol']} mg/dL (Borderline High)")
124
+ if inputs['fbs'] == 'Yes':
125
+ risk_factors += 1
126
+ risk_details.append("- Fasting Blood Sugar > 120 mg/dL")
127
+ risk_params.append('fbs')
128
+ if inputs['restecg'] == 'Left ventricular hypertrophy':
129
+ risk_factors += 1
130
+ risk_details.append("- ECG: Left Ventricular Hypertrophy")
131
+ risk_params.append('restecg')
132
+ elif inputs['restecg'] == 'ST-T wave abnormality':
133
+ risk_factors += 0.5
134
+ risk_details.append("- ECG: ST-T Wave Abnormality")
135
+ if inputs['thalach'] < 120:
136
+ risk_factors += 1
137
+ risk_details.append(f"- Max Heart Rate: {inputs['thalach']} (Low)")
138
+ risk_params.append('thalach')
139
+ if inputs['exang'] == 'Yes':
140
+ risk_factors += 1
141
+ risk_details.append("- Exercise Induced Angina: Yes")
142
+ risk_params.append('exang')
143
+ if inputs['oldpeak'] >= 2.0:
144
+ risk_factors += 1
145
+ risk_details.append(f"- ST Depression: {inputs['oldpeak']} (High)")
146
+ risk_params.append('oldpeak')
147
+ elif inputs['oldpeak'] >= 1.0:
148
+ risk_factors += 0.5
149
+ risk_details.append(f"- ST Depression: {inputs['oldpeak']} (Moderate)")
150
+ if inputs['slope'] == 'Downsloping':
151
+ risk_factors += 1
152
+ risk_details.append("- ST Slope: Downsloping (Highest risk)")
153
+ risk_params.append('slope')
154
+ elif inputs['slope'] == 'Flat':
155
+ risk_factors += 0.5
156
+ risk_details.append("- ST Slope: Flat (Moderate risk)")
157
+ if inputs['ca'] == '3':
158
+ risk_factors += 1
159
+ risk_details.append("- Major Vessels: 3 (Highest risk)")
160
+ risk_params.append('ca')
161
+ elif inputs['ca'] in ['1', '2']:
162
+ risk_factors += 0.5
163
+ risk_details.append(f"- Major Vessels: {inputs['ca']} (Moderate risk)")
164
+ if inputs['thal'] == 'Fixed defect':
165
+ risk_factors += 1
166
+ risk_details.append("- Thalassemia: Fixed Defect")
167
+ risk_params.append('thal')
168
+ elif inputs['thal'] == 'Reversible defect':
169
+ risk_factors += 0.5
170
+ risk_details.append("- Thalassemia: Reversible Defect")
171
+
172
+ probability = min(risk_factors / total_possible, 0.99) # Cap at 99%
173
+
174
+ if probability <= RISK_THRESHOLDS['Low']:
175
+ risk_level = 'Low'
176
+ color = "green"
177
+ icon = "✅"
178
+ advice = "Maintain your healthy lifestyle with regular check-ups."
179
+ elif probability <= RISK_THRESHOLDS['Moderate']:
180
+ risk_level = 'Moderate'
181
+ color = "orange"
182
+ icon = "⚠️"
183
+ advice = "Consider lifestyle changes and consult your doctor."
184
  else:
185
+ risk_level = 'High'
186
+ color = "red"
187
+ icon = ""
188
+ advice = "Please consult a cardiologist immediately."
189
+
190
+ return risk_level, probability, color, icon, advice, risk_details, risk_params
191
+
192
+ # App layout
193
+ st.title("❤️ Heart Disease Risk Assessment")
194
+ st.markdown("""
195
+ This tool evaluates your heart disease risk based on clinical parameters and provides personalized insights.
196
+ """)
197
+
198
+ # Create tabs
199
+ tab1, tab2, tab3 = st.tabs(["Risk Assessment", "Parameter Analysis", "Health Guidance"])
200
+
201
+ with tab1:
202
+ with st.form("heart_form"):
203
+ st.subheader("Enter Your Health Parameters")
204
+ col1, col2, col3 = st.columns(3)
205
+ inputs = {}
206
+
207
+ # Column 1
208
+ with col1:
209
+ inputs["age"] = st.number_input(
210
+ feature_info["age"]["label"],
211
+ min_value=feature_info["age"]["min"],
212
+ max_value=feature_info["age"]["max"],
213
+ value=st.session_state.get("age", feature_info["age"]["value"]),
214
+ help=f"{feature_info['age']['desc']}. Healthy range: {feature_info['age']['healthy_range']}"
215
+ )
216
+ inputs["sex"] = st.selectbox(
217
+ feature_info["sex"]["label"],
218
+ feature_info["sex"]["options"],
219
+ index=feature_info["sex"]["options"].index(st.session_state.get("sex", feature_info["sex"]["options"][0])),
220
+ help=feature_info["sex"]["desc"]
221
+ )
222
+ inputs["cp"] = st.selectbox(
223
+ feature_info["cp"]["label"],
224
+ feature_info["cp"]["options"],
225
+ index=feature_info["cp"]["options"].index(st.session_state.get("cp", feature_info["cp"]["options"][0])),
226
+ help=f"{feature_info['cp']['desc']}. Healthy: {feature_info['cp']['healthy']}"
227
+ )
228
+ inputs["trestbps"] = st.number_input(
229
+ feature_info["trestbps"]["label"],
230
+ min_value=feature_info["trestbps"]["min"],
231
+ max_value=feature_info["trestbps"]["max"],
232
+ value=st.session_state.get("trestbps", feature_info["trestbps"]["value"]),
233
+ help=f"{feature_info['trestbps']['desc']}. Healthy: {feature_info['trestbps']['healthy_range']}"
234
+ )
235
+ inputs["chol"] = st.number_input(
236
+ feature_info["chol"]["label"],
237
+ min_value=feature_info["chol"]["min"],
238
+ max_value=feature_info["chol"]["max"],
239
+ value=st.session_state.get("chol", feature_info["chol"]["value"]),
240
+ help=f"{feature_info['chol']['desc']}. Healthy: {feature_info['chol']['healthy_range']}"
241
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
 
243
+ # Column 2
244
+ with col2:
245
+ inputs["fbs"] = st.selectbox(
246
+ feature_info["fbs"]["label"],
247
+ feature_info["fbs"]["options"],
248
+ index=feature_info["fbs"]["options"].index(st.session_state.get("fbs", feature_info["fbs"]["options"][0])),
249
+ help=f"{feature_info['fbs']['desc']}. Healthy: {feature_info['fbs']['healthy']}"
250
+ )
251
+ inputs["restecg"] = st.selectbox(
252
+ feature_info["restecg"]["label"],
253
+ feature_info["restecg"]["options"],
254
+ index=feature_info["restecg"]["options"].index(st.session_state.get("restecg", feature_info["restecg"]["options"][0])),
255
+ help=f"{feature_info['restecg']['desc']}. Healthy: {feature_info['restecg']['healthy']}"
256
+ )
257
+ inputs["thalach"] = st.number_input(
258
+ feature_info["thalach"]["label"],
259
+ min_value=feature_info["thalach"]["min"],
260
+ max_value=feature_info["thalach"]["max"],
261
+ value=st.session_state.get("thalach", feature_info["thalach"]["value"]),
262
+ help=f"{feature_info['thalach']['desc']}. Healthy: {feature_info['thalach']['healthy_range']}"
263
+ )
264
+ inputs["exang"] = st.selectbox(
265
+ feature_info["exang"]["label"],
266
+ feature_info["exang"]["options"],
267
+ index=feature_info["exang"]["options"].index(st.session_state.get("exang", feature_info["exang"]["options"][0])),
268
+ help=f"{feature_info['exang']['desc']}. Healthy: {feature_info['exang']['healthy']}"
269
+ )
270
+ inputs["oldpeak"] = st.number_input(
271
+ feature_info["oldpeak"]["label"],
272
+ min_value=feature_info["oldpeak"]["min"],
273
+ max_value=feature_info["oldpeak"]["max"],
274
+ value=st.session_state.get("oldpeak", feature_info["oldpeak"]["value"]),
275
+ step=feature_info["oldpeak"]["step"],
276
+ help=f"{feature_info['oldpeak']['desc']}. Healthy: {feature_info['oldpeak']['healthy_range']}"
277
+ )
278
 
279
+ # Column 3
280
+ with col3:
281
+ inputs["slope"] = st.selectbox(
282
+ feature_info["slope"]["label"],
283
+ feature_info["slope"]["options"],
284
+ index=feature_info["slope"]["options"].index(st.session_state.get("slope", feature_info["slope"]["options"][0])),
285
+ help=f"{feature_info['slope']['desc']}. Healthy: {feature_info['slope']['healthy']}"
286
+ )
287
+ inputs["ca"] = st.selectbox(
288
+ feature_info["ca"]["label"],
289
+ feature_info["ca"]["options"],
290
+ index=feature_info["ca"]["options"].index(st.session_state.get("ca", feature_info["ca"]["options"][0])),
291
+ help=f"{feature_info['ca']['desc']}. Healthy: {feature_info['ca']['healthy']}"
292
+ )
293
+ inputs["thal"] = st.selectbox(
294
+ feature_info["thal"]["label"],
295
+ feature_info["thal"]["options"],
296
+ index=feature_info["thal"]["options"].index(st.session_state.get("thal", feature_info["thal"]["options"][0])),
297
+ help=f"{feature_info['thal']['desc']}. Healthy: {feature_info['thal']['healthy']}"
298
+ )
299
 
300
+ submitted = st.form_submit_button("Assess My Heart Disease Risk")
301
+
302
+ if submitted:
303
+ risk_level, prediction_proba, color, icon, advice, risk_details, risk_params = calculate_risk(inputs)
304
+
305
+ st.subheader("Risk Assessment Results")
306
+ cols = st.columns(3)
307
+ with cols[1]:
308
+ st.metric("Heart Disease Risk", f"{prediction_proba * 100:.1f}%", f"{risk_level} Risk", delta_color="off")
309
+ risk_meter = st.progress(prediction_proba)
310
+ st.markdown(f"""
311
+ <div style="background-color:#f0f2f6;padding:15px;border-radius:10px">
312
+ <h4 style="color:{color};text-align:center">{icon} {risk_level} Risk Category</h4>
313
+ <p style="text-align:center">{advice}</p>
314
+ </div>
315
+ """, unsafe_allow_html=True)
316
+
317
+ st.subheader("Risk Factor Analysis")
318
+ total_factors = 13
319
+ risk_factors = len(risk_details)
320
+ st.write(f"**{risk_factors} out of {total_factors}** parameters show elevated risk:")
321
+ if risk_details:
322
+ st.markdown("\n".join(risk_details))
323
+ else:
324
+ st.success("All parameters are within healthy ranges!")
325
+
326
+ st.session_state.results = {
327
+ "risk_level": risk_level,
328
+ "probability": prediction_proba,
329
+ "risk_factors": risk_details,
330
+ "risk_params": risk_params,
331
+ "inputs": inputs
332
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
 
334
+ with tab2:
335
+ st.header("Parameter Analysis")
336
+ if "results" not in st.session_state:
337
+ st.warning("Please complete the risk assessment first.")
338
+ else:
339
+ st.subheader("Health Parameter Radar Chart")
340
+ radar_params = ["age", "trestbps", "chol", "thalach", "oldpeak"]
341
+ radar_labels = [feature_info[p]["label"] for p in radar_params]
342
+ radar_values = [st.session_state.results["inputs"][p] for p in radar_params]
343
+
344
+ max_values = {
345
+ "age": 100,
346
+ "trestbps": 200,
347
+ "chol": 600,
348
+ "thalach": 220,
349
+ "oldpeak": 6.0
350
+ }
351
+ normalized_values = [v / max_values[p] for v, p in zip(radar_values, radar_params)]
352
+ angles = np.linspace(0, 2 * np.pi, len(radar_labels), endpoint=False).tolist()
353
+ angles += angles[:1]
354
+ normalized_values += normalized_values[:1]
355
+ fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
356
+ ax.fill(angles, normalized_values, color='red', alpha=0.25)
357
+ ax.plot(angles, normalized_values, color='red', linewidth=2)
358
+ ax.set_xticks(angles[:-1])
359
+ ax.set_xticklabels(radar_labels)
360
+ ax.set_yticklabels([])
361
+ ax.set_title("Your Health Parameters (Normalized)", pad=20)
362
+ st.pyplot(fig)
363
+
364
+ st.subheader("Parameter Comparison")
365
+ comparison_data = []
366
+ for param in ["age", "trestbps", "chol", "thalach", "oldpeak"]:
367
+ info = feature_info[param]
368
+ status = "⚠️ Out of Range" if param in st.session_state.results["risk_params"] else "✅ Within Range"
369
+ comparison_data.append({
370
+ "Parameter": info["label"],
371
+ "Your Value": st.session_state.results["inputs"][param],
372
+ "Healthy Range": info.get("healthy_range", info.get("healthy", "")),
373
+ "Status": status
374
+ })
375
+ st.table(pd.DataFrame(comparison_data))
376
+
377
+ with tab3:
378
+ st.header("Personalized Health Guidance")
379
+ if "results" not in st.session_state:
380
+ st.warning("Please complete the risk assessment first.")
381
+ else:
382
+ risk_level = st.session_state.results["risk_level"]
383
+
384
+ # Basic guidance
385
+ if risk_level == "Low":
386
+ st.success("""
387
+ **Your heart health looks good!**
388
+ Maintain these healthy habits:
389
+ - Continue regular exercise (150 mins/week moderate activity)
390
+ - Eat a balanced diet rich in fruits, vegetables, and whole grains
391
+ - Get 7-9 hours of quality sleep nightly
392
+ - Manage stress through relaxation techniques
393
+ - Annual check-ups with your doctor
394
+ """)
395
+ elif risk_level == "Moderate":
396
+ st.warning("""
397
+ **Your heart health needs attention**
398
+ Recommended actions:
399
+ - Increase physical activity (aim for 30 mins/day)
400
+ - Reduce sodium (<1500mg/day) and saturated fat intake
401
+ - Quit smoking if applicable
402
+ - Limit alcohol to 1 drink/day (women) or 2/day (men)
403
+ - Monitor blood pressure regularly
404
+ - Schedule a doctor's visit within 3 months
405
+ - Consider cholesterol screening
406
+ """)
407
+ else:
408
+ st.error("""
409
+ **Your heart health requires immediate attention**
410
+ Critical next steps:
411
+ - Consult a cardiologist within 1 month
412
+ - Begin a supervised exercise program
413
+ - Strict dietary modifications (Mediterranean diet recommended)
414
+ - Medication may be needed (doctor will advise)
415
+ - Regular monitoring of blood pressure and cholesterol
416
+ - Consider cardiac rehabilitation program
417
+ - Emergency care if chest pain or shortness of breath occurs
418
+ """)
419
+
420
+ # AI-Powered Recommendations
421
+ st.subheader("AI-Powered Recommendations")
422
+ if st.button("Get Personalized AI Advice"):
423
+ with st.spinner("Generating AI recommendations..."):
424
+ try:
425
+ # Prepare the prompt for Mistral API
426
+ prompt = f"""
427
+ Provide detailed, personalized heart health recommendations for a patient with the following profile:
428
+ - Age: {st.session_state.results["inputs"]["age"]}
429
+ - Sex: {st.session_state.results["inputs"]["sex"]}
430
+ - Risk Level: {risk_level}
431
+ - Key Risk Factors: {', '.join(st.session_state.results["risk_params"])}
432
+
433
+ Provide specific, actionable advice on:
434
+ 1. Lifestyle modifications
435
+ 2. Dietary recommendations
436
+ 3. Exercise suggestions
437
+ 4. Monitoring suggestions
438
+ 5. When to seek medical attention
439
+
440
+ Format the response with clear headings and bullet points.
441
+ """
442
+
443
+ # Call Mistral API
444
+ headers = {
445
+ "Authorization": f"Bearer {MISTRAL_API_KEY}",
446
+ "Content-Type": "application/json"
447
+ }
448
+
449
+ data = {
450
+ "model": "mistral-tiny",
451
+ "messages": [{"role": "user", "content": prompt}]
452
+ }
453
+
454
+ response = requests.post(
455
+ "https://api.mistral.ai/v1/chat/completions",
456
+ headers=headers,
457
+ json=data
458
+ )
459
+
460
+ if response.status_code == 200:
461
+ ai_response = response.json()["choices"][0]["message"]["content"]
462
+ st.markdown("### AI Health Advisor Recommendations")
463
+ st.write(ai_response)
464
+ else:
465
+ st.error(f"Failed to get AI recommendations. Status code: {response.status_code}")
466
+
467
+ except Exception as e:
468
+ st.error(f"An error occurred: {str(e)}")
469
+
470
+ # Sidebar with sample profiles
471
+ st.sidebar.header("Quick Start")
472
+ profile = st.sidebar.selectbox("Load sample profile:", list(SAMPLE_PROFILES.keys()))
473
+ if st.sidebar.button("Load Profile"):
474
+ for key in list(st.session_state.keys()):
475
+ if key in feature_info:
476
+ del st.session_state[key]
477
+ for key, value in SAMPLE_PROFILES[profile].items():
478
+ st.session_state[key] = value
479
+ st.rerun()
480
+
481
+ # Disclaimer
482
+ st.sidebar.markdown("""
483
+ ### Important Disclaimer
484
+ This tool provides risk assessment only and is not a substitute for professional medical advice.
485
+ Always consult with a qualified healthcare provider for medical concerns.
486
+ """)