halil21 commited on
Commit
75cf953
·
verified ·
1 Parent(s): 888d601

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -63
app.py CHANGED
@@ -1,5 +1,7 @@
1
  with open("app.py", "w") as f:
2
  f.write("""
 
 
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
@@ -7,6 +9,11 @@ import gradio as gr
7
  import pickle
8
  import pandas as pd
9
 
 
 
 
 
 
10
  # TabTransformer Model Tanımı
11
  class TabTransformer(nn.Module):
12
  def __init__(self, input_dim, num_classes=2, d_model=64, nhead=4, num_layers=3, dropout=0.1):
@@ -25,102 +32,122 @@ class TabTransformer(nn.Module):
25
 
26
  def forward(self, x):
27
  x = self.embedding(x)
28
- x = x.unsqueeze(0) # Add sequence length dimension
29
  x = self.transformer_encoder(x)
30
- x = x.squeeze(0) # Remove sequence length dimension
31
  return self.fc(x)
32
 
33
- # Kategorik ve sayısal özellikler
34
  categorical_features = ['Multifocal_PVC', 'Nonsustained_VT', 'gender', 'HTN', 'DM', 'Fullcompansasion']
35
  numeric_features = ['pvc_percent', 'PVCQRS', 'EF', 'Age', 'PVC_Prematurity_index', 'QRS_ratio',
36
- 'mean_HR', 'symptom_duration', 'QTc_sinus', 'PVCCI_dispersion',
37
- 'CI_variability', 'PVC_Peak_QRS_duration', 'PVCCI', 'PVC_Compansatory_interval']
38
 
39
- # Mean değerleri ile varsayılanlar
40
  numeric_means = {
41
- 'pvc_percent': 11.96,
42
- 'PVCQRS': 155.1,
43
- 'EF': 59.93,
44
- 'Age': 52.19,
45
- 'PVC_Prematurity_index': 0.6158,
46
- 'QRS_ratio': 1.933,
47
- 'mean_HR': 71.28,
48
- 'symptom_duration': 14.91,
49
- 'QTc_sinus': 425.0,
50
- 'PVCCI_dispersion': 57.1,
51
- 'CI_variability': 22.98,
52
- 'PVC_Peak_QRS_duration': 76.13,
53
- 'PVCCI': 513.4,
54
  'PVC_Compansatory_interval': 1044
55
  }
56
 
57
- try:
58
- # Model ve scaler'ı yükleme
59
- model_path = "tabtransformer_model.pth"
60
- scaler_path = "trans_scaler.pkl"
61
-
62
- # Model tanımı
63
- input_dim = len(categorical_features) + len(numeric_features) # Toplam giriş boyutu
64
- model = TabTransformer(input_dim=input_dim)
65
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) # Model ağırlıklarını yükle
66
- model.eval() # Değerlendirme moduna al
67
-
68
- # Scaler yükleme
69
- with open(scaler_path, "rb") as f:
70
- scaler = pickle.load(f)
71
 
72
- except Exception as e:
73
- print(f"Model yükleme hatası: {str(e)}")
74
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  def predict(*inputs):
 
 
 
77
  try:
78
- # Girdileri kategorik ve sayısal olarak ayır
79
  cat_inputs = inputs[:len(categorical_features)]
80
  num_inputs = inputs[len(categorical_features):]
81
-
82
- # Kategorik girdiler (binary olarak 0/1 kodlama: "Yes" -> 1, "No" -> 0)
83
  cat_data = [1 if val == "Yes" else 0 for val in cat_inputs]
84
-
85
- # Sayısal girdiler
86
  num_data = [float(val) for val in num_inputs]
87
-
88
- # Veriyi birleştir ve ölçeklendir
89
  data = pd.DataFrame([cat_data + num_data], columns=categorical_features + numeric_features)
 
 
90
  scaled_data = scaler.transform(data)
91
-
92
- # Modelden tahmin al
93
- tensor_data = torch.FloatTensor(scaled_data)
94
  with torch.no_grad():
 
95
  logits = model(tensor_data)
96
  probabilities = F.softmax(logits, dim=1).numpy()
97
-
98
  return {
99
  "Probability of Response": float(probabilities[0][0]),
100
  "Probability of Non-Response": float(probabilities[0][1])
101
  }
102
  except Exception as e:
103
  print(f"Tahmin hatası: {str(e)}")
104
- raise
105
 
106
  # Gradio arayüzü
107
- inputs = (
108
- [gr.Dropdown(choices=['Yes', 'No'], label=feature) for feature in categorical_features] +
109
- [gr.Number(label=feature, value=numeric_means[feature]) for feature in numeric_features]
110
- )
111
- outputs = gr.Label(label="Prediction")
112
- interface = gr.Interface(
113
- fn=predict,
114
- inputs=inputs,
115
- outputs=outputs,
116
- title="TabTransformer Prediction",
117
- description="Enter the features to predict the response probability"
118
- )
 
119
 
120
- # Spaces için başlatma
121
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
122
  try:
123
- interface.launch(server_name="0.0.0.0", server_port=7860)
 
124
  except Exception as e:
125
  print(f"Arayüz başlatma hatası: {str(e)}")
126
- raise""")
 
1
  with open("app.py", "w") as f:
2
  f.write("""
3
+ import os
4
+ import sys
5
  import torch
6
  import torch.nn as nn
7
  import torch.nn.functional as F
 
9
  import pickle
10
  import pandas as pd
11
 
12
+ # Debug için print fonksiyonları
13
+ print("Python version:", sys.version)
14
+ print("Current working directory:", os.getcwd())
15
+ print("Directory contents:", os.listdir())
16
+
17
  # TabTransformer Model Tanımı
18
  class TabTransformer(nn.Module):
19
  def __init__(self, input_dim, num_classes=2, d_model=64, nhead=4, num_layers=3, dropout=0.1):
 
32
 
33
  def forward(self, x):
34
  x = self.embedding(x)
35
+ x = x.unsqueeze(0)
36
  x = self.transformer_encoder(x)
37
+ x = x.squeeze(0)
38
  return self.fc(x)
39
 
40
+ # Özellikler
41
  categorical_features = ['Multifocal_PVC', 'Nonsustained_VT', 'gender', 'HTN', 'DM', 'Fullcompansasion']
42
  numeric_features = ['pvc_percent', 'PVCQRS', 'EF', 'Age', 'PVC_Prematurity_index', 'QRS_ratio',
43
+ 'mean_HR', 'symptom_duration', 'QTc_sinus', 'PVCCI_dispersion',
44
+ 'CI_variability', 'PVC_Peak_QRS_duration', 'PVCCI', 'PVC_Compansatory_interval']
45
 
 
46
  numeric_means = {
47
+ 'pvc_percent': 11.96, 'PVCQRS': 155.1, 'EF': 59.93, 'Age': 52.19,
48
+ 'PVC_Prematurity_index': 0.6158, 'QRS_ratio': 1.933, 'mean_HR': 71.28,
49
+ 'symptom_duration': 14.91, 'QTc_sinus': 425.0, 'PVCCI_dispersion': 57.1,
50
+ 'CI_variability': 22.98, 'PVC_Peak_QRS_duration': 76.13, 'PVCCI': 513.4,
 
 
 
 
 
 
 
 
 
51
  'PVC_Compansatory_interval': 1044
52
  }
53
 
54
+ # Global değişkenler
55
+ model = None
56
+ scaler = None
 
 
 
 
 
 
 
 
 
 
 
57
 
58
+ def load_model_and_scaler():
59
+ global model, scaler
60
+ try:
61
+ print("Model ve scaler yükleniyor...")
62
+
63
+ # Model dosyası kontrolü
64
+ model_path = "tabtransformer_model.pth"
65
+ if not os.path.exists(model_path):
66
+ raise FileNotFoundError(f"Model dosyası bulunamadı: {model_path}")
67
+
68
+ # Scaler dosyası kontrolü
69
+ scaler_path = "trans_scaler.pkl"
70
+ if not os.path.exists(scaler_path):
71
+ raise FileNotFoundError(f"Scaler dosyası bulunamadı: {scaler_path}")
72
+
73
+ # Model yükleme
74
+ input_dim = len(categorical_features) + len(numeric_features)
75
+ model = TabTransformer(input_dim=input_dim)
76
+ model.load_state_dict(torch.load(model_path, map_location='cpu'))
77
+ model.eval()
78
+
79
+ # Scaler yükleme
80
+ with open(scaler_path, 'rb') as f:
81
+ scaler = pickle.load(f)
82
+
83
+ print("Model ve scaler başarıyla yüklendi!")
84
+ return True
85
+ except Exception as e:
86
+ print(f"Model yükleme hatası: {str(e)}")
87
+ return False
88
 
89
  def predict(*inputs):
90
+ if model is None or scaler is None:
91
+ return {"Error": "Model henüz yüklenmedi"}
92
+
93
  try:
94
+ # Girdileri ayır
95
  cat_inputs = inputs[:len(categorical_features)]
96
  num_inputs = inputs[len(categorical_features):]
97
+
98
+ # Kategorik verileri dönüştür
99
  cat_data = [1 if val == "Yes" else 0 for val in cat_inputs]
100
+
101
+ # Sayısal verileri dönüştür
102
  num_data = [float(val) for val in num_inputs]
103
+
104
+ # DataFrame oluştur
105
  data = pd.DataFrame([cat_data + num_data], columns=categorical_features + numeric_features)
106
+
107
+ # Veriyi ölçeklendir
108
  scaled_data = scaler.transform(data)
109
+
110
+ # Tahmin
 
111
  with torch.no_grad():
112
+ tensor_data = torch.FloatTensor(scaled_data)
113
  logits = model(tensor_data)
114
  probabilities = F.softmax(logits, dim=1).numpy()
115
+
116
  return {
117
  "Probability of Response": float(probabilities[0][0]),
118
  "Probability of Non-Response": float(probabilities[0][1])
119
  }
120
  except Exception as e:
121
  print(f"Tahmin hatası: {str(e)}")
122
+ return {"Error": str(e)}
123
 
124
  # Gradio arayüzü
125
+ def create_interface():
126
+ inputs = [gr.Dropdown(choices=['Yes', 'No'], label=feat) for feat in categorical_features]
127
+ inputs.extend([gr.Number(label=feat, value=numeric_means[feat]) for feat in numeric_features])
128
+
129
+ outputs = gr.Label(label="Prediction")
130
+
131
+ return gr.Interface(
132
+ fn=predict,
133
+ inputs=inputs,
134
+ outputs=outputs,
135
+ title="TabTransformer Prediction",
136
+ description="Enter the features to predict the response probability"
137
+ )
138
 
 
139
  if __name__ == "__main__":
140
+ print("Uygulama başlatılıyor...")
141
+
142
+ # Model ve scaler'ı yükle
143
+ if not load_model_and_scaler():
144
+ print("Model yüklenemedi. Uygulama sonlandırılıyor.")
145
+ sys.exit(1)
146
+
147
+ # Arayüzü oluştur ve başlat
148
  try:
149
+ demo = create_interface()
150
+ demo.launch(server_name="0.0.0.0", server_port=7860)
151
  except Exception as e:
152
  print(f"Arayüz başlatma hatası: {str(e)}")
153
+ sys.exit(1)""")