Spaces:
Build error
Build error
Shingome
commited on
Commit
·
abd6e8b
1
Parent(s):
583f55a
initial commit
Browse files- app.py +16 -0
- final_synapses.npz +0 -0
- requirements.txt +4 -0
- src/create_dataset.py +114 -0
- src/frommap.py +39 -0
- src/predict.py +95 -0
- src/prepare_image.py +28 -0
- src/prepare_image_for_dataset.py +27 -0
- src/training.py +151 -0
app.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.predict import *
|
| 2 |
+
import gradio as gr
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def predict(image):
|
| 6 |
+
return draw_image(create_map(prepare_image(image)), image.size)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
if __name__ == "__main__":
|
| 10 |
+
iface = gr.Interface(
|
| 11 |
+
fn=predict,
|
| 12 |
+
inputs=gr.Image(type="pil"),
|
| 13 |
+
outputs=["image"]
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
iface.launch(share=True)
|
final_synapses.npz
ADDED
|
Binary file (11.6 kB). View file
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio==4.31.5
|
| 2 |
+
matplotlib==3.9.0
|
| 3 |
+
numpy==1.26.4
|
| 4 |
+
Pillow==10.3.0
|
src/create_dataset.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tkinter as tk
|
| 2 |
+
from random import randint as rd
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import ImageTk, Image
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class App:
|
| 8 |
+
def save(fun):
|
| 9 |
+
def saved(self):
|
| 10 |
+
if self.dataset:
|
| 11 |
+
dataset = np.asarray(self.dataset, dtype=object)
|
| 12 |
+
try:
|
| 13 |
+
old_dataset = np.load('./../dataset.npy',
|
| 14 |
+
allow_pickle=True)
|
| 15 |
+
dataset = np.vstack((dataset, old_dataset))
|
| 16 |
+
np.save('./../dataset.npy', dataset)
|
| 17 |
+
except:
|
| 18 |
+
np.save('./../dataset.npy', dataset)
|
| 19 |
+
print('Создан новый файл')
|
| 20 |
+
print(dataset.shape[0])
|
| 21 |
+
fun(self)
|
| 22 |
+
|
| 23 |
+
return saved
|
| 24 |
+
|
| 25 |
+
def __init__(self):
|
| 26 |
+
self.window = tk.Tk()
|
| 27 |
+
self.window.resizable(False, False)
|
| 28 |
+
self.window.title('dataset')
|
| 29 |
+
self.window.geometry('550x320')
|
| 30 |
+
|
| 31 |
+
self.dataset = []
|
| 32 |
+
|
| 33 |
+
button_135 = tk.Button(self.window, text='135', width=10, height=5, command=lambda: self.save_image(4))
|
| 34 |
+
button_135.place(x=450, y=10)
|
| 35 |
+
|
| 36 |
+
button_45 = tk.Button(self.window, text='45', width=10, height=5, command=lambda: self.save_image(3))
|
| 37 |
+
button_45.place(x=450, y=110)
|
| 38 |
+
|
| 39 |
+
button_180 = tk.Button(self.window, text='180', width=10, height=5, command=lambda: self.save_image(2))
|
| 40 |
+
button_180.place(x=350, y=10)
|
| 41 |
+
|
| 42 |
+
button_90 = tk.Button(self.window, text='90', width=10, height=5, command=lambda: self.save_image(1))
|
| 43 |
+
button_90.place(x=350, y=110)
|
| 44 |
+
|
| 45 |
+
button_none = tk.Button(self.window, text='none', width=10, height=5, command=lambda: self.save_image(0))
|
| 46 |
+
button_none.place(x=350, y=210)
|
| 47 |
+
|
| 48 |
+
button_next = tk.Button(self.window, text='next', width=10, height=5, command=lambda: self.next_image())
|
| 49 |
+
button_next.place(x=450, y=210)
|
| 50 |
+
|
| 51 |
+
self.images = np.load('./../images_for_dataset.npy')
|
| 52 |
+
|
| 53 |
+
self.canvas = tk.Canvas(self.window, height=300, width=300)
|
| 54 |
+
self.image_prev = self.images[rd(0, np.shape(self.images)[0])]
|
| 55 |
+
self.image = Image.fromarray(self.image_prev)
|
| 56 |
+
self.image = self.image.resize((300, 300), resample=Image.NEAREST)
|
| 57 |
+
self.photo = ImageTk.PhotoImage(self.image)
|
| 58 |
+
self.image = self.canvas.create_image(0, 0, anchor='nw', image=self.photo)
|
| 59 |
+
self.canvas.place(x=10, y=10)
|
| 60 |
+
|
| 61 |
+
self.window.mainloop()
|
| 62 |
+
|
| 63 |
+
@save
|
| 64 |
+
def next_image(self):
|
| 65 |
+
self.dataset = []
|
| 66 |
+
self.image_prev = self.images[rd(0, np.shape(self.images)[0])]
|
| 67 |
+
self.image = Image.fromarray(self.image_prev)
|
| 68 |
+
self.image = self.image.resize((300, 300), resample=Image.NEAREST)
|
| 69 |
+
self.photo = ImageTk.PhotoImage(self.image)
|
| 70 |
+
self.image = self.canvas.create_image(0, 0, anchor='nw', image=self.photo)
|
| 71 |
+
self.canvas.place(x=10, y=10)
|
| 72 |
+
|
| 73 |
+
def save_image(self, answer):
|
| 74 |
+
image_mat = np.asarray(self.image_prev) / 255
|
| 75 |
+
image_array = np.reshape(image_mat, (1, 64))
|
| 76 |
+
self.dataset.append([answer, image_array])
|
| 77 |
+
self.dataset.append([answer, np.ones((1, 64)) - image_array])
|
| 78 |
+
|
| 79 |
+
if answer in (3, 4):
|
| 80 |
+
image_array = image_mat.T
|
| 81 |
+
image_array = np.reshape(image_array, (1, 64))
|
| 82 |
+
self.dataset.append([answer, image_array])
|
| 83 |
+
self.dataset.append([answer, np.ones((1, 64)) - image_array])
|
| 84 |
+
|
| 85 |
+
image_array = np.rot90(np.rot90(image_mat))
|
| 86 |
+
image_array = np.reshape(image_array, (1, 64))
|
| 87 |
+
self.dataset.append([answer, image_array])
|
| 88 |
+
self.dataset.append([answer, np.ones((1, 64)) - image_array])
|
| 89 |
+
|
| 90 |
+
image_array = image_mat.T
|
| 91 |
+
image_array = np.reshape(image_array, (1, 64))
|
| 92 |
+
self.dataset.append([answer, image_array])
|
| 93 |
+
self.dataset.append([answer, np.ones((1, 64)) - image_array])
|
| 94 |
+
|
| 95 |
+
else:
|
| 96 |
+
image_array = np.flipud(image_mat)
|
| 97 |
+
image_array = np.reshape(image_array, (1, 64))
|
| 98 |
+
self.dataset.append([answer, image_array])
|
| 99 |
+
self.dataset.append([answer, np.ones((1, 64)) - image_array])
|
| 100 |
+
|
| 101 |
+
image_array = np.fliplr(image_mat)
|
| 102 |
+
image_array = np.reshape(image_array, (1, 64))
|
| 103 |
+
self.dataset.append([answer, image_array])
|
| 104 |
+
self.dataset.append([answer, np.ones((1, 64)) - image_array])
|
| 105 |
+
|
| 106 |
+
image_array = np.flipud(image_mat)
|
| 107 |
+
image_array = np.reshape(image_array, (1, 64))
|
| 108 |
+
self.dataset.append([answer, image_array])
|
| 109 |
+
self.dataset.append([answer, np.ones((1, 64)) - image_array])
|
| 110 |
+
|
| 111 |
+
self.next_image()
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
app = App()
|
src/frommap.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image, ImageDraw
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
step = 5
|
| 6 |
+
width = 600
|
| 7 |
+
height = 600
|
| 8 |
+
|
| 9 |
+
width *= step
|
| 10 |
+
height *= step
|
| 11 |
+
|
| 12 |
+
image = Image.new('RGB', (width, height), (255, 255, 255))
|
| 13 |
+
draw = ImageDraw.Draw(image)
|
| 14 |
+
|
| 15 |
+
map = np.load('./../map.npy')
|
| 16 |
+
|
| 17 |
+
print(len(map))
|
| 18 |
+
|
| 19 |
+
iter = 0
|
| 20 |
+
|
| 21 |
+
k = 8
|
| 22 |
+
|
| 23 |
+
for x in range(0, width, step):
|
| 24 |
+
for y in range(0, height, step):
|
| 25 |
+
if map[iter] == 1:
|
| 26 |
+
xn, yn = x, y + k
|
| 27 |
+
elif map[iter] == 2:
|
| 28 |
+
xn, yn = x + k, y
|
| 29 |
+
elif map[iter] == 3:
|
| 30 |
+
xn, yn = x + k, y - k
|
| 31 |
+
elif map[iter] == 4:
|
| 32 |
+
xn, yn = x + k, y + k
|
| 33 |
+
else:
|
| 34 |
+
iter += 1
|
| 35 |
+
continue
|
| 36 |
+
draw.line(xy=[(x, y), (xn, yn)], fill='black')
|
| 37 |
+
iter += 1
|
| 38 |
+
|
| 39 |
+
image.show()
|
src/predict.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from PIL import Image, ImageDraw
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def prepare_image(image: Image):
|
| 6 |
+
# convert image
|
| 7 |
+
width, height = image.size
|
| 8 |
+
width = width // 8 * 8
|
| 9 |
+
height = height // 8 * 8
|
| 10 |
+
image = image.crop((0, 0, width, height))
|
| 11 |
+
image = image.convert('L')
|
| 12 |
+
|
| 13 |
+
image_array = []
|
| 14 |
+
|
| 15 |
+
# image to arrays
|
| 16 |
+
for x in range(width):
|
| 17 |
+
for y in range(height):
|
| 18 |
+
crop = image.crop((x, y, x + 8, y + 8))
|
| 19 |
+
image_array.append(np.reshape(np.asarray(crop) / 255, (1, 64)))
|
| 20 |
+
|
| 21 |
+
# save image_array
|
| 22 |
+
image_array = np.asarray(image_array)
|
| 23 |
+
|
| 24 |
+
return image_array
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def draw_image(map, size):
|
| 28 |
+
# size
|
| 29 |
+
step = 10
|
| 30 |
+
width, height = size
|
| 31 |
+
new_width = width // 8 * 8 * step
|
| 32 |
+
new_height = height // 8 * 8 * step
|
| 33 |
+
|
| 34 |
+
# create canvas
|
| 35 |
+
image = Image.new('RGB', (new_width, new_height), (255, 255, 255))
|
| 36 |
+
draw = ImageDraw.Draw(image)
|
| 37 |
+
|
| 38 |
+
iter = 0
|
| 39 |
+
|
| 40 |
+
# drawing
|
| 41 |
+
for x in range(0, new_width, step):
|
| 42 |
+
for y in range(0, new_height, step):
|
| 43 |
+
if map[iter] == 1:
|
| 44 |
+
xn, yn = x, y + 8
|
| 45 |
+
elif map[iter] == 2:
|
| 46 |
+
xn, yn = x + 8, y
|
| 47 |
+
elif map[iter] == 3:
|
| 48 |
+
xn, yn = x + 8, y - 8
|
| 49 |
+
elif map[iter] == 4:
|
| 50 |
+
xn, yn = x + 8, y + 8
|
| 51 |
+
else:
|
| 52 |
+
iter += 1
|
| 53 |
+
continue
|
| 54 |
+
draw.line(xy=[(x, y), (xn, yn)], fill='black')
|
| 55 |
+
iter += 1
|
| 56 |
+
|
| 57 |
+
image = image.resize((width, height), Image.Resampling.LANCZOS)
|
| 58 |
+
|
| 59 |
+
return image
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def create_map(image_array):
|
| 63 |
+
# Load synapses
|
| 64 |
+
synapses = np.load('./final_synapses.npz')
|
| 65 |
+
W1 = synapses['arr_0']
|
| 66 |
+
b1 = synapses['arr_1']
|
| 67 |
+
W2 = synapses['arr_2']
|
| 68 |
+
b2 = synapses['arr_3']
|
| 69 |
+
W3 = synapses['arr_4']
|
| 70 |
+
b3 = synapses['arr_5']
|
| 71 |
+
|
| 72 |
+
def predict(x):
|
| 73 |
+
def relu(t):
|
| 74 |
+
return np.maximum(t, 0)
|
| 75 |
+
|
| 76 |
+
def softmax(t):
|
| 77 |
+
out = np.exp(t)
|
| 78 |
+
return out / np.sum(out)
|
| 79 |
+
|
| 80 |
+
# Calculate
|
| 81 |
+
t1 = x @ W1 + b1
|
| 82 |
+
h1 = relu(t1)
|
| 83 |
+
t2 = h1 @ W2 + b2
|
| 84 |
+
h2 = relu(t2)
|
| 85 |
+
t3 = h2 @ W3 + b3
|
| 86 |
+
z = softmax(t3)
|
| 87 |
+
return z
|
| 88 |
+
|
| 89 |
+
# Form map
|
| 90 |
+
map = []
|
| 91 |
+
for x in image_array:
|
| 92 |
+
z = predict(x)
|
| 93 |
+
y_pred = np.argmax(z)
|
| 94 |
+
map.append(y_pred)
|
| 95 |
+
return map
|
src/prepare_image.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from PIL import Image
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# size
|
| 6 |
+
width = 600
|
| 7 |
+
height = 600
|
| 8 |
+
|
| 9 |
+
# open image file
|
| 10 |
+
image = Image.open("img.jpg")
|
| 11 |
+
image = image.crop((0, 0, width, height))
|
| 12 |
+
image = image.convert('L')
|
| 13 |
+
image.show()
|
| 14 |
+
|
| 15 |
+
image_array = []
|
| 16 |
+
|
| 17 |
+
# image_array to arrays
|
| 18 |
+
for x in range(width):
|
| 19 |
+
for y in range(height):
|
| 20 |
+
crop = image.crop((x, y, x + 8, y + 8))
|
| 21 |
+
image_array.append(np.reshape(np.asarray(crop) / 255, (1, 64)))
|
| 22 |
+
|
| 23 |
+
# save image_array
|
| 24 |
+
image_array = np.asarray(image_array)
|
| 25 |
+
|
| 26 |
+
np.save('./../training/image', image_array)
|
| 27 |
+
|
| 28 |
+
print(np.shape(image_array))
|
src/prepare_image_for_dataset.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from PIL import Image
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# open image_for_dataset file
|
| 6 |
+
image_for_dataset = Image.open("./../img_for_dataset.jpg")
|
| 7 |
+
|
| 8 |
+
# change image
|
| 9 |
+
width, height = image_for_dataset.size
|
| 10 |
+
image_for_dataset = image_for_dataset.crop((0, 0, width // 8 * 8, height // 8 * 8))
|
| 11 |
+
image_for_dataset = image_for_dataset.convert('L')
|
| 12 |
+
image_for_dataset.show()
|
| 13 |
+
|
| 14 |
+
images = []
|
| 15 |
+
|
| 16 |
+
# images to arrays
|
| 17 |
+
for x in range(width):
|
| 18 |
+
for y in range(height):
|
| 19 |
+
image = image_for_dataset.crop((x, y, x + 8, y + 8))
|
| 20 |
+
images.append(np.asarray(image))
|
| 21 |
+
|
| 22 |
+
# save crops_for_dataset
|
| 23 |
+
images_for_dateset = np.asarray(images)
|
| 24 |
+
|
| 25 |
+
print(np.shape(images))
|
| 26 |
+
|
| 27 |
+
np.save('./../images_for_dataset', images)
|
src/training.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# Hyper arguments
|
| 6 |
+
INPUT_DIM = 64
|
| 7 |
+
OUT_DIM = 5
|
| 8 |
+
H1_DIM = 16
|
| 9 |
+
H2_DIM = 10
|
| 10 |
+
|
| 11 |
+
ALPHA = 0.005
|
| 12 |
+
NUM_EPOCHS = 100
|
| 13 |
+
|
| 14 |
+
# Global
|
| 15 |
+
img_map = []
|
| 16 |
+
loss_arr = []
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def relu(t):
|
| 20 |
+
return np.maximum(t, 0)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def softmax(t):
|
| 24 |
+
out = np.exp(t)
|
| 25 |
+
return out / np.sum(out)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def sparse_cross_entropy(z, y):
|
| 29 |
+
return -np.log(z[0, y])
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def to_full(y, num_classes):
|
| 33 |
+
y_full = np.zeros((1, num_classes))
|
| 34 |
+
y_full[0, y] = 1
|
| 35 |
+
return y_full
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def relu_deriv(t):
|
| 39 |
+
return (t >= 0).astype(float)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def predict(x):
|
| 43 |
+
t1 = x @ W1 + b1
|
| 44 |
+
h1 = relu(t1)
|
| 45 |
+
t2 = h1 @ W2 + b2
|
| 46 |
+
h2 = relu(t2)
|
| 47 |
+
t3 = h2 @ W3 + b3
|
| 48 |
+
z = softmax(t3)
|
| 49 |
+
return z
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def calc_accuracy():
|
| 53 |
+
correct = 0
|
| 54 |
+
for y, x in dataset:
|
| 55 |
+
z = predict(x)
|
| 56 |
+
y_pred = np.argmax(z)
|
| 57 |
+
if y_pred == y:
|
| 58 |
+
correct += 1
|
| 59 |
+
acc = correct / len(dataset)
|
| 60 |
+
return acc
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def create_map():
|
| 64 |
+
for x in image:
|
| 65 |
+
z = predict(x)
|
| 66 |
+
y_pred = np.argmax(z)
|
| 67 |
+
img_map.append(y_pred)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
if __name__ == "__main__":
|
| 71 |
+
# Load infomation
|
| 72 |
+
dataset = np.load('./../dataset.npy', allow_pickle=True)
|
| 73 |
+
|
| 74 |
+
image = np.load('./../image.npy')
|
| 75 |
+
|
| 76 |
+
# Random synapses
|
| 77 |
+
W1 = np.random.rand(INPUT_DIM, H1_DIM)
|
| 78 |
+
b1 = np.random.rand(1, H1_DIM)
|
| 79 |
+
W2 = np.random.rand(H1_DIM, H2_DIM)
|
| 80 |
+
b2 = np.random.rand(1, H2_DIM)
|
| 81 |
+
W3 = np.random.rand(H2_DIM, OUT_DIM)
|
| 82 |
+
b3 = np.random.rand(1, OUT_DIM)
|
| 83 |
+
|
| 84 |
+
W1 = (W1 - 0.5) * 2 * np.sqrt(1/INPUT_DIM)
|
| 85 |
+
b1 = (b1 - 0.5) * 2 * np.sqrt(1/INPUT_DIM)
|
| 86 |
+
W2 = (W2 - 0.5) * 2 * np.sqrt(1/H1_DIM)
|
| 87 |
+
b2 = (b2 - 0.5) * 2 * np.sqrt(1/H1_DIM)
|
| 88 |
+
W3 = (W3 - 0.5) * 2 * np.sqrt(1/H2_DIM)
|
| 89 |
+
b3 = (b3 - 0.5) * 2 * np.sqrt(1/H2_DIM)
|
| 90 |
+
|
| 91 |
+
loss = 0
|
| 92 |
+
|
| 93 |
+
# Backpropagation
|
| 94 |
+
for ep in range(NUM_EPOCHS):
|
| 95 |
+
print(ep)
|
| 96 |
+
np.random.shuffle(dataset)
|
| 97 |
+
for i in range(len(dataset)):
|
| 98 |
+
|
| 99 |
+
x = dataset[i][1]
|
| 100 |
+
y = dataset[i][0]
|
| 101 |
+
|
| 102 |
+
# Forward
|
| 103 |
+
t1 = x @ W1 + b1
|
| 104 |
+
h1 = relu(t1)
|
| 105 |
+
t2 = h1 @ W2 + b2
|
| 106 |
+
h2 = relu(t2)
|
| 107 |
+
t3 = h2 @ W3 + b3
|
| 108 |
+
z = softmax(t3)
|
| 109 |
+
E = sparse_cross_entropy(z, y)
|
| 110 |
+
|
| 111 |
+
# Backward
|
| 112 |
+
y_full = to_full(y, OUT_DIM)
|
| 113 |
+
dE_dt3 = z - y_full
|
| 114 |
+
dE_dW3 = h2.T @ dE_dt3
|
| 115 |
+
dE_db3 = np.sum(dE_dt3, axis=0, keepdims=True)
|
| 116 |
+
dE_dh2 = dE_dt3 @ W3.T
|
| 117 |
+
dE_dt2 = dE_dh2 * relu_deriv(t2)
|
| 118 |
+
dE_dW2 = h1.T @ dE_dt2
|
| 119 |
+
dE_db2 = np.sum(dE_dt2, axis=0, keepdims=True)
|
| 120 |
+
dE_dh1 = dE_dt2 @ W2.T
|
| 121 |
+
dE_dt1 = dE_dh1 * relu_deriv(t1)
|
| 122 |
+
dE_dW1 = x.T @ dE_dt1
|
| 123 |
+
dE_db1 = np.sum(dE_dt1, axis=0, keepdims=True)
|
| 124 |
+
|
| 125 |
+
# Update
|
| 126 |
+
W1 = W1 - ALPHA * dE_dW1
|
| 127 |
+
b1 = b1 - ALPHA * dE_db1
|
| 128 |
+
W2 = W2 - ALPHA * dE_dW2
|
| 129 |
+
b2 = b2 - ALPHA * dE_db2
|
| 130 |
+
W3 = W3 - ALPHA * dE_dW3
|
| 131 |
+
b3 = b3 - ALPHA * dE_db3
|
| 132 |
+
|
| 133 |
+
loss += E
|
| 134 |
+
|
| 135 |
+
loss_arr.append(loss)
|
| 136 |
+
loss = 0
|
| 137 |
+
|
| 138 |
+
# Accuracy
|
| 139 |
+
accuracy = calc_accuracy()
|
| 140 |
+
print("Accuracy:", accuracy)
|
| 141 |
+
|
| 142 |
+
# Map
|
| 143 |
+
create_map()
|
| 144 |
+
np.save('map', np.asarray(img_map))
|
| 145 |
+
|
| 146 |
+
# Plot
|
| 147 |
+
plt.plot(loss_arr)
|
| 148 |
+
plt.show()
|
| 149 |
+
|
| 150 |
+
# Save synapses
|
| 151 |
+
np.savez('./../synapses', W1, b1, W2, b2, W3, b3)
|