Naz786 commited on
Commit
9c74b9c
ยท
verified ยท
1 Parent(s): 25f411f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +758 -0
app.py ADDED
@@ -0,0 +1,758 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Q-Learning AI for Sensor Placement - Interactive Demo
3
+ For Hugging Face Spaces
4
+ """
5
+
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+ import gradio as gr
9
+ from collections import defaultdict
10
+
11
+ np.random.seed(42)
12
+
13
+
14
+ # ==============================================================================
15
+ # PART 1: THE SECRET WORLD
16
+ # ==============================================================================
17
+
18
+ class ThiefWorld:
19
+ """Where thieves REALLY appear (AI must discover this!)"""
20
+
21
+ def __init__(self, hotspot1=2.5, hotspot2=7.0):
22
+ self.hotspot1 = hotspot1
23
+ self.hotspot2 = hotspot2
24
+ self.n_zones = 10
25
+
26
+ def get_thief_probability(self, zone):
27
+ zone_center = zone + 0.5
28
+ prob = (
29
+ 0.6 * np.exp(-((zone_center - self.hotspot1)**2) / 1.0) +
30
+ 0.4 * np.exp(-((zone_center - self.hotspot2)**2) / 0.8) +
31
+ 0.05
32
+ )
33
+ return min(prob, 1.0)
34
+
35
+ def generate_thieves(self):
36
+ thieves = np.zeros(self.n_zones)
37
+ for zone in range(self.n_zones):
38
+ if np.random.random() < self.get_thief_probability(zone):
39
+ thieves[zone] = 1
40
+ return thieves
41
+
42
+
43
+ # ==============================================================================
44
+ # PART 2: SENSOR
45
+ # ==============================================================================
46
+
47
+ class Sensor:
48
+ def __init__(self, catch_probability=0.9):
49
+ self.catch_prob = catch_probability
50
+
51
+ def try_catch(self, thief_present):
52
+ if thief_present:
53
+ return np.random.random() < self.catch_prob
54
+ return False
55
+
56
+
57
+ # ==============================================================================
58
+ # PART 3: ENVIRONMENT
59
+ # ==============================================================================
60
+
61
+ class SensorPlacementEnv:
62
+ def __init__(self, n_sensors=4, hotspot1=2.5, hotspot2=7.0):
63
+ self.world = ThiefWorld(hotspot1, hotspot2)
64
+ self.sensor = Sensor()
65
+ self.n_sensors = n_sensors
66
+ self.n_zones = 10
67
+ self.reset()
68
+
69
+ def reset(self):
70
+ self.zone_attempts = np.zeros(self.n_zones)
71
+ self.zone_catches = np.zeros(self.n_zones)
72
+ self.day = 0
73
+ self.total_caught = 0
74
+ self.total_thieves = 0
75
+ return self._get_state()
76
+
77
+ def _get_state(self):
78
+ if self.zone_attempts.sum() == 0:
79
+ return (0, 0)
80
+ most_tried = int(np.argmax(self.zone_attempts))
81
+ catch_rates = np.zeros(self.n_zones)
82
+ for z in range(self.n_zones):
83
+ if self.zone_attempts[z] > 0:
84
+ catch_rates[z] = self.zone_catches[z] / self.zone_attempts[z]
85
+ best_zone = int(np.argmax(catch_rates))
86
+ return (most_tried, best_zone)
87
+
88
+ def step(self, action):
89
+ thieves = self.world.generate_thieves()
90
+ n_thieves = int(thieves.sum())
91
+ self.total_thieves += n_thieves
92
+
93
+ caught = 0
94
+ for zone in action:
95
+ if zone < self.n_zones:
96
+ self.zone_attempts[zone] += 1
97
+ if thieves[zone] == 1:
98
+ if self.sensor.try_catch(True):
99
+ caught += 1
100
+ self.zone_catches[zone] += 1
101
+
102
+ self.total_caught += caught
103
+ self.day += 1
104
+ reward = caught + 0.1 * len(set(action))
105
+ done = self.day >= 30
106
+
107
+ return self._get_state(), reward, done, {'caught': caught}
108
+
109
+
110
+ # ==============================================================================
111
+ # PART 4: Q-LEARNING AGENT
112
+ # ==============================================================================
113
+
114
+ class QLearningAgent:
115
+ def __init__(self):
116
+ self.q_table = defaultdict(lambda: defaultdict(float))
117
+ self.learning_rate = 0.1
118
+ self.discount_factor = 0.95
119
+ self.epsilon = 1.0
120
+ self.epsilon_decay = 0.995
121
+ self.epsilon_min = 0.01
122
+
123
+ def _get_possible_actions(self):
124
+ return [
125
+ (1, 3, 6, 8), (0, 3, 6, 9), (2, 4, 6, 8),
126
+ (0, 1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5),
127
+ (5, 6, 7, 8), (6, 7, 8, 9), (4, 5, 6, 7),
128
+ (2, 3, 7, 8), (1, 2, 6, 7), (2, 3, 6, 7),
129
+ (3, 4, 5, 6), (0, 2, 5, 9), (1, 4, 7, 9),
130
+ ]
131
+
132
+ def choose_action(self, state):
133
+ actions = self._get_possible_actions()
134
+ if np.random.random() < self.epsilon:
135
+ return actions[np.random.randint(len(actions))]
136
+ else:
137
+ best_action = None
138
+ best_value = -999999
139
+ for action in actions:
140
+ value = self.q_table[state][action]
141
+ if value > best_value:
142
+ best_value = value
143
+ best_action = action
144
+ if best_action is None:
145
+ best_action = actions[np.random.randint(len(actions))]
146
+ return best_action
147
+
148
+ def learn(self, state, action, reward, next_state, done):
149
+ old_q = self.q_table[state][action]
150
+ if done:
151
+ max_future_q = 0
152
+ else:
153
+ actions = self._get_possible_actions()
154
+ max_future_q = max([self.q_table[next_state][a] for a in actions])
155
+ target = reward + self.discount_factor * max_future_q
156
+ new_q = old_q + self.learning_rate * (target - old_q)
157
+ self.q_table[state][action] = new_q
158
+
159
+ def decay_epsilon(self):
160
+ self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
161
+
162
+
163
+ # ==============================================================================
164
+ # TRAINING AND TESTING FUNCTIONS
165
+ # ==============================================================================
166
+
167
+ def train_and_test(n_episodes, hotspot1, hotspot2, progress=gr.Progress()):
168
+ """Train AI and compare with other strategies."""
169
+
170
+ np.random.seed(42)
171
+
172
+ # Training
173
+ env = SensorPlacementEnv(hotspot1=hotspot1, hotspot2=hotspot2)
174
+ agent = QLearningAgent()
175
+
176
+ episode_rewards = []
177
+ episode_catch_rates = []
178
+ epsilon_history = []
179
+
180
+ for episode in progress.tqdm(range(n_episodes), desc="Training AI"):
181
+ state = env.reset()
182
+ total_reward = 0
183
+
184
+ for day in range(30):
185
+ action = agent.choose_action(state)
186
+ next_state, reward, done, _ = env.step(action)
187
+ agent.learn(state, action, reward, next_state, done)
188
+ state = next_state
189
+ total_reward += reward
190
+ if done:
191
+ break
192
+
193
+ agent.decay_epsilon()
194
+ episode_rewards.append(total_reward)
195
+ catch_rate = env.total_caught / max(env.total_thieves, 1) * 100
196
+ episode_catch_rates.append(catch_rate)
197
+ epsilon_history.append(agent.epsilon)
198
+
199
+ # Testing
200
+ n_tests = 50
201
+ results = {}
202
+
203
+ # Q-Learning AI
204
+ agent.epsilon = 0
205
+ catches = []
206
+ for _ in range(n_tests):
207
+ state = env.reset()
208
+ for day in range(30):
209
+ action = agent.choose_action(state)
210
+ state, _, done, _ = env.step(action)
211
+ if done:
212
+ break
213
+ catches.append(env.total_caught / max(env.total_thieves, 1) * 100)
214
+ results['Q-Learning AI'] = np.mean(catches)
215
+
216
+ # Random
217
+ catches = []
218
+ for _ in range(n_tests):
219
+ env.reset()
220
+ for day in range(30):
221
+ action = tuple(np.random.choice(10, 4, replace=False))
222
+ _, _, done, _ = env.step(action)
223
+ if done:
224
+ break
225
+ catches.append(env.total_caught / max(env.total_thieves, 1) * 100)
226
+ results['Random'] = np.mean(catches)
227
+
228
+ # Static
229
+ catches = []
230
+ for _ in range(n_tests):
231
+ env.reset()
232
+ for day in range(30):
233
+ _, _, done, _ = env.step((1, 3, 6, 8))
234
+ if done:
235
+ break
236
+ catches.append(env.total_caught / max(env.total_thieves, 1) * 100)
237
+ results['Static Uniform'] = np.mean(catches)
238
+
239
+ # Perfect
240
+ h1_zone = int(hotspot1)
241
+ h2_zone = int(hotspot2)
242
+ perfect_action = (h1_zone, h1_zone+1, h2_zone, h2_zone+1)
243
+ perfect_action = tuple(min(z, 9) for z in perfect_action)
244
+ catches = []
245
+ for _ in range(n_tests):
246
+ env.reset()
247
+ for day in range(30):
248
+ _, _, done, _ = env.step(perfect_action)
249
+ if done:
250
+ break
251
+ catches.append(env.total_caught / max(env.total_thieves, 1) * 100)
252
+ results['Perfect (Cheating)'] = np.mean(catches)
253
+
254
+ # Create plots
255
+ fig = plt.figure(figsize=(16, 12))
256
+
257
+ # Plot 1: Learning curve
258
+ ax1 = fig.add_subplot(2, 2, 1)
259
+ window = max(10, n_episodes // 20)
260
+ if len(episode_catch_rates) >= window:
261
+ smoothed = np.convolve(episode_catch_rates, np.ones(window)/window, mode='valid')
262
+ ax1.plot(episode_catch_rates, alpha=0.3, color='green', label='Raw')
263
+ ax1.plot(range(window-1, len(episode_catch_rates)), smoothed,
264
+ color='green', linewidth=2, label='Smoothed')
265
+ else:
266
+ ax1.plot(episode_catch_rates, color='green', linewidth=2)
267
+ ax1.set_xlabel('Episode', fontsize=12)
268
+ ax1.set_ylabel('Catch Rate (%)', fontsize=12)
269
+ ax1.set_title('๐ŸŽ“ AI Learning Progress', fontsize=14)
270
+ ax1.legend()
271
+ ax1.grid(True, alpha=0.3)
272
+
273
+ # Plot 2: Epsilon decay
274
+ ax2 = fig.add_subplot(2, 2, 2)
275
+ ax2.plot(epsilon_history, color='purple', linewidth=2)
276
+ ax2.set_xlabel('Episode', fontsize=12)
277
+ ax2.set_ylabel('Epsilon (Exploration Rate)', fontsize=12)
278
+ ax2.set_title('๐Ÿ” Explore vs Exploit Balance', fontsize=14)
279
+ ax2.grid(True, alpha=0.3)
280
+
281
+ # Add annotations
282
+ ax2.annotate('100% Random\n(Exploring)', xy=(0, 1), fontsize=10,
283
+ xytext=(n_episodes*0.1, 0.8), arrowprops=dict(arrowstyle='->', color='gray'))
284
+ ax2.annotate('Mostly Using\nKnowledge', xy=(n_episodes-1, epsilon_history[-1]), fontsize=10,
285
+ xytext=(n_episodes*0.7, 0.3), arrowprops=dict(arrowstyle='->', color='gray'))
286
+
287
+ # Plot 3: What AI learned vs Truth
288
+ ax3 = fig.add_subplot(2, 2, 3)
289
+
290
+ zone_values = np.zeros(10)
291
+ zone_counts = np.zeros(10)
292
+ for state, actions in agent.q_table.items():
293
+ for action, value in actions.items():
294
+ for zone in action:
295
+ zone_values[zone] += value
296
+ zone_counts[zone] += 1
297
+ zone_counts[zone_counts == 0] = 1
298
+ learned = zone_values / zone_counts
299
+
300
+ world = ThiefWorld(hotspot1, hotspot2)
301
+ truth = [world.get_thief_probability(z) for z in range(10)]
302
+
303
+ x = np.arange(10)
304
+ width = 0.35
305
+ ax3.bar(x - width/2, learned / max(learned.max(), 0.01), width,
306
+ label='AI Learned', color='blue', alpha=0.7)
307
+ ax3.bar(x + width/2, np.array(truth) / max(truth), width,
308
+ label='True Probability', color='red', alpha=0.7)
309
+ ax3.axvline(hotspot1, color='red', linestyle='--', alpha=0.5, label=f'Hotspot 1 ({hotspot1})')
310
+ ax3.axvline(hotspot2, color='darkred', linestyle='--', alpha=0.5, label=f'Hotspot 2 ({hotspot2})')
311
+ ax3.set_xlabel('Zone', fontsize=12)
312
+ ax3.set_ylabel('Normalized Value', fontsize=12)
313
+ ax3.set_title('๐Ÿง  Did AI Learn the Truth?', fontsize=14)
314
+ ax3.legend(loc='upper right')
315
+ ax3.grid(True, alpha=0.3)
316
+ ax3.set_xticks(range(10))
317
+
318
+ # Plot 4: Final comparison
319
+ ax4 = fig.add_subplot(2, 2, 4)
320
+ names = list(results.keys())
321
+ values = list(results.values())
322
+ colors = ['green', 'gray', 'orange', 'blue']
323
+ bars = ax4.bar(names, values, color=colors, alpha=0.7, edgecolor='black')
324
+
325
+ for bar, val in zip(bars, values):
326
+ ax4.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
327
+ f'{val:.1f}%', ha='center', fontsize=12, fontweight='bold')
328
+
329
+ ax4.set_ylabel('Catch Rate (%)', fontsize=12)
330
+ ax4.set_title('๐Ÿ† Final Comparison', fontsize=14)
331
+ ax4.grid(True, alpha=0.3, axis='y')
332
+ plt.setp(ax4.xaxis.get_majorticklabels(), rotation=15, ha='right')
333
+
334
+ plt.tight_layout()
335
+
336
+ # Results text
337
+ results_text = f"""
338
+ ## ๐ŸŽฏ Training Complete!
339
+
340
+ ### Training Summary:
341
+ - Episodes trained: **{n_episodes}**
342
+ - Hotspot 1: Zone **{hotspot1}**
343
+ - Hotspot 2: Zone **{hotspot2}**
344
+ - Final exploration rate: **{epsilon_history[-1]*100:.1f}%**
345
+
346
+ ### ๐Ÿ“Š Test Results (50 test runs each):
347
+
348
+ | Strategy | Catch Rate |
349
+ |----------|------------|
350
+ | ๐Ÿ† **Q-Learning AI** | **{results['Q-Learning AI']:.1f}%** |
351
+ | Random | {results['Random']:.1f}% |
352
+ | Static Uniform | {results['Static Uniform']:.1f}% |
353
+ | Perfect (Cheating) | {results['Perfect (Cheating)']:.1f}% |
354
+
355
+ ### ๐Ÿง  What AI Learned:
356
+ The AI discovered that zones **{int(hotspot1)}** and **{int(hotspot2)}** have more thieves!
357
+
358
+ ### ๐ŸŽ“ Key Insight:
359
+ AI started knowing **NOTHING** and learned through **trial and error**!
360
+ """
361
+
362
+ return fig, results_text
363
+
364
+
365
+ def explain_qlearning():
366
+ """Create explanation visualization."""
367
+
368
+ fig, axes = plt.subplots(1, 2, figsize=(14, 5))
369
+
370
+ # Plot 1: Q-Learning cycle
371
+ ax1 = axes[0]
372
+ ax1.axis('off')
373
+
374
+ # Draw cycle
375
+ cycle_text = """
376
+ โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
377
+ โ”‚ Q-LEARNING CYCLE โ”‚
378
+ โ”‚ โ”‚
379
+ โ”‚ โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”‚
380
+ โ”‚ โ”‚ STATE โ”‚ โ”‚
381
+ โ”‚ โ”‚(What AI โ”‚ โ”‚
382
+ โ”‚ โ”‚ sees) โ”‚ โ”‚
383
+ โ”‚ โ””โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”˜ โ”‚
384
+ โ”‚ โ”‚ โ”‚
385
+ โ”‚ โ–ผ โ”‚
386
+ โ”‚ โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”‚
387
+ โ”‚ โ”‚ UPDATE โ”‚โ—„โ”€โ”€โ”€โ”€โ”‚ ACTION โ”‚โ”€โ”€โ”€โ”€โ–บโ”‚ REWARD โ”‚ โ”‚
388
+ โ”‚ โ”‚ Q-TABLE โ”‚ โ”‚(Place โ”‚ โ”‚(Caught โ”‚ โ”‚
389
+ โ”‚ โ”‚(Remember)โ”‚ โ”‚sensors) โ”‚ โ”‚thieves?) โ”‚ โ”‚
390
+ โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ”‚
391
+ โ”‚ โ”‚ โ”‚ โ”‚
392
+ โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ”‚
393
+ โ”‚ REPEAT! โ”‚
394
+ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜
395
+ """
396
+ ax1.text(0.5, 0.5, cycle_text, transform=ax1.transAxes, fontsize=10,
397
+ verticalalignment='center', horizontalalignment='center',
398
+ fontfamily='monospace', bbox=dict(boxstyle='round', facecolor='lightyellow'))
399
+ ax1.set_title('How Q-Learning Works', fontsize=14)
400
+
401
+ # Plot 2: Epsilon explanation
402
+ ax2 = axes[1]
403
+ episodes = np.arange(500)
404
+ epsilon = 1.0 * (0.995 ** episodes)
405
+ epsilon = np.maximum(epsilon, 0.01)
406
+
407
+ ax2.fill_between(episodes, epsilon, alpha=0.3, color='blue', label='EXPLORE')
408
+ ax2.fill_between(episodes, 0, 1-epsilon, alpha=0.3, color='green', label='EXPLOIT')
409
+ ax2.plot(episodes, epsilon, 'b-', linewidth=2)
410
+ ax2.plot(episodes, 1-epsilon, 'g-', linewidth=2)
411
+
412
+ ax2.axvline(50, color='gray', linestyle='--', alpha=0.5)
413
+ ax2.axvline(200, color='gray', linestyle='--', alpha=0.5)
414
+ ax2.axvline(400, color='gray', linestyle='--', alpha=0.5)
415
+
416
+ ax2.text(25, 0.5, 'Early:\n80% Explore', fontsize=9, ha='center')
417
+ ax2.text(125, 0.5, 'Middle:\n50-50', fontsize=9, ha='center')
418
+ ax2.text(300, 0.5, 'Late:\n80% Exploit', fontsize=9, ha='center')
419
+
420
+ ax2.set_xlabel('Episode', fontsize=12)
421
+ ax2.set_ylabel('Probability', fontsize=12)
422
+ ax2.set_title('Explore vs Exploit Over Time', fontsize=14)
423
+ ax2.legend(loc='center right')
424
+ ax2.grid(True, alpha=0.3)
425
+
426
+ plt.tight_layout()
427
+ return fig
428
+
429
+
430
+ def show_environment(hotspot1, hotspot2):
431
+ """Visualize the thief world."""
432
+
433
+ fig, ax = plt.subplots(figsize=(12, 5))
434
+
435
+ world = ThiefWorld(hotspot1, hotspot2)
436
+ zones = np.arange(10)
437
+ probs = [world.get_thief_probability(z) for z in zones]
438
+
439
+ colors = ['red' if p > 0.4 else 'orange' if p > 0.2 else 'green' for p in probs]
440
+ bars = ax.bar(zones, probs, color=colors, alpha=0.7, edgecolor='black')
441
+
442
+ for bar, prob in zip(bars, probs):
443
+ ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
444
+ f'{prob*100:.0f}%', ha='center', fontsize=10, fontweight='bold')
445
+
446
+ ax.axvline(hotspot1, color='red', linestyle='--', linewidth=2, label=f'Hotspot 1 ({hotspot1})')
447
+ ax.axvline(hotspot2, color='darkred', linestyle='--', linewidth=2, label=f'Hotspot 2 ({hotspot2})')
448
+
449
+ ax.set_xlabel('Zone', fontsize=12)
450
+ ax.set_ylabel('Thief Probability', fontsize=12)
451
+ ax.set_title('๐Ÿฆน Secret Thief Locations (AI Must Discover This!)', fontsize=14)
452
+ ax.set_xticks(zones)
453
+ ax.legend()
454
+ ax.grid(True, alpha=0.3, axis='y')
455
+
456
+ plt.tight_layout()
457
+ return fig
458
+
459
+
460
+ def simulate_one_episode(hotspot1, hotspot2):
461
+ """Simulate and visualize one episode."""
462
+
463
+ np.random.seed(None) # Random seed for variety
464
+
465
+ env = SensorPlacementEnv(hotspot1=hotspot1, hotspot2=hotspot2)
466
+ agent = QLearningAgent()
467
+ agent.epsilon = 0.5 # 50% explore for demo
468
+
469
+ state = env.reset()
470
+
471
+ # Track daily data
472
+ daily_actions = []
473
+ daily_caught = []
474
+ daily_thieves = []
475
+
476
+ for day in range(30):
477
+ action = agent.choose_action(state)
478
+ daily_actions.append(action)
479
+
480
+ old_caught = env.total_caught
481
+ old_thieves = env.total_thieves
482
+
483
+ state, reward, done, info = env.step(action)
484
+
485
+ daily_caught.append(env.total_caught - old_caught)
486
+ daily_thieves.append(env.total_thieves - old_thieves)
487
+
488
+ agent.learn(state, action, reward, state, done)
489
+
490
+ # Create visualization
491
+ fig, axes = plt.subplots(2, 2, figsize=(14, 10))
492
+
493
+ # Plot 1: Sensor placements over days
494
+ ax1 = axes[0, 0]
495
+ for day, action in enumerate(daily_actions):
496
+ for zone in action:
497
+ ax1.scatter(day, zone, c='blue', s=30, alpha=0.6)
498
+
499
+ ax1.axhline(hotspot1, color='red', linestyle='--', alpha=0.5, label=f'Hotspot 1')
500
+ ax1.axhline(hotspot2, color='darkred', linestyle='--', alpha=0.5, label=f'Hotspot 2')
501
+ ax1.set_xlabel('Day', fontsize=12)
502
+ ax1.set_ylabel('Zone', fontsize=12)
503
+ ax1.set_title('๐Ÿ“ Where AI Placed Sensors Each Day', fontsize=14)
504
+ ax1.legend()
505
+ ax1.grid(True, alpha=0.3)
506
+ ax1.set_yticks(range(10))
507
+
508
+ # Plot 2: Daily catches
509
+ ax2 = axes[0, 1]
510
+ days = range(1, 31)
511
+ ax2.bar(days, daily_caught, color='green', alpha=0.7, label='Caught')
512
+ ax2.plot(days, daily_thieves, 'ro-', markersize=5, label='Total Thieves')
513
+ ax2.set_xlabel('Day', fontsize=12)
514
+ ax2.set_ylabel('Count', fontsize=12)
515
+ ax2.set_title('๐ŸŽฏ Daily Catches', fontsize=14)
516
+ ax2.legend()
517
+ ax2.grid(True, alpha=0.3)
518
+
519
+ # Plot 3: Cumulative performance
520
+ ax3 = axes[1, 0]
521
+ cum_caught = np.cumsum(daily_caught)
522
+ cum_thieves = np.cumsum(daily_thieves)
523
+ ax3.fill_between(days, cum_caught, alpha=0.3, color='green')
524
+ ax3.plot(days, cum_caught, 'g-', linewidth=2, label='Cumulative Caught')
525
+ ax3.plot(days, cum_thieves, 'r--', linewidth=2, label='Cumulative Thieves')
526
+ ax3.set_xlabel('Day', fontsize=12)
527
+ ax3.set_ylabel('Cumulative Count', fontsize=12)
528
+ ax3.set_title('๐Ÿ“ˆ Cumulative Performance', fontsize=14)
529
+ ax3.legend()
530
+ ax3.grid(True, alpha=0.3)
531
+
532
+ # Plot 4: Zone usage
533
+ ax4 = axes[1, 1]
534
+ zone_usage = np.zeros(10)
535
+ for action in daily_actions:
536
+ for zone in action:
537
+ zone_usage[zone] += 1
538
+
539
+ colors = ['blue' if z in [int(hotspot1), int(hotspot1)+1, int(hotspot2), int(hotspot2)+1]
540
+ else 'gray' for z in range(10)]
541
+ ax4.bar(range(10), zone_usage, color=colors, alpha=0.7, edgecolor='black')
542
+ ax4.axvline(hotspot1, color='red', linestyle='--', alpha=0.5)
543
+ ax4.axvline(hotspot2, color='darkred', linestyle='--', alpha=0.5)
544
+ ax4.set_xlabel('Zone', fontsize=12)
545
+ ax4.set_ylabel('Times Used', fontsize=12)
546
+ ax4.set_title('๐Ÿ—บ๏ธ Zone Usage (Blue = Near Hotspots)', fontsize=14)
547
+ ax4.set_xticks(range(10))
548
+ ax4.grid(True, alpha=0.3, axis='y')
549
+
550
+ plt.tight_layout()
551
+
552
+ # Summary
553
+ catch_rate = env.total_caught / max(env.total_thieves, 1) * 100
554
+ summary = f"""
555
+ ## ๐Ÿ“Š Episode Summary
556
+
557
+ - **Total Thieves:** {env.total_thieves}
558
+ - **Total Caught:** {env.total_caught}
559
+ - **Catch Rate:** {catch_rate:.1f}%
560
+
561
+ ### Zones Most Used:
562
+ {', '.join([f'Zone {i}' for i in np.argsort(zone_usage)[-3:][::-1]])}
563
+
564
+ ### Note:
565
+ This is just ONE episode with 50% exploration.
566
+ Train for 500+ episodes to see real learning!
567
+ """
568
+
569
+ return fig, summary
570
+
571
+
572
+ # ==============================================================================
573
+ # GRADIO INTERFACE
574
+ # ==============================================================================
575
+
576
+ with gr.Blocks(title="Q-Learning AI Demo", theme=gr.themes.Soft()) as demo:
577
+
578
+ gr.Markdown("""
579
+ # ๐Ÿค– Q-Learning AI for Sensor Placement
580
+
581
+ **Watch an AI learn where to place sensors to catch thieves!**
582
+
583
+ The AI starts knowing NOTHING and learns through trial-and-error.
584
+
585
+ ---
586
+ """)
587
+
588
+ with gr.Tabs():
589
+
590
+ # ==== TAB 1: Explanation ====
591
+ with gr.TabItem("1๏ธโƒฃ What is Q-Learning?"):
592
+ gr.Markdown("""
593
+ ## ๐ŸŽ“ Q-Learning Explained Simply
594
+
595
+ ### Like Teaching a Dog:
596
+ ```
597
+ 1. Dog tries something โ†’ 2. Gets treat (or not) โ†’ 3. Remembers โ†’ 4. Gets smarter!
598
+ ```
599
+
600
+ ### For Our AI:
601
+ ```
602
+ 1. AI places sensors โ†’ 2. Catches thieves (reward!) โ†’ 3. Updates Q-Table โ†’ 4. Gets smarter!
603
+ ```
604
+
605
+ ### The Q-Table (AI's Memory):
606
+
607
+ | State | Action | Expected Reward |
608
+ |-------|--------|-----------------|
609
+ | "Day 1" | Zones (1,3,6,8) | 1.5 points |
610
+ | "Day 1" | Zones (2,3,7,8) | 3.2 points โ† Better! |
611
+
612
+ ### Explore vs Exploit:
613
+ - **EXPLORE**: Try random things to learn
614
+ - **EXPLOIT**: Use what you already know
615
+
616
+ Early training โ†’ More EXPLORE
617
+ Late training โ†’ More EXPLOIT
618
+ """)
619
+
620
+ explain_btn = gr.Button("๐Ÿ“Š Show Visual Explanation", variant="primary")
621
+ explain_plot = gr.Plot()
622
+ explain_btn.click(explain_qlearning, outputs=explain_plot)
623
+
624
+ # ==== TAB 2: Environment ====
625
+ with gr.TabItem("2๏ธโƒฃ The Secret World"):
626
+ gr.Markdown("""
627
+ ## ๐Ÿฆน Where Do Thieves Appear?
628
+
629
+ The AI doesn't know this! It must DISCOVER it through learning.
630
+
631
+ Adjust the hotspot locations and see the thief distribution:
632
+ """)
633
+
634
+ with gr.Row():
635
+ h1_slider = gr.Slider(0, 9, value=2.5, step=0.5, label="Hotspot 1 Location")
636
+ h2_slider = gr.Slider(0, 9, value=7.0, step=0.5, label="Hotspot 2 Location")
637
+
638
+ env_btn = gr.Button("๐Ÿ—บ๏ธ Show Thief Distribution", variant="primary")
639
+ env_plot = gr.Plot()
640
+ env_btn.click(show_environment, [h1_slider, h2_slider], env_plot)
641
+
642
+ # ==== TAB 3: One Episode ====
643
+ with gr.TabItem("3๏ธโƒฃ Watch One Episode"):
644
+ gr.Markdown("""
645
+ ## ๐Ÿ‘€ See One Month (30 Days) of Simulation
646
+
647
+ Watch how AI makes decisions and catches thieves.
648
+
649
+ (Note: This is untrained AI with 50% exploration rate)
650
+ """)
651
+
652
+ with gr.Row():
653
+ h1_ep = gr.Slider(0, 9, value=2.5, step=0.5, label="Hotspot 1")
654
+ h2_ep = gr.Slider(0, 9, value=7.0, step=0.5, label="Hotspot 2")
655
+
656
+ ep_btn = gr.Button("โ–ถ๏ธ Run One Episode", variant="primary")
657
+ ep_plot = gr.Plot()
658
+ ep_summary = gr.Markdown()
659
+ ep_btn.click(simulate_one_episode, [h1_ep, h2_ep], [ep_plot, ep_summary])
660
+
661
+ # ==== TAB 4: Full Training ====
662
+ with gr.TabItem("4๏ธโƒฃ Train the AI!"):
663
+ gr.Markdown("""
664
+ ## ๐Ÿ‹๏ธ Train Q-Learning AI
665
+
666
+ Train the AI and compare it against other strategies!
667
+
668
+ โš ๏ธ Training takes a few seconds depending on episodes.
669
+ """)
670
+
671
+ with gr.Row():
672
+ episodes_slider = gr.Slider(100, 1000, value=300, step=50,
673
+ label="Number of Episodes")
674
+
675
+ with gr.Row():
676
+ h1_train = gr.Slider(0, 9, value=2.5, step=0.5, label="Hotspot 1")
677
+ h2_train = gr.Slider(0, 9, value=7.0, step=0.5, label="Hotspot 2")
678
+
679
+ train_btn = gr.Button("๐Ÿš€ Train AI!", variant="primary", size="lg")
680
+
681
+ train_plot = gr.Plot()
682
+ train_results = gr.Markdown()
683
+
684
+ train_btn.click(train_and_test,
685
+ [episodes_slider, h1_train, h2_train],
686
+ [train_plot, train_results])
687
+
688
+ # ==== TAB 5: Summary ====
689
+ with gr.TabItem("5๏ธโƒฃ Key Concepts"):
690
+ gr.Markdown("""
691
+ ## ๐Ÿ“š Summary: Q-Learning Key Concepts
692
+
693
+ ### 1. Q-Table
694
+ ```
695
+ A "cheat sheet" that stores:
696
+ "In STATE X, if I do ACTION Y, I expect REWARD Z"
697
+ ```
698
+
699
+ ### 2. State
700
+ ```
701
+ What the AI "sees" at any moment.
702
+ Example: (most_tried_zone, best_zone_so_far)
703
+ ```
704
+
705
+ ### 3. Action
706
+ ```
707
+ What the AI can do.
708
+ Example: Place sensors in zones (2, 3, 7, 8)
709
+ ```
710
+
711
+ ### 4. Reward
712
+ ```
713
+ Points for good actions.
714
+ Example: +1 for each thief caught
715
+ ```
716
+
717
+ ### 5. Epsilon (ฮต)
718
+ ```
719
+ Exploration rate.
720
+ ฮต = 1.0 โ†’ 100% random (exploring)
721
+ ฮต = 0.01 โ†’ 1% random (exploiting knowledge)
722
+ ```
723
+
724
+ ### 6. Learning Formula
725
+ ```
726
+ Q(s,a) = Q(s,a) + ฮฑ ร— (reward + ฮณ ร— max(Q(s',a')) - Q(s,a))
727
+
728
+ In simple terms:
729
+ New Memory = Old Memory + Learning Rate ร— (Reality - Expectation)
730
+ ```
731
+
732
+ ---
733
+
734
+ ## ๐ŸŽฏ Why This Matters
735
+
736
+ This same technique is used in:
737
+ - ๐ŸŽฎ Game AI (AlphaGo, Chess engines)
738
+ - ๐Ÿš— Self-driving cars
739
+ - ๐Ÿค– Robots
740
+ - ๐Ÿ“ฑ Recommendation systems
741
+
742
+ **You just learned how real AI works!** ๐ŸŽ“
743
+ """)
744
+
745
+ gr.Markdown("""
746
+ ---
747
+
748
+ ### ๐Ÿ”— About
749
+
750
+ This demo shows **Q-Learning Reinforcement Learning** for sensor placement.
751
+
752
+ The AI learns through trial-and-error, just like humans!
753
+ """)
754
+
755
+
756
+ # Launch
757
+ if __name__ == "__main__":
758
+ demo.launch()