unfinished work on backtracking
Browse files- sudoku/train.py +25 -1
sudoku/train.py
CHANGED
|
@@ -401,6 +401,9 @@ class SudokuLightning(pl.LightningModule):
|
|
| 401 |
|
| 402 |
# TODO adapt training to something softer
|
| 403 |
#
|
|
|
|
|
|
|
|
|
|
| 404 |
class SudokuTrialErrorLightning(SudokuLightning):
|
| 405 |
def __init__(self, **kwargs):
|
| 406 |
super().__init__(**kwargs)
|
|
@@ -752,7 +755,7 @@ class SudokuTrialErrorLightning(SudokuLightning):
|
|
| 752 |
mask_possibility[pos]=False
|
| 753 |
if mask_possibility.sum()==0:
|
| 754 |
print('mask_possible=0')
|
| 755 |
-
raise
|
| 756 |
|
| 757 |
with torch.no_grad():
|
| 758 |
x_reg = self.sym_preprocess.forward(s_new_X.view(1,2,-1))
|
|
@@ -861,7 +864,28 @@ class SudokuTrialErrorLightning(SudokuLightning):
|
|
| 861 |
return x
|
| 862 |
X_tried = new_X
|
| 863 |
# if one of X_tried is complete (weird but possible) -> return x with tried_position mask_complet set to 1 (cause we still want a step by step resolution)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 864 |
|
|
|
|
| 865 |
def on_validation_epoch_start(self) -> None:
|
| 866 |
# self.buffer = BufferArray(self.nets_number, self.batch_size)
|
| 867 |
self.trial_error_buffer = Buffer(self.batch_size)
|
|
|
|
| 401 |
|
| 402 |
# TODO adapt training to something softer
|
| 403 |
#
|
| 404 |
+
class TrialEveryPosException(Exception):
|
| 405 |
+
pass
|
| 406 |
+
|
| 407 |
class SudokuTrialErrorLightning(SudokuLightning):
|
| 408 |
def __init__(self, **kwargs):
|
| 409 |
super().__init__(**kwargs)
|
|
|
|
| 755 |
mask_possibility[pos]=False
|
| 756 |
if mask_possibility.sum()==0:
|
| 757 |
print('mask_possible=0')
|
| 758 |
+
raise TrialEveryPosException()
|
| 759 |
|
| 760 |
with torch.no_grad():
|
| 761 |
x_reg = self.sym_preprocess.forward(s_new_X.view(1,2,-1))
|
|
|
|
| 864 |
return x
|
| 865 |
X_tried = new_X
|
| 866 |
# if one of X_tried is complete (weird but possible) -> return x with tried_position mask_complet set to 1 (cause we still want a step by step resolution)
|
| 867 |
+
|
| 868 |
+
def backtracking_predict(self, x):
|
| 869 |
+
"""
|
| 870 |
+
return is_valid, new_x
|
| 871 |
+
"""
|
| 872 |
+
try:
|
| 873 |
+
x = self.predict(x)
|
| 874 |
+
except TrialEveryPosException:
|
| 875 |
+
pos = self.search_trial(x.view(2,729), [])
|
| 876 |
+
new_x = deepcopy(x.view(2,729))
|
| 877 |
+
output = self.forward_layer(x.view(1,2,729))
|
| 878 |
+
pos_output = torch.argmax(output[0,:,pos])
|
| 879 |
+
if (pos_output[0].item()-self.threshold_abs)>(pos_output[1].item-self.threshold_pos):
|
| 880 |
+
trial_abs_pres = 0
|
| 881 |
+
else:
|
| 882 |
+
trial_abs_pres = 1
|
| 883 |
+
new_x[0,trial_abs_pres,pos]=1
|
| 884 |
+
is_valid, new_x = self.backtracking_predict(new_x)
|
| 885 |
+
if not is_valid:
|
| 886 |
+
return is_valid ### TO continue
|
| 887 |
|
| 888 |
+
|
| 889 |
def on_validation_epoch_start(self) -> None:
|
| 890 |
# self.buffer = BufferArray(self.nets_number, self.batch_size)
|
| 891 |
self.trial_error_buffer = Buffer(self.batch_size)
|