halil21 commited on
Commit
f34b5b5
·
verified ·
1 Parent(s): b9bfb18

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -102
app.py CHANGED
@@ -1,20 +1,16 @@
1
  with open("app.py", "w") as f:
2
  f.write("""
3
  import os
4
- import sys
5
  import torch
6
  import torch.nn as nn
7
- import torch.nn.functional as F
8
  import gradio as gr
9
  import pickle
10
  import pandas as pd
11
 
12
- # Debug için print fonksiyonları
13
- print("Python version:", sys.version)
14
- print("Current working directory:", os.getcwd())
15
- print("Directory contents:", os.listdir())
16
 
17
- # TabTransformer Model Tanımı
18
  class TabTransformer(nn.Module):
19
  def __init__(self, input_dim, num_classes=2, d_model=64, nhead=4, num_layers=3, dropout=0.1):
20
  super().__init__()
@@ -37,117 +33,92 @@ class TabTransformer(nn.Module):
37
  x = x.squeeze(0)
38
  return self.fc(x)
39
 
40
- # Özellikler
41
- categorical_features = ['Multifocal_PVC', 'Nonsustained_VT', 'gender', 'HTN', 'DM', 'Fullcompansasion']
42
- numeric_features = ['pvc_percent', 'PVCQRS', 'EF', 'Age', 'PVC_Prematurity_index', 'QRS_ratio',
43
- 'mean_HR', 'symptom_duration', 'QTc_sinus', 'PVCCI_dispersion',
44
- 'CI_variability', 'PVC_Peak_QRS_duration', 'PVCCI', 'PVC_Compansatory_interval']
45
-
46
- numeric_means = {
47
- 'pvc_percent': 11.96, 'PVCQRS': 155.1, 'EF': 59.93, 'Age': 52.19,
48
- 'PVC_Prematurity_index': 0.6158, 'QRS_ratio': 1.933, 'mean_HR': 71.28,
49
- 'symptom_duration': 14.91, 'QTc_sinus': 425.0, 'PVCCI_dispersion': 57.1,
50
- 'CI_variability': 22.98, 'PVC_Peak_QRS_duration': 76.13, 'PVCCI': 513.4,
51
- 'PVC_Compansatory_interval': 1044
52
- }
53
-
54
- # Global değişkenler
55
- model = None
56
- scaler = None
57
-
58
- def load_model_and_scaler():
59
- global model, scaler
60
- try:
61
- print("Model ve scaler yükleniyor...")
62
-
63
- # Model dosyası kontrolü
64
- model_path = "tabtransformer_model.pth"
65
- if not os.path.exists(model_path):
66
- raise FileNotFoundError(f"Model dosyası bulunamadı: {model_path}")
67
-
68
- # Scaler dosyası kontrolü
69
- scaler_path = "trans_scaler.pkl"
70
- if not os.path.exists(scaler_path):
71
- raise FileNotFoundError(f"Scaler dosyası bulunamadı: {scaler_path}")
72
-
73
- # Model yükleme
74
- input_dim = len(categorical_features) + len(numeric_features)
75
- model = TabTransformer(input_dim=input_dim)
76
- model.load_state_dict(torch.load(model_path, map_location='cpu'))
77
- model.eval()
78
-
79
- # Scaler yükleme
80
- with open(scaler_path, 'rb') as f:
81
- scaler = pickle.load(f)
82
-
83
- print("Model ve scaler başarıyla yüklendi!")
84
- return True
85
- except Exception as e:
86
- print(f"Model yükleme hatası: {str(e)}")
87
- return False
88
-
89
  def predict(*inputs):
90
- if model is None or scaler is None:
91
- return {"Error": "Model henüz yüklenmedi"}
92
-
93
  try:
94
- # Girdileri ayır
 
 
 
 
 
 
 
95
  cat_inputs = inputs[:len(categorical_features)]
96
  num_inputs = inputs[len(categorical_features):]
97
 
98
- # Kategorik verileri dönüştür
99
  cat_data = [1 if val == "Yes" else 0 for val in cat_inputs]
100
-
101
- # Sayısal verileri dönüştür
102
  num_data = [float(val) for val in num_inputs]
103
 
104
- # DataFrame oluştur
105
  data = pd.DataFrame([cat_data + num_data], columns=categorical_features + numeric_features)
106
-
107
- # Veriyi ölçeklendir
 
 
 
108
  scaled_data = scaler.transform(data)
 
 
 
 
 
 
 
109
 
110
- # Tahmin
111
  with torch.no_grad():
112
  tensor_data = torch.FloatTensor(scaled_data)
113
- logits = model(tensor_data)
114
- probabilities = F.softmax(logits, dim=1).numpy()
115
-
 
116
  return {
117
- "Probability of Response": float(probabilities[0][0]),
118
- "Probability of Non-Response": float(probabilities[0][1])
119
  }
120
  except Exception as e:
121
- print(f"Tahmin hatası: {str(e)}")
122
- return {"Error": str(e)}
 
 
 
 
 
 
 
 
 
123
 
124
- # Gradio arayüzü
125
- def create_interface():
126
- inputs = [gr.Dropdown(choices=['Yes', 'No'], label=feat) for feat in categorical_features]
127
- inputs.extend([gr.Number(label=feat, value=numeric_means[feat]) for feat in numeric_features])
128
-
129
- outputs = gr.Label(label="Prediction")
130
-
131
- return gr.Interface(
132
- fn=predict,
133
- inputs=inputs,
134
- outputs=outputs,
135
- title="TabTransformer Prediction",
136
- description="Enter the features to predict the response probability"
137
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
  if __name__ == "__main__":
140
- print("Uygulama başlatılıyor...")
141
-
142
- # Model ve scaler'ı yükle
143
- if not load_model_and_scaler():
144
- print("Model yüklenemedi. Uygulama sonlandırılıyor.")
145
- sys.exit(1)
146
-
147
- # Arayüzü oluştur ve başlat
148
- try:
149
- demo = create_interface()
150
- demo.launch(server_name="0.0.0.0", server_port=7860)
151
- except Exception as e:
152
- print(f"Arayüz başlatma hatası: {str(e)}")
153
- sys.exit(1)""")
 
