| | import gradio as gr |
| | import pandas as pd |
| | import numpy as np |
| | from sklearn.ensemble import RandomForestClassifier |
| | from sklearn.model_selection import train_test_split |
| | from sklearn.metrics import accuracy_score, balanced_accuracy_score, precision_score, recall_score, roc_auc_score |
| | from sklearn.calibration import calibration_curve |
| | import matplotlib.pyplot as plt |
| | import seaborn as sns |
| | from io import StringIO |
| | import warnings |
| | warnings.filterwarnings('ignore') |
| | import numpy as np |
| | import pandas as pd |
| | import pyarrow.parquet as pq |
| | from sklearn.preprocessing import OneHotEncoder,MinMaxScaler |
| | from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier |
| | from sklearn.model_selection import train_test_split,cross_val_score,StratifiedKFold,RepeatedStratifiedKFold |
| | from sklearn.metrics import confusion_matrix,classification_report,precision_score, recall_score, f1_score, accuracy_score, balanced_accuracy_score, matthews_corrcoef |
| | from sklearn.metrics import roc_auc_score,auc |
| | import pickle |
| |
|
| | from sklearn.utils.class_weight import compute_sample_weight |
| |
|
| | import xgboost as xgb |
| | from xgboost.sklearn import XGBClassifier |
| | from sklearn.naive_bayes import GaussianNB |
| | from sklearn.ensemble import AdaBoostClassifier |
| | from sklearn.svm import SVC |
| | from sklearn.linear_model import LogisticRegression |
| | from sklearn.preprocessing import StandardScaler |
| | from sklearn.metrics import brier_score_loss |
| | from sklearn.calibration import calibration_curve |
| | import matplotlib.pyplot as plt |
| | from sklearn.calibration import CalibratedClassifierCV |
| | from sklearn.linear_model import LinearRegression |
| |
|
| | |
# Module-level state shared by load_training_data, train_and_evaluate and run_model.
training_data = column_names = None  # training DataFrame / candidate feature names; set by load_training_data()
test_list = []  # outcome column names for the current run; run_model reads test_list[track]
def _fit_predict(X_cl, y_cl, x_te, tracount):
    """Fit the model selected by ``tracount`` on (X_cl, y_cl) and score ``x_te``.

    Returns ``(out, probs)``: hard predictions and the predicted probability of
    class 1 for every row of ``x_te``.

    Raises ValueError for an unknown ``tracount``.
    """
    if tracount == 0:
        base = RandomForestClassifier(n_estimators=100, criterion='entropy', max_features=None,
                                      random_state=42, bootstrap=True, oob_score=True,
                                      class_weight='balanced', ccp_alpha=0.01)
        model = CalibratedClassifierCV(estimator=base, method='isotonic', cv=5)
    elif tracount == 1:
        dtrain = xgb.DMatrix(X_cl.to_numpy(), label=y_cl)
        # Prediction needs no labels (the previous code referenced an undefined `y_te` here).
        dtest = xgb.DMatrix(x_te.to_numpy())
        params = {
            'objective': 'binary:logistic',
            'eval_metric': 'logloss',
            'max_depth': 60,
            'eta': 0.1,
            'subsample': 0.8,
            'colsample_bytree': 0.8,
            'seed': 42}
        num_rounds = 100
        booster = xgb.train(params, dtrain, num_rounds)
        probs = booster.predict(dtest)
        return (probs > 0.5).astype(int), probs
    elif tracount == 2:
        # NOTE(review): this AdaBoost branch used to be a second `elif tracount==1`,
        # so it was unreachable (and it never called fit). It is exposed as 2 here —
        # confirm the intended selector value.
        model = AdaBoostClassifier(
            n_estimators=100, random_state=42,
            estimator=RandomForestClassifier(n_estimators=100, criterion='entropy',
                                             random_state=42, bootstrap=True, oob_score=True,
                                             class_weight='balanced', ccp_alpha=0.01))
    elif tracount == 4:
        model = GaussianNB(var_smoothing=1e-9)
    elif tracount == 5:
        model = LogisticRegression(penalty='l2', solver='newton-cholesky', max_iter=200)
    elif tracount == 6:
        model = SVC(probability=True, C=3)
    else:
        raise ValueError(f"Unknown model selector tracount={tracount!r}")

    model.fit(X_cl, y_cl)
    return model.predict(x_te), model.predict_proba(x_te)[:, 1]


