Martinacap02's picture
Init deploy branch for HF Space
f7d11f7
import httpx
from loguru import logger
import pandas as pd
from predicting_outcomes_in_heart_failure.config import API_URL, FIGURES_DIR
async def _fetch_api(endpoint: str):
    """GET ``{API_URL}/{endpoint}`` and return the decoded JSON body.

    Failures are never raised. A connection problem, an HTTP error status, or
    an undecodable body is logged and returned as ``{"error": "<message>"}``.
    """
    async with httpx.AsyncClient() as client:
        try:
            resp = await client.get(f"{API_URL}/{endpoint}")
            resp.raise_for_status()
            body = resp.json()
        except Exception as exc:
            logger.error(f"Error fetching {endpoint}: {exc}")
            body = {"error": str(exc)}
    return body
class Wrapper:
async def prediction_with_explanation(
age,
chest_pain_type,
resting_bp,
cholesterol,
fasting_bs,
resting_ecg,
max_hr,
exercise_angina,
oldpeak,
st_slope,
):
payload = {
"Age": age,
"ChestPainType": chest_pain_type,
"RestingBP": resting_bp,
"Cholesterol": cholesterol,
"FastingBS": fasting_bs,
"RestingECG": resting_ecg,
"MaxHR": max_hr,
"ExerciseAngina": exercise_angina,
"Oldpeak": round(oldpeak, 2),
"ST_Slope": st_slope,
}
async with httpx.AsyncClient() as client:
try:
pred_resp = await client.post(f"{API_URL}/predictions", json=payload)
pred_resp.raise_for_status()
pred_json = pred_resp.json()
prediction_value = pred_json["data"]["prediction"]
status = "🆘 Risk Detected" if prediction_value == 1 else "✅ No Risk Detected"
status_text = f"# Patient's status: {status}"
except Exception as e:
logger.error(f"Error making prediction: {e}")
return f"Error during prediction: {str(e)}", ""
try:
expl_resp = await client.post(f"{API_URL}/explanations", json=payload)
expl_resp.raise_for_status()
expl_json = expl_resp.json()
plot_rel_url = expl_json["data"].get("explanation_plot_url")
if not plot_rel_url:
logger.warning("No explanation_plot_url found in /explanations response.")
return status_text, ""
filename = plot_rel_url.split("/")[-1]
plot_path = FIGURES_DIR / filename
return status_text, str(plot_path)
except Exception as e:
logger.error(f"Error getting explanation: {e}")
return status_text, ""
async def batch_prediction(file):
async with httpx.AsyncClient(timeout=30.0) as client:
try:
df = pd.read_csv(file)
payload = []
for _, row in df.iterrows():
sample = {
"Age": int(row["Age"]),
"ChestPainType": row["ChestPainType"],
"RestingBP": int(row["RestingBP"]),
"Cholesterol": int(row["Cholesterol"]),
"FastingBS": int(row["FastingBS"]),
"RestingECG": row["RestingECG"],
"MaxHR": int(row["MaxHR"]),
"ExerciseAngina": row["ExerciseAngina"],
"Oldpeak": round(float(row["Oldpeak"]), 2),
"ST_Slope": row["ST_Slope"],
}
payload.append(sample)
response = await client.post(f"{API_URL}/batch-predictions", json=payload)
response.raise_for_status()
result = response.json()
results = result["data"]["results"]
df_results = pd.DataFrame(
[
{
"Patients's index": r["index"],
"Patient's status": "🆘 Risk Detected"
if r["prediction"] == 1
else "✅ No Risk Detected",
}
for r in results
]
)
return df_results
except Exception as e:
logger.error(f"Error making batch prediction: {e}")
return pd.DataFrame({"error": [str(e)]})
async def get_model_card():
data = await _fetch_api("cards/model_card")
card_lines = data.get("data").get("card_lines")
return "\n".join(card_lines)
async def get_dataset_card():
data = await _fetch_api("cards/dataset_card")
card_lines = data.get("data").get("card_lines")
return "\n".join(card_lines)
async def get_hyperparameters():
data = await _fetch_api("model/hyperparameters")
if "error" in data:
return f"## Error\n{data['error']}"
data = data.get("data", {}).get("hyperparameters", {}).get("cv", {})
md = ""
for key, value in data.items():
md += f"- **{key}**: {value}\n"
return md
async def get_metrics():
data = await _fetch_api("model/metrics")
if "error" in data:
return f"## Error\n{data['error']}"
metrics = data.get("data", {}).get("metrics", {})
if not metrics:
return "## No metrics found"
md = ""
for key, value in metrics.items():
md += f"- **{key}**: {value:.4f}\n"
return md
async def batch_explanation(file, patient_index: int):
"""Return SHAP plot (filepath) for a specific patient in the uploaded CSV."""
try:
df = pd.read_csv(file)
except Exception as e:
logger.error(f"Error reading CSV for batch explanation: {e}")
return None
try:
idx = int(patient_index)
except (TypeError, ValueError):
logger.error(f"Invalid patient_index: {patient_index}")
return None
if idx < 0 or idx >= len(df):
logger.error(f"patient_index {idx} out of range (0..{len(df) - 1})")
return None
row = df.iloc[idx]
payload = {
"Age": int(row["Age"]),
"ChestPainType": row["ChestPainType"],
"RestingBP": int(row["RestingBP"]),
"Cholesterol": int(row["Cholesterol"]),
"FastingBS": int(row["FastingBS"]),
"RestingECG": row["RestingECG"],
"MaxHR": int(row["MaxHR"]),
"ExerciseAngina": row["ExerciseAngina"],
"Oldpeak": round(float(row["Oldpeak"]), 2),
"ST_Slope": row["ST_Slope"],
}
async with httpx.AsyncClient() as client:
try:
expl_resp = await client.post(f"{API_URL}/explanations", json=payload)
expl_resp.raise_for_status()
expl_json = expl_resp.json()
plot_rel_url = expl_json["data"].get("explanation_plot_url")
if not plot_rel_url:
logger.warning(
"No explanation_plot_url found in /explanations response (batch)."
)
return None
filename = plot_rel_url.split("/")[-1]
plot_path = FIGURES_DIR / filename
return str(plot_path)
except Exception as e:
logger.error(f"Error getting batch explanation: {e}")
return None