1
  with open("app.py", "w") as f:
2
  f.write("""
3
  import os
 
4
  import torch
5
  import torch.nn as nn
 
6
  import gradio as gr
7
  import pickle
8
  import pandas as pd
9
 
10
+ print("Starting application...")
11
+ print("Current directory:", os.getcwd())
12
+ print("Files in directory:", os.listdir())
 
13
 
 
14
  class TabTransformer(nn.Module):
15
  def __init__(self, input_dim, num_classes=2, d_model=64, nhead=4, num_layers=3, dropout=0.1):
16
  super().__init__()
 
33
  x = x.squeeze(0)
34
  return self.fc(x)
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def predict(*inputs):
 
 
 
37
  try:
38
+ print("Prediction started...")
39
+ # Feature lists
40
+ categorical_features = ['Multifocal_PVC', 'Nonsustained_VT', 'gender', 'HTN', 'DM', 'Fullcompansasion']
41
+ numeric_features = ['pvc_percent', 'PVCQRS', 'EF', 'Age', 'PVC_Prematurity_index', 'QRS_ratio',
42
+ 'mean_HR', 'symptom_duration', 'QTc_sinus', 'PVCCI_dispersion',
43
+ 'CI_variability', 'PVC_Peak_QRS_duration', 'PVCCI', 'PVC_Compansatory_interval']
44
+
45
+ # Split inputs
46
  cat_inputs = inputs[:len(categorical_features)]
47
  num_inputs = inputs[len(categorical_features):]
48
 
49
+ # Convert inputs
50
  cat_data = [1 if val == "Yes" else 0 for val in cat_inputs]
 
 
51
  num_data = [float(val) for val in num_inputs]
52
 
53
+ # Create DataFrame
54
  data = pd.DataFrame([cat_data + num_data], columns=categorical_features + numeric_features)
55
+ print("Data prepared:", data.shape)
56
+
57
+ # Load scaler and transform data
58
+ with open("trans_scaler.pkl", 'rb') as f:
59
+ scaler = pickle.load(f)
60
  scaled_data = scaler.transform(data)
61
+ print("Data scaled")
62
+
63
+ # Load model and predict
64
+ input_dim = len(categorical_features) + len(numeric_features)
65
+ model = TabTransformer(input_dim=input_dim)
66
+ model.load_state_dict(torch.load("tabtransformer_model.pth", map_location='cpu'))
67
+ model.eval()
68
 
 
69
  with torch.no_grad():
70
  tensor_data = torch.FloatTensor(scaled_data)
71
+ output = model(tensor_data)
72
+ probabilities = torch.softmax(output, dim=1)
73
+ print("Prediction completed")
74
+
75
  return {
76
+ "Response Probability": float(probabilities[0][0]),
77
+ "Non-Response Probability": float(probabilities[0][1])
78
  }
79
  except Exception as e:
80
+ print(f"Error in prediction: {str(e)}")
81
+ return {"error": str(e)}
82
+
83
+ # Default values
84
+ numeric_defaults = {
85
+ 'pvc_percent': 11.96, 'PVCQRS': 155.1, 'EF': 59.93, 'Age': 52.19,
86
+ 'PVC_Prematurity_index': 0.6158, 'QRS_ratio': 1.933, 'mean_HR': 71.28,
87
+ 'symptom_duration': 14.91, 'QTc_sinus': 425.0, 'PVCCI_dispersion': 57.1,
88
+ 'CI_variability': 22.98, 'PVC_Peak_QRS_duration': 76.13, 'PVCCI': 513.4,
89
+ 'PVC_Compansatory_interval': 1044
90
+ }
91
 
92
+ # Create interface
93
+ demo = gr.Interface(
94
+ fn=predict,
95
+ inputs=[
96
+ gr.Dropdown(choices=["Yes", "No"], label="Multifocal_PVC"),
97
+ gr.Dropdown(choices=["Yes", "No"], label="Nonsustained_VT"),
98
+ gr.Dropdown(choices=["Yes", "No"], label="gender"),
99
+ gr.Dropdown(choices=["Yes", "No"], label="HTN"),
100
+ gr.Dropdown(choices=["Yes", "No"], label="DM"),
101
+ gr.Dropdown(choices=["Yes", "No"], label="Fullcompansasion"),
102
+ gr.Number(value=numeric_defaults['pvc_percent'], label="pvc_percent"),
103
+ gr.Number(value=numeric_defaults['PVCQRS'], label="PVCQRS"),
104
+ gr.Number(value=numeric_defaults['EF'], label="EF"),
105
+ gr.Number(value=numeric_defaults['Age'], label="Age"),
106
+ gr.Number(value=numeric_defaults['PVC_Prematurity_index'], label="PVC_Prematurity_index"),
107
+ gr.Number(value=numeric_defaults['QRS_ratio'], label="QRS_ratio"),
108
+ gr.Number(value=numeric_defaults['mean_HR'], label="mean_HR"),
109
+ gr.Number(value=numeric_defaults['symptom_duration'], label="symptom_duration"),
110
+ gr.Number(value=numeric_defaults['QTc_sinus'], label="QTc_sinus"),
111
+ gr.Number(value=numeric_defaults['PVCCI_dispersion'], label="PVCCI_dispersion"),
112
+ gr.Number(value=numeric_defaults['CI_variability'], label="CI_variability"),
113
+ gr.Number(value=numeric_defaults['PVC_Peak_QRS_duration'], label="PVC_Peak_QRS_duration"),
114
+ gr.Number(value=numeric_defaults['PVCCI'], label="PVCCI"),
115
+ gr.Number(value=numeric_defaults['PVC_Compansatory_interval'], label="PVC_Compansatory_interval")
116
+ ],
117
+ outputs=gr.Label(label="Prediction"),
118
+ title="PVC Response Predictor",
119
+ description="Enter patient features to predict response probability"
120
+ )
121
 
122
  if __name__ == "__main__":
123
+ print("Launching application...")
124
+ demo.launch()""")