def rand_for(neww_list, x_te, rf, lab, x_tr, actual, paramss, X_Tempp, enco,
             my_table_str, my_table_num, tabl, tracount):
    """Fit one model per training subset and predict on the test set.

    Parameters
    ----------
    neww_list : list of DataFrame
        Training subsets; each holds the feature columns plus the label column ``lab``.
    x_te : array-like
        Test features, positionally aligned with ``X_Tempp.columns``.
    lab : str
        Name of the label column inside each subset.
    X_Tempp : DataFrame
        Only its ``columns`` are used, to label ``x_te``.
    tracount : int
        Model selector: 0 = isotonic-calibrated random forest, 1 = XGBoost,
        2 = AdaBoost over random forests, 4 = Gaussian naive Bayes,
        5 = logistic regression, 6 = SVC.
    rf, x_tr, actual, paramss, enco, my_table_str, my_table_num, tabl
        Unused; kept so existing call sites keep working.

    Returns
    -------
    (cl_list, pro_list)
        One entry per subset: hard test-set predictions and the predicted
        probability of class 1.

    Raises
    ------
    ValueError
        If ``tracount`` does not select a known model.
    """
    # The labelled test matrix is identical for every subset, so build it once.
    x_te = pd.DataFrame(x_te, columns=X_Tempp.columns)

    cl_list = []
    pro_list = []
    for subset in neww_list:
        y_cl = subset.loc[:, lab]
        X_cl = subset.drop(columns=[lab])
        out, probs = _fit_predict(X_cl, y_cl, x_te, tracount)
        cl_list.append(out)
        pro_list.append(probs)
    return cl_list, pro_list
def ne_calib(some_prob, down_factor, origin_factor):
    """Rescale a probability from the sampled class prior back to the original one.

    ``some_prob`` is a class-1 probability from a model trained on data whose
    class-1 fraction was ``down_factor``; the odds are reweighted so they
    correspond to the class-1 fraction ``origin_factor``. Operates element-wise,
    so scalars and NumPy arrays are both accepted.
    """
    positive_weight = some_prob * origin_factor / down_factor
    negative_weight = (1 - some_prob) * (1 - origin_factor) / (1 - down_factor)
    return positive_weight / (negative_weight + positive_weight)
def actualll(sl_list, pro_list, delt, down_factor, origin_factor):
    """Combine per-model votes and probabilities into one ensemble prediction per sample.

    Parameters
    ----------
    sl_list : sequence of sequences
        Votes, one row per model and one column per sample; each vote is -1 or +1
        (as produced by ``sli_mod``).
    pro_list : sequence of sequences
        Predicted probability of class 1, same layout as ``sl_list``.
    delt : float
        Threshold on the mean vote: a sample is labelled 1 when its mean vote
        is ``>= delt``, otherwise 0.
    down_factor, origin_factor : float
        Passed to ``ne_calib`` to rescale each model's probability before averaging.

    Returns
    -------
    (ac_list, probab_list, second_probab_list)
        Labels (0/1), mean ``ne_calib``-adjusted probabilities and mean raw
        probabilities, one entry per sample. All three are empty when
        ``sl_list`` is empty.
    """
    if len(sl_list) == 0:
        return [], [], []

    votes = np.asarray(sl_list, dtype=float)   # (n_models, n_samples)
    raw = np.asarray(pro_list, dtype=float)    # (n_models, n_samples)

    mean_vote = votes.mean(axis=0)
    mean_adjusted = ne_calib(raw, down_factor, origin_factor).mean(axis=0)
    mean_raw = raw.mean(axis=0)

    positive = mean_vote >= delt
    # NOTE(review): samples labelled 0 whose mean vote is still >= 0 report 1 - p,
    # while every other sample reports p. run_model's caller passes probab_list to
    # roc_auc_score / calibration_curve as the class-1 probability, so this mixes
    # two conventions. Preserved as-is — confirm intent.
    flipped = ~positive & (mean_vote >= 0)

    ac_list = positive.astype(int).tolist()
    probab_list = np.where(flipped, 1 - mean_adjusted, mean_adjusted).tolist()
    second_probab_list = np.where(flipped, 1 - mean_raw, mean_raw).tolist()
    return ac_list, probab_list, second_probab_list
