Spaces:
Build error
Build error
Commit
·
0987346
1
Parent(s):
0a04cd7
stuck on working with computeOutput function, getting dim error every
Browse files- go.mod +3 -3
- go.sum +7 -0
- nn/backprop.go +91 -0
- nn/main.go +32 -36
- nn/split.go +22 -4
- nn/train.go +1 -7
- server.go +0 -1
go.mod
CHANGED
|
@@ -18,7 +18,7 @@ require (
|
|
| 18 |
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
| 19 |
github.com/valyala/fasthttp v1.49.0 // indirect
|
| 20 |
github.com/valyala/tcplisten v1.0.0 // indirect
|
| 21 |
-
golang.org/x/net v0.
|
| 22 |
-
golang.org/x/sys v0.
|
| 23 |
-
gonum.org/v1/gonum v0.
|
| 24 |
)
|
|
|
|
| 18 |
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
| 19 |
github.com/valyala/fasthttp v1.49.0 // indirect
|
| 20 |
github.com/valyala/tcplisten v1.0.0 // indirect
|
| 21 |
+
golang.org/x/net v0.17.0 // indirect
|
| 22 |
+
golang.org/x/sys v0.13.0 // indirect
|
| 23 |
+
gonum.org/v1/gonum v0.14.0 // indirect
|
| 24 |
)
|
go.sum
CHANGED
|
@@ -55,6 +55,7 @@ golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL
|
|
| 55 |
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
| 56 |
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3 h1:n9HxLrNxWWtEb1cA950nuEEj3QnKbtsCJ6KjcgisNUs=
|
| 57 |
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
|
|
|
|
| 58 |
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
|
| 59 |
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
| 60 |
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
|
@@ -71,6 +72,8 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL
|
|
| 71 |
golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
|
| 72 |
golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ=
|
| 73 |
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
|
|
|
|
|
|
|
| 74 |
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
| 75 |
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
| 76 |
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
|
@@ -82,6 +85,8 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
|
|
| 82 |
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
| 83 |
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
|
| 84 |
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
|
|
|
|
|
| 85 |
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
| 86 |
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
| 87 |
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
|
@@ -95,6 +100,8 @@ gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJ
|
|
| 95 |
gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0=
|
| 96 |
gonum.org/v1/gonum v0.9.1 h1:HCWmqqNoELL0RAQeKBXWtkp04mGk8koafcB4He6+uhc=
|
| 97 |
gonum.org/v1/gonum v0.9.1/go.mod h1:TZumC3NeyVQskjXqmyWt4S3bINhy7B4eYwW69EbyX+0=
|
|
|
|
|
|
|
| 98 |
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc=
|
| 99 |
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw=
|
| 100 |
gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc=
|
|
|
|
| 55 |
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
| 56 |
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3 h1:n9HxLrNxWWtEb1cA950nuEEj3QnKbtsCJ6KjcgisNUs=
|
| 57 |
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
|
| 58 |
+
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug=
|
| 59 |
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
|
| 60 |
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
| 61 |
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
|
|
|
| 72 |
golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
|
| 73 |
golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ=
|
| 74 |
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
|
| 75 |
+
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
| 76 |
+
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
| 77 |
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
| 78 |
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
| 79 |
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
|
|
|
| 85 |
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
| 86 |
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
|
| 87 |
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
| 88 |
+
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
| 89 |
+
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
| 90 |
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
| 91 |
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
| 92 |
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
|
|
|
| 100 |
gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0=
|
| 101 |
gonum.org/v1/gonum v0.9.1 h1:HCWmqqNoELL0RAQeKBXWtkp04mGk8koafcB4He6+uhc=
|
| 102 |
gonum.org/v1/gonum v0.9.1/go.mod h1:TZumC3NeyVQskjXqmyWt4S3bINhy7B4eYwW69EbyX+0=
|
| 103 |
+
gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0=
|
| 104 |
+
gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU=
|
| 105 |
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc=
|
| 106 |
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw=
|
| 107 |
gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc=
|
nn/backprop.go
CHANGED
|
@@ -1,8 +1,99 @@
|
|
| 1 |
package nn
|
| 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
func (nn *NN) Backprop() {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
for i := 0; i < nn.Epochs; i++ {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
}
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
}
|
|
|
|
| 1 |
package nn
|
| 2 |
|
| 3 |
+
import (
|
| 4 |
+
"fmt"
|
| 5 |
+
|
| 6 |
+
"gonum.org/v1/gonum/mat"
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
func (nn *NN) Backprop() {
|
| 10 |
+
var (
|
| 11 |
+
activation = *nn.ActivationFunc
|
| 12 |
+
// lossHist []float64
|
| 13 |
+
)
|
| 14 |
|
| 15 |
for i := 0; i < nn.Epochs; i++ {
|
| 16 |
+
// compute output with current w + b
|
| 17 |
+
// then compute loss & backprop
|
| 18 |
+
hiddenOutput, err := computeOutput(
|
| 19 |
+
nn.XTrain,
|
| 20 |
+
nn.Wh,
|
| 21 |
+
nn.Bh,
|
| 22 |
+
activation,
|
| 23 |
+
)
|
| 24 |
+
if err != nil {
|
| 25 |
+
fmt.Printf("error computing hidden output: %v", err)
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
yHat, err := computeOutput(
|
| 29 |
+
hiddenOutput,
|
| 30 |
+
nn.Wo,
|
| 31 |
+
nn.Bo,
|
| 32 |
+
activation,
|
| 33 |
+
)
|
| 34 |
+
if err != nil {
|
| 35 |
+
fmt.Printf("error computing yHat: %v", err)
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
mse := meanSquaredError(nn.YTrain, yHat)
|
| 39 |
+
fmt.Println(mse)
|
| 40 |
+
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
func computeOutput(arr, w, b *mat.Dense, activationFunc func(float64) float64) (*mat.Dense, error) {
|
| 46 |
+
// Check if any of the input matrices is nil
|
| 47 |
+
if arr == nil || w == nil || b == nil {
|
| 48 |
+
return nil, fmt.Errorf("Input matrices cannot be nil")
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
// Check input dimensions
|
| 52 |
+
arrRows, arrCols := arr.Dims()
|
| 53 |
+
wRows, wCols := w.Dims()
|
| 54 |
+
bRows, bCols := b.Dims()
|
| 55 |
+
|
| 56 |
+
if arrCols != wRows || bCols != wCols {
|
| 57 |
+
return nil, fmt.Errorf("Matrix dimension mismatch: arr[%d, %d], w[%d, %d], b[%d, %d]", arrRows, arrCols, wRows, wCols, bRows, bCols)
|
| 58 |
}
|
| 59 |
|
| 60 |
+
// Compute the dot product between the input matrix 'arr' and the weight matrix 'w'
|
| 61 |
+
var product mat.Dense
|
| 62 |
+
product.Mul(arr, w)
|
| 63 |
+
|
| 64 |
+
// Check dimensions of product and bias
|
| 65 |
+
productRows, productCols := product.Dims()
|
| 66 |
+
if productCols != bCols {
|
| 67 |
+
return nil, fmt.Errorf("Matrix dimension mismatch: product[%d, %d], b[%d, %d]", productRows, productCols, bRows, bCols)
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
// Add the bias matrix 'b' to the product
|
| 71 |
+
var result mat.Dense
|
| 72 |
+
result.Add(&product, b)
|
| 73 |
+
|
| 74 |
+
// Apply the activation function to the result
|
| 75 |
+
applyActivation(&result, activationFunc)
|
| 76 |
+
|
| 77 |
+
return &result, nil
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
func applyActivation(m *mat.Dense, f func(float64) float64) {
|
| 81 |
+
r, c := m.Dims()
|
| 82 |
+
data := m.RawMatrix().Data
|
| 83 |
+
for i := 0; i < r*c; i++ {
|
| 84 |
+
data[i] = f(data[i])
|
| 85 |
+
}
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
func meanSquaredError(y, yHat *mat.Dense) float64 {
|
| 89 |
+
var sum float64
|
| 90 |
+
r, c := y.Dims()
|
| 91 |
+
|
| 92 |
+
for row := 0; row < r; row++ {
|
| 93 |
+
for col := 0; col < c; col++ {
|
| 94 |
+
diff := y.At(row, col) - yHat.At(row, col)
|
| 95 |
+
sum += (diff * diff)
|
| 96 |
+
}
|
| 97 |
+
}
|
| 98 |
+
return sum / float64((r * c))
|
| 99 |
}
|
nn/main.go
CHANGED
|
@@ -7,6 +7,7 @@ import (
|
|
| 7 |
|
| 8 |
"github.com/go-gota/gota/dataframe"
|
| 9 |
"github.com/gofiber/fiber/v2"
|
|
|
|
| 10 |
)
|
| 11 |
|
| 12 |
type NN struct {
|
|
@@ -23,14 +24,14 @@ type NN struct {
|
|
| 23 |
// attributes set after args above are parsed
|
| 24 |
ActivationFunc *func(float64) float64
|
| 25 |
Df *dataframe.DataFrame
|
| 26 |
-
XTrain *
|
| 27 |
-
YTrain *
|
| 28 |
-
XTest *
|
| 29 |
-
YTest *
|
| 30 |
-
Wh *
|
| 31 |
-
Bh *
|
| 32 |
-
Wo *
|
| 33 |
-
Bo *
|
| 34 |
}
|
| 35 |
|
| 36 |
func NewNN(c *fiber.Ctx) (*NN, error) {
|
|
@@ -53,36 +54,31 @@ func (nn *NN) InitWnB() {
|
|
| 53 |
hiddenSize := nn.HiddenSize
|
| 54 |
outputSize := 1 // only predicting one thing
|
| 55 |
|
| 56 |
-
// input hidden layer weights
|
| 57 |
-
wh :=
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
wh[i][j] = rand.Float64() - 0.5
|
| 62 |
-
}
|
| 63 |
-
}
|
| 64 |
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
| 69 |
|
| 70 |
-
//
|
| 71 |
-
wo :=
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
wo[i][j] = rand.Float64() - 0.5
|
| 76 |
-
}
|
| 77 |
-
}
|
| 78 |
|
| 79 |
-
bo :=
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
}
|
| 83 |
|
| 84 |
-
nn.Wh =
|
| 85 |
-
nn.Bh =
|
| 86 |
-
nn.Wo =
|
| 87 |
-
nn.Bo =
|
| 88 |
}
|
|
|
|
| 7 |
|
| 8 |
"github.com/go-gota/gota/dataframe"
|
| 9 |
"github.com/gofiber/fiber/v2"
|
| 10 |
+
"gonum.org/v1/gonum/mat"
|
| 11 |
)
|
| 12 |
|
| 13 |
type NN struct {
|
|
|
|
| 24 |
// attributes set after args above are parsed
|
| 25 |
ActivationFunc *func(float64) float64
|
| 26 |
Df *dataframe.DataFrame
|
| 27 |
+
XTrain *mat.Dense
|
| 28 |
+
YTrain *mat.Dense
|
| 29 |
+
XTest *mat.Dense
|
| 30 |
+
YTest *mat.Dense
|
| 31 |
+
Wh *mat.Dense
|
| 32 |
+
Bh *mat.Dense
|
| 33 |
+
Wo *mat.Dense
|
| 34 |
+
Bo *mat.Dense
|
| 35 |
}
|
| 36 |
|
| 37 |
func NewNN(c *fiber.Ctx) (*NN, error) {
|
|
|
|
| 54 |
hiddenSize := nn.HiddenSize
|
| 55 |
outputSize := 1 // only predicting one thing
|
| 56 |
|
| 57 |
+
// Initialize input hidden layer weights as a Gonum matrix
|
| 58 |
+
wh := mat.NewDense(inputSize, hiddenSize, nil)
|
| 59 |
+
wh.Apply(func(i, j int, v float64) float64 {
|
| 60 |
+
return rand.Float64() - 0.5
|
| 61 |
+
}, wh)
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
+
// Initialize hidden layer bias as a Gonum matrix
|
| 64 |
+
bh := mat.NewDense(1, hiddenSize, nil)
|
| 65 |
+
bh.Apply(func(i, j int, v float64) float64 {
|
| 66 |
+
return rand.Float64() - 0.5
|
| 67 |
+
}, bh)
|
| 68 |
|
| 69 |
+
// Initialize weights and biases for hidden -> output layer as Gonum matrices
|
| 70 |
+
wo := mat.NewDense(hiddenSize, outputSize, nil)
|
| 71 |
+
wo.Apply(func(i, j int, v float64) float64 {
|
| 72 |
+
return rand.Float64() - 0.5
|
| 73 |
+
}, wo)
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
+
bo := mat.NewDense(1, outputSize, nil)
|
| 76 |
+
bo.Apply(func(i, j int, v float64) float64 {
|
| 77 |
+
return rand.Float64() - 0.5
|
| 78 |
+
}, bo)
|
| 79 |
|
| 80 |
+
nn.Wh = wh
|
| 81 |
+
nn.Bh = bh
|
| 82 |
+
nn.Wo = wo
|
| 83 |
+
nn.Bo = bo
|
| 84 |
}
|
nn/split.go
CHANGED
|
@@ -3,6 +3,9 @@ package nn
|
|
| 3 |
import (
|
| 4 |
"math"
|
| 5 |
"math/rand"
|
|
|
|
|
|
|
|
|
|
| 6 |
)
|
| 7 |
|
| 8 |
func (nn *NN) TrainTestSplit() {
|
|
@@ -34,9 +37,24 @@ func (nn *NN) TrainTestSplit() {
|
|
| 34 |
XTest := test.Select(nn.Features)
|
| 35 |
YTest := test.Select(nn.Target)
|
| 36 |
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
}
|
|
|
|
| 3 |
import (
|
| 4 |
"math"
|
| 5 |
"math/rand"
|
| 6 |
+
|
| 7 |
+
"github.com/go-gota/gota/dataframe"
|
| 8 |
+
"gonum.org/v1/gonum/mat"
|
| 9 |
)
|
| 10 |
|
| 11 |
func (nn *NN) TrainTestSplit() {
|
|
|
|
| 37 |
XTest := test.Select(nn.Features)
|
| 38 |
YTest := test.Select(nn.Target)
|
| 39 |
|
| 40 |
+
// to make linear algebra easier & faster,
|
| 41 |
+
// we convert these dataframes that we are
|
| 42 |
+
// performing potentially expensive computations
|
| 43 |
+
// on into gonum matrices since we no longer need the
|
| 44 |
+
// column names.
|
| 45 |
+
nn.XTrain = df2mat(&XTrain)
|
| 46 |
+
nn.YTrain = df2mat(&YTrain)
|
| 47 |
+
nn.XTest = df2mat(&XTest)
|
| 48 |
+
nn.YTest = df2mat(&YTest)
|
| 49 |
+
}
|
| 50 |
|
| 51 |
+
// df2mat -> converts gota dataframe into gonum matrix
|
| 52 |
+
func df2mat(df *dataframe.DataFrame) *mat.Dense {
|
| 53 |
+
m := mat.NewDense(df.Nrow(), df.Ncol(), nil)
|
| 54 |
+
for i := 0; i < df.Nrow(); i++ {
|
| 55 |
+
for j := 0; j < df.Ncol(); j++ {
|
| 56 |
+
m.Set(i, j, df.Elem(i, j).Float())
|
| 57 |
+
}
|
| 58 |
+
}
|
| 59 |
+
return m
|
| 60 |
}
|
nn/train.go
CHANGED
|
@@ -3,11 +3,5 @@ package nn
|
|
| 3 |
func (nn *NN) Train() {
|
| 4 |
nn.InitWnB()
|
| 5 |
nn.TrainTestSplit()
|
| 6 |
-
|
| 7 |
-
// iterate n times where n = nn.Epochs
|
| 8 |
-
// use backprop algorithm on each iteration
|
| 9 |
-
// to fit the model to the data
|
| 10 |
-
for i := 0; i < nn.Epochs; i++ {
|
| 11 |
-
}
|
| 12 |
-
|
| 13 |
}
|
|
|
|
| 3 |
func (nn *NN) Train() {
|
| 4 |
nn.InitWnB()
|
| 5 |
nn.TrainTestSplit()
|
| 6 |
+
nn.Backprop()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
}
|
server.go
CHANGED
|
@@ -19,7 +19,6 @@ func main() {
|
|
| 19 |
}
|
| 20 |
|
| 21 |
nn.Train()
|
| 22 |
-
|
| 23 |
return c.SendString("No error")
|
| 24 |
})
|
| 25 |
|
|
|
|
| 19 |
}
|
| 20 |
|
| 21 |
nn.Train()
|
|
|
|
| 22 |
return c.SendString("No error")
|
| 23 |
})
|
| 24 |
|