Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -5,55 +5,251 @@ import matplotlib.pyplot as plt
|
|
| 5 |
import gradio as gr
|
| 6 |
import pandas as pd
|
| 7 |
|
|
|
|
| 8 |
|
| 9 |
def flatten(img : np.array) -> list[int] :
|
|
|
|
| 10 |
new : list[int] = []
|
| 11 |
for row in img:
|
| 12 |
for item in row:
|
| 13 |
new.append(int(item))
|
| 14 |
return new
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
class Hopfield:
|
| 17 |
def __init__(self,patts):
|
| 18 |
self.E : list[int] = []
|
| 19 |
self.patts = patts
|
| 20 |
-
self.size = (4,4)
|
| 21 |
self.Px :int = len(patts)
|
| 22 |
self.Py :int = len(patts[0])
|
| 23 |
-
|
| 24 |
self.W : np.array = np.zeros((self.Py,self.Py),dtype=np.float16)
|
| 25 |
|
| 26 |
def train(self):
|
| 27 |
-
|
| 28 |
for i in range(self.Py):
|
| 29 |
for j in range(self.Py):
|
| 30 |
if i == j:
|
| 31 |
self.W[i][j] = 0
|
| 32 |
continue
|
| 33 |
-
|
| 34 |
self.W[i][j] = (1 / self.Px) * sum([patt[i] * patt[j] for patt in self.patts])
|
| 35 |
|
| 36 |
def Energy(self):
|
| 37 |
-
|
| 38 |
return self.E
|
| 39 |
|
| 40 |
def update(self,pattern):
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
| 42 |
pattern_flat = flatten(pattern)
|
| 43 |
|
| 44 |
-
|
| 45 |
H : list[int] = []
|
| 46 |
for i in self.W:
|
| 47 |
-
|
| 48 |
-
H.append((sum([w * s for w,s in zip(i, pattern_flat)])))
|
| 49 |
|
| 50 |
H = np.array(H)
|
| 51 |
-
H = np.sign(H)
|
| 52 |
|
|
|
|
| 53 |
E = 0
|
| 54 |
for i in range(self.Py):
|
| 55 |
for j in range(self.Py):
|
| 56 |
E += float(-0.5 * self.W[i][j] * H[i] * H[j])
|
| 57 |
self.E.append(E)
|
| 58 |
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
import gradio as gr
|
| 6 |
import pandas as pd
|
| 7 |
|
| 8 |
+
# --- Your Helper Functions ---
|
| 9 |
|
| 10 |
def flatten(img : np.array) -> list[int] :
|
| 11 |
+
"""Converts a 2D numpy array into a 1D list."""
|
| 12 |
new : list[int] = []
|
| 13 |
for row in img:
|
| 14 |
for item in row:
|
| 15 |
new.append(int(item))
|
| 16 |
return new
|
| 17 |
|
| 18 |
+
def sgn(x):
|
| 19 |
+
"""Sign function."""
|
| 20 |
+
if x < 0:
|
| 21 |
+
return -1
|
| 22 |
+
if x == 0:
|
| 23 |
+
return 0
|
| 24 |
+
return 1
|
| 25 |
+
|
| 26 |
+
# --- Your Hopfield Class (with one bugfix) ---
|
| 27 |
+
|
| 28 |
class Hopfield:
|
| 29 |
def __init__(self,patts):
|
| 30 |
self.E : list[int] = []
|
| 31 |
self.patts = patts
|
| 32 |
+
self.size = (4,4) # Fixed size for reshaping
|
| 33 |
self.Px :int = len(patts)
|
| 34 |
self.Py :int = len(patts[0])
|
| 35 |
+
# Initialize weights
|
| 36 |
self.W : np.array = np.zeros((self.Py,self.Py),dtype=np.float16)
|
| 37 |
|
| 38 |
def train(self):
|
| 39 |
+
"""Trains the network on the patterns provided in __init__."""
|
| 40 |
for i in range(self.Py):
|
| 41 |
for j in range(self.Py):
|
| 42 |
if i == j:
|
| 43 |
self.W[i][j] = 0
|
| 44 |
continue
|
| 45 |
+
# Hebbian rule
|
| 46 |
self.W[i][j] = (1 / self.Px) * sum([patt[i] * patt[j] for patt in self.patts])
|
| 47 |
|
| 48 |
def Energy(self):
|
| 49 |
+
"""Returns the list of energy values recorded during updates."""
|
| 50 |
return self.E
|
| 51 |
|
| 52 |
def update(self,pattern):
|
| 53 |
+
"""
|
| 54 |
+
Performs one asynchronous update step on the entire pattern.
|
| 55 |
+
"""
|
| 56 |
+
# Flatten the 2D input pattern to 1D
|
| 57 |
pattern_flat = flatten(pattern)
|
| 58 |
|
| 59 |
+
# Calculate the new state vector H
|
| 60 |
H : list[int] = []
|
| 61 |
for i in self.W:
|
| 62 |
+
# H_i = sgn(sum(W_ij * S_j))
|
| 63 |
+
H.append(sgn(sum([w * s for w,s in zip(i, pattern_flat)])))
|
| 64 |
|
| 65 |
H = np.array(H)
|
|
|
|
| 66 |
|
| 67 |
+
# Calculate the energy of this new state H
|
| 68 |
E = 0
|
| 69 |
for i in range(self.Py):
|
| 70 |
for j in range(self.Py):
|
| 71 |
E += float(-0.5 * self.W[i][j] * H[i] * H[j])
|
| 72 |
self.E.append(E)
|
| 73 |
|
| 74 |
+
# --- FIX ---
|
| 75 |
+
# Use reshape, not resize. Resize can add/remove elements.
|
| 76 |
+
# Reshape will fail if H doesn't have 16 elements, which is safer.
|
| 77 |
+
return H.reshape(self.size)
|
| 78 |
+
|
| 79 |
+
# --- Default Patterns for the Gradio App ---
|
| 80 |
+
|
| 81 |
+
# Pattern 1: 'X'
|
| 82 |
+
patt_1_default = [[ 1, -1, -1, 1],
|
| 83 |
+
[-1, 1, 1, -1],
|
| 84 |
+
[-1, 1, 1, -1],
|
| 85 |
+
[ 1, -1, -1, 1]]
|
| 86 |
+
|
| 87 |
+
# Pattern 2: 'C'
|
| 88 |
+
patt_2_default = [[ 1, 1, 1, -1],
|
| 89 |
+
[ 1, -1, -1, -1],
|
| 90 |
+
[ 1, -1, -1, -1],
|
| 91 |
+
[ 1, 1, 1, -1]]
|
| 92 |
+
|
| 93 |
+
# Pattern 3: 'L'
|
| 94 |
+
patt_3_default = [[ 1, -1, -1, -1],
|
| 95 |
+
[ 1, -1, -1, -1],
|
| 96 |
+
[ 1, -1, -1, -1],
|
| 97 |
+
[ 1, 1, 1, 1]]
|
| 98 |
+
|
| 99 |
+
# Initial (corrupted) shape to test
|
| 100 |
+
initial_shape_default = [[ 1, 1, -1, -1],
|
| 101 |
+
[ 1, -1, -1, -1],
|
| 102 |
+
[ 1, -1, 1, -1],
|
| 103 |
+
[ 1, 1, 1, -1]]
|
| 104 |
+
|
| 105 |
+
# --- Gradio Core Logic ---
|
| 106 |
+
|
| 107 |
+
def clean_dataframe(df):
|
| 108 |
+
"""Helper to convert Gradio dataframe to a clean NumPy array."""
|
| 109 |
+
# Fill any empty cells (None) with -1 and convert to int
|
| 110 |
+
return df.fillna(-1).to_numpy(dtype=int)
|
| 111 |
+
|
| 112 |
+
def run_hopfield(patt1_df, patt2_df, patt3_df, initial_shape_df, steps):
|
| 113 |
+
"""
|
| 114 |
+
The main function for the Gradio interface.
|
| 115 |
+
"""
|
| 116 |
+
|
| 117 |
+
# 1. Clean inputs
|
| 118 |
+
p1 = clean_dataframe(patt1_df)
|
| 119 |
+
p2 = clean_dataframe(patt2_df)
|
| 120 |
+
p3 = clean_dataframe(patt3_df)
|
| 121 |
+
initial_shape = clean_dataframe(initial_shape_df)
|
| 122 |
+
|
| 123 |
+
# 2. Collect patterns to train (ignore empty/all -1 patterns)
|
| 124 |
+
patterns_to_train = []
|
| 125 |
+
if np.any(p1 == 1):
|
| 126 |
+
patterns_to_train.append(flatten(p1))
|
| 127 |
+
if np.any(p2 == 1):
|
| 128 |
+
patterns_to_train.append(flatten(p2))
|
| 129 |
+
if np.any(p3 == 1):
|
| 130 |
+
patterns_to_train.append(flatten(p3))
|
| 131 |
+
|
| 132 |
+
# 3. Check if any patterns were provided
|
| 133 |
+
if not patterns_to_train:
|
| 134 |
+
fig_shape = plt.figure()
|
| 135 |
+
plt.title("Error: No patterns provided to train.")
|
| 136 |
+
plt.axis('off')
|
| 137 |
+
|
| 138 |
+
fig_energy = plt.figure()
|
| 139 |
+
plt.title("Error: No patterns provided to train.")
|
| 140 |
+
|
| 141 |
+
return fig_shape, fig_energy
|
| 142 |
+
|
| 143 |
+
# 4. Create and train the model
|
| 144 |
+
patts = np.array(patterns_to_train)
|
| 145 |
+
model = Hopfield(patts)
|
| 146 |
+
model.train()
|
| 147 |
+
|
| 148 |
+
# 5. Run the evolution
|
| 149 |
+
current_shape = initial_shape
|
| 150 |
+
for _ in range(int(steps)):
|
| 151 |
+
next_shape = model.update(current_shape)
|
| 152 |
+
# Check for convergence
|
| 153 |
+
if np.array_equal(current_shape, next_shape):
|
| 154 |
+
break
|
| 155 |
+
current_shape = next_shape
|
| 156 |
+
|
| 157 |
+
# 6. Generate final shape plot
|
| 158 |
+
fig_shape = plt.figure()
|
| 159 |
+
plt.imshow(current_shape, cmap='gray', vmin=-1, vmax=1)
|
| 160 |
+
plt.title("Final Evolved Shape")
|
| 161 |
+
plt.axis('off')
|
| 162 |
+
|
| 163 |
+
# 7. Generate energy plot
|
| 164 |
+
fig_energy = plt.figure()
|
| 165 |
+
energy_data = model.Energy()
|
| 166 |
+
if energy_data:
|
| 167 |
+
plt.plot(list(range(len(energy_data))), energy_data, marker='o')
|
| 168 |
+
plt.title("Energy Evolution")
|
| 169 |
+
plt.xlabel("Update Step")
|
| 170 |
+
plt.ylabel("Energy")
|
| 171 |
+
plt.grid(True)
|
| 172 |
+
else:
|
| 173 |
+
plt.title("Energy (No Updates Run)")
|
| 174 |
+
|
| 175 |
+
return fig_shape, fig_energy
|
| 176 |
+
|
| 177 |
+
# --- Gradio Interface ---
|
| 178 |
+
|
| 179 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 180 |
+
gr.Markdown("# 🧠 Hopfield Network Simulator")
|
| 181 |
+
gr.Markdown(
|
| 182 |
+
"Define up to 3 patterns (1 for 'on', -1 for 'off'). "
|
| 183 |
+
"The network will learn them. Then, draw an 'Initial Shape' "
|
| 184 |
+
"and see if the network can evolve it into one of the patterns it learned."
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
with gr.Row():
|
| 188 |
+
with gr.Column():
|
| 189 |
+
gr.Markdown("### 1. Define Patterns to Memorize")
|
| 190 |
+
# Set headers to be invisible, type to pandas for .fillna
|
| 191 |
+
patt1_in = gr.Dataframe(
|
| 192 |
+
value=patt_1_default,
|
| 193 |
+
label="Pattern 1",
|
| 194 |
+
headers=None,
|
| 195 |
+
datatype="number",
|
| 196 |
+
col_count=4,
|
| 197 |
+
row_count=4,
|
| 198 |
+
type="pandas"
|
| 199 |
+
)
|
| 200 |
+
patt2_in = gr.Dataframe(
|
| 201 |
+
value=patt_2_default,
|
| 202 |
+
label="Pattern 2",
|
| 203 |
+
headers=None,
|
| 204 |
+
datatype="number",
|
| 205 |
+
col_count=4,
|
| 206 |
+
row_count=4,
|
| 207 |
+
type="pandas"
|
| 208 |
+
)
|
| 209 |
+
patt3_in = gr.Dataframe(
|
| 210 |
+
value=patt_3_default,
|
| 211 |
+
label="Pattern 3",
|
| 212 |
+
headers=None,
|
| 213 |
+
datatype="number",
|
| 214 |
+
col_count=4,
|
| 215 |
+
row_count=4,
|
| 216 |
+
type="pandas"
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
with gr.Column():
|
| 220 |
+
gr.Markdown("### 2. Set Initial Shape & Run")
|
| 221 |
+
initial_in = gr.Dataframe(
|
| 222 |
+
value=initial_shape_default,
|
| 223 |
+
label="Initial Shape (Test Pattern)",
|
| 224 |
+
headers=None,
|
| 225 |
+
datatype="number",
|
| 226 |
+
col_count=4,
|
| 227 |
+
row_count=4,
|
| 228 |
+
type="pandas"
|
| 229 |
+
)
|
| 230 |
+
steps_in = gr.Slider(
|
| 231 |
+
minimum=1,
|
| 232 |
+
maximum=10,
|
| 233 |
+
value=5,
|
| 234 |
+
step=1,
|
| 235 |
+
label="Max Evolution Steps"
|
| 236 |
+
)
|
| 237 |
+
run_btn = gr.Button("Run Evolution", variant="primary")
|
| 238 |
+
|
| 239 |
+
with gr.Row():
|
| 240 |
+
with gr.Column():
|
| 241 |
+
gr.Markdown("### 3. Results")
|
| 242 |
+
shape_out = gr.Plot(label="Final Evolved Shape")
|
| 243 |
+
with gr.Column():
|
| 244 |
+
gr.Markdown("### 4. Diagnostics")
|
| 245 |
+
energy_out = gr.Plot(label="Energy Evolution")
|
| 246 |
+
|
| 247 |
+
# Connect the button to the function
|
| 248 |
+
run_btn.click(
|
| 249 |
+
fn=run_hopfield,
|
| 250 |
+
inputs=[patt1_in, patt2_in, patt3_in, initial_in, steps_in],
|
| 251 |
+
outputs=[shape_out, energy_out]
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
if __name__ == "__main__":
|
| 255 |
+
demo.launch()
|