| | |
| |
|
| |
|
def sli_mod(c_lisy):
    """Turn each model's predictions into -1/+1 votes.

    Values below 0.5 become -1 and values at or above 0.5 become 1. Each
    model's predictions are returned as a list whose elements keep the dtype
    NumPy infers for that input.
    """
    def _as_votes(predictions):
        votes = np.array(predictions)
        votes[votes < 0.5] = -1
        votes[votes >= 0.5] = 1
        return list(votes)

    return [_as_votes(predictions) for predictions in c_lisy]
| |
|
def _undersampling_plan(outcome, n_minority):
    """Return ``(n_rounds, ratio)`` for bagged undersampling of the majority class.

    ``n_rounds`` subsets are drawn, each with ``ratio`` majority rows (sampled
    with replacement) per minority row.
    """
    if n_minority <= 60:
        # Small minority class: more rounds and outcome-specific ratios.
        # NOTE(review): 'VOD' and 'ACSPSHI' do not appear in train_and_evaluate's
        # outcome list ('VOCPSHI' does) — confirm these names.
        if outcome in ('VOD', 'STROKEHI'):
            return 20, 3
        if outcome == 'ACSPSHI':
            return 20, 2.5
        return 20, 2
    return 10, 3


def run_model(x_tr, x_te, y_tr, deltaa, lab, rf, X_Tempp, track, actual, paramss,
              enco, my_table_str, my_table_num, tabl, tracount, origin_factor):
    """Train an undersampled ensemble for one outcome and predict on the test set.

    All training rows with label 1 are kept; for each seed a random sample of
    label-0 rows is added, giving one subset per seed. ``rand_for`` fits a
    model per subset, ``sli_mod`` turns predictions into -1/+1 votes and
    ``actualll`` aggregates them.

    Parameters
    ----------
    x_tr, y_tr : array-like
        Training features (columns as ``X_Tempp.columns``) and 0/1 labels.
    x_te : array-like
        Test features with the same column layout.
    deltaa : float
        Vote threshold forwarded to ``actualll``.
    lab : str
        Outcome name; also used as the label column name inside the subsets.
    track : int
        Unused; the outcome is taken from ``lab``.
    tracount : int
        Model selector forwarded to ``rand_for``.
    origin_factor : float
        Class-1 fraction of the full training set, forwarded to ``actualll``.
    rf, actual, paramss, enco, my_table_str, my_table_num, tabl
        Forwarded to ``rand_for`` unchanged.

    Returns
    -------
    (labels, adjusted_probs, raw_probs)
        As returned by ``actualll``.

    Raises
    ------
    ValueError
        If the training labels contain no rows equal to 1.
    """
    x_tr = pd.DataFrame(x_tr, columns=X_Tempp.columns)
    y_tr = pd.DataFrame(y_tr, columns=[lab])
    master_table = pd.concat([x_tr, y_tr], axis=1)

    only_minority = master_table.loc[master_table[lab] == 1]
    if only_minority.empty:
        raise ValueError(f"No positive (label 1) training rows for outcome {lab!r}")
    only_majority = master_table.drop(only_minority.index)

    n_rounds, ratio = _undersampling_plan(lab, len(only_minority))
    sample_size = int(ratio * len(only_minority))
    # Nominal class-1 fraction of each subset; ne_calib uses it to undo the prior shift.
    down_factor = 1 / (1 + ratio)

    df_list = []
    for seed in range(30, 30 + n_rounds):
        # A private RandomState yields the same draws as np.random.seed(seed) followed
        # by np.random.choice, without mutating NumPy's global RNG.
        rng = np.random.RandomState(seed)
        sampled = rng.choice(only_majority.index, size=sample_size, replace=True)
        df_list.append(pd.concat([only_majority.loc[sampled], only_minority]))

    c_lisy, pro_lisy = rand_for(df_list, x_te, rf, lab, x_tr, actual, paramss, X_Tempp,
                                enco, my_table_str, my_table_num, tabl, tracount)
    sli_lisy = sli_mod(c_lisy)
    return actualll(sli_lisy, pro_lisy, deltaa, down_factor, origin_factor)
