unfinished work on backtracking
#1
by
SebastienGuissart
- opened
- app.py +9 -2
- sudoku/train.py +33 -1
app.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
-
from sudoku.train import SudokuTrialErrorLightning
|
| 3 |
from sudoku.helper import display_as_dataframe, get_grid_number_soluce
|
| 4 |
import numpy as np
|
| 5 |
import re
|
|
@@ -116,7 +116,14 @@ if n_sol==1:
|
|
| 116 |
while new_X.sum()<729:
|
| 117 |
i+=1
|
| 118 |
st.markdown(f'iteration {i}')
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
st.html(display_as_dataframe(new_X).to_html(escape=False, index=False))
|
| 121 |
new_X_sum = new_X.sum()
|
| 122 |
assert new_X_sum> X_sum
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
+
from sudoku.train import SudokuTrialErrorLightning, TrialEveryPosException
|
| 3 |
from sudoku.helper import display_as_dataframe, get_grid_number_soluce
|
| 4 |
import numpy as np
|
| 5 |
import re
|
|
|
|
| 116 |
while new_X.sum()<729:
|
| 117 |
i+=1
|
| 118 |
st.markdown(f'iteration {i}')
|
| 119 |
+
try:
|
| 120 |
+
new_X = model.predict(new_X)
|
| 121 |
+
except TrialEveryPosException:
|
| 122 |
+
st.markdown('''## The grid is super evil!
|
| 123 |
+
please share it as A Discussion in the `Community` tab.
|
| 124 |
+
Except if it is this one: https://www.telegraph.co.uk/news/science/science-news/9359579/Worlds-hardest-sudoku-can-you-crack-it.html''')
|
| 125 |
+
is_valid, new_X = model.backtracking_predict(new_X)
|
| 126 |
+
assert is_valid
|
| 127 |
st.html(display_as_dataframe(new_X).to_html(escape=False, index=False))
|
| 128 |
new_X_sum = new_X.sum()
|
| 129 |
assert new_X_sum> X_sum
|
sudoku/train.py
CHANGED
|
@@ -401,6 +401,9 @@ class SudokuLightning(pl.LightningModule):
|
|
| 401 |
|
| 402 |
# TODO adapt training to something softer
|
| 403 |
#
|
|
|
|
|
|
|
|
|
|
| 404 |
class SudokuTrialErrorLightning(SudokuLightning):
|
| 405 |
def __init__(self, **kwargs):
|
| 406 |
super().__init__(**kwargs)
|
|
@@ -752,7 +755,7 @@ class SudokuTrialErrorLightning(SudokuLightning):
|
|
| 752 |
mask_possibility[pos]=False
|
| 753 |
if mask_possibility.sum()==0:
|
| 754 |
print('mask_possible=0')
|
| 755 |
-
raise
|
| 756 |
|
| 757 |
with torch.no_grad():
|
| 758 |
x_reg = self.sym_preprocess.forward(s_new_X.view(1,2,-1))
|
|
@@ -861,7 +864,36 @@ class SudokuTrialErrorLightning(SudokuLightning):
|
|
| 861 |
return x
|
| 862 |
X_tried = new_X
|
| 863 |
# if one of X_tried is complete (weird but possible) -> return x with tried_position mask_complet set to 1 (cause we still want a step by step resolution)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 864 |
|
|
|
|
| 865 |
def on_validation_epoch_start(self) -> None:
|
| 866 |
# self.buffer = BufferArray(self.nets_number, self.batch_size)
|
| 867 |
self.trial_error_buffer = Buffer(self.batch_size)
|
|
|
|
| 401 |
|
| 402 |
# TODO adapt training to something softer
|
| 403 |
#
|
| 404 |
+
class TrialEveryPosException(Exception):
|
| 405 |
+
pass
|
| 406 |
+
|
| 407 |
class SudokuTrialErrorLightning(SudokuLightning):
|
| 408 |
def __init__(self, **kwargs):
|
| 409 |
super().__init__(**kwargs)
|
|
|
|
| 755 |
mask_possibility[pos]=False
|
| 756 |
if mask_possibility.sum()==0:
|
| 757 |
print('mask_possible=0')
|
| 758 |
+
raise TrialEveryPosException()
|
| 759 |
|
| 760 |
with torch.no_grad():
|
| 761 |
x_reg = self.sym_preprocess.forward(s_new_X.view(1,2,-1))
|
|
|
|
| 864 |
return x
|
| 865 |
X_tried = new_X
|
| 866 |
# if one of X_tried is complete (weird but possible) -> return x with tried_position mask_complet set to 1 (cause we still want a step by step resolution)
|
| 867 |
+
|
| 868 |
+
def backtracking_predict(self, x):
|
| 869 |
+
"""
|
| 870 |
+
return is_valid, new_x
|
| 871 |
+
"""
|
| 872 |
+
try:
|
| 873 |
+
next_x = self.predict(x)
|
| 874 |
+
except TrialEveryPosException:
|
| 875 |
+
pos = self.search_trial(x.view(2,729), [])
|
| 876 |
+
new_x = deepcopy(x.view(2,729))
|
| 877 |
+
output = self.forward_layer(x.view(1,2,729))
|
| 878 |
+
pos_output = torch.argmax(output[0,:,pos])
|
| 879 |
+
if (pos_output[0].item()-self.threshold_abs)>(pos_output[1].item-self.threshold_pos):
|
| 880 |
+
trial_abs_pres = 0
|
| 881 |
+
else:
|
| 882 |
+
trial_abs_pres = 1
|
| 883 |
+
new_x[0,trial_abs_pres,pos]=1
|
| 884 |
+
is_valid, new_x = self.backtracking_predict(new_x)
|
| 885 |
+
if not is_valid:
|
| 886 |
+
new_x = deepcopy(x.view(2,729))
|
| 887 |
+
new_x[0,(trial_abs_pres+1)%2,pos]=1
|
| 888 |
+
return self.backtracking_predict(new_x)
|
| 889 |
+
return True, new_x
|
| 890 |
+
except ValueError:
|
| 891 |
+
return False, None
|
| 892 |
+
if next_x.sum()==729:
|
| 893 |
+
return True, next_x
|
| 894 |
+
return self.backtracking_predict(next_x)
|
| 895 |
|
| 896 |
+
|
| 897 |
def on_validation_epoch_start(self) -> None:
|
| 898 |
# self.buffer = BufferArray(self.nets_number, self.batch_size)
|
| 899 |
self.trial_error_buffer = Buffer(self.batch_size)
|