Spaces:
Build error
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| import pandas as pd | |
| from autogluon.multimodal import MultiModalPredictor | |
# FastAPI application object for this service.
app = FastAPI()

# Directory of the saved AutoGluon MultiModalPredictor. This is a relative
# path, so it is resolved against the process's current working directory.
MODEL_DIR = "model_ml_dart"

# Load the model once, at import time; all requests share this instance.
predictor = MultiModalPredictor.load(MODEL_DIR)
# Input schema
class PredictionInput(BaseModel):
    """Request body for a single ventilation-mode prediction.

    Field names are the API-facing names. The module-level ``rename_map``
    maps each of them to the column name used when the model was trained.
    Field declaration order is the key order of the dumped dict, and
    therefore the column order of the one-row DataFrame built from it.
    """

    # NOTE(review): the groupings below are inferred from field names only.
    # No units or valid ranges are enforced here; confirm them against the
    # training data.

    # Presumably demographics / vital signs (age, diastolic BP, heart rate,
    # systolic BP).
    anchor_age: int
    dbp: int
    heart_rate: int
    sbp: int
    # Presumably arterial blood-gas values.
    pH: float
    PaCO2: float
    PaO2: float
    HCO3: float
    SaO2: float
    # Presumably ventilator measurements / settings. Several of these are
    # renamed to training column names containing spaces and punctuation.
    Compliance: float
    Flow_Rate_L_min: float
    Inspired_O2_Fraction: float
    Minute_Volume: float
    Peak_Insp_Pressure: float
    Plateau_Pressure: float
    Resistance_Exp: float
    Resistance_Insp: float
    Respiratory_Rate_Total: float
    Tidal_Volume_observed: float
    Tidal_Volume_set: float
    Total_PEEP_Level: float
    # The only string-typed input. Presumably diagnosis text or a category;
    # confirm the expected format against the training data.
    respiratory_diagnoses: str
# Column renaming to match training data.
# Only these API field names differ from the training column names. The
# training names contain spaces and punctuation, which cannot appear in the
# pydantic field identifiers.
_TRAINING_COLUMN_OVERRIDES = {
    "Flow_Rate_L_min": "Flow Rate (L/min)",
    "Inspired_O2_Fraction": "Inspired O2 Fraction",
    "Minute_Volume": "Minute Volume",
    "Peak_Insp_Pressure": "Peak Insp. Pressure",
    "Plateau_Pressure": "Plateau Pressure",
    "Resistance_Exp": "Resistance Exp",
    "Resistance_Insp": "Resistance Insp",
    "Respiratory_Rate_Total": "Respiratory Rate (Total)",
    "Tidal_Volume_observed": "Tidal Volume (observed)",
    "Tidal_Volume_set": "Tidal Volume (set)",
    "Total_PEEP_Level": "Total PEEP Level",
}

# Every input field, listed in the same order as the PredictionInput fields.
# A field without an override maps to itself.
rename_map = {
    field: _TRAINING_COLUMN_OVERRIDES.get(field, field)
    for field in (
        "anchor_age",
        "dbp",
        "heart_rate",
        "sbp",
        "pH",
        "PaCO2",
        "PaO2",
        "HCO3",
        "SaO2",
        "Compliance",
        "Flow_Rate_L_min",
        "Inspired_O2_Fraction",
        "Minute_Volume",
        "Peak_Insp_Pressure",
        "Plateau_Pressure",
        "Resistance_Exp",
        "Resistance_Insp",
        "Respiratory_Rate_Total",
        "Tidal_Volume_observed",
        "Tidal_Volume_set",
        "Total_PEEP_Level",
        "respiratory_diagnoses",
    )
}
# Mapping from predicted class index to readable label.
# Each name's position in this tuple is its class index.
# NOTE(review): assumes this order matches the label encoding used at
# training time; confirm.
_VENTILATION_MODES = ("APRV", "CMV", "NIV", "SPECIAL", "PSV", "SIMV", "SPONT")
label_map = dict(enumerate(_VENTILATION_MODES))
# Register the function as a POST route on ``app``; without a route decorator
# FastAPI never exposes it.
# NOTE(review): the "/predict" path is an assumption; confirm the route that
# clients expect.
@app.post("/predict")
def predict(input_data: PredictionInput):
    """Predict the ventilation mode for a single patient record.

    The request fields are renamed to the training-time column names via
    ``rename_map`` and packed into a one-row DataFrame. That frame is passed
    to the module-level ``predictor``. The predicted class index is then
    translated to a readable label via ``label_map``.

    Args:
        input_data: Validated request body.

    Returns:
        ``{"ventilation_mode": <label>}``. ``<label>`` is ``"Unknown"`` when
        the predicted index is not a key of ``label_map``.

    Raises:
        HTTPException: Status 500 if any step of the prediction fails. The
            original error is chained as the cause.
    """
    try:
        # Pydantic v2 renamed ``.dict()`` to ``.model_dump()``. The old name
        # still works but is deprecated, so prefer the new one when present.
        dump = getattr(input_data, "model_dump", None) or input_data.dict
        input_dict = dump()
        # Rename API field names to the column names used at training time.
        renamed_input = {rename_map[k]: v for k, v in input_dict.items()}
        # This one-row frame has a default RangeIndex. The ``[0]`` below
        # assumes the predictor's output keeps that index (or is array-like);
        # confirm.
        df = pd.DataFrame([renamed_input])
        raw_prediction = predictor.predict(df)[0]
        # Convert the numeric class index to its readable label.
        ventilation_mode = label_map.get(int(raw_prediction), "Unknown")
    except Exception as e:
        # Chain the original error so the underlying cause (``__cause__``)
        # is preserved.
        raise HTTPException(
            status_code=500, detail=f"Prediction failed: {str(e)}"
        ) from e
    return {"ventilation_mode": ventilation_mode}