def load_training_data(parquet_path='year6.parquet', variables_path='final_variable.csv'):
    """Load the training table and the candidate feature list into module globals.

    Reads ``parquet_path``, drops rows whose ``YEARGPF`` equals ``'< 2008'`` and
    stores the result in ``training_data``; stores the first column of
    ``variables_path`` as a list in ``column_names``. Prints the ``YEARGPF``
    value counts and the feature list as diagnostics.

    Parameters
    ----------
    parquet_path : str
        Training data file (defaults to the previously hard-coded path).
    variables_path : str
        CSV whose first column lists the candidate feature names.

    Returns
    -------
    None on success, or the string ``"No training Data"`` when either file is
    missing; the globals are left unchanged in that case.
    """
    global training_data, column_names

    try:
        my_table = pq.read_table(parquet_path).to_pandas()
        variables = pd.read_csv(variables_path)
    except FileNotFoundError:
        return "No training Data"

    print(my_table['YEARGPF'].value_counts())
    my_table = my_table[my_table['YEARGPF'] != '< 2008'].reset_index(drop=True)
    pali = list(variables.iloc[:, 0])
    print(pali)

    training_data = my_table
    column_names = pali
    return None
| |
|
def _resolve_upload_path(input_file):
    """Return a filesystem path for an uploaded file.

    ``gr.File(type="filepath")`` hands over a path string; a tempfile-like
    object exposing ``.name`` is also accepted.
    """
    import os

    if isinstance(input_file, (str, os.PathLike)):
        return input_file
    return input_file.name


def _yes_no_columns(table):
    """Return the columns of ``table`` whose values are all 'Yes' or 'No'."""
    return [col for col in table.columns if table[col].isin(['Yes', 'No']).all()]


def _encode_yes_no(table, yes_no_cols):
    """Map the given Yes/No columns to int64 1/0 and append them after the other columns."""
    converted = table[yes_no_cols].replace(['Yes', 'No'], [1, 0]).astype('int64')
    return pd.concat([table.drop(columns=yes_no_cols), converted], axis=1)


def _build_feature_matrices(train_features, test_features):
    """Encode training and test features with one training-derived schema.

    Yes/No columns are detected on the training table and mapped to 1/0 in
    both tables (a test value other than Yes/No in such a column raises).
    The remaining non-numeric columns are one-hot encoded with an encoder
    fitted on the training rows only; unseen test categories encode as all
    zeros. Test columns are selected by the training column names so both
    matrices line up column-for-column.

    Returns ``(X_train, X_test, encoder, train_str, train_num, train_onehot)``.
    """
    yes_no_cols = _yes_no_columns(train_features)
    train_mod = _encode_yes_no(train_features, yes_no_cols)
    test_mod = _encode_yes_no(test_features, yes_no_cols)

    train_str = train_mod.select_dtypes(exclude=['number'])
    train_num = train_mod.select_dtypes(include=['number']).reset_index(drop=True)
    test_str = test_mod[train_str.columns]
    test_num = test_mod[train_num.columns].reset_index(drop=True)

    encoder = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
    if train_str.shape[1] > 0:
        encoder.fit(train_str)
        names = encoder.get_feature_names_out()
        train_onehot = pd.DataFrame(encoder.transform(train_str), columns=names)
        test_onehot = pd.DataFrame(encoder.transform(test_str), columns=names)
    else:
        # No categorical columns: nothing to encode; keep row counts aligned.
        train_onehot = pd.DataFrame(index=range(len(train_mod)))
        test_onehot = pd.DataFrame(index=range(len(test_mod)))

    X_train = pd.concat([train_onehot, train_num], axis=1)
    X_test = pd.concat([test_onehot, test_num], axis=1)
    return X_train, X_test, encoder, train_str, train_num, train_onehot


def _binarize_outcome(y_train, y_test):
    """Map an outcome to 1 (training minority label) / 0 (training majority label).

    The mapping is derived from the training labels only and applied to both
    series, so label 1 means the same thing in training and test.
    """
    counts = y_train.value_counts()
    minority, majority = counts.idxmin(), counts.idxmax()
    y_train_bin = y_train.replace([minority, majority], [1, 0]).astype(int)
    y_test_bin = y_test.replace([minority, majority], [1, 0]).astype(int)
    return y_train_bin, y_test_bin


def _calibration_figure(mean_pred, fraction_pos, outcome_name):
    """Draw a reliability diagram for one outcome and return the Figure."""
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot([0, 1], [0, 1], 'k--', label='Perfect Calibration')
    ax.plot(mean_pred, fraction_pos, 'o-', label=f'{outcome_name}')
    ax.set_xlabel('Mean Predicted Probability')
    ax.set_ylabel('Fraction of Positives')
    ax.set_title(f'Calibration Plot - {outcome_name}')
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    # Remove the figure from pyplot's global registry so repeated runs do not
    # accumulate open figures; the returned Figure object is still usable.
    plt.close(fig)
    return fig


