halil21 commited on
Commit
888d601
·
verified ·
1 Parent(s): 6267b8e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -33
app.py CHANGED
@@ -54,46 +54,54 @@ numeric_means = {
54
  'PVC_Compansatory_interval': 1044
55
  }
56
 
57
- # Model ve scaler'ı yükleme
58
- model_path = "tabtransformer_model.pth"
59
- scaler_path = "trans_scaler.pkl"
 
60
 
61
- # Model tanımı
62
- input_dim = len(categorical_features) + len(numeric_features) # Toplam giriş boyutu
63
- model = TabTransformer(input_dim=input_dim)
64
- model.load_state_dict(torch.load(model_path, weights_only=True)) # Model ağırlıklarını yükle
65
- model.eval() # Değerlendirme moduna al
66
 
67
- # Scaler yükleme
68
- with open(scaler_path, "rb") as f:
69
- scaler = pickle.load(f)
 
 
 
 
70
 
71
- # Prediction fonksiyonu
72
  def predict(*inputs):
73
- # Girdileri kategorik ve sayısal olarak ayır
74
- cat_inputs = inputs[:len(categorical_features)]
75
- num_inputs = inputs[len(categorical_features):]
 
76
 
77
- # Kategorik girdiler (binary olarak 0/1 kodlama: "Yes" -> 1, "No" -> 0)
78
- cat_data = [1 if val == "Yes" else 0 for val in cat_inputs]
79
 
80
- # Sayısal girdiler
81
- num_data = [float(val) for val in num_inputs]
82
 
83
- # Veriyi birleştir ve ölçeklendir
84
- data = pd.DataFrame([cat_data + num_data])
85
- scaled_data = scaler.transform(data)
86
 
87
- # Modelden tahmin al
88
- tensor_data = torch.FloatTensor(scaled_data)
89
- with torch.no_grad():
90
- logits = model(tensor_data)
91
- probabilities = F.softmax(logits, dim=1).numpy()
92
 
93
- return {
94
- "Probability of Response": probabilities[0][0],
95
- "Probability of Non-Response": probabilities[0][1]
96
- }
 
 
 
97
 
98
  # Gradio arayüzü
99
  inputs = (
@@ -101,8 +109,18 @@ inputs = (
101
  [gr.Number(label=feature, value=numeric_means[feature]) for feature in numeric_features]
102
  )
103
  outputs = gr.Label(label="Prediction")
104
- interface = gr.Interface(fn=predict, inputs=inputs, outputs=outputs, title="TabTransformer Prediction")
 
 
 
 
 
 
105
 
106
  # Spaces için başlatma
107
  if __name__ == "__main__":
108
- interface.launch(server_name="0.0.0.0", server_port=8080)""")
 
 
 
 
 
54
  'PVC_Compansatory_interval': 1044
55
  }
56
 
57
+ try:
58
+ # Model ve scaler'ı yükleme
59
+ model_path = "tabtransformer_model.pth"
60
+ scaler_path = "trans_scaler.pkl"
61
 
62
+ # Model tanımı
63
+ input_dim = len(categorical_features) + len(numeric_features) # Toplam giriş boyutu
64
+ model = TabTransformer(input_dim=input_dim)
65
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) # Model ağırlıklarını yükle
66
+ model.eval() # Değerlendirme moduna al
67
 
68
+ # Scaler yükleme
69
+ with open(scaler_path, "rb") as f:
70
+ scaler = pickle.load(f)
71
+
72
+ except Exception as e:
73
+ print(f"Model yükleme hatası: {str(e)}")
74
+ raise
75
 
 
76
  def predict(*inputs):
77
+ try:
78
+ # Girdileri kategorik ve sayısal olarak ayır
79
+ cat_inputs = inputs[:len(categorical_features)]
80
+ num_inputs = inputs[len(categorical_features):]
81
 
82
+ # Kategorik girdiler (binary olarak 0/1 kodlama: "Yes" -> 1, "No" -> 0)
83
+ cat_data = [1 if val == "Yes" else 0 for val in cat_inputs]
84
 
85
+ # Sayısal girdiler
86
+ num_data = [float(val) for val in num_inputs]
87
 
88
+ # Veriyi birleştir ve ölçeklendir
89
+ data = pd.DataFrame([cat_data + num_data], columns=categorical_features + numeric_features)
90
+ scaled_data = scaler.transform(data)
91
 
92
+ # Modelden tahmin al
93
+ tensor_data = torch.FloatTensor(scaled_data)
94
+ with torch.no_grad():
95
+ logits = model(tensor_data)
96
+ probabilities = F.softmax(logits, dim=1).numpy()
97
 
98
+ return {
99
+ "Probability of Response": float(probabilities[0][0]),
100
+ "Probability of Non-Response": float(probabilities[0][1])
101
+ }
102
+ except Exception as e:
103
+ print(f"Tahmin hatası: {str(e)}")
104
+ raise
105
 
106
  # Gradio arayüzü
107
  inputs = (
 
109
  [gr.Number(label=feature, value=numeric_means[feature]) for feature in numeric_features]
110
  )
111
  outputs = gr.Label(label="Prediction")
112
+ interface = gr.Interface(
113
+ fn=predict,
114
+ inputs=inputs,
115
+ outputs=outputs,
116
+ title="TabTransformer Prediction",
117
+ description="Enter the features to predict the response probability"
118
+ )
119
 
120
  # Spaces için başlatma
121
  if __name__ == "__main__":
122
+ try:
123
+ interface.launch(server_name="0.0.0.0", server_port=7860)
124
+ except Exception as e:
125
+ print(f"Arayüz başlatma hatası: {str(e)}")
126
+ raise""")