Spaces:
No application file
No application file
| # -*- coding: utf-8 -*- | |
| """ | |
| Created on Mon Apr 1 22:06:46 2024 | |
| @author: admin | |
| """ | |
| import tensorflow as tf | |
| import pandas as pd | |
| import numpy as np | |
| from sklearn.model_selection import train_test_split | |
| from storage import result_output,preprocess_data,process_train_data,turn_back,result_output,draw_acc,cal_accuracy | |
# Load the raw VPA dataset and normalise its column names into valid Python
# identifiers. regex=True / regex=False is spelled out on every call because
# pandas >= 2.0 treats the pattern as a literal string by default, which made
# the regex clean-ups below silent no-ops.
filepath = 'data/VPA10.8.xlsx'
df = pd.read_excel(filepath)
# Remove braces and colons (astype(str) so non-string headers don't break .str).
df.columns = df.columns.astype(str).str.replace(r'[{}:]', '', regex=True)
# Make sure every column name is a valid identifier.
df.columns = df.columns.str.replace(' ', '_', regex=False)  # spaces -> underscores
# Prefix names that start with a digit with 'X'. The zero-width lookahead keeps
# the digit, so e.g. '1a' and '2a' become 'X1a' / 'X2a' instead of colliding.
df.columns = df.columns.str.replace(r'^(?=[0-9])', 'X', regex=True)
# Drop any remaining special characters.
df.columns = df.columns.str.replace(r'[^a-zA-Z0-9_]', '', regex=True)

# Number of non-missing DV observations per ID.
result = df.groupby('ID')['DV'].count().reset_index(name='Count')
# IDs that have at least one DV observation (threshold is >= 1).
filtered_ids = result[result['Count'] >= 1]['ID']
# Keep only rows of those IDs. .copy() makes filtered_df an independent frame
# so the column assignments below do not write through a view of df.
filtered_df = df[df['ID'].isin(filtered_ids)].copy()
# Fill a missing AMT with the previous row's AMT within the same ID.
# NOTE(review): shift() only looks one row back, so a run of two or more
# missing AMTs leaves the later ones NaN (and log(NaN) feeds the model);
# confirm whether a groupby forward-fill was intended.
filtered_df['AMT'] = filtered_df['AMT'].fillna(filtered_df.groupby('ID')['AMT'].shift())
# Only rows with an observed DV are usable as samples.
filtered_df = filtered_df.dropna(subset=['DV'])
# Containers for the (features, DV) pairs built per subject; samples_tr is
# currently not filled or read anywhere in this script.
samples_train, samples_val, samples_tr = [], [], []

# Observed range of the raw (pre-log) dose and of DV. These are not referenced
# later in this script.
amt_col = filtered_df['AMT']
min_amt, max_amt = amt_col.min(), amt_col.max()
dv_col = filtered_df['DV']
min_dv, max_dv = dv_col.min(), dv_col.max()

# Polynomial BSA terms (not among the six model inputs used below) and the
# log-transformed dose that the model actually consumes.
bsa_col = filtered_df['BSA']
filtered_df['BSA_square'] = bsa_col.pow(2)
filtered_df['BSA_cubic'] = bsa_col.pow(3)
filtered_df['AMT'] = filtered_df['AMT'].pipe(np.log)
# Build one (features, DV) sample per subject from the row at position
# count - 1, i.e. the subject's last DV-bearing row (the earlier pairwise
# `for i in range(count - 1)` loop was replaced by this single pick).
# Subjects whose ID exceeds 290 are routed to the validation set.
for subject_id, rows in filtered_df.groupby('ID'):
    n_obs = rows['DV'].count()
    if n_obs < 1:
        continue
    last_row = rows.iloc[n_obs - 1]
    features = last_row[['AMT', 't', 'BSA', 'BW', 'age', 'height']].tolist()
    target = last_row['DV']
    destination = samples_val if last_row['ID'] > 290 else samples_train
    destination.append((features, target))