def train_and_evaluate(input_file):
    """Train per-outcome models on the stored training data and evaluate them on an uploaded CSV.

    Parameters
    ----------
    input_file : str, os.PathLike, object with ``.name``, or None
        The uploaded CSV.

    Returns
    -------
    (metrics_df, calibration_df, calibration_plots)
        On success:
          - a metrics table (accuracy, balanced accuracy, weighted
            precision/recall, AUC);
          - a calibration slope/intercept table;
          - one matplotlib Figure per outcome in the fixed outcome order, with
            ``None`` for an outcome absent from either dataset.
        ``(None, None, None)`` when no file is given, and
        ``(message, None, None)`` with a string message on any error.
    """
    global training_data, column_names, test_list

    if training_data is None or column_names is None:
        load_training_data()

    if input_file is None:
        return None, None, None

    if training_data is None or column_names is None:
        return "Error: training data could not be loaded", None, None

    outcome_cols = ['DEAD', 'GF', 'AGVHD', 'CGVHD', 'VOCPSHI', 'STROKEHI']
    outcome_names = ['Overall Survival', 'Graft Failure', 'Acute GVHD', 'Chronic GVHD',
                     'Vaso-Occlusive Crisis Post-HCT', 'Stroke Post-HCT']

    try:
        input_data = pd.read_csv(_resolve_upload_path(input_file))

        # Features must exist in both tables; outcome columns are never used as features.
        features = [col for col in column_names
                    if col in training_data.columns and col in input_data.columns
                    and col not in outcome_cols]
        if not features:
            return "Error: No matching columns found between datasets", None, None

        present = [col for col in outcome_cols
                   if col in training_data.columns and col in input_data.columns]
        if not present:
            return "Error: No outcome columns found in both datasets", None, None

        # run_model may look the outcome up as test_list[track]; keep them aligned.
        test_list = outcome_cols.copy()
        total_cols = features + present

        inter_df = training_data[total_cols].dropna().reset_index(drop=True)

        if 'YEARGPF' in input_data.columns:
            input_data = input_data[input_data['YEARGPF'] != '< 2008']
        inter_input = input_data[total_cols].dropna().reset_index(drop=True)

        if inter_df.empty or inter_input.empty:
            return "Error: No complete rows left after dropping missing values", None, None

        (X_train_full, my_test_real, enco,
         my_table_str, my_table_num, tabl) = _build_feature_matrices(inter_df[features],
                                                                     inter_input[features])

        metrics_results = []
        calibration_results = []
        calibration_plots = []
        deltaa = 0.2   # mean-vote threshold for a positive call (see actualll)
        tracount = 0   # rand_for model selector: 0 = calibrated random forest

        for i, (outcome_col, outcome_name) in enumerate(zip(outcome_cols, outcome_names)):
            if outcome_col not in present:
                # Keep positional alignment with the UI's fixed per-outcome plot slots.
                calibration_plots.append(None)
                continue

            y_train_full, y_test_full = _binarize_outcome(inter_df[outcome_col],
                                                          inter_input[outcome_col])
            X_train, y_train = X_train_full.values, y_train_full.values
            x_te, y_test = my_test_real.values, y_test_full.values
            # Class-1 fraction of the training labels, used to undo the undersampling prior shift.
            vddc = np.count_nonzero(y_train == 1) / len(y_train)

            y_pred, y_pred_proba, _ = run_model(
                X_train, x_te, y_train, deltaa, outcome_col, None, X_train_full, i,
                X_train_full, {}, enco, my_table_str, my_table_num, tabl, tracount, vddc)

            accuracy = accuracy_score(y_test, y_pred)
            balanced_acc = balanced_accuracy_score(y_test, y_pred)
            precision = precision_score(y_test, y_pred, average='weighted', zero_division=0)
            recall = recall_score(y_test, y_pred, average='weighted', zero_division=0)
            # ROC AUC is undefined when the test labels contain a single class.
            if np.unique(y_test).size > 1:
                auc_value = roc_auc_score(y_test, y_pred_proba)
            else:
                auc_value = float('nan')

            metrics_results.append([outcome_name, f"{accuracy:.3f}", f"{balanced_acc:.3f}",
                                    f"{precision:.3f}", f"{recall:.3f}", f"{auc_value:.3f}"])

            fraction_pos, mean_pred = calibration_curve(y_test, y_pred_proba, n_bins=10)
            if len(mean_pred) > 1 and len(fraction_pos) > 1:
                slope, intercept = np.polyfit(mean_pred, fraction_pos, 1)
            else:
                slope, intercept = 1.0, 0.0
            calibration_results.append([outcome_name, f"{slope:.3f}", f"{intercept:.3f}"])

            calibration_plots.append(_calibration_figure(mean_pred, fraction_pos, outcome_name))

        metrics_df = pd.DataFrame(metrics_results,
                                  columns=['Outcome', 'Accuracy', 'Balanced Accuracy',
                                           'Precision', 'Recall', 'AUC'])
        calibration_df = pd.DataFrame(calibration_results,
                                      columns=['Outcome', 'Slope', 'Intercept'])
        return metrics_df, calibration_df, calibration_plots

    except Exception as e:
        return f"Error processing data: {str(e)}", None, None
