Spaces:
Sleeping
Sleeping
File size: 7,464 Bytes
f7d11f7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 |
import httpx
from loguru import logger
import pandas as pd
from predicting_outcomes_in_heart_failure.config import API_URL, FIGURES_DIR
async def _fetch_api(endpoint: str):
    """GET ``{API_URL}/{endpoint}`` and return the decoded JSON body.

    Any failure (connection problem, non-2xx status, undecodable body) is logged
    and reported as ``{"error": "<message>"}`` instead of being raised, so callers
    can check for the ``"error"`` key.
    """
    async with httpx.AsyncClient() as http:
        try:
            reply = await http.get(f"{API_URL}/{endpoint}")
            reply.raise_for_status()
            body = reply.json()
        except Exception as exc:
            logger.error(f"Error fetching {endpoint}: {exc}")
            return {"error": str(exc)}
    return body
class Wrapper:
    """Async helpers that call the prediction API and format results for display.

    The class keeps no state and acts as a namespace: every method is a
    ``staticmethod``, so it works both as ``Wrapper.method(...)`` and when called
    on an instance. HTTP/network failures are logged and turned into a value the
    caller can display (an error string, an ``error`` DataFrame, ``""`` or
    ``None``) rather than raised.
    """

    # ---- private helpers ---------------------------------------------------

    @staticmethod
    def _status_label(prediction) -> str:
        """Return the status label for a model prediction (``1`` means risk)."""
        return "🆘 Risk Detected" if prediction == 1 else "✅ No Risk Detected"

    @staticmethod
    def _payload_from_row(row) -> dict:
        """Build a request payload from one CSV row (any ``row[column]`` mapping).

        Numeric columns are cast to built-in ``int``/``float`` so the payload is
        JSON-serialisable even when *row* holds numpy scalars (e.g. a pandas row);
        ``Oldpeak`` is rounded to 2 decimals. Categorical columns pass through.

        Raises:
            KeyError: a required column is missing.
            ValueError, TypeError: a numeric column holds NaN or non-numeric data.
        """
        return {
            "Age": int(row["Age"]),
            "ChestPainType": row["ChestPainType"],
            "RestingBP": int(row["RestingBP"]),
            "Cholesterol": int(row["Cholesterol"]),
            "FastingBS": int(row["FastingBS"]),
            "RestingECG": row["RestingECG"],
            "MaxHR": int(row["MaxHR"]),
            "ExerciseAngina": row["ExerciseAngina"],
            "Oldpeak": round(float(row["Oldpeak"]), 2),
            "ST_Slope": row["ST_Slope"],
        }

    @staticmethod
    def _plot_path_from_explanation(expl_json):
        """Return the local SHAP plot path named by an ``/explanations`` response.

        Only the last segment of ``data.explanation_plot_url`` is used; it is
        resolved inside ``FIGURES_DIR``. Returns ``None`` when the response has
        no usable URL.
        """
        plot_url = (expl_json.get("data") or {}).get("explanation_plot_url")
        if not plot_url:
            return None
        filename = plot_url.split("/")[-1]
        if not filename:
            # A URL ending in "/" would otherwise resolve to FIGURES_DIR itself.
            return None
        # NOTE(review): assumes the API writes its plots into a FIGURES_DIR this
        # process can read (shared filesystem) - confirm for remote deployments.
        return str(FIGURES_DIR / filename)

    @staticmethod
    async def _card_markdown(endpoint: str) -> str:
        """Fetch a card endpoint and join its ``card_lines`` into one markdown string.

        Returns an ``## Error`` message when the request failed and a placeholder
        when the response contains no card lines.
        """
        data = await _fetch_api(endpoint)
        if "error" in data:
            return f"## Error\n{data['error']}"
        card_lines = (data.get("data") or {}).get("card_lines")
        if not card_lines:
            return "## No card found"
        return "\n".join(card_lines)

    # ---- public API --------------------------------------------------------

    @staticmethod
    async def prediction_with_explanation(
        age,
        chest_pain_type,
        resting_bp,
        cholesterol,
        fasting_bs,
        resting_ecg,
        max_hr,
        exercise_angina,
        oldpeak,
        st_slope,
    ):
        """Predict risk for a single patient and fetch the SHAP explanation plot.

        Returns:
            ``(status_markdown, plot_path)``. ``plot_path`` is ``""`` when no
            explanation could be obtained. If the prediction itself fails, the
            first element is an error message and the second is ``""``.
        """
        async with httpx.AsyncClient() as client:
            try:
                # Built inside the try so bad inputs (e.g. oldpeak=None, where
                # round() raises) are reported like any other prediction error.
                payload = {
                    "Age": age,
                    "ChestPainType": chest_pain_type,
                    "RestingBP": resting_bp,
                    "Cholesterol": cholesterol,
                    "FastingBS": fasting_bs,
                    "RestingECG": resting_ecg,
                    "MaxHR": max_hr,
                    "ExerciseAngina": exercise_angina,
                    "Oldpeak": round(oldpeak, 2),
                    "ST_Slope": st_slope,
                }
                pred_resp = await client.post(f"{API_URL}/predictions", json=payload)
                pred_resp.raise_for_status()
                prediction_value = pred_resp.json()["data"]["prediction"]
            except Exception as e:
                logger.error(f"Error making prediction: {e}")
                return f"Error during prediction: {str(e)}", ""
            status_text = f"# Patient's status: {Wrapper._status_label(prediction_value)}"

            try:
                expl_resp = await client.post(f"{API_URL}/explanations", json=payload)
                expl_resp.raise_for_status()
                plot_path = Wrapper._plot_path_from_explanation(expl_resp.json())
            except Exception as e:
                logger.error(f"Error getting explanation: {e}")
                return status_text, ""

        if plot_path is None:
            logger.warning("No explanation_plot_url found in /explanations response.")
            return status_text, ""
        return status_text, plot_path

    @staticmethod
    async def batch_prediction(file):
        """Predict risk for every row of an uploaded CSV via ``/batch-predictions``.

        Returns:
            A DataFrame with columns ``Patient's index`` and ``Patient's status``
            (one row per returned result), or a one-column ``error`` DataFrame if
            reading the CSV, converting a row, or the request fails.
        """
        try:
            df = pd.read_csv(file)
            payload = [Wrapper._payload_from_row(row) for _, row in df.iterrows()]
            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.post(f"{API_URL}/batch-predictions", json=payload)
                response.raise_for_status()
                results = response.json()["data"]["results"]
            return pd.DataFrame(
                [
                    {
                        "Patient's index": r["index"],
                        "Patient's status": Wrapper._status_label(r["prediction"]),
                    }
                    for r in results
                ],
                # Explicit columns keep the headers even when results is empty.
                columns=["Patient's index", "Patient's status"],
            )
        except Exception as e:
            logger.error(f"Error making batch prediction: {e}")
            return pd.DataFrame({"error": [str(e)]})

    @staticmethod
    async def get_model_card():
        """Return the model card as markdown (or an ``## Error`` message)."""
        return await Wrapper._card_markdown("cards/model_card")

    @staticmethod
    async def get_dataset_card():
        """Return the dataset card as markdown (or an ``## Error`` message)."""
        return await Wrapper._card_markdown("cards/dataset_card")

    @staticmethod
    async def get_hyperparameters():
        """Return the ``cv`` hyperparameters as a markdown bullet list.

        Returns an ``## Error`` message when the request failed and ``""`` when
        no hyperparameters are present.
        """
        data = await _fetch_api("model/hyperparameters")
        if "error" in data:
            return f"## Error\n{data['error']}"
        # `or {}` also covers keys that are present but null in the JSON.
        hyperparameters = (data.get("data") or {}).get("hyperparameters") or {}
        cv_params = hyperparameters.get("cv") or {}
        return "".join(f"- **{key}**: {value}\n" for key, value in cv_params.items())

    @staticmethod
    async def get_metrics():
        """Return the model metrics as a markdown bullet list.

        Numeric values are shown with 4 decimals. Returns an ``## Error`` message
        when the request failed and ``## No metrics found`` when none are present.
        """
        data = await _fetch_api("model/metrics")
        if "error" in data:
            return f"## Error\n{data['error']}"
        metrics = (data.get("data") or {}).get("metrics") or {}
        if not metrics:
            return "## No metrics found"
        lines = []
        for key, value in metrics.items():
            # Non-numeric values would make the ".4f" format spec raise; show them as-is.
            shown = f"{value:.4f}" if isinstance(value, (int, float)) else value
            lines.append(f"- **{key}**: {shown}\n")
        return "".join(lines)

    @staticmethod
    async def batch_explanation(file, patient_index: int):
        """Return the SHAP plot file path for one patient of the uploaded CSV.

        Args:
            file: CSV source accepted by ``pd.read_csv``.
            patient_index: 0-based row index; converted with ``int()``.

        Returns:
            The plot path as a string, or ``None`` (after logging) when the CSV
            cannot be read, the index is invalid or out of range, the row cannot
            be converted to a payload, or the ``/explanations`` call fails.
        """
        try:
            df = pd.read_csv(file)
        except Exception as e:
            logger.error(f"Error reading CSV for batch explanation: {e}")
            return None

        try:
            idx = int(patient_index)
        except (TypeError, ValueError, OverflowError):  # OverflowError: int(inf)
            logger.error(f"Invalid patient_index: {patient_index}")
            return None
        if idx < 0 or idx >= len(df):
            logger.error(f"patient_index {idx} out of range (0..{len(df) - 1})")
            return None

        try:
            payload = Wrapper._payload_from_row(df.iloc[idx])
        except (KeyError, TypeError, ValueError) as e:
            logger.error(f"Invalid data for patient_index {idx}: {e}")
            return None

        async with httpx.AsyncClient() as client:
            try:
                expl_resp = await client.post(f"{API_URL}/explanations", json=payload)
                expl_resp.raise_for_status()
                plot_path = Wrapper._plot_path_from_explanation(expl_resp.json())
            except Exception as e:
                logger.error(f"Error getting batch explanation: {e}")
                return None

        if plot_path is None:
            logger.warning(
                "No explanation_plot_url found in /explanations response (batch)."
            )
        return plot_path
|