Marcel0123 commited on
Commit
f4bae97
Β·
verified Β·
1 Parent(s): 3c90c43

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +171 -154
app.py CHANGED
@@ -1,195 +1,212 @@
1
- # app.py – Titanic Gradio App (volledige versie, NL, met achtergrond + zonsondergangeffect)
2
  import gradio as gr
3
  import pandas as pd
 
 
 
 
4
  from sklearn.model_selection import train_test_split
5
  from sklearn.preprocessing import LabelEncoder
 
6
  from sklearn.ensemble import RandomForestClassifier
7
  from sklearn.linear_model import LogisticRegression
8
- from sklearn.metrics import accuracy_score
9
- import plotly.express as px
10
- import numpy as np
11
- import os
12
 
13
  # =======================
14
- # DATA LADEN
15
  # =======================
16
- REQUIRED_COLS = {"survived", "pclass", "sex", "age", "sibsp", "parch", "fare", "embarked"}
17
-
18
- def load_data(csv_path="Titanic-Dataset.csv"):
19
- if not os.path.exists(csv_path):
20
- raise FileNotFoundError(f"Bestand niet gevonden: {csv_path}. Plaats het in de root van de Space.")
21
- df = pd.read_csv(csv_path)
22
- df.columns = [c.strip().lower() for c in df.columns]
23
- missing = REQUIRED_COLS - set(df.columns)
24
- if missing:
25
- raise ValueError(f"Ontbrekende kolommen in dataset: {', '.join(sorted(missing))}")
26
- # Missende waarden invullen
 
 
27
  for col in df.columns:
28
  if df[col].isna().any():
29
  if df[col].dtype == "object":
30
- df[col] = df[col].fillna(df[col].mode().iloc[0])
31
  else:
32
  df[col] = df[col].fillna(df[col].median())
 
 
 
 
 
33
  return df
34
 
35
  df = load_data()
36
 
37
  # =======================
38
- # MODEL FUNCTIES
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  # =======================
40
- def train_model(modeltype="Random Forest"):
41
- X = df.drop("survived", axis=1).copy()
42
- y = df["survived"].astype(int).copy()
43
- # Encode categorisch
 
44
  for c in X.select_dtypes("object").columns:
45
  le = LabelEncoder()
46
  X[c] = le.fit_transform(X[c])
47
- X_train, X_test, y_train, y_test = train_test_split(
48
- X, y, test_size=0.2, random_state=42, stratify=y
49
- )
50
- if modeltype == "Random Forest":
51
- model = RandomForestClassifier(n_estimators=300, random_state=42)
52
- else:
53
- model = LogisticRegression(max_iter=1000)
54
  model.fit(X_train, y_train)
55
- pred = model.predict(X_test)
56
- acc = accuracy_score(y_test, pred)
57
- return model, acc
 
 
 
 
 
 
58
 
59
  # =======================
60
- # TABS
61
  # =======================
62
-
63
- # Tab 1 – Introductie
64
- def tab_intro():
65
- html = """
66
- <h1 style='text-align:center; color:white;'>πŸ›³οΈ Titanic Overlevingsanalyse</h1>
67
- <p style='text-align:center; color:white; max-width:820px; margin:auto;'>
68
- Ontdek de data achter de tragedie van de Titanic.
69
- Verken patronen, train machine-learningmodellen en bereken jouw kans om te overleven.
70
- </p>
71
- <div style='text-align:center; color:#d6e3ff; margin-top:10px;'>
72
- Datasetkolommen gedetecteerd: <code>survived, pclass, sex, age, sibsp, parch, fare, embarked</code>
73
- </div>
74
- """
75
- return html
76
-
77
- # Tab 2 – Verkenning
78
- def tab_verkenning():
79
- fig1 = px.histogram(
80
- df,
81
- x="age",
82
- color=df["survived"].map({0: "Niet overleefd", 1: "Overleefd"}),
83
- nbins=30,
84
- title="Leeftijdsverdeling per overlevingsstatus",
85
- )
86
- fig1.update_layout(legend_title_text="Status", bargap=0.05)
87
- fig2 = px.box(
88
- df,
89
- x="pclass",
90
- y="fare",
91
- color=df["survived"].map({0: "Niet overleefd", 1: "Overleefd"}),
92
- title="Ticketprijs per klasse",
93
  )
