resolverkatla commited on
Commit
430820d
·
verified ·
1 Parent(s): 9da6376

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -7,7 +7,10 @@ from sklearn.preprocessing import LabelEncoder
7
 
8
  # Load dataset
9
  df = pd.read_csv("titanic.csv")
10
- df = df[["Pclass", "Sex", "Age", "SibSp", "Parch", "Fare", "Survived"]].dropna()
 
 
 
11
 
12
  # Encode 'Sex'
13
  df["Sex"] = LabelEncoder().fit_transform(df["Sex"]) # male=1, female=0
@@ -21,9 +24,9 @@ model = RandomForestClassifier(n_estimators=100, random_state=42)
21
  model.fit(X, y)
22
 
23
  # Prediction function
24
- def predict_survival(pclass, sex, age, sibsp, parch, fare):
25
  sex_encoded = 1 if sex == "male" else 0
26
- input_data = np.array([[pclass, sex_encoded, age, sibsp, parch, fare]])
27
  prediction = model.predict(input_data)[0]
28
  return "✅ Survived" if prediction == 1 else "❌ Did not survive"
29
 
@@ -34,8 +37,6 @@ iface = gr.Interface(
34
  gr.Dropdown([1, 2, 3], label="Passenger Class"),
35
  gr.Radio(["male", "female"], label="Sex"),
36
  gr.Slider(0, 80, step=1, label="Age"),
37
- gr.Slider(0, 10, step=1, label="Siblings/Spouses Aboard"),
38
- gr.Slider(0, 10, step=1, label="Parents/Children Aboard"),
39
  gr.Slider(0, 500, step=1, label="Fare"),
40
  ],
41
  outputs="text",
 
7
 
8
  # Load dataset
9
  df = pd.read_csv("titanic.csv")
10
+ expected_cols = ["Pclass", "Sex", "Age", "SibSp", "Parch", "Fare", "Survived"]
11
+ available_cols = [col for col in expected_cols if col in df.columns]
12
+ df = df[available_cols].dropna()
13
+
14
 
15
  # Encode 'Sex'
16
  df["Sex"] = LabelEncoder().fit_transform(df["Sex"]) # male=1, female=0
 
24
  model.fit(X, y)
25
 
26
  # Prediction function
27
+ def predict_survival(pclass, sex, age, fare):
28
  sex_encoded = 1 if sex == "male" else 0
29
+ input_data = np.array([[pclass, sex_encoded, age, fare]])
30
  prediction = model.predict(input_data)[0]
31
  return "✅ Survived" if prediction == 1 else "❌ Did not survive"
32
 
 
37
  gr.Dropdown([1, 2, 3], label="Passenger Class"),
38
  gr.Radio(["male", "female"], label="Sex"),
39
  gr.Slider(0, 80, step=1, label="Age"),
 
 
40
  gr.Slider(0, 500, step=1, label="Fare"),
41
  ],
42
  outputs="text",