| #pragma once |
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <random>
#include <string>
#include <utility>
#include <vector>
|
|
namespace newnet {

/// Dense float tensor with an optional, separately allocated gradient buffer.
///
/// Elements are stored flat in `data`; 2-D tensors are indexed row-major
/// (`row * shape[1] + col`, see operator()). `grad` stays empty until
/// init_grad() is called.
class Tensor {
public:
    /// Flat element storage. For tensors built from a shape, its length is the
    /// product of `shape`; a default-constructed Tensor has no storage.
    std::vector<float> data;
    /// Gradient storage; resized to `data.size()` by init_grad().
    std::vector<float> grad;
    /// Extent of each dimension.
    std::vector<int> shape;
    /// Set to true by xavier(). NOTE(review): presumably marks trainable
    /// parameters for code outside this header — confirm against callers.
    bool requires_grad = false;

    /// Empty tensor: no shape and no storage.
    Tensor() = default;

    /// Zero-filled tensor of the given shape.
    /// @param shape_ dimension extents; each must be non-negative.
    Tensor(std::vector<int> shape_) : Tensor(std::move(shape_), 0.0f) {}

    /// Tensor of the given shape with every element set to `val`.
    /// @param shape_ dimension extents; each must be non-negative.
    /// @param val    fill value.
    Tensor(std::vector<int> shape_, float val) : shape(std::move(shape_)) {
        // Fill in the body, not the init list: `data` is declared before
        // `shape`, so an init-list `data(...)` would run before `shape` is set.
        data.assign(static_cast<std::size_t>(element_count(shape)), val);
    }

    /// Tensor of the given shape filled with 0.
    static Tensor zeros(std::vector<int> shape) {
        return Tensor(std::move(shape), 0.0f);
    }

    /// Tensor of the given shape filled with 1.
    static Tensor ones(std::vector<int> shape) {
        return Tensor(std::move(shape), 1.0f);
    }

    /// Xavier/Glorot-uniform weight matrix of shape {fan_in, fan_out}.
    ///
    /// Each element is drawn from U(-limit, limit) with
    /// limit = sqrt(6 / (fan_in + fan_out)). The result has requires_grad = true.
    /// @param fan_in  number of rows.
    /// @param fan_out number of columns.
    /// @param seed    mt19937 seed. The default (42) reproduces the historical
    ///                weights; pass distinct seeds so layers of the same shape
    ///                do not start out identical.
    static Tensor xavier(int fan_in, int fan_out, unsigned int seed = 42) {
        assert(fan_in + fan_out > 0 && "xavier: fan_in + fan_out must be positive");
        Tensor t({fan_in, fan_out});
        const float limit = std::sqrt(6.0f / static_cast<float>(fan_in + fan_out));
        std::mt19937 gen(seed);
        std::uniform_real_distribution<float> dist(-limit, limit);
        for (float& w : t.data) {
            w = dist(gen);
        }
        t.requires_grad = true;
        return t;
    }

    /// Product of the extents in `shape` (1 for an empty shape).
    int size() const {
        return element_count(shape);
    }

    /// Extent of dimension 0.
    int rows() const {
        assert(shape.size() >= 1);
        return shape[0];
    }

    /// Extent of dimension 1.
    int cols() const {
        assert(shape.size() >= 2);
        return shape[1];
    }

    /// Mutable access to element (row, col) of a 2-D tensor (row-major).
    float& operator()(int row, int col) {
        assert(shape.size() == 2);
        assert(row >= 0 && row < shape[0]);
        assert(col >= 0 && col < shape[1]);
        return data[row * shape[1] + col];
    }

    /// Read-only access to element (row, col) of a 2-D tensor (row-major).
    const float& operator()(int row, int col) const {
        assert(shape.size() == 2);
        assert(row >= 0 && row < shape[0]);
        assert(col >= 0 && col < shape[1]);
        return data[row * shape[1] + col];
    }

    /// Size `grad` to match `data`. New entries are 0; existing entries are kept.
    void init_grad() {
        grad.resize(data.size(), 0.0f);
    }

    /// Set every existing gradient entry to 0 (no-op if grad is empty).
    void zero_grad() {
        std::fill(grad.begin(), grad.end(), 0.0f);
    }

    /// Print an optional name, the shape (e.g. "[2x3]:"), and a preview of the
    /// values to std::cout with 4 decimals.
    ///
    /// 2-D tensors show at most 6x6 elements; other ranks show at most the first
    /// 10 elements. The caller's std::cout flags and precision are restored
    /// before returning.
    void print(const std::string& name = "") const {
        // Save stream state so std::fixed / setprecision do not leak into
        // unrelated output that follows this call.
        const std::ios_base::fmtflags saved_flags = std::cout.flags();
        const std::streamsize saved_precision = std::cout.precision();

        if (!name.empty()) std::cout << name << " ";
        std::cout << "[";
        for (std::size_t i = 0; i < shape.size(); ++i) {
            if (i > 0) std::cout << "x";
            std::cout << shape[i];
        }
        std::cout << "]:\n";

        std::cout << std::fixed << std::setprecision(4);
        if (shape.size() == 2) {
            const int r = std::min(rows(), 6);
            const int c = std::min(cols(), 6);
            for (int i = 0; i < r; ++i) {
                std::cout << " ";
                for (int j = 0; j < c; ++j) {
                    std::cout << std::setw(9) << (*this)(i, j);
                }
                if (cols() > 6) std::cout << " ...";
                std::cout << "\n";
            }
            if (rows() > 6) std::cout << " ...\n";
        } else {
            // Bound by data.size(), not size(): a default-constructed Tensor has
            // an empty shape (size() == 1) but no storage, so data[0] would be
            // out of bounds.
            const int count = static_cast<int>(data.size());
            const int n = std::min(count, 10);
            std::cout << " ";
            for (int i = 0; i < n; ++i) {
                std::cout << std::setw(9) << data[i];
            }
            if (count > 10) std::cout << " ...";
            std::cout << "\n";
        }

        std::cout.flags(saved_flags);
        std::cout.precision(saved_precision);
    }

private:
    /// Product of the extents in `dims` (1 for an empty list); asserts each
    /// extent is non-negative, since a negative count would wrap to a huge
    /// allocation size.
    static int element_count(const std::vector<int>& dims) {
        int count = 1;
        for (int dim : dims) {
            assert(dim >= 0 && "tensor dimensions must be non-negative");
            count *= dim;
        }
        return count;
    }
};

}
|
|