import httpx from loguru import logger import pandas as pd from predicting_outcomes_in_heart_failure.config import API_URL, FIGURES_DIR async def _fetch_api(endpoint: str): async with httpx.AsyncClient() as client: try: response = await client.get(f"{API_URL}/{endpoint}") response.raise_for_status() return response.json() except Exception as e: logger.error(f"Error fetching {endpoint}: {e}") return {"error": str(e)} class Wrapper: async def prediction_with_explanation( age, chest_pain_type, resting_bp, cholesterol, fasting_bs, resting_ecg, max_hr, exercise_angina, oldpeak, st_slope, ): payload = { "Age": age, "ChestPainType": chest_pain_type, "RestingBP": resting_bp, "Cholesterol": cholesterol, "FastingBS": fasting_bs, "RestingECG": resting_ecg, "MaxHR": max_hr, "ExerciseAngina": exercise_angina, "Oldpeak": round(oldpeak, 2), "ST_Slope": st_slope, } async with httpx.AsyncClient() as client: try: pred_resp = await client.post(f"{API_URL}/predictions", json=payload) pred_resp.raise_for_status() pred_json = pred_resp.json() prediction_value = pred_json["data"]["prediction"] status = "🆘 Risk Detected" if prediction_value == 1 else "✅ No Risk Detected" status_text = f"# Patient's status: {status}" except Exception as e: logger.error(f"Error making prediction: {e}") return f"Error during prediction: {str(e)}", "" try: expl_resp = await client.post(f"{API_URL}/explanations", json=payload) expl_resp.raise_for_status() expl_json = expl_resp.json() plot_rel_url = expl_json["data"].get("explanation_plot_url") if not plot_rel_url: logger.warning("No explanation_plot_url found in /explanations response.") return status_text, "" filename = plot_rel_url.split("/")[-1] plot_path = FIGURES_DIR / filename return status_text, str(plot_path) except Exception as e: logger.error(f"Error getting explanation: {e}") return status_text, "" async def batch_prediction(file): async with httpx.AsyncClient(timeout=30.0) as client: try: df = pd.read_csv(file) payload = [] for _, row in df.iterrows(): sample = { "Age": int(row["Age"]), "ChestPainType": row["ChestPainType"], "RestingBP": int(row["RestingBP"]), "Cholesterol": int(row["Cholesterol"]), "FastingBS": int(row["FastingBS"]), "RestingECG": row["RestingECG"], "MaxHR": int(row["MaxHR"]), "ExerciseAngina": row["ExerciseAngina"], "Oldpeak": round(float(row["Oldpeak"]), 2), "ST_Slope": row["ST_Slope"], } payload.append(sample) response = await client.post(f"{API_URL}/batch-predictions", json=payload) response.raise_for_status() result = response.json() results = result["data"]["results"] df_results = pd.DataFrame( [ { "Patients's index": r["index"], "Patient's status": "🆘 Risk Detected" if r["prediction"] == 1 else "✅ No Risk Detected", } for r in results ] ) return df_results except Exception as e: logger.error(f"Error making batch prediction: {e}") return pd.DataFrame({"error": [str(e)]}) async def get_model_card(): data = await _fetch_api("cards/model_card") card_lines = data.get("data").get("card_lines") return "\n".join(card_lines) async def get_dataset_card(): data = await _fetch_api("cards/dataset_card") card_lines = data.get("data").get("card_lines") return "\n".join(card_lines) async def get_hyperparameters(): data = await _fetch_api("model/hyperparameters") if "error" in data: return f"## Error\n{data['error']}" data = data.get("data", {}).get("hyperparameters", {}).get("cv", {}) md = "" for key, value in data.items(): md += f"- **{key}**: {value}\n" return md async def get_metrics(): data = await _fetch_api("model/metrics") if "error" in data: return f"## Error\n{data['error']}" metrics = data.get("data", {}).get("metrics", {}) if not metrics: return "## No metrics found" md = "" for key, value in metrics.items(): md += f"- **{key}**: {value:.4f}\n" return md async def batch_explanation(file, patient_index: int): """Return SHAP plot (filepath) for a specific patient in the uploaded CSV.""" try: df = pd.read_csv(file) except Exception as e: logger.error(f"Error reading CSV for batch explanation: {e}") return None try: idx = int(patient_index) except (TypeError, ValueError): logger.error(f"Invalid patient_index: {patient_index}") return None if idx < 0 or idx >= len(df): logger.error(f"patient_index {idx} out of range (0..{len(df) - 1})") return None row = df.iloc[idx] payload = { "Age": int(row["Age"]), "ChestPainType": row["ChestPainType"], "RestingBP": int(row["RestingBP"]), "Cholesterol": int(row["Cholesterol"]), "FastingBS": int(row["FastingBS"]), "RestingECG": row["RestingECG"], "MaxHR": int(row["MaxHR"]), "ExerciseAngina": row["ExerciseAngina"], "Oldpeak": round(float(row["Oldpeak"]), 2), "ST_Slope": row["ST_Slope"], } async with httpx.AsyncClient() as client: try: expl_resp = await client.post(f"{API_URL}/explanations", json=payload) expl_resp.raise_for_status() expl_json = expl_resp.json() plot_rel_url = expl_json["data"].get("explanation_plot_url") if not plot_rel_url: logger.warning( "No explanation_plot_url found in /explanations response (batch)." ) return None filename = plot_rel_url.split("/")[-1] plot_path = FIGURES_DIR / filename return str(plot_path) except Exception as e: logger.error(f"Error getting batch explanation: {e}") return None