94
- fig2.update_layout(legend_title_text="Status")
95
- return fig1, fig2
96
-
97
- # Tab 3 – Machine Learning
98
- def tab_model(model_type):
99
- try:
100
- _, acc = train_model(model_type)
101
- return f"Het {model_type}-model behaalt een nauwkeurigheid van **{acc:.2%}**."
102
- except Exception as e:
103
- return f"⚠️ Fout bij trainen: {e}"
104
-
105
- # Tab 4 – Voorspelling
106
- def predict_overleven(pclass, sex, age, sibsp, parch, fare, embarked):
107
- X = df.drop("survived", axis=1).copy()
108
- y = df["survived"].astype(int).copy()
109
- for c in X.select_dtypes("object").columns:
110
- le = LabelEncoder()
111
- X[c] = le.fit_transform(X[c])
112
- rf = RandomForestClassifier(n_estimators=300, random_state=42)
113
- rf.fit(X, y)
114
- # Encode invoer
115
- sex_enc = 1 if str(sex).lower().startswith("v") else 0 # Vrouw=1, Man=0
116
- embarked_enc = {"C": 0, "Q": 1, "S": 2}.get(str(embarked).strip()[0].upper(), 2)
117
- row = [[int(pclass), sex_enc, float(age), int(sibsp), int(parch), float(fare), embarked_enc]]
118
- p = rf.predict_proba(row)[0, 1]
119
- return f"🎯 Je geschatte overlevingskans is **{p:.1%}**."
120
 
121
  # =======================
122
- # UI – Gradio
123
  # =======================
124
- custom_css = """
125
  body {
126
- background: url('titanic_bg.png') no-repeat center center fixed;
127
- background-size: cover;
128
- color: white;
 
129
  }
130
  .gradio-container {
131
- background: rgba(10, 16, 26, 0.70);
132
  }
133
  .gradio-container::before {
134
- content: '';
135
- position: fixed;
136
- top: 0; right: 0;
137
- width: 42vw; height: 42vh;
138
- background: radial-gradient(circle at top right, rgba(255,190,120,0.45) 0%, rgba(255,190,120,0.10) 45%, transparent 70%);
139
- pointer-events: none;
140
- z-index: 0;
141
- }
142
- h1, h2, h3, p, label, .gr-markdown { color: #eef5ff !important; }
143
- label { font-weight: 600; }
144
- div.svelte-1ipelgc, .block.padded {
145
- background: rgba(20, 28, 42, 0.70) !important;
146
- border-radius: 16px;
147
- border: 1px solid rgba(60, 80, 110, 0.5);
148
- }
149
- button.svelte-1ipelgc, .tabitem {
150
- backdrop-filter: blur(2px);
151
  }
 
 
 
 
152
  """
153
 
