Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,16 +1,20 @@
|
|
| 1 |
with open("app.py", "w") as f:
|
| 2 |
f.write("""
|
| 3 |
import os
|
|
|
|
| 4 |
import torch
|
| 5 |
import torch.nn as nn
|
| 6 |
import torch.nn.functional as F
|
| 7 |
import gradio as gr
|
| 8 |
import pickle
|
| 9 |
import pandas as pd
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
print("
|
|
|
|
| 12 |
print("Current directory:", os.getcwd())
|
| 13 |
-
print("
|
| 14 |
|
| 15 |
class TabTransformer(nn.Module):
|
| 16 |
def __init__(self, input_dim, num_classes=2, d_model=64, nhead=4, num_layers=3, dropout=0.1):
|
|
@@ -34,15 +38,42 @@ class TabTransformer(nn.Module):
|
|
| 34 |
x = x.squeeze(0)
|
| 35 |
return self.fc(x)
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
def predict(*inputs):
|
| 38 |
try:
|
| 39 |
-
print("Prediction started...")
|
| 40 |
-
# Feature lists
|
| 41 |
-
categorical_features = ['Multifocal_PVC', 'Nonsustained_VT', 'gender', 'HTN', 'DM', 'Fullcompansasion']
|
| 42 |
-
numeric_features = ['pvc_percent', 'PVCQRS', 'EF', 'Age', 'PVC_Prematurity_index', 'QRS_ratio',
|
| 43 |
-
'mean_HR', 'symptom_duration', 'QTc_sinus', 'PVCCI_dispersion',
|
| 44 |
-
'CI_variability', 'PVC_Peak_QRS_duration', 'PVCCI', 'PVC_Compansatory_interval']
|
| 45 |
-
|
| 46 |
# Split inputs
|
| 47 |
cat_inputs = inputs[:len(categorical_features)]
|
| 48 |
num_inputs = inputs[len(categorical_features):]
|
|
@@ -53,69 +84,63 @@ def predict(*inputs):
|
|
| 53 |
|
| 54 |
# Create DataFrame
|
| 55 |
data = pd.DataFrame([cat_data + num_data], columns=categorical_features + numeric_features)
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
# Load scaler and transform data
|
| 59 |
-
with open("trans_scaler.pkl", 'rb') as f:
|
| 60 |
-
scaler = pickle.load(f)
|
| 61 |
scaled_data = scaler.transform(data)
|
| 62 |
-
print("Data scaled")
|
| 63 |
-
|
| 64 |
-
# Load model and predict
|
| 65 |
-
input_dim = len(categorical_features) + len(numeric_features)
|
| 66 |
-
model = TabTransformer(input_dim=input_dim)
|
| 67 |
-
model.load_state_dict(torch.load("tabtransformer_model.pth", map_location='cpu'))
|
| 68 |
-
model.eval()
|
| 69 |
|
|
|
|
| 70 |
with torch.no_grad():
|
| 71 |
tensor_data = torch.FloatTensor(scaled_data)
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
print("Prediction completed")
|
| 75 |
-
|
| 76 |
-
response_prob = float(probabilities[0][0])
|
| 77 |
-
non_response_prob = float(probabilities[0][1])
|
| 78 |
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
except Exception as e:
|
| 81 |
print(f"Error in prediction: {str(e)}")
|
| 82 |
return f"Error: {str(e)}"
|
| 83 |
|
| 84 |
-
# Default values
|
| 85 |
-
numeric_defaults = {
|
| 86 |
-
'pvc_percent': 11.96, 'PVCQRS': 155.1, 'EF': 59.93, 'Age': 52.19,
|
| 87 |
-
'PVC_Prematurity_index': 0.6158, 'QRS_ratio': 1.933, 'mean_HR': 71.28,
|
| 88 |
-
'symptom_duration': 14.91, 'QTc_sinus': 425.0, 'PVCCI_dispersion': 57.1,
|
| 89 |
-
'CI_variability': 22.98, 'PVC_Peak_QRS_duration': 76.13, 'PVCCI': 513.4,
|
| 90 |
-
'PVC_Compansatory_interval': 1044
|
| 91 |
-
}
|
| 92 |
-
|
| 93 |
# Create interface
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
gr.
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
]
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
if __name__ == "__main__":
|
| 120 |
-
print("
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
with open("app.py", "w") as f:
|
| 2 |
f.write("""
|
| 3 |
import os
|
| 4 |
+
import sys
|
| 5 |
import torch
|
| 6 |
import torch.nn as nn
|
| 7 |
import torch.nn.functional as F
|
| 8 |
import gradio as gr
|
| 9 |
import pickle
|
| 10 |
import pandas as pd
|
| 11 |
+
import warnings
|
| 12 |
+
warnings.filterwarnings('ignore')
|
| 13 |
|
| 14 |
+
print("Python version:", sys.version)
|
| 15 |
+
print("Torch version:", torch.__version__)
|
| 16 |
print("Current directory:", os.getcwd())
|
| 17 |
+
print("Directory contents:", os.listdir())
|
| 18 |
|
| 19 |
class TabTransformer(nn.Module):
|
| 20 |
def __init__(self, input_dim, num_classes=2, d_model=64, nhead=4, num_layers=3, dropout=0.1):
|
|
|
|
| 38 |
x = x.squeeze(0)
|
| 39 |
return self.fc(x)
|
| 40 |
|
| 41 |
+
# Özellik listeleri
|
| 42 |
+
categorical_features = ['Multifocal_PVC', 'Nonsustained_VT', 'gender', 'HTN', 'DM', 'Fullcompansasion']
|
| 43 |
+
numeric_features = ['pvc_percent', 'PVCQRS', 'EF', 'Age', 'PVC_Prematurity_index', 'QRS_ratio',
|
| 44 |
+
'mean_HR', 'symptom_duration', 'QTc_sinus', 'PVCCI_dispersion',
|
| 45 |
+
'CI_variability', 'PVC_Peak_QRS_duration', 'PVCCI', 'PVC_Compansatory_interval']
|
| 46 |
+
|
| 47 |
+
# Varsayılan değerler
|
| 48 |
+
numeric_defaults = {
|
| 49 |
+
'pvc_percent': 11.96, 'PVCQRS': 155.1, 'EF': 59.93, 'Age': 52.19,
|
| 50 |
+
'PVC_Prematurity_index': 0.6158, 'QRS_ratio': 1.933, 'mean_HR': 71.28,
|
| 51 |
+
'symptom_duration': 14.91, 'QTc_sinus': 425.0, 'PVCCI_dispersion': 57.1,
|
| 52 |
+
'CI_variability': 22.98, 'PVC_Peak_QRS_duration': 76.13, 'PVCCI': 513.4,
|
| 53 |
+
'PVC_Compansatory_interval': 1044
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
try:
|
| 57 |
+
print("Loading model...")
|
| 58 |
+
# Model tanımı
|
| 59 |
+
input_dim = len(categorical_features) + len(numeric_features)
|
| 60 |
+
model = TabTransformer(input_dim=input_dim)
|
| 61 |
+
model.load_state_dict(torch.load("tabtransformer_model.pth", map_location=torch.device('cpu')))
|
| 62 |
+
model.eval()
|
| 63 |
+
print("Model loaded successfully")
|
| 64 |
+
|
| 65 |
+
print("Loading scaler...")
|
| 66 |
+
# Scaler yükleme
|
| 67 |
+
with open("trans_scaler.pkl", "rb") as f:
|
| 68 |
+
scaler = pickle.load(f)
|
| 69 |
+
print("Scaler loaded successfully")
|
| 70 |
+
|
| 71 |
+
except Exception as e:
|
| 72 |
+
print(f"Error during initialization: {str(e)}")
|
| 73 |
+
sys.exit(1)
|
| 74 |
+
|
| 75 |
def predict(*inputs):
|
| 76 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
# Split inputs
|
| 78 |
cat_inputs = inputs[:len(categorical_features)]
|
| 79 |
num_inputs = inputs[len(categorical_features):]
|
|
|
|
| 84 |
|
| 85 |
# Create DataFrame
|
| 86 |
data = pd.DataFrame([cat_data + num_data], columns=categorical_features + numeric_features)
|
| 87 |
+
|
| 88 |
+
# Scale data
|
|
|
|
|
|
|
|
|
|
| 89 |
scaled_data = scaler.transform(data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
+
# Predict
|
| 92 |
with torch.no_grad():
|
| 93 |
tensor_data = torch.FloatTensor(scaled_data)
|
| 94 |
+
outputs = model(tensor_data)
|
| 95 |
+
probs = F.softmax(outputs, dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
+
response_prob = float(probs[0][0])
|
| 98 |
+
non_response_prob = float(probs[0][1])
|
| 99 |
+
|
| 100 |
+
return f"Response: {response_prob:.1%}\nNon-Response: {non_response_prob:.1%}"
|
| 101 |
+
|
| 102 |
except Exception as e:
|
| 103 |
print(f"Error in prediction: {str(e)}")
|
| 104 |
return f"Error: {str(e)}"
|
| 105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
# Create interface
|
| 107 |
+
with gr.Blocks() as demo:
|
| 108 |
+
gr.Markdown("# PVC Response Predictor")
|
| 109 |
+
gr.Markdown("Enter patient features to predict response probability")
|
| 110 |
+
|
| 111 |
+
with gr.Row():
|
| 112 |
+
with gr.Column():
|
| 113 |
+
inputs = []
|
| 114 |
+
# Categorical inputs
|
| 115 |
+
for feat in categorical_features:
|
| 116 |
+
inputs.append(gr.Dropdown(
|
| 117 |
+
choices=["Yes", "No"],
|
| 118 |
+
value="No",
|
| 119 |
+
label=feat
|
| 120 |
+
))
|
| 121 |
+
|
| 122 |
+
# Numeric inputs
|
| 123 |
+
for feat in numeric_features:
|
| 124 |
+
inputs.append(gr.Number(
|
| 125 |
+
value=numeric_defaults[feat],
|
| 126 |
+
label=feat
|
| 127 |
+
))
|
| 128 |
+
|
| 129 |
+
with gr.Column():
|
| 130 |
+
output = gr.Textbox(label="Prediction Results")
|
| 131 |
+
|
| 132 |
+
submit_btn = gr.Button("Predict")
|
| 133 |
+
submit_btn.click(
|
| 134 |
+
fn=predict,
|
| 135 |
+
inputs=inputs,
|
| 136 |
+
outputs=output
|
| 137 |
+
)
|
| 138 |
|
| 139 |
if __name__ == "__main__":
|
| 140 |
+
print("Starting server...")
|
| 141 |
+
demo.launch(
|
| 142 |
+
server_name="0.0.0.0",
|
| 143 |
+
show_error=True,
|
| 144 |
+
share=False,
|
| 145 |
+
debug=True
|
| 146 |
+
)""")
|