| |
|
def create_interface():
    """Build and return the Gradio Blocks app for the evaluation pipeline.

    Loads the training data up front, then lays out:
      - a CSV upload and a "Process Dataset" button;
      - a metrics table and a calibration table;
      - one calibration plot per outcome.
    Clicking the button runs ``train_and_evaluate`` on the uploaded file.
    Error messages from the pipeline are shown as Gradio errors.
    """
    load_training_data()

    # Same order as train_and_evaluate's outcome list: figures are assigned to
    # these slots positionally.
    plot_labels = ["Overall Survival", "Graft Failure", "Acute GVHD", "Chronic GVHD",
                   "Vaso-Occlusive Crisis Post-HCT", "Stroke Post-HCT"]

    css = """
    .gradio-container {
        max-width: none !important;
        height: 100vh;
        overflow-y: auto;
    }
    .main-container {
        padding: 20px;
    }
    .big-title {
        font-size: 2.5em;
        font-weight: bold;
        margin-bottom: 30px;
        text-align: center;
    }
    .section-title {
        font-size: 2em;
        font-weight: bold;
        margin: 40px 0 20px 0;
        color: #2d5aa0;
    }
    .subsection-title {
        font-size: 1.5em;
        font-weight: bold;
        margin: 30px 0 15px 0;
        color: #4a4a4a;
    }
    """

    with gr.Blocks(css=css, title="ML Model Evaluation Pipeline") as demo:
        with gr.Column(elem_classes=["main-container"]):
            gr.HTML('<div class="big-title">Input</div>')
            gr.Markdown("### Please upload the dataset:")
            file_input = gr.File(
                label="Upload Dataset (CSV)",
                file_types=[".csv"],
                type="filepath"
            )
            process_btn = gr.Button("Process Dataset", variant="primary", size="lg")

            gr.HTML('<div class="section-title">Outputs</div>')

            gr.HTML('<div class="subsection-title">Metrics</div>')
            metrics_table = gr.Dataframe(
                headers=["Outcome", "Accuracy", "Balanced Accuracy", "Precision", "Recall", "AUC"],
                interactive=False,
                wrap=True
            )

            gr.HTML('<div class="subsection-title">Calibration</div>')
            calibration_table = gr.Dataframe(
                headers=["Outcome", "Slope", "Intercept"],
                interactive=False,
                wrap=True
            )

            gr.Markdown("#### Calibration Curves")
            plots = [gr.Plot(label=label) for label in plot_labels]

        def process_and_display(file):
            """Run the pipeline and fan its results out to the tables and plot slots."""
            metrics_df, calibration_df, calibration_plots = train_and_evaluate(file)

            if isinstance(metrics_df, str):
                # A string here is an error message; gr.Dataframe expects tabular
                # data, so report it through Gradio's error display instead.
                raise gr.Error(metrics_df)

            plot_outputs = list(calibration_plots or [])[:len(plots)]
            plot_outputs += [None] * (len(plots) - len(plot_outputs))
            return (metrics_df, calibration_df, *plot_outputs)

        process_btn.click(
            fn=process_and_display,
            inputs=[file_input],
            outputs=[metrics_table, calibration_table] + plots
        )

    return demo
| |
|
| |
|
if __name__ == "__main__":
    demo = create_interface()
    # Launch settings for the local app. share=True presumably requests a public
    # Gradio share URL — NOTE(review): confirm that is acceptable for uploaded data.
    launch_options = {"share": True, "inbrowser": True, "height": 800, "show_error": True}
    demo.launch(**launch_options)