154
- with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="blue", secondary_hue="blue")) as demo:
155
- gr.Markdown("<h1 style='text-align:center;'>πŸ›³οΈ Titanic Data Explorer</h1>")
156
- with gr.Tabs():
157
- with gr.Tab("Introductie"):
158
- gr.HTML(tab_intro())
159
-
160
- with gr.Tab("Verkenning"):
161
- btn1 = gr.Button("Toon grafieken")
162
- out1 = gr.Plot(label="Leeftijdsverdeling")
163
- out2 = gr.Plot(label="Ticketprijs per klasse")
164
- btn1.click(fn=tab_verkenning, outputs=[out1, out2])
165
-
166
- with gr.Tab("Machine Learning"):
167
- with gr.Row():
168
- model_dropdown = gr.Dropdown(
169
- ["Random Forest", "Logistic Regression"],
170
- label="Kies modeltype",
171
- value="Random Forest",
172
- )
173
- btn2 = gr.Button("Train model")
174
- out3 = gr.Markdown()
175
- btn2.click(fn=tab_model, inputs=model_dropdown, outputs=out3)
176
-
177
- with gr.Tab("Voorspel je kans"):
178
- with gr.Row():
179
- pclass = gr.Slider(1, 3, 2, step=1, label="Klasse (1=1e, 3=3e)")
180
- sex = gr.Radio(["Man", "Vrouw"], label="Geslacht", value="Man")
181
- age = gr.Slider(0, 80, 30, label="Leeftijd")
182
- with gr.Row():
183
- sibsp = gr.Slider(0, 8, 1, step=1, label="Aantal broers/zussen aan boord")
184
- parch = gr.Slider(0, 6, 0, step=1, label="Aantal ouders/kinderen aan boord")
185
- fare = gr.Slider(0, 500, 50, label="Ticketprijs (Β£)")
186
- embarked = gr.Radio(["C", "Q", "S"], label="Vertrekhaven", value="S")
187
- btn3 = gr.Button("Voorspel")
188
- out4 = gr.Markdown()
189
- btn3.click(
190
- fn=predict_overleven,
191
- inputs=[pclass, sex, age, sibsp, parch, fare, embarked],
192
- outputs=out4,
193
- )
194
 
195
  demo.launch()
 
1
+ # app.py – Titanic Data Explorer – Gradio One-Page Edition (Glossy Night Sky)
2
  import gradio as gr
3
  import pandas as pd
4
+ import numpy as np
5
+ import os
6
+ import plotly.express as px
7
+ import plotly.graph_objects as go
8
  from sklearn.model_selection import train_test_split
9
  from sklearn.preprocessing import LabelEncoder
10
+ from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score, roc_curve
11
  from sklearn.ensemble import RandomForestClassifier
12
  from sklearn.linear_model import LogisticRegression
 
 
 
 
13
 
14
  # =======================
15
+ # DATA
16
  # =======================
17
+ def load_data(path="Titanic-Dataset.csv"):
18
+ if not os.path.exists(path):
19
+ raise FileNotFoundError("❌ Titanic-Dataset.csv niet gevonden in de rootmap.")
20
+ df = pd.read_csv(path)
21
+ df.columns = [c.lower().strip() for c in df.columns]
22
+
23
+ # kolommen check
24
+ req = {"survived", "pclass", "sex", "age", "sibsp", "parch", "fare", "embarked"}
25
+ miss = req - set(df.columns)
26
+ if miss:
27
+ raise ValueError(f"Ontbrekende kolommen: {miss}")
28
+
29
+ # missende waarden vullen
30
  for col in df.columns:
31
  if df[col].isna().any():
32
  if df[col].dtype == "object":
33
+ df[col] = df[col].fillna(df[col].mode()[0])
34
  else:
35
  df[col] = df[col].fillna(df[col].median())
36
+
37
+ df["family_size"] = df["sibsp"] + df["parch"] + 1
38
+ df["sex"] = df["sex"].astype(str).str.title()
39
+ df["embarked"] = df["embarked"].astype(str).str.upper()
40
+ df["status"] = df["survived"].map({0: "Niet overleefd", 1: "Overleefd"})
41
  return df
42
 
43
  df = load_data()
44
 
45
  # =======================
