Joel Woodfield commited on
Commit
dd6ed2e
·
1 Parent(s): c5b45d7

Fix bug in number of data points rounding

Browse files
backend/src/__pycache__/logic.cpython-312.pyc CHANGED
Binary files a/backend/src/__pycache__/logic.cpython-312.pyc and b/backend/src/__pycache__/logic.cpython-312.pyc differ
 
backend/src/logic.py CHANGED
@@ -42,11 +42,12 @@ def generate_dataset(
42
  f = lambdify(('x1', 'x2'), function, modules='numpy')
43
 
44
  if generation_options.method == 'grid':
45
- x1 = np.linspace(x1_lim[0], x1_lim[1], int(np.sqrt(generation_options.num_samples)))
46
- x2 = np.linspace(x2_lim[0], x2_lim[1], int(np.sqrt(generation_options.num_samples)))
 
47
  X1, X2 = np.meshgrid(x1, x2)
48
- X1_flat = X1.flatten()
49
- X2_flat = X2.flatten()
50
  elif generation_options.method == 'random':
51
  X1_flat = np.random.uniform(x1_lim[0], x1_lim[1], generation_options.num_samples)
52
  X2_flat = np.random.uniform(x2_lim[0], x2_lim[1], generation_options.num_samples)
 
42
  f = lambdify(('x1', 'x2'), function, modules='numpy')
43
 
44
  if generation_options.method == 'grid':
45
+ side_length = int(np.ceil(np.sqrt(generation_options.num_samples)))
46
+ x1 = np.linspace(x1_lim[0], x1_lim[1], side_length)
47
+ x2 = np.linspace(x2_lim[0], x2_lim[1], side_length)
48
  X1, X2 = np.meshgrid(x1, x2)
49
+ X1_flat = X1.flatten()[:generation_options.num_samples]
50
+ X2_flat = X2.flatten()[:generation_options.num_samples]
51
  elif generation_options.method == 'random':
52
  X1_flat = np.random.uniform(x1_lim[0], x1_lim[1], generation_options.num_samples)
53
  X2_flat = np.random.uniform(x2_lim[0], x2_lim[1], generation_options.num_samples)