# Unzip the (features, target) pairs into parallel feature / label lists.
X = [pair[0] for pair in samples_train]
y = [pair[1] for pair in samples_train]
val_x = [pair[0] for pair in samples_val]
val_y = [pair[1] for pair in samples_val]
# Hold out 10% of the non-validation subjects as a test split.
# NOTE(review): no random_state is set, so the split changes on every run and
# need not match the split the saved model was trained on; "train" metrics
# may then include samples the model never saw. Confirm against the
# training script.
train_x, test_x, train_y, test_y = train_test_split(X, y, test_size=0.1)
# Evaluate the saved CNN on the train, validation and test splits.
save_path = 'model/model_CNN'
loaded_model = tf.saved_model.load(save_path)
model_pre = loaded_model.signatures['serving_default']


def _predict_batch(signature, batch):
    """Run a serving *signature* on *batch* and return its 'dense_5' output as NumPy."""
    return signature(tf.constant(batch, dtype=tf.float32))['dense_5'].numpy()


# Model inputs are reshaped to (n_samples, 6 features, 1).
train_x = np.array(train_x).reshape(-1, 6, 1)
test_x = np.array(test_x).reshape(-1, 6, 1)
val_x = np.array(val_x).reshape(-1, 6, 1)
train_predictions = _predict_batch(model_pre, train_x)
# The test split used to be left as a raw dict of tensors and never scored;
# it is now converted like the other splits and evaluated below.
test_predictions = _predict_batch(model_pre, test_x)
val_predictions = _predict_batch(model_pre, val_x)
# Targets as column vectors; assumes the 'dense_5' output is (n, 1) so the
# shapes line up inside cal_accuracy — TODO confirm.
train_y = np.reshape(train_y, (-1, 1))
test_y = np.reshape(test_y, (-1, 1))
val_y = np.reshape(val_y, (-1, 1))
cal_accuracy(train_predictions, train_y)
cal_accuracy(val_predictions, val_y)
cal_accuracy(test_predictions, test_y)
| import gradio as gr | |
| import tensorflow as tf | |
| import numpy as np | |
# Model handle for the Gradio demo: load the exported SavedModel and keep its
# default serving signature as the prediction function.
# NOTE(review): this is the same directory the evaluation section loads; the
# second load is redundant but keeps this section self-contained.
model_path = 'model/model_CNN'
loaded_model = tf.saved_model.load(export_dir=model_path)
model_predict = loaded_model.signatures['serving_default']
def predict(AMT, t, BSA, BW, age, height):
    """Predict the drug response (the model's 'dense_5' output) for one patient.

    The six inputs are fed to the model in the same order used to build the
    evaluation features (AMT, t, BSA, BW, age, height), with AMT
    log-transformed as in the preprocessing.

    Returns:
        The single predicted value.

    Raises:
        ValueError: if any input is missing (e.g. a cleared Gradio field) or
            AMT is not a positive number. log(AMT) would otherwise be
            -inf/NaN and the model would return a meaningless value.
    """
    if any(value is None for value in (AMT, t, BSA, BW, age, height)):
        raise ValueError('All of AMT, t, BSA, BW, age and height are required.')
    # `not AMT > 0` also rejects NaN, which `AMT <= 0` would let through.
    if not AMT > 0:
        raise ValueError('AMT must be positive (the model uses log(AMT)).')
    # Shape (1, 6, 1): one sample, six features, one channel.
    input_features = np.array(
        [[np.log(AMT), t, BSA, BW, age, height]], dtype=float
    ).reshape(1, 6, 1)
    predictions = model_predict(tf.constant(input_features, dtype=tf.float32))['dense_5'].numpy()
    return predictions.flatten()[0]
# Build and launch the Gradio UI: six numeric inputs in the order predict()
# expects, with the prediction rendered as text.
# Initial values are given with `value=`. The old `default=` keyword belongs to
# the Gradio 2 `gr.inputs` API and raises a TypeError in current Gradio.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Number(label='AMT', value=1.0),
        gr.Number(label='t', value=1.0),
        gr.Number(label='BSA', value=1.0),
        gr.Number(label='BW', value=1.0),
        gr.Number(label='age', value=30),
        gr.Number(label='height', value=160),
    ],
    outputs='text',
    title="Drug Response Prediction",
    description="Enter the values for AMT, t, BSA, BW, age, and height to predict the drug response.",
)
iface.launch()