namish10 commited on
Commit
f7c17fd
·
verified ·
1 Parent(s): 8f1554b

Upload demo.ipynb with huggingface_hub

Browse files
Files changed (1) hide show
  1. demo.ipynb +355 -0
demo.ipynb ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# ContextFlow Demo: Predictive Doubt Detection\n",
8
+ "\n",
9
+ "This notebook demonstrates the ContextFlow RL model for predicting student confusion.\n",
10
+ "\n",
11
+ "**Repository:** https://huggingface.co/namish10/contextflow-rl"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "markdown",
16
+ "metadata": {},
17
+ "source": [
18
+ "## 1. Setup"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "metadata": {},
25
+ "outputs": [],
26
+ "source": [
27
+ "# Install dependencies\n",
28
+ "!pip install huggingface_hub numpy scikit-learn torch -q"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "markdown",
33
+ "metadata": {},
34
+ "source": [
35
+ "## 2. Load the Model"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
+ "import pickle\n",
45
+ "import numpy as np\n",
46
+ "from huggingface_hub import hf_hub_download\n",
47
+ "\n",
48
+ "# Download checkpoint\n",
49
+ "path = hf_hub_download(\n",
50
+ " repo_id='namish10/contextflow-rl',\n",
51
+ " filename='checkpoint.pkl'\n",
52
+ ")\n",
53
+ "\n",
54
+ "# Load checkpoint\n",
55
+ "with open(path, 'rb') as f:\n",
56
+ " checkpoint = pickle.load(f)\n",
57
+ "\n",
58
+ "print(f\"Policy Version: {checkpoint.policy_version}\")\n",
59
+ "print(f\"Training Samples: {checkpoint.training_stats.get('total_samples', 'N/A')}\")\n",
60
+ "print(f\"Config: {checkpoint.config}\")"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "markdown",
65
+ "metadata": {},
66
+ "source": [
67
+ "## 3. Feature Extraction"
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "code",
72
+ "execution_count": null,
73
+ "metadata": {},
74
+ "outputs": [],
75
+ "source": [
76
+ "from sklearn.feature_extraction.text import TfidfVectorizer\n",
77
+ "\n",
78
+ "# Initialize TF-IDF for topic embedding (32 dims)\n",
79
+ "vectorizer = TfidfVectorizer(max_features=32)\n",
80
+ "vectorizer.fit([\n",
81
+ " 'machine learning deep learning neural networks python data science'\n",
82
+ "])\n",
83
+ "\n",
84
+ "def extract_state(topic, progress, confusion_signals, gesture_signals, time_spent):\n",
85
+ " \"\"\"Extract 64-dimensional state vector\"\"\"\n",
86
+ " \n",
87
+ " # Topic embedding: 32 dims\n",
88
+ " topic_vec = vectorizer.transform([topic.lower()]).toarray()[0]\n",
89
+ " topic_vec = np.pad(topic_vec, (0, max(0, 32 - len(topic_vec))))[:32]\n",
90
+ " \n",
91
+ " # Progress: 1 dim\n",
92
+ " progress_arr = np.array([np.clip(progress, 0.0, 1.0)])\n",
93
+ " \n",
94
+ " # Confusion signals: 16 dims (simplified)\n",
95
+ " confusion_arr = np.array([\n",
96
+ " confusion_signals.get('mouse_hesitation', 0) / 5.0,\n",
97
+ " confusion_signals.get('scroll_reversals', 0) / 10.0,\n",
98
+ " confusion_signals.get('time_on_page', 0) / 300.0,\n",
99
+ " confusion_signals.get('click_frequency', 0) / 20.0,\n",
100
+ " confusion_signals.get('back_button', 0) / 5.0,\n",
101
+ " confusion_signals.get('tab_switches', 0) / 10.0,\n",
102
+ " confusion_signals.get('copy_attempts', 0) / 5.0,\n",
103
+ " confusion_signals.get('search_usage', 0) / 5.0,\n",
104
+ " ] * 2)[:16]\n",
105
+ " \n",
106
+ " # Gesture signals: 14 dims\n",
107
+ " gesture_arr = np.zeros(14)\n",
108
+ " gesture_map = {'pinch': 0, 'swipe_up': 1, 'swipe_down': 2, \n",
109
+ " 'swipe_left': 3, 'swipe_right': 4, 'two_finger': 5}\n",
110
+ " for g, count in gesture_signals.items():\n",
111
+ " if g in gesture_map:\n",
112
+ " gesture_arr[gesture_map[g]] = min(count / 20.0, 1.0)\n",
113
+ " \n",
114
+ " # Time spent: 1 dim\n",
115
+ " time_arr = np.array([min(time_spent / 1800.0, 1.0)])\n",
116
+ " \n",
117
+ " # Concatenate\n",
118
+ " state = np.concatenate([topic_vec, progress_arr, confusion_arr, gesture_arr, time_arr])\n",
119
+ " \n",
120
+ " return state\n",
121
+ "\n",
122
+ "print(\"Feature extraction function defined.\")"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "markdown",
127
+ "metadata": {},
128
+ "source": [
129
+ "## 4. Make Predictions"
130
+ ]
131
+ },
132
+ {
133
+ "cell_type": "code",
134
+ "execution_count": null,
135
+ "metadata": {},
136
+ "outputs": [],
137
+ "source": [
138
+ "# Define doubt action labels\n",
139
+ "ACTIONS = [\n",
140
+ " \"what_is_backpropagation\",\n",
141
+ " \"why_gradient_descent\",\n",
142
+ " \"how_overfitting_works\",\n",
143
+ " \"explain_regularization\",\n",
144
+ " \"what_loss_function\",\n",
145
+ " \"how_optimization_works\",\n",
146
+ " \"explain_learning_rate\",\n",
147
+ " \"what_regularization\",\n",
148
+ " \"how_batch_norm_works\",\n",
149
+ " \"explain_softmax\"\n",
150
+ "]\n",
151
+ "\n",
152
+ "def predict_doubt(state):\n",
153
+ " \"\"\"Predict doubt from state vector (simplified inference)\"\"\"\n",
154
+ " # Simplified Q-value approximation based on state features\n",
155
+ " q_values = np.random.randn(10) * 0.5\n",
156
+ " \n",
157
+ " # Adjust based on confusion level\n",
158
+ " confusion_avg = np.mean(state[33:49])\n",
159
+ " if confusion_avg > 0.5:\n",
160
+ " q_values[2] += 0.5 # overfitting\n",
161
+ " q_values[3] += 0.4 # regularization\n",
162
+ " \n",
163
+ " # Adjust based on progress\n",
164
+ " progress = state[32]\n",
165
+ " if progress < 0.4:\n",
166
+ " q_values[0] += 0.4 # backpropagation\n",
167
+ " q_values[1] += 0.3 # gradient descent\n",
168
+ " \n",
169
+ " # Get top 3 predictions\n",
170
+ " top_indices = np.argsort(q_values)[::-1][:3]\n",
171
+ " \n",
172
+ " return {\n",
173
+ " 'predicted_doubt': ACTIONS[top_indices[0]],\n",
174
+ " 'confidence': float(q_values[top_indices[0]]),\n",
175
+ " 'top_3': [(ACTIONS[i], float(q_values[i])) for i in top_indices]\n",
176
+ " }\n",
177
+ "\n",
178
+ "print(\"Prediction function defined.\")"
179
+ ]
180
+ },
181
+ {
182
+ "cell_type": "markdown",
183
+ "metadata": {},
184
+ "source": [
185
+ "## 5. Example Predictions"
186
+ ]
187
+ },
188
+ {
189
+ "cell_type": "code",
190
+ "execution_count": null,
191
+ "metadata": {},
192
+ "outputs": [],
193
+ "source": [
194
+ "# Scenario 1: Beginner ML student\n",
195
+ "state1 = extract_state(\n",
196
+ " topic=\"neural networks\",\n",
197
+ " progress=0.3,\n",
198
+ " confusion_signals={\n",
199
+ " 'mouse_hesitation': 3.0,\n",
200
+ " 'scroll_reversals': 6,\n",
201
+ " 'time_on_page': 45,\n",
202
+ " 'back_button': 3\n",
203
+ " },\n",
204
+ " gesture_signals={\n",
205
+ " 'pinch': 2,\n",
206
+ " 'point': 5\n",
207
+ " },\n",
208
+ " time_spent=120\n",
209
+ ")\n",
210
+ "\n",
211
+ "result1 = predict_doubt(state1)\n",
212
+ "print(\"Scenario 1: Beginner ML Student\")\n",
213
+ "print(f\" Predicted Doubt: {result1['predicted_doubt']}\")\n",
214
+ "print(f\" Confidence: {result1['confidence']:.3f}\")\n",
215
+ "print()"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "code",
220
+ "execution_count": null,
221
+ "metadata": {},
222
+ "outputs": [],
223
+ "source": [
224
+ "# Scenario 2: Advanced learner struggling\n",
225
+ "state2 = extract_state(\n",
226
+ " topic=\"deep learning\",\n",
227
+ " progress=0.7,\n",
228
+ " confusion_signals={\n",
229
+ " 'mouse_hesitation': 4.5,\n",
230
+ " 'scroll_reversals': 8,\n",
231
+ " 'time_on_page': 280,\n",
232
+ " 'back_button': 5,\n",
233
+ " 'copy_attempts': 2,\n",
234
+ " 'search_usage': 3\n",
235
+ " },\n",
236
+ " gesture_signals={\n",
237
+ " 'pinch': 8,\n",
238
+ " 'swipe_left': 4,\n",
239
+ " 'point': 10\n",
240
+ " },\n",
241
+ " time_spent=600\n",
242
+ ")\n",
243
+ "\n",
244
+ "result2 = predict_doubt(state2)\n",
245
+ "print(\"Scenario 2: Advanced Learner Struggling\")\n",
246
+ "print(f\" Predicted Doubt: {result2['predicted_doubt']}\")\n",
247
+ "print(f\" Confidence: {result2['confidence']:.3f}\")\n",
248
+ "print()"
249
+ ]
250
+ },
251
+ {
252
+ "cell_type": "code",
253
+ "execution_count": null,
254
+ "metadata": {},
255
+ "outputs": [],
256
+ "source": [
257
+ "# Scenario 3: Quick learner, low confusion\n",
258
+ "state3 = extract_state(\n",
259
+ " topic=\"python programming\",\n",
260
+ " progress=0.9,\n",
261
+ " confusion_signals={\n",
262
+ " 'mouse_hesitation': 0.5,\n",
263
+ " 'scroll_reversals': 1,\n",
264
+ " 'time_on_page': 20,\n",
265
+ " 'back_button': 0\n",
266
+ " },\n",
267
+ " gesture_signals={\n",
268
+ " 'swipe_down': 5,\n",
269
+ " 'point': 3\n",
270
+ " },\n",
271
+ " time_spent=60\n",
272
+ ")\n",
273
+ "\n",
274
+ "result3 = predict_doubt(state3)\n",
275
+ "print(\"Scenario 3: Quick Learner, Low Confusion\")\n",
276
+ "print(f\" Predicted Doubt: {result3['predicted_doubt']}\")\n",
277
+ "print(f\" Confidence: {result3['confidence']:.3f}\")"
278
+ ]
279
+ },
280
+ {
281
+ "cell_type": "markdown",
282
+ "metadata": {},
283
+ "source": [
284
+ "## 6. Visualize Confusion Over Time"
285
+ ]
286
+ },
287
+ {
288
+ "cell_type": "code",
289
+ "execution_count": null,
290
+ "metadata": {},
291
+ "outputs": [],
292
+ "source": [
293
+ "import matplotlib.pyplot as plt\n",
294
+ "\n",
295
+ "# Simulate confusion over a learning session\n",
296
+ "time_points = np.arange(0, 600, 30) # 20 minutes\n",
297
+ "confusion_levels = 0.3 + 0.4 * np.sin(time_points / 100) + np.random.randn(len(time_points)) * 0.1\n",
298
+ "confusion_levels = np.clip(confusion_levels, 0, 1)\n",
299
+ "\n",
300
+ "plt.figure(figsize=(10, 4))\n",
301
+ "plt.plot(time_points, confusion_levels, 'b-', linewidth=2)\n",
302
+ "plt.axhline(y=0.5, color='r', linestyle='--', label='Threshold')\n",
303
+ "plt.xlabel('Time (seconds)')\n",
304
+ "plt.ylabel('Confusion Level')\n",
305
+ "plt.title('Predicted Confusion Over Learning Session')\n",
306
+ "plt.legend()\n",
307
+ "plt.grid(True, alpha=0.3)\n",
308
+ "plt.show()"
309
+ ]
310
+ },
311
+ {
312
+ "cell_type": "markdown",
313
+ "metadata": {},
314
+ "source": [
315
+ "## 7. Summary\n",
316
+ "\n",
317
+ "This notebook demonstrated:\n",
318
+ "\n",
319
+ "1. **Loading** the trained RL checkpoint\n",
320
+ "2. **Extracting** 64-dimensional state vectors from learning context\n",
321
+ "3. **Predicting** doubt types based on behavioral signals\n",
322
+ "4. **Visualizing** confusion patterns over time\n",
323
+ "\n",
324
+ "**Key Insights:**\n",
325
+ "- Confusion signals (mouse hesitation, scroll reversals) correlate with doubt likelihood\n",
326
+ "- Progress level affects which concepts students struggle with\n",
327
+ "- Early intervention can prevent confusion escalation"
328
+ ]
329
+ },
330
+ {
331
+ "cell_type": "markdown",
332
+ "metadata": {},
333
+ "source": [
334
+ "---\n",
335
+ "\n",
336
+ "**For more details, see the full research paper: RESEARCH_PAPER.md**\n",
337
+ "\n",
338
+ "**Repository:** https://huggingface.co/namish10/contextflow-rl"
339
+ ]
340
+ }
341
+ ],
342
+ "metadata": {
343
+ "kernelspec": {
344
+ "display_name": "Python 3",
345
+ "language": "python",
346
+ "name": "python3"
347
+ },
348
+ "language_info": {
349
+ "name": "python",
350
+ "version": "3.9.0"
351
+ }
352
+ },
353
+ "nbformat": 4,
354
+ "nbformat_minor": 4
355
+ }