trohith89 commited on
Commit
dde03ea
·
verified ·
1 Parent(s): 9326720

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -6
app.py CHANGED
@@ -14,18 +14,21 @@ if "epoch" not in st.session_state:
14
  st.session_state.epoch = 0
15
  if "running" not in st.session_state:
16
  st.session_state.running = False
17
- if "loss_history" not in st.session_state:
18
- st.session_state.loss_history = []
 
 
19
 
20
  # ================= TRAINING CONTROL PANEL (Top) =================
21
  st.markdown("### Training Controls")
22
- col1, col2, col3, col4, col5, col6 = st.columns(6)
23
 
24
  with col1:
25
  if st.button("↩️ Reset"):
26
  st.session_state.epoch = 0
27
  st.session_state.running = False
28
- st.session_state.loss_history = []
 
29
  with col2:
30
  if st.button("▶️ Train"):
31
  st.session_state.running = True
@@ -38,11 +41,14 @@ with col5:
38
  problem_type = st.selectbox("Problem Type", ["Classification", "Regression"])
39
  with col6:
40
  learning_rate = st.selectbox("Learning Rate", [0.0001, 0.001, 0.01, 0.03, 0.1])
 
 
41
 
42
  if st.session_state.running:
43
  time.sleep(1)
44
  st.session_state.epoch += 1
45
- st.session_state.loss_history.append(np.exp(-0.1 * st.session_state.epoch) + np.random.uniform(0, 0.05))
 
46
 
47
  # ================= MAIN LAYOUT =================
48
  col_features, col_hidden, col_plot = st.columns([2, 2, 3])
@@ -74,10 +80,12 @@ with col_plot:
74
 
75
  # Loss Plot
76
  fig, ax = plt.subplots()
77
- ax.plot(range(len(st.session_state.loss_history)), st.session_state.loss_history, marker="o")
 
78
  ax.set_title("Epoch vs. Loss")
79
  ax.set_xlabel("Epoch")
80
  ax.set_ylabel("Loss")
 
81
  st.pyplot(fig)
82
 
83
  # ================= DATASET SELECTION (Sidebar) =================
 
14
  st.session_state.epoch = 0
15
  if "running" not in st.session_state:
16
  st.session_state.running = False
17
+ if "train_loss_history" not in st.session_state:
18
+ st.session_state.train_loss_history = []
19
+ if "test_loss_history" not in st.session_state:
20
+ st.session_state.test_loss_history = []
21
 
22
  # ================= TRAINING CONTROL PANEL (Top) =================
23
  st.markdown("### Training Controls")
24
+ col1, col2, col3, col4, col5, col6, col7 = st.columns(7)
25
 
26
  with col1:
27
  if st.button("↩️ Reset"):
28
  st.session_state.epoch = 0
29
  st.session_state.running = False
30
+ st.session_state.train_loss_history = []
31
+ st.session_state.test_loss_history = []
32
  with col2:
33
  if st.button("▶️ Train"):
34
  st.session_state.running = True
 
41
  problem_type = st.selectbox("Problem Type", ["Classification", "Regression"])
42
  with col6:
43
  learning_rate = st.selectbox("Learning Rate", [0.0001, 0.001, 0.01, 0.03, 0.1])
44
+ with col7:
45
+ num_epochs = st.slider("Epochs", 1, 100, 10)
46
 
47
  if st.session_state.running:
48
  time.sleep(1)
49
  st.session_state.epoch += 1
50
+ st.session_state.train_loss_history.append(np.exp(-0.1 * st.session_state.epoch) + np.random.uniform(0, 0.05))
51
+ st.session_state.test_loss_history.append(np.exp(-0.1 * st.session_state.epoch) + np.random.uniform(0, 0.05) + 0.02)
52
 
53
  # ================= MAIN LAYOUT =================
54
  col_features, col_hidden, col_plot = st.columns([2, 2, 3])
 
80
 
81
  # Loss Plot
82
  fig, ax = plt.subplots()
83
+ ax.plot(range(len(st.session_state.train_loss_history)), st.session_state.train_loss_history, marker="o", label="Train Loss")
84
+ ax.plot(range(len(st.session_state.test_loss_history)), st.session_state.test_loss_history, marker="s", label="Test Loss")
85
  ax.set_title("Epoch vs. Loss")
86
  ax.set_xlabel("Epoch")
87
  ax.set_ylabel("Loss")
88
+ ax.legend()
89
  st.pyplot(fig)
90
 
91
  # ================= DATASET SELECTION (Sidebar) =================