from pathlib import Path

import gradio as gr
import joblib
import numpy as np
import onnxruntime as ort
import pandas as pd
|
|
| |
# Column order the scaler and ONNX model expect; predict_risk selects its
# inputs in exactly this order. Defined outside the try below so the expected
# input schema does not depend on whether the artifacts loaded.
feature_names = ['Age', 'Sex', 'CD4+ T-cell count', 'Viral load', 'WBC count', 'Hemoglobin', 'Platelet count']

# Resolve artifacts next to this script instead of the current working
# directory, so the app works regardless of where it is launched from.
# __file__ is undefined in notebooks/REPLs; fall back to the working directory there.
try:
    _ARTIFACT_DIR = Path(__file__).resolve().parent
except NameError:
    _ARTIFACT_DIR = Path.cwd()

MODEL_PATH = _ARTIFACT_DIR / "hiv_model.onnx"
SCALER_PATH = _ARTIFACT_DIR / "hiv_scaler.pkl"

try:
    ort_session = ort.InferenceSession(str(MODEL_PATH))
    # NOTE(review): joblib.load unpickles the file, which can execute arbitrary
    # code -- only ship/load a scaler file from a trusted source.
    scaler = joblib.load(str(SCALER_PATH))
    model_loaded = True
    scaler_loaded = True
except Exception as e:
    # Best-effort startup: keep the UI available and let predict_risk report
    # that the artifacts are missing instead of crashing at import time.
    print(f"Error loading model or scaler: {e}")
    model_loaded = False
    scaler_loaded = False
    ort_session = None
    scaler = None
|
|
def predict_risk(age, sex, cd4_count, viral_load, wbc_count, hemoglobin, platelet_count):
    """
    Predict the HIV high-risk probability for a single patient record.

    The inputs are assembled into one row, reordered to the module-level
    ``feature_names`` column order, transformed with the loaded ``scaler`` and
    scored with the ONNX Runtime session ``ort_session``.

    Args:
        age: Age value from the form.
        sex: "Female" (encoded as 0) or "Male" (encoded as 1).
        cd4_count: CD4+ T-cell count.
        viral_load: Viral load.
        wbc_count: White blood cell count.
        hemoglobin: Hemoglobin value.
        platelet_count: Platelet count.

    Returns:
        str: "High Risk Probability: <p>" with four decimals on success;
        otherwise a human-readable message. Problems are returned as strings
        rather than raised so they appear in the Gradio text output.
    """
    if not model_loaded or not scaler_loaded:
        return "Model or scaler not loaded. Please ensure 'hiv_model.onnx' and 'hiv_scaler.pkl' are in the same directory."

    # Validate before touching the model: a blank gr.Number arrives as None,
    # which would otherwise surface only as an opaque scaler error.
    numeric_fields = {
        'Age': age,
        'CD4+ T-cell count': cd4_count,
        'Viral load': viral_load,
        'WBC count': wbc_count,
        'Hemoglobin': hemoglobin,
        'Platelet count': platelet_count,
    }
    missing = [name for name, value in numeric_fields.items() if value is None]
    if missing:
        return f"Please provide a value for: {', '.join(missing)}"

    # Reject unset/unknown values instead of silently encoding them as "Male".
    sex_codes = {"Female": 0, "Male": 1}
    if sex not in sex_codes:
        return "Please select a sex (Female or Male)."

    row = {**numeric_fields, 'Sex': sex_codes[sex]}

    try:
        # Select columns in the training order; an unknown name raises KeyError.
        input_df = pd.DataFrame([row])[feature_names]
        scaled = scaler.transform(input_df)
        # ONNX Runtime does not cast inputs; float32 is presumably the model's
        # declared input type -- confirm via ort_session.get_inputs()[0].type.
        input_array = np.asarray(scaled, dtype=np.float32)
        input_name = ort_session.get_inputs()[0].name
        outputs = ort_session.run(None, {input_name: input_array})
        # Assumes the first output holds per-class probabilities for the batch,
        # with the high-risk class at index 1 -- TODO confirm against the
        # exported model's output signature.
        risk_probability = float(outputs[0][0][1])
    except Exception as e:
        # Broad catch on purpose: this is the UI boundary, so any failure is
        # reported to the user instead of breaking the request.
        return f"An error occurred during prediction: {e}"

    return f"High Risk Probability: {risk_probability:.4f}"
|
|
|
|
| |
# Gradio form controls, listed in the same order as predict_risk's parameters.
age_input = gr.Number(label="Age", value=30)
sex_input = gr.Radio(["Female", "Male"], label="Sex", value="Female")
cd4_input = gr.Number(label="CD4+ T-cell count", value=500)
viral_input = gr.Number(label="Viral load", value=10000)
wbc_input = gr.Number(label="WBC count", value=7000)
hemoglobin_input = gr.Number(label="Hemoglobin", value=14.0)
platelet_input = gr.Number(label="Platelet count", value=250000)

iface = gr.Interface(
    fn=predict_risk,
    inputs=[age_input, sex_input, cd4_input, viral_input, wbc_input, hemoglobin_input, platelet_input],
    outputs="text",
    title="Sentinel-P1: HIV Risk Prediction Demo",
    description="Enter blood report values to estimate HIV risk. This is a demonstration model and should not be used for medical advice.",
)

# Only start the web server when run as a script, so importing this module
# (e.g. from tests or another app) does not launch it as a side effect.
if __name__ == "__main__":
    iface.launch()