Suriyaaan commited on
Commit
3f40137
·
verified ·
1 Parent(s): 71b6d09

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -9
app.py CHANGED
@@ -23,12 +23,25 @@ if st.button("Train"):
23
  model.compile(optimizer="sgd",loss="sparse_categorical_crossentropy",metrics=["accuracy"])
24
  model.fit(x_train,y_train,epochs=10,batch_size=128,validation_split=0.2)
25
  fig, axs = plt.subplots(6, 1, figsize=(8, 6))
26
- for i in range(6):
27
- axs[i].imshow(
28
- m.fit_transform(model.layers[0].weights[0][:, :, :, i][:, :, 0]),
29
- cmap="gray"
30
- )
31
- axs[i].axis("off") # remove axes for cleaner look
32
-
33
- st.pyplot(fig)
34
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  model.compile(optimizer="sgd",loss="sparse_categorical_crossentropy",metrics=["accuracy"])
24
  model.fit(x_train,y_train,epochs=10,batch_size=128,validation_split=0.2)
25
  fig, axs = plt.subplots(6, 1, figsize=(8, 6))
26
+ col1,col2,col3 = st.columns(3)
27
+ with col1:
28
+ st.write("Filters")
29
+ for i in range(6):
30
+ axs[i].imshow(
31
+ m.fit_transform(model.layers[0].weights[0][:, :, :, i][:, :, 0]),
32
+ cmap="gray"
33
+ )
34
+ axs[i].axis("off")
35
+ with col2:
36
+ st.write("Conv2D Layer")
37
+ sub = Model(inputs=model.inputs[0],outputs=model.layers[0].output)
38
+ fig,ax = plt.subplots(6,1,figsize=(8,6))
39
+ for i in range(6):
40
+ ax[i].imshow(sub.predict(x_train[0:1,:].reshape(1,28,28,1))[0,:,:,i],cmap="gray")
41
+ ax[i].axis("off")
42
+ with col3:
43
+ st.write("Average Pooling Layer")
44
+ max1 = Model(inputs=model.inputs[0],outputs=model.layers[1].output)
45
+ fig,ax1 = plt.subplots(6,1,figsize=(8,6))
46
+ for i in range(6):
47
+ ax1[i].imshow(max1.predict(x_train[0:1,:].reshape(1,28,28,1))[0,:,:,i],cmap="gray")