46
+ # PLOTS
47
+ # =======================
48
+ def make_plot(fig, title):
49
+ fig.update_layout(
50
+ title=title,
51
+ paper_bgcolor="rgba(0,0,0,0)",
52
+ plot_bgcolor="rgba(0,0,0,0)",
53
+ font=dict(color="#EAF2FF"),
54
+ title_font=dict(size=18, color="#FFD26A"),
55
+ margin=dict(l=40, r=40, t=60, b=40)
56
+ )
57
+ return fig
58
+
59
+ def plot_class_distribution(x):
60
+ f = px.pie(x, names="pclass", color="pclass", color_discrete_sequence=px.colors.sequential.Blues)
61
+ return make_plot(f, "Verdeling per Klasse")
62
+
63
+ def plot_survival_heatmap(x):
64
+ pivot = x.pivot_table(index="sex", columns="pclass", values="survived", aggfunc="mean")
65
+ f = go.Figure(data=go.Heatmap(
66
+ z=pivot.values,
67
+ x=[str(c) for c in pivot.columns],
68
+ y=pivot.index,
69
+ colorscale="YlGnBu",
70
+ zmin=0,
71
+ zmax=1
72
+ ))
73
+ return make_plot(f, "Overlevingspercentage per Geslacht en Klasse")
74
+
75
+ def plot_density_age_fare(x):
76
+ f = px.density_contour(x, x="age", y="fare", color="status", marginal_x="histogram", marginal_y="histogram")
77
+ return make_plot(f, "Leeftijd vs Ticketprijs (dichtheidsverdeling)")
78
+
79
+ def plot_bubble_family_fare(x):
80
+ f = px.scatter(
81
+ x, x="fare", y="family_size", size="age", color="status",
82
+ hover_data=["sex", "pclass"], size_max=40, color_discrete_sequence=px.colors.qualitative.Set3
83
+ )
84
+ return make_plot(f, "Bubble Chart β€” Fare vs Family Size vs Age")
85
+
86
+ def plot_sunburst(x):
87
+ f = px.sunburst(x, path=["sex", "pclass", "status"], color="status",
88
+ color_discrete_map={"Overleefd": "#FFD26A", "Niet overleefd": "#1E3E78"})
89
+ return make_plot(f, "Sunburst β€” Geslacht β†’ Klasse β†’ Overleving")
90
+
91
+ def plot_treemap(x):
92
+ f = px.treemap(x, path=["embarked", "pclass", "status"], values="fare",
93
+ color="status", color_discrete_map={"Overleefd": "#FFD26A", "Niet overleefd": "#1E3E78"})
94
+ return make_plot(f, "Treemap β€” Vertrekhaven β†’ Klasse β†’ Overleving")
95
+
96
+ def plot_corr_heatmap(x):
97
+ corr = x[["age", "fare", "family_size", "pclass", "sibsp", "parch", "survived"]].corr()
98
+ f = go.Figure(data=go.Heatmap(z=corr.values, x=corr.columns, y=corr.columns,
99
+ colorscale="Blues", zmin=-1, zmax=1))
100
+ return make_plot(f, "Correlatiematrix (numerieke variabelen)")
101
+
102
  # =======================
103
+ # MACHINE LEARNING
104
+ # =======================
105
+ def train_and_evaluate(x):
106
+ X = x[["pclass", "sex", "age", "fare", "embarked", "family_size", "sibsp", "parch"]].copy()
107
+ y = x["survived"].astype(int)
108
  for c in X.select_dtypes("object").columns:
109
  le = LabelEncoder()
110
  X[c] = le.fit_transform(X[c])
111
+
112
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)
113
+ model = RandomForestClassifier(n_estimators=300, random_state=42)
 
 
 
 
114
  model.fit(X_train, y_train)
115
+ y_pred = model.predict(X_test)
116
+ acc = accuracy_score(y_test, y_pred)
117
+ auc = roc_auc_score(y_test, model.predict_proba(X_test)[:, 1])
118
+ cm = confusion_matrix(y_test, y_pred)
119
+
120
+ fig_cm = go.Figure(data=go.Heatmap(z=cm, text=cm, texttemplate="%{text}", colorscale="Blues"))
121
+ fig_cm = make_plot(fig_cm, "Confusion Matrix")
122
+
123
+ return f"🎯 **Nauwkeurigheid:** {acc:.2%} | **ROC AUC:** {auc:.3f}", fig_cm
124
 
125
  # =======================
126
+ # GRADIO INTERFACE
127
  # =======================
128
+ def dashboard():
129
+ acc_text, cm_fig = train_and_evaluate(df)
130
+ return (
131
+ f"{len(df)}", f"{df['survived'].sum()}",
132
+ f"{df['survived'].mean()*100:.1f}%", ", ".join(map(str, sorted(df['pclass'].unique()))),
133
+ plot_class_distribution(df),
134
+ plot_survival_heatmap(df),
135
+ plot_density_age_fare(df),
136
+ plot_bubble_family_fare(df),
137
+ plot_sunburst(df),
138
+ plot_treemap(df),
139
+ plot_corr_heatmap(df),
140
+ acc_text, cm_fig,
141
+ df.head(200)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
  # =======================
145
+ # CSS THEMA
146
  # =======================
147
+ CUSTOM_CSS = """
148
  body {
149
+ background-image: url('titanic_bg.png');
150
+ background-size: cover;
151
+ background-position: center;
152
+ color: #EAF2FF;
153
  }
154
  .gradio-container {
155
+ background: rgba(10, 16, 26, 0.7);
156
  }
157
  .gradio-container::before {
158
+ content: "";
159
+ position: fixed;
160
+ top: 0; right: 0;
161
+ width: 40vw; height: 40vh;
162
+ background: radial-gradient(circle at top right, rgba(255,190,120,0.4) 0%, transparent 70%);
163
+ pointer-events: none;
 
 
 
 
 
 
 
 
 
 
 
164
  }
165
+ .kpi {background: rgba(20,28,42,0.8); border-radius: 12px; padding: 12px; text-align:center;}
166
+ .kpi .value {font-size:1.6rem; font-weight:800; color:#FFD26A;}
167
+ .kpi .label {font-size:0.9rem; color:#C4D7F0;}
168
+ .section-title {font-size:1.3rem; font-weight:800; color:#FFD26A; margin-top:12px;}
169
  """
170
 
171
+ with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft(primary_hue="blue", secondary_hue="blue")) as demo:
172
+ gr.HTML("<h1 style='text-align:center;margin-top:10px;'>πŸ›³οΈ Titanic Data Explorer – Night Sky Edition</h1>")
173
+ gr.HTML("<p style='text-align:center;color:#C4D7F0;'>Interactieve visualisatie & machine learning analyse</p>")
174
+
175
+ with gr.Row():
176
+ kpi1 = gr.HTML("<div class='kpi'><div class='value'>–</div><div class='label'>Totaal passagiers</div></div>")
177
+ kpi2 = gr.HTML("<div class='kpi'><div class='value'>–</div><div class='label'>Overlevenden</div></div>")
178
+ kpi3 = gr.HTML("<div class='kpi'><div class='value'>–</div><div class='label'>% Overleefd</div></div>")
179
+ kpi4 = gr.HTML("<div class='kpi'><div class='value'>–</div><div class='label'>Klassen aanwezig</div></div>")
180
+
181
+ gr.HTML("<div class='section-title'>πŸ“Š Verkenning & Patronen</div>")
182
+ with gr.Row():
183
+ fig1 = gr.Plot(label="Klasse")
184
+ fig2 = gr.Plot(label="Heatmap")
185
+ with gr.Row():
186
+ fig3 = gr.Plot(label="Density")
187
+ fig4 = gr.Plot(label="Bubble Chart")
188
+ with gr.Row():
189
+ fig5 = gr.Plot(label="Sunburst")
190
+ fig6 = gr.Plot(label="Treemap")
191
+ with gr.Row():
192
+ fig7 = gr.Plot(label="Correlaties")
193
+
194
+ gr.HTML("<div class='section-title'>πŸ€– Machine Learning</div>")
195
+ acc_md = gr.Markdown()
196
+ fig_cm = gr.Plot(label="Confusion Matrix")
197
+
198
+ gr.HTML("<div class='section-title'>πŸ—‚οΈ Data voorbeeld</div>")
199
+ table = gr.Dataframe(height=300)
200
+
201
+ def update_dashboard():
202
+ return dashboard()
203
+
204
+ demo.load(
205
+ fn=update_dashboard,
206
+ inputs=[],
207
+ outputs=[kpi1, kpi2, kpi3, kpi4,
208
+ fig1, fig2, fig3, fig4, fig5, fig6, fig7,
209
+ acc_md, fig_cm, table]
210
+ )
211
 
212
  demo.launch()