{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "collapsed_sections": [ "dp7hwReFXokU", "0wpUVBCiXmJg" ], "machine_shape": "hm", "gpuType": "G4" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU", "widgets": { "application/vnd.jupyter.widget-state+json": { "d5c2a63f6f8544e79268b9bade807345": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_f027a57fcb6a49cfbbd28b95a6c0adf7", "IPY_MODEL_4df257d047cb4e7ab2a0a6c57be58b81", "IPY_MODEL_214861fbbd124b45ba164e934f91a024" ], "layout": "IPY_MODEL_f6cf4e1cd78c4c0da8bf756a7cc8760a" } }, "f027a57fcb6a49cfbbd28b95a6c0adf7": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_bd3077d18b9640abb14a4c385cac7c39", "placeholder": "​", "style": "IPY_MODEL_c39efce145764de59d3dc9cb7a818d33", "value": "Loading weights: 100%" } }, "4df257d047cb4e7ab2a0a6c57be58b81": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", 
"_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_b9fc695d6d70408690bbd7ef8bb5841a", "max": 202, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_a1486fd248674e9584e90d907b5156c2", "value": 202 } }, "214861fbbd124b45ba164e934f91a024": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_3472ba9857a543568ec3cd2f340571d6", "placeholder": "​", "style": "IPY_MODEL_89eae54541d447bdaa9625cbbe357e55", "value": " 202/202 [00:00<00:00, 4335.55it/s, Materializing param=cls.predictions.transform.dense.weight]" } }, "f6cf4e1cd78c4c0da8bf756a7cc8760a": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, 
"order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "bd3077d18b9640abb14a4c385cac7c39": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "c39efce145764de59d3dc9cb7a818d33": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "b9fc695d6d70408690bbd7ef8bb5841a": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", 
"_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "a1486fd248674e9584e90d907b5156c2": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "3472ba9857a543568ec3cd2f340571d6": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, 
"grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "89eae54541d447bdaa9625cbbe357e55": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "ae1bdb89a9704d09a1c03ec82354905e": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_8d03191c19b64a698be6d6fc141817cb", "IPY_MODEL_8d24919b783b47748aefac2a6c234313", "IPY_MODEL_38bea3901f6c4e339d17a0269f0e535b" ], "layout": "IPY_MODEL_e4655e0917814b1b950d53a8f59d1fa5" } }, "8d03191c19b64a698be6d6fc141817cb": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", 
"description_tooltip": null, "layout": "IPY_MODEL_325923e6a4054298be70581c2cd1061d", "placeholder": "​", "style": "IPY_MODEL_97bcf95ab5ac4823b2f696f42db534da", "value": "Loading weights: 100%" } }, "8d24919b783b47748aefac2a6c234313": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_866d9713b1f5436b924a5505e36b9e7d", "max": 202, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_8f481ca713e04111a7dcab82f19a072b", "value": 202 } }, "38bea3901f6c4e339d17a0269f0e535b": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_8566ff523c5d4bb5b799cdd12a95c633", "placeholder": "​", "style": "IPY_MODEL_bd53b01466524897aea749b68000684a", "value": " 202/202 [00:00<00:00, 4209.01it/s, Materializing param=cls.predictions.transform.dense.weight]" } }, "e4655e0917814b1b950d53a8f59d1fa5": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": 
null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "325923e6a4054298be70581c2cd1061d": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "97bcf95ab5ac4823b2f696f42db534da": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", 
"model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "866d9713b1f5436b924a5505e36b9e7d": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "8f481ca713e04111a7dcab82f19a072b": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "8566ff523c5d4bb5b799cdd12a95c633": { "model_module": "@jupyter-widgets/base", "model_name": 
"LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "bd53b01466524897aea749b68000684a": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } } } } }, "cells": [ { "cell_type": "markdown", "source": [ "# geolip-rescale experimentation" ], "metadata": { "id": "dp7hwReFXokU" } }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "KrVSx-_rMPBx", "outputId": "0972f0ef-a170-4554-cfac-13f98e6a894a" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "======================================================================\n", "ITERATIVE MULTI-SCALE GEOMETRIC TRANSFER\n", 
"======================================================================\n", " Device: cuda\n", "\n", " Task: 16-class pattern recognition, seq_len=64, noise=0.3\n", " Chance accuracy: 6.2%\n", " Scales: 256 → 224 → 192 → 160 → 131 → 113 → 97 → 73 → 64\n", " CV tolerance: ±0.05\n", "\n", "======================================================================\n", "SCALE 0: 256-dim (ROOT — train from scratch)\n", "======================================================================\n", " Params: 153,872\n", " CV before training: 0.0771\n", " Trained: 200 epochs → acc=0.8720, cv=0.0839\n", " layer_0: CV=0.1256 eff_rank=62.2\n", " layer_1: CV=0.0966 eff_rank=203.9\n", " layer_2: CV=0.1023 eff_rank=194.6\n", " layer_3: CV=0.0308 eff_rank=15.9\n", "\n", "======================================================================\n", "SCALE 1: 256-dim → 224-dim (12% reduction)\n", "======================================================================\n", " Params: 120,304 (78.2% of root)\n", "\n", " Projecting 256 → 224...\n", " After transfer: acc=0.8214, cv=0.1000\n", " layer_0: CV=0.1393 eff_rank=61.9\n", " layer_1: CV=0.0941 eff_rank=176.7\n", " layer_2: CV=0.1014 eff_rank=169.4\n", " layer_3: CV=0.0395 eff_rank=15.9\n", "\n", " Healing toward parent CV=0.0839 (±0.05)...\n", " Healed: 1 epochs (0.5s) → acc=0.8600, cv=0.0950\n", " layer_0: CV=0.1305 eff_rank=61.9\n", " layer_1: CV=0.0888 eff_rank=176.7\n", " layer_2: CV=0.1015 eff_rank=170.0\n", " layer_3: CV=0.0421 eff_rank=15.9\n", "\n", "======================================================================\n", "SCALE 2: 224-dim → 192-dim (14% reduction)\n", "======================================================================\n", " Params: 90,832 (59.0% of root)\n", "\n", " Projecting 224 → 192...\n", " After transfer: acc=0.8190, cv=0.0968\n", " layer_0: CV=0.1230 eff_rank=61.4\n", " layer_1: CV=0.1080 eff_rank=151.3\n", " layer_2: CV=0.1105 eff_rank=145.6\n", " layer_3: CV=0.0510 eff_rank=15.9\n", "\n", " Healing 
toward parent CV=0.0950 (±0.05)...\n", " Healed: 1 epochs (0.5s) → acc=0.8855, cv=0.0971\n", " layer_0: CV=0.1356 eff_rank=61.5\n", " layer_1: CV=0.0970 eff_rank=151.3\n", " layer_2: CV=0.1015 eff_rank=145.9\n", " layer_3: CV=0.0538 eff_rank=15.9\n", "\n", "======================================================================\n", "SCALE 3: 192-dim → 160-dim (17% reduction)\n", "======================================================================\n", " Params: 65,456 (42.5% of root)\n", "\n", " Projecting 192 → 160...\n", " After transfer: acc=0.7910, cv=0.0999\n", " layer_0: CV=0.1229 eff_rank=60.8\n", " layer_1: CV=0.1212 eff_rank=125.2\n", " layer_2: CV=0.1105 eff_rank=122.0\n", " layer_3: CV=0.0496 eff_rank=15.8\n", "\n", " Healing toward parent CV=0.0971 (±0.05)...\n", " Healed: 1 epochs (0.5s) → acc=0.8845, cv=0.1006\n", " layer_0: CV=0.1307 eff_rank=60.8\n", " layer_1: CV=0.1207 eff_rank=125.3\n", " layer_2: CV=0.1217 eff_rank=122.1\n", " layer_3: CV=0.0535 eff_rank=15.9\n", "\n", "======================================================================\n", "SCALE 4: 160-dim → 131-dim (18% reduction)\n", "======================================================================\n", " Params: 45,997 (29.9% of root)\n", "\n", " Projecting 160 → 131...\n", " After transfer: acc=0.7444, cv=0.1079\n", " layer_0: CV=0.1331 eff_rank=59.9\n", " layer_1: CV=0.1135 eff_rank=102.2\n", " layer_2: CV=0.1180 eff_rank=100.1\n", " layer_3: CV=0.0790 eff_rank=15.8\n", "\n", " Healing toward parent CV=0.1006 (±0.05)...\n", " Healed: 1 epochs (0.5s) → acc=0.8670, cv=0.1126\n", " layer_0: CV=0.1277 eff_rank=59.9\n", " layer_1: CV=0.1292 eff_rank=102.2\n", " layer_2: CV=0.1101 eff_rank=99.9\n", " layer_3: CV=0.0695 eff_rank=15.8\n", "\n", "======================================================================\n", "SCALE 5: 131-dim → 113-dim (14% reduction)\n", "======================================================================\n", " Params: 35,611 (23.1% of root)\n", "\n", " 
Projecting 131 → 113...\n", " After transfer: acc=0.7944, cv=0.1136\n", " layer_0: CV=0.1412 eff_rank=59.0\n", " layer_1: CV=0.1237 eff_rank=88.6\n", " layer_2: CV=0.1269 eff_rank=86.3\n", " layer_3: CV=0.0863 eff_rank=15.8\n", "\n", " Healing toward parent CV=0.1126 (±0.05)...\n", " Healed: 1 epochs (0.5s) → acc=0.8775, cv=0.1123\n", " layer_0: CV=0.1186 eff_rank=59.0\n", " layer_1: CV=0.1226 eff_rank=88.6\n", " layer_2: CV=0.1189 eff_rank=86.1\n", " layer_3: CV=0.0726 eff_rank=15.8\n", "\n", "======================================================================\n", "SCALE 6: 113-dim → 97-dim (14% reduction)\n", "======================================================================\n", " Params: 27,467 (17.9% of root)\n", "\n", " Projecting 113 → 97...\n", " After transfer: acc=0.7706, cv=0.1218\n", " layer_0: CV=0.1447 eff_rank=57.8\n", " layer_1: CV=0.1389 eff_rank=76.0\n", " layer_2: CV=0.1464 eff_rank=73.9\n", " layer_3: CV=0.0872 eff_rank=15.7\n", "\n", " Healing toward parent CV=0.1123 (±0.05)...\n", " Healed: 1 epochs (0.5s) → acc=0.8685, cv=0.1193\n", " layer_0: CV=0.1383 eff_rank=57.8\n", " layer_1: CV=0.1388 eff_rank=76.0\n", " layer_2: CV=0.1364 eff_rank=73.5\n", " layer_3: CV=0.0780 eff_rank=15.8\n", "\n", "======================================================================\n", "SCALE 7: 97-dim → 73-dim (25% reduction)\n", "======================================================================\n", " Params: 17,171 (11.2% of root)\n", "\n", " Projecting 97 → 73...\n", " After transfer: acc=0.6966, cv=0.1427\n", " layer_0: CV=0.1180 eff_rank=54.3\n", " layer_1: CV=0.1580 eff_rank=56.0\n", " layer_2: CV=0.1436 eff_rank=54.4\n", " layer_3: CV=0.1031 eff_rank=15.6\n", "\n", " Healing toward parent CV=0.1193 (±0.05)...\n", " Healed: 1 epochs (0.5s) → acc=0.8670, cv=0.1325\n", " layer_0: CV=0.1279 eff_rank=54.3\n", " layer_1: CV=0.1600 eff_rank=56.0\n", " layer_2: CV=0.1592 eff_rank=54.3\n", " layer_3: CV=0.0945 eff_rank=15.7\n", "\n", 
"======================================================================\n", "SCALE 8: 73-dim → 64-dim (12% reduction)\n", "======================================================================\n", " Params: 13,904 (9.0% of root)\n", "\n", " Projecting 73 → 64...\n", " After transfer: acc=0.7416, cv=0.1433\n", " layer_0: CV=0.1420 eff_rank=51.5\n", " layer_1: CV=0.1614 eff_rank=49.2\n", " layer_2: CV=0.1705 eff_rank=48.2\n", " layer_3: CV=0.1047 eff_rank=15.6\n", "\n", " Healing toward parent CV=0.1325 (±0.05)...\n", " Healed: 1 epochs (0.5s) → acc=0.8455, cv=0.1372\n", " layer_0: CV=0.1499 eff_rank=51.6\n", " layer_1: CV=0.1602 eff_rank=49.2\n", " layer_2: CV=0.1749 eff_rank=47.9\n", " layer_3: CV=0.0936 eff_rank=15.7\n", "\n", "======================================================================\n", "BASELINE: Train 64-dim from scratch\n", "======================================================================\n", " Trained: 200 epochs → acc=0.8560, cv=0.1436\n", "\n", "======================================================================\n", "DIRECT PROJECTION: 256 → 64 (single jump)\n", "======================================================================\n", " After direct transfer: acc=0.2960, cv=0.1740\n", "\n", "======================================================================\n", "RESULTS — ITERATIVE CASCADE\n", "======================================================================\n", "\n", " Scale Params Acc(proj) Acc(heal) CV(proj) CV(heal) Epochs\n", " ──────── ──────── ────────── ────────── ───────── ───────── ───────\n", " 256 153,872 0.8720 0.8720 0.0839 0.0839 200\n", " 224 120,304 0.8214 0.8600 0.1000 0.0950 1\n", " 192 90,832 0.8190 0.8855 0.0968 0.0971 1\n", " 160 65,456 0.7910 0.8845 0.0999 0.1006 1\n", " 131 45,997 0.7444 0.8670 0.1079 0.1126 1\n", " 113 35,611 0.7944 0.8775 0.1136 0.1123 1\n", " 97 27,467 0.7706 0.8685 0.1218 0.1193 1\n", " 73 17,171 0.6966 0.8670 0.1427 0.1325 1\n", " 64 13,904 0.7416 0.8455 0.1433 0.1372 1\n", 
"\n", " COMPARISONS\n", " ────────────────────────────────────────────────────────────\n", " Cascade 64-dim: acc=0.8455 cv=0.1372 (total 8 heal epochs)\n", " Direct proj 64-dim: acc=0.2960 cv=0.1740 (0 training)\n", " Scratch 64-dim: acc=0.8560 cv=0.1436 (200 epochs)\n", " Chance: acc=0.0625\n", "\n", " COMPRESSION:\n", " Root: 153,872 params\n", " Target: 13,904 params (9.0%)\n", " Ratio: 11.1×\n", "\n", " GEOMETRIC PRESERVATION:\n", " Root CV: 0.0839\n", " Final CV: 0.1372\n", " Δ CV: 0.0533\n", " Direct CV: 0.1740\n", " Scratch CV: 0.1436\n", "\n", "Done.\n" ] } ], "source": [ "# ============================================================================\n", "# ITERATIVE MULTI-SCALE GEOMETRIC TRANSFER\n", "#\n", "# Cascade: 256 → 224 → 192 → 160 → 131\n", "# At each scale:\n", "# 1. Procrustes-project from parent\n", "# 2. Measure accuracy + CV\n", "# 3. Train ONLY until CV reaches parent's CV band (±tolerance)\n", "# 4. Measure accuracy again\n", "# 5. Project down to next scale\n", "#\n", "# The hypothesis: small iterative steps preserve geometric structure\n", "# better than one large jump, because each intermediate model can\n", "# \"heal\" the projection distortion through minimal training.\n", "# ============================================================================\n", "\n", "import math\n", "import time\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import numpy as np\n", "\n", "\n", "# ══════════════════════════════════════════════════════════════════\n", "# GEOMETRIC UTILITIES\n", "# ══════════════════════════════════════════════════════════════════\n", "\n", "def cayley_menger_vol2(pts):\n", " with torch.amp.autocast(\"cuda\", enabled=False):\n", " pts = pts.float()\n", " diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)\n", " d2 = (diff * diff).sum(-1)\n", " B, V, _ = d2.shape\n", " cm = torch.zeros(B, V+1, V+1, device=d2.device, dtype=torch.float32)\n", " cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = 
d2\n", " s = (-1.0)**V; f = math.factorial(V-1)\n", " return s / ((2.0**(V-1)) * f*f) * torch.linalg.det(cm)\n", "\n", "\n", "def pentachoron_cv(embeddings, n_samples=200):\n", " B = embeddings.shape[0]\n", " if B < 5:\n", " return 0.0\n", " vols = []\n", " for _ in range(n_samples):\n", " idx = torch.randperm(B, device=embeddings.device)[:5]\n", " v2 = cayley_menger_vol2(embeddings[idx].unsqueeze(0))\n", " v = torch.sqrt(F.relu(v2[0]) + 1e-12).item()\n", " if v > 0:\n", " vols.append(v)\n", " if len(vols) < 10:\n", " return 0.0\n", " a = np.array(vols, dtype=np.float64)\n", " return float(a.std() / max(a.mean(), 1e-12))\n", "\n", "\n", "def profile_model(model):\n", " \"\"\"Profile all linear layers: CV, effective rank.\"\"\"\n", " results = {}\n", " for i, layer in enumerate(model.get_linear_layers()):\n", " W = layer.weight.detach().float()\n", " cv = pentachoron_cv(W, n_samples=200)\n", " S = torch.linalg.svdvals(W)\n", " S_norm = S / S.sum()\n", " eff_rank = torch.exp(-torch.sum(S_norm * torch.log(S_norm + 1e-12))).item()\n", " results[f\"layer_{i}\"] = {\"cv\": cv, \"eff_rank\": eff_rank}\n", " mean_cv = np.mean([v[\"cv\"] for v in results.values()])\n", " return results, mean_cv\n", "\n", "\n", "# ══════════════════════════════════════════════════════════════════\n", "# TASK: Multi-class sequence pattern recognition (harder than needle)\n", "# ══════════════════════════════════════════════════════════════════\n", "\n", "class PatternTask:\n", " \"\"\"\n", " Multi-pattern classification. 
Each class has a distinct learned template.\n", " The model must learn to recognize WHICH pattern, not just WHERE.\n", " This forces genuine geometric restructuring during training.\n", "\n", " Input: (B, seq_len) — noisy pattern\n", " Target: (B,) — pattern class\n", " \"\"\"\n", " def __init__(self, n_classes=16, seq_len=64, noise=0.3, device=\"cpu\"):\n", " self.n_classes = n_classes\n", " self.seq_len = seq_len\n", " self.noise = noise\n", " self.device = device\n", "\n", " # Fixed random templates — each class has a unique pattern\n", " torch.manual_seed(42)\n", " self.templates = torch.randn(n_classes, seq_len, device=device)\n", " self.templates = F.normalize(self.templates, dim=-1)\n", "\n", " def generate(self, n_samples):\n", " labels = torch.randint(0, self.n_classes, (n_samples,), device=self.device)\n", " patterns = self.templates[labels]\n", " noise = torch.randn_like(patterns) * self.noise\n", " inputs = patterns + noise\n", " return inputs, labels\n", "\n", "\n", "# ══════════════════════════════════════════════════════════════════\n", "# MODEL\n", "# ══════════════════════════════════════════════════════════════════\n", "\n", "class PatternModel(nn.Module):\n", " def __init__(self, seq_len, hidden_dim, n_classes, n_layers=4):\n", " super().__init__()\n", " self.seq_len = seq_len\n", " self.hidden_dim = hidden_dim\n", " self.n_classes = n_classes\n", "\n", " layers = []\n", " layers.append(nn.Linear(seq_len, hidden_dim))\n", " layers.append(nn.GELU())\n", " layers.append(nn.LayerNorm(hidden_dim))\n", " for _ in range(n_layers - 2):\n", " layers.append(nn.Linear(hidden_dim, hidden_dim))\n", " layers.append(nn.GELU())\n", " layers.append(nn.LayerNorm(hidden_dim))\n", " layers.append(nn.Linear(hidden_dim, n_classes))\n", " self.network = nn.Sequential(*layers)\n", "\n", " def forward(self, x):\n", " return self.network(x)\n", "\n", " def get_linear_layers(self):\n", " return [m for m in self.network.modules() if isinstance(m, nn.Linear)]\n", "\n", 
"\n", "# ══════════════════════════════════════════════════════════════════\n", "# PROCRUSTES PROJECTION: truncated SVD\n", "# ══════════════════════════════════════════════════════════════════\n", "\n", "@torch.no_grad()\n", "def svd_project(W_large, out_dim, in_dim):\n", " \"\"\"Project weight matrix via truncated SVD reconstruction.\"\"\"\n", " W = W_large.float()\n", " U, S, Vt = torch.linalg.svd(W, full_matrices=True)\n", " k = min(S.shape[0], out_dim, in_dim)\n", " U_k = U[:min(W.shape[0], out_dim), :k]\n", " Vt_k = Vt[:k, :min(W.shape[1], in_dim)]\n", " W_small = U_k @ torch.diag(S[:k]) @ Vt_k\n", " result = torch.zeros(out_dim, in_dim, device=W.device)\n", " r, c = W_small.shape\n", " result[:r, :c] = W_small\n", " return result\n", "\n", "\n", "@torch.no_grad()\n", "def transfer_weights(source, target):\n", " \"\"\"Procrustes-project all linear layers + layernorms from source → target.\"\"\"\n", " src_layers = source.get_linear_layers()\n", " tgt_layers = target.get_linear_layers()\n", "\n", " for L, S in zip(src_layers, tgt_layers):\n", " to, ti = S.weight.shape\n", " S.weight.data.copy_(svd_project(L.weight.data, to, ti))\n", "\n", " if L.bias is not None and S.bias is not None:\n", " b = L.bias.data.float()\n", " if b.shape[0] > to:\n", " U, _, _ = torch.linalg.svd(L.weight.data.float(), full_matrices=True)\n", " S.bias.data.copy_(U[:, :to].T @ b)\n", " elif b.shape[0] < to:\n", " S.bias.data.zero_()\n", " S.bias.data[:b.shape[0]].copy_(b)\n", " else:\n", " S.bias.data.copy_(b)\n", "\n", " # LayerNorms\n", " src_norms = [m for m in source.network.modules() if isinstance(m, nn.LayerNorm)]\n", " tgt_norms = [m for m in target.network.modules() if isinstance(m, nn.LayerNorm)]\n", " for ln_s, ln_t in zip(src_norms, tgt_norms):\n", " d = min(ln_s.weight.shape[0], ln_t.weight.shape[0])\n", " ln_t.weight.data[:d].copy_(ln_s.weight.data[:d])\n", " ln_t.bias.data[:d].copy_(ln_s.bias.data[:d])\n", "\n", "\n", "# 
# ══════════════════════════════════════════════════════════════════
# CV-GATED TRAINING
# ══════════════════════════════════════════════════════════════════

def train_until_cv(model, task, target_cv, cv_tolerance=0.05,
                   max_epochs=200, lr=3e-4, batch_size=256):
    """
    Train until mean CV reaches target ± tolerance.
    Returns: epochs_used, final_acc, final_cv
    """
    device = next(model.parameters()).device
    train_x, train_y = task.generate(10000)
    test_x, test_y = task.generate(2000)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)

    def _test_accuracy():
        # Fraction of held-out samples classified correctly.
        with torch.no_grad():
            return (model(test_x).argmax(-1) == test_y).float().mean().item()

    for epoch in range(max_epochs):
        model.train()
        order = torch.randperm(train_x.shape[0], device=device)
        for start in range(0, train_x.shape[0], batch_size):
            batch = order[start:start + batch_size]
            loss = F.cross_entropy(model(train_x[batch]), train_y[batch])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()

        # Profiling is expensive — check CV on epoch 0, then every 5 epochs.
        if epoch == 0 or (epoch + 1) % 5 == 0:
            model.eval()
            _, mean_cv = profile_model(model)
            acc = _test_accuracy()

            # Stop when the geometry matches the target, or when accuracy
            # has saturated after a warm-up period.
            if abs(mean_cv - target_cv) <= cv_tolerance:
                return epoch + 1, acc, mean_cv
            if acc >= 0.99 and epoch > 20:
                return epoch + 1, acc, mean_cv

    # Max epochs reached without hitting the CV target.
    model.eval()
    _, final_cv = profile_model(model)
    return max_epochs, _test_accuracy(), final_cv
def run_experiment():
    """Iterative multi-scale geometric transfer.

    Trains a root model at SCALES[0], then walks down the SCALES cascade:
    at each step the previous model's weights are SVD-projected into a
    smaller model, which is then "healed" (trained until its mean CV matches
    the parent's). Compares against (a) the final scale trained from scratch
    and (b) a single direct projection from the root.

    Relies on module-level helpers: PatternTask, PatternModel, transfer_weights,
    train_until_cv, and profile_model — which here returns a tuple
    (per-layer profile dict, mean CV), per the unpacking below.
    Returns: list of per-scale result dicts.
    """
    print("=" * 70)
    print("ITERATIVE MULTI-SCALE GEOMETRIC TRANSFER")
    print("=" * 70)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f" Device: {device}")

    # ── Configuration ──
    SEQ_LEN = 64
    N_CLASSES = 16
    NOISE = 0.3
    N_LAYERS = 4
    SCALES = [256, 224, 192, 160, 131, 113, 97, 73, 64]  # cascade
    CV_TOLERANCE = 0.05
    MAX_HEAL_EPOCHS = 500

    print(f"\n Task: {N_CLASSES}-class pattern recognition, seq_len={SEQ_LEN}, noise={NOISE}")
    print(f" Chance accuracy: {1/N_CLASSES:.1%}")
    print(f" Scales: {' → '.join(str(s) for s in SCALES)}")
    print(f" CV tolerance: ±{CV_TOLERANCE}")

    task = PatternTask(N_CLASSES, SEQ_LEN, NOISE, device)
    test_x, test_y = task.generate(5000)

    # ══════════════════════════════════════════════════════════
    # STEP 0: Train root model (256-dim) to convergence
    # ══════════════════════════════════════════════════════════

    print(f"\n{'='*70}")
    print(f"SCALE 0: {SCALES[0]}-dim (ROOT — train from scratch)")
    print(f"{'='*70}")

    root = PatternModel(SEQ_LEN, SCALES[0], N_CLASSES, N_LAYERS).to(device)
    n_root = sum(p.numel() for p in root.parameters())
    print(f" Params: {n_root:,}")

    # Profile before training (CV of a randomly-initialised model).
    _, cv_before = profile_model(root)
    print(f" CV before training: {cv_before:.4f}")

    # Root uses a looser CV gate (±0.1) than the cascade steps below.
    epochs, acc, cv = train_until_cv(root, task, target_cv=0.20,
                                     cv_tolerance=0.1, max_epochs=200)
    print(f" Trained: {epochs} epochs → acc={acc:.4f}, cv={cv:.4f}")

    # Full per-layer profile
    profile, _ = profile_model(root)
    for name, stats in profile.items():
        print(f" {name}: CV={stats['cv']:.4f} eff_rank={stats['eff_rank']:.1f}")

    # ══════════════════════════════════════════════════════════
    # ITERATIVE CASCADE
    # ══════════════════════════════════════════════════════════

    results = [{
        "scale": SCALES[0],
        "params": n_root,
        "acc_after_transfer": acc,
        "acc_after_heal": acc,
        "cv_after_transfer": cv,
        "cv_after_heal": cv,
        "heal_epochs": epochs,
        "source": "scratch",
    }]

    parent_model = root
    parent_cv = cv

    for i in range(1, len(SCALES)):
        dim = SCALES[i]
        parent_dim = SCALES[i-1]

        print(f"\n{'='*70}")
        print(f"SCALE {i}: {parent_dim}-dim → {dim}-dim "
              f"({(parent_dim-dim)/parent_dim:.0%} reduction)")
        print(f"{'='*70}")

        # Build target model
        child = PatternModel(SEQ_LEN, dim, N_CLASSES, N_LAYERS).to(device)
        n_child = sum(p.numel() for p in child.parameters())
        print(f" Params: {n_child:,} ({n_child/n_root:.1%} of root)")

        # ── Transfer ──
        print(f"\n Projecting {parent_dim} → {dim}...")
        transfer_weights(parent_model, child)

        # Measure immediately after transfer (no training)
        child.eval()
        _, cv_transfer = profile_model(child)
        with torch.no_grad():
            acc_transfer = (child(test_x).argmax(-1) == test_y).float().mean().item()
        print(f" After transfer: acc={acc_transfer:.4f}, cv={cv_transfer:.4f}")

        child_profile, _ = profile_model(child)
        for name, stats in child_profile.items():
            print(f" {name}: CV={stats['cv']:.4f} eff_rank={stats['eff_rank']:.1f}")

        # ── Heal: train until CV matches parent ──
        print(f"\n Healing toward parent CV={parent_cv:.4f} (±{CV_TOLERANCE})...")
        t0 = time.time()
        heal_epochs, acc_heal, cv_heal = train_until_cv(
            child, task, target_cv=parent_cv,
            cv_tolerance=CV_TOLERANCE, max_epochs=MAX_HEAL_EPOCHS)
        elapsed = time.time() - t0
        print(f" Healed: {heal_epochs} epochs ({elapsed:.1f}s) → "
              f"acc={acc_heal:.4f}, cv={cv_heal:.4f}")

        # Post-heal profile
        heal_profile, _ = profile_model(child)
        for name, stats in heal_profile.items():
            print(f" {name}: CV={stats['cv']:.4f} eff_rank={stats['eff_rank']:.1f}")

        results.append({
            "scale": dim,
            "params": n_child,
            "acc_after_transfer": acc_transfer,
            "acc_after_heal": acc_heal,
            "cv_after_transfer": cv_transfer,
            "cv_after_heal": cv_heal,
            "heal_epochs": heal_epochs,
            "source": f"projected from {parent_dim}",
        })

        # This child becomes next parent
        parent_model = child
        parent_cv = cv_heal

    # ══════════════════════════════════════════════════════════
    # BASELINE: Train the final-scale (SCALES[-1]-dim) model from scratch
    # (banner previously said "131-dim" — stale; SCALES[-1] is 64)
    # ══════════════════════════════════════════════════════════

    print(f"\n{'='*70}")
    print(f"BASELINE: Train {SCALES[-1]}-dim from scratch")
    print(f"{'='*70}")

    baseline = PatternModel(SEQ_LEN, SCALES[-1], N_CLASSES, N_LAYERS).to(device)
    n_base = sum(p.numel() for p in baseline.parameters())
    base_epochs, base_acc, base_cv = train_until_cv(
        baseline, task, target_cv=0.20,
        cv_tolerance=0.05, max_epochs=200)
    print(f" Trained: {base_epochs} epochs → acc={base_acc:.4f}, cv={base_cv:.4f}")

    # ══════════════════════════════════════════════════════════
    # DIRECT PROJECTION: SCALES[0] → SCALES[-1] (single jump baseline)
    # ══════════════════════════════════════════════════════════

    print(f"\n{'='*70}")
    print(f"DIRECT PROJECTION: {SCALES[0]} → {SCALES[-1]} (single jump)")
    print(f"{'='*70}")

    direct = PatternModel(SEQ_LEN, SCALES[-1], N_CLASSES, N_LAYERS).to(device)
    transfer_weights(root, direct)
    direct.eval()
    _, direct_cv = profile_model(direct)
    with torch.no_grad():
        direct_acc = (direct(test_x).argmax(-1) == test_y).float().mean().item()
    print(f" After direct transfer: acc={direct_acc:.4f}, cv={direct_cv:.4f}")

    # ══════════════════════════════════════════════════════════
    # FINAL REPORT
    # ══════════════════════════════════════════════════════════

    print(f"\n{'='*70}")
    print(f"RESULTS — ITERATIVE CASCADE")
    print(f"{'='*70}\n")

    print(f" {'Scale':<8s} {'Params':>8s} {'Acc(proj)':>10s} {'Acc(heal)':>10s} "
          f"{'CV(proj)':>9s} {'CV(heal)':>9s} {'Epochs':>7s}")
    print(f" {'─'*8} {'─'*8} {'─'*10} {'─'*10} {'─'*9} {'─'*9} {'─'*7}")

    total_heal_epochs = 0
    for r in results:
        print(f" {r['scale']:<8d} {r['params']:>8,} {r['acc_after_transfer']:>10.4f} "
              f"{r['acc_after_heal']:>10.4f} {r['cv_after_transfer']:>9.4f} "
              f"{r['cv_after_heal']:>9.4f} {r['heal_epochs']:>7d}")
        if r['source'] != 'scratch':
            total_heal_epochs += r['heal_epochs']

    print(f"\n {'COMPARISONS':}")
    print(f" {'─'*60}")
    print(f" Cascade {SCALES[-1]}-dim: acc={results[-1]['acc_after_heal']:.4f} "
          f"cv={results[-1]['cv_after_heal']:.4f} "
          f"(total {total_heal_epochs} heal epochs)")
    print(f" Direct proj {SCALES[-1]}-dim: acc={direct_acc:.4f} "
          f"cv={direct_cv:.4f} (0 training)")
    print(f" Scratch {SCALES[-1]}-dim: acc={base_acc:.4f} "
          f"cv={base_cv:.4f} ({base_epochs} epochs)")
    print(f" Chance: acc={1/N_CLASSES:.4f}")

    # n_child here is the last cascade step's parameter count.
    print(f"\n COMPRESSION:")
    print(f" Root: {n_root:>8,} params")
    print(f" Target: {n_child:>8,} params ({n_child/n_root:.1%})")
    print(f" Ratio: {n_root/n_child:.1f}×")

    # ── Geometric preservation ──
    root_cv = results[0]["cv_after_heal"]
    final_cv = results[-1]["cv_after_heal"]
    print(f"\n GEOMETRIC PRESERVATION:")
    print(f" Root CV: {root_cv:.4f}")
    print(f" Final CV: {final_cv:.4f}")
    print(f" Δ CV: {abs(root_cv - final_cv):.4f}")
    print(f" Direct CV: {direct_cv:.4f}")
    print(f" Scratch CV: {base_cv:.4f}")

    print(f"\nDone.")
    return results


if __name__ == "__main__":
    results = run_experiment()

# ───────────────────────────── (next notebook cell) ─────────────────────────

# ============================================================================
# OPTIMAL SCALING RATIO EXPERIMENT
#
# Sweep: What ratio between consecutive scales minimizes accuracy loss
# while maximizing compression?
#
# For each ratio r ∈ {0.50, 0.55, 0.60, 0.618, 0.65, 0.70, 0.707, 0.75,
# 0.80, 0.85, 0.90, 0.95}:
# - Build cascade from 256 → 64 using steps of dim[i+1] = round(dim[i] * r)
# - Apply iterative project + 1-epoch heal at each step
# - Measure: final accuracy, total heal epochs, CV preservation
#
# Natural candidates:
# φ⁻¹ = 0.6180 (golden ratio inverse — nature's scaling constant)
# 2⁻⁰·⁵ = 0.7071 (inverse sqrt 2 — octave halving)
# 1-0.29514 = 0.7049 (Phil's recurring ratio complement)
# e⁻¹ = 0.3679 (too aggressive, but worth checking; NOTE: not in the sweep)
# ============================================================================

import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


# ══════════════════════════════════════════════════════════════════
# GEOMETRIC UTILITIES
# ══════════════════════════════════════════════════════════════════

def cayley_menger_vol2(pts):
    """Batched squared simplex volume via the Cayley–Menger determinant.

    pts is unpacked as (B, V, _) after pairwise differencing, i.e. a batch of
    V points each; returns a (B,) tensor of squared volumes. autocast is
    disabled so the determinant is computed in full float32.
    """
    with torch.amp.autocast("cuda", enabled=False):
        pts = pts.float()
        diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)
        d2 = (diff * diff).sum(-1)
        B, V, _ = d2.shape
        # Bordered squared-distance matrix for the Cayley–Menger determinant.
        cm = torch.zeros(B, V+1, V+1, device=d2.device, dtype=torch.float32)
        cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
        s = (-1.0)**V; f = math.factorial(V-1)
        return s / ((2.0**(V-1)) * f*f) * torch.linalg.det(cm)
def pentachoron_cv(embeddings, n_samples=100):
    """Coefficient of variation of random pentachoron (5-simplex) volumes.

    Samples `n_samples` random 5-row subsets of `embeddings` and returns
    std/mean of their simplex volumes; 0.0 when fewer than 5 rows exist or
    too few positive volumes were observed.
    """
    B = embeddings.shape[0]
    if B < 5:
        return 0.0
    vols = []
    for _ in range(n_samples):
        idx = torch.randperm(B, device=embeddings.device)[:5]
        v2 = cayley_menger_vol2(embeddings[idx].unsqueeze(0))
        v = torch.sqrt(F.relu(v2[0]) + 1e-12).item()
        if v > 0:
            vols.append(v)
    if len(vols) < 10:
        return 0.0
    a = np.array(vols, dtype=np.float64)
    return float(a.std() / max(a.mean(), 1e-12))


def profile_model(model):
    """Mean pentachoron-CV over the rows of every linear layer's weight matrix.

    NOTE: unlike the first cell's profile_model, this returns a single float.
    """
    results = {}
    for i, layer in enumerate(model.get_linear_layers()):
        W = layer.weight.detach().float()
        cv = pentachoron_cv(W, n_samples=100)
        results[f"layer_{i}"] = {"cv": cv}
    mean_cv = np.mean([v["cv"] for v in results.values()])
    return mean_cv


# ══════════════════════════════════════════════════════════════════
# TASK
# ══════════════════════════════════════════════════════════════════

class PatternTask:
    """Noisy pattern classification: fixed unit-norm class templates + noise."""

    def __init__(self, n_classes=16, seq_len=64, noise=0.3, device="cpu"):
        self.n_classes = n_classes
        self.seq_len = seq_len
        self.noise = noise
        self.device = device
        # Fixed seed → identical templates on every construction.
        torch.manual_seed(42)
        self.templates = F.normalize(
            torch.randn(n_classes, seq_len, device=device), dim=-1)

    def generate(self, n_samples):
        """Return (inputs (B, seq_len), labels (B,))."""
        labels = torch.randint(0, self.n_classes, (n_samples,), device=self.device)
        patterns = self.templates[labels]
        return patterns + torch.randn_like(patterns) * self.noise, labels


# ══════════════════════════════════════════════════════════════════
# MODEL
# ══════════════════════════════════════════════════════════════════

class PatternModel(nn.Module):
    """MLP classifier: seq_len → (n_layers-1) × [Linear, GELU, LayerNorm] → n_classes."""

    def __init__(self, seq_len, hidden_dim, n_classes, n_layers=4):
        super().__init__()
        layers = []
        layers.append(nn.Linear(seq_len, hidden_dim))
        layers.append(nn.GELU())
        layers.append(nn.LayerNorm(hidden_dim))
        for _ in range(n_layers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.GELU())
            layers.append(nn.LayerNorm(hidden_dim))
        layers.append(nn.Linear(hidden_dim, n_classes))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)

    def get_linear_layers(self):
        """All nn.Linear submodules, in network order."""
        return [m for m in self.network.modules() if isinstance(m, nn.Linear)]


# ══════════════════════════════════════════════════════════════════
# PROJECTION + HEALING
# ══════════════════════════════════════════════════════════════════

@torch.no_grad()
def svd_project(W_large, out_dim, in_dim):
    """Truncated-SVD projection of a weight matrix to (out_dim, in_dim)."""
    W = W_large.float()
    # Economy SVD suffices: k ≤ min(W.shape), so full bases are never needed.
    U, S, Vt = torch.linalg.svd(W, full_matrices=False)
    k = min(S.shape[0], out_dim, in_dim)
    U_k = U[:min(W.shape[0], out_dim), :k]
    Vt_k = Vt[:k, :min(W.shape[1], in_dim)]
    W_small = U_k @ torch.diag(S[:k]) @ Vt_k
    result = torch.zeros(out_dim, in_dim, device=W.device)
    r, c = W_small.shape
    result[:r, :c] = W_small
    return result


@torch.no_grad()
def transfer_weights(source, target):
    """SVD-project all linear layers + copy layernorm params, source → target."""
    src_layers = source.get_linear_layers()
    tgt_layers = target.get_linear_layers()
    for L, Sm in zip(src_layers, tgt_layers):
        to, ti = Sm.weight.shape
        Sm.weight.data.copy_(svd_project(L.weight.data, to, ti))
        if L.bias is not None and Sm.bias is not None:
            b = L.bias.data.float()
            if b.shape[0] > to:
                # Shrinking: rotate the bias into the top-`to` left-singular
                # directions. Full SVD is required — economy U may have fewer
                # than `to` columns for wide matrices.
                U, _, _ = torch.linalg.svd(L.weight.data.float(), full_matrices=True)
                Sm.bias.data.copy_(U[:, :to].T @ b)
            elif b.shape[0] < to:
                # Growing: copy and zero-pad.
                Sm.bias.data.zero_()
                Sm.bias.data[:b.shape[0]].copy_(b)
            else:
                Sm.bias.data.copy_(b)
    src_norms = [m for m in source.network.modules() if isinstance(m, nn.LayerNorm)]
    tgt_norms = [m for m in target.network.modules() if isinstance(m, nn.LayerNorm)]
    for ln_s, ln_t in zip(src_norms, tgt_norms):
        d = min(ln_s.weight.shape[0], ln_t.weight.shape[0])
        ln_t.weight.data[:d].copy_(ln_s.weight.data[:d])
        ln_t.bias.data[:d].copy_(ln_s.bias.data[:d])


def heal_one_epoch(model, task, batch_size=256, lr=3e-4):
    """Run a single in-place healing epoch of cross-entropy training.

    Returns the model itself (NOT an accuracy — the original docstring was
    wrong; use `evaluate` to measure accuracy).
    """
    device = next(model.parameters()).device
    train_x, train_y = task.generate(10000)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()
    perm = torch.randperm(train_x.shape[0], device=device)
    for i in range(0, train_x.shape[0], batch_size):
        idx = perm[i:i+batch_size]
        loss = F.cross_entropy(model(train_x[idx]), train_y[idx])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return model


def evaluate(model, test_x, test_y):
    """Top-1 accuracy of `model` on (test_x, test_y), as a float."""
    model.eval()
    with torch.no_grad():
        return (model(test_x).argmax(-1) == test_y).float().mean().item()
# ══════════════════════════════════════════════════════════════════
# BUILD SCALE CASCADE FOR A GIVEN RATIO
# ══════════════════════════════════════════════════════════════════

def build_cascade(start_dim, end_dim, ratio):
    """Generate dimension sequence: start, round(start*r), round(start*r²), ... until ≤ end.

    Always terminates: if rounding would not shrink the dimension (ratio ≈ 1),
    a step of −1 is forced. The last element is always exactly `end_dim`.
    """
    dims = [start_dim]
    while True:
        next_dim = max(round(dims[-1] * ratio), end_dim)
        if next_dim >= dims[-1]:
            # Ratio too close to 1, force a step down
            next_dim = dims[-1] - 1
        if next_dim <= end_dim:
            if dims[-1] != end_dim:
                dims.append(end_dim)
            break
        dims.append(next_dim)
    return dims


def run_cascade(root_model, task, test_x, test_y, scales, device):
    """
    Run a full cascade: project + heal at each scale.
    Returns: (final_acc, total_heal_epochs, final_cv, per-step data).
    """
    parent = root_model
    steps = []
    total_epochs = 0

    # Fix: with fewer than two scales the loop below never runs and
    # acc_heal/cv_heal would be unbound (NameError) — report the root as-is.
    if len(scales) < 2:
        return evaluate(root_model, test_x, test_y), 0, profile_model(root_model), steps

    for i in range(1, len(scales)):
        dim = scales[i]
        child = PatternModel(task.seq_len, dim, task.n_classes, 4).to(device)
        transfer_weights(parent, child)

        acc_proj = evaluate(child, test_x, test_y)
        cv_proj = profile_model(child)

        # 1 healing epoch
        heal_one_epoch(child, task)
        total_epochs += 1

        acc_heal = evaluate(child, test_x, test_y)
        cv_heal = profile_model(child)

        steps.append({
            "from": scales[i-1], "to": dim,
            "acc_proj": acc_proj, "acc_heal": acc_heal,
            "cv_proj": cv_proj, "cv_heal": cv_heal,
        })

        parent = child

    return acc_heal, total_epochs, cv_heal, steps


# ══════════════════════════════════════════════════════════════════
# EXPERIMENT
# ══════════════════════════════════════════════════════════════════

def run_experiment():
    """Sweep scaling ratios for the project-and-heal cascade.

    Trains one 256-dim root model, then for each candidate ratio builds a
    cascade 256 → 64 and runs project + 1-epoch heal per step. Reports the
    accuracy/cost trade-off, natural-constant ratios, and a Pareto frontier.
    Returns: list of per-ratio result dicts, sorted by final accuracy.
    """
    print("=" * 70)
    print("OPTIMAL SCALING RATIO EXPERIMENT")
    print("=" * 70)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f" Device: {device}")

    # Config
    SEQ_LEN = 64
    N_CLASSES = 16
    NOISE = 0.3
    START_DIM = 256
    END_DIM = 64
    N_LAYERS = 4

    # Named ratios
    PHI_INV = 1.0 / ((1 + math.sqrt(5)) / 2)  # 0.6180
    SQRT2_INV = 1.0 / math.sqrt(2)            # 0.7071
    PHIL_COMP = 1.0 - 0.29514                 # 0.7049
    E_INV = 1.0 / math.e                      # 0.3679 — kept for reference; not swept

    RATIOS = [
        (0.50, "0.500 (halving)"),
        (0.55, "0.550"),
        (0.60, "0.600"),
        (PHI_INV, f"0.618 (1/φ golden)"),
        (0.65, "0.650"),
        (0.70, "0.700"),
        (PHIL_COMP, f"0.705 (1-0.295)"),
        (SQRT2_INV, f"0.707 (1/√2)"),
        (0.75, "0.750"),
        (0.80, "0.800"),
        (0.85, "0.850"),
        (0.90, "0.900"),
        (0.95, "0.950"),
    ]

    task = PatternTask(N_CLASSES, SEQ_LEN, NOISE, device)
    test_x, test_y = task.generate(5000)

    print(f"\n Task: {N_CLASSES}-class, seq_len={SEQ_LEN}, noise={NOISE}")
    print(f" Compression: {START_DIM} → {END_DIM}")
    print(f" Testing {len(RATIOS)} scaling ratios")

    # ── Train root model ──
    print(f"\n Training root model ({START_DIM}-dim)...")
    root = PatternModel(SEQ_LEN, START_DIM, N_CLASSES, N_LAYERS).to(device)
    optimizer = torch.optim.AdamW(root.parameters(), lr=3e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
    train_x, train_y = task.generate(10000)

    for epoch in range(200):
        root.train()
        perm = torch.randperm(10000, device=device)
        for i in range(0, 10000, 256):
            idx = perm[i:i+256]
            loss = F.cross_entropy(root(train_x[idx]), train_y[idx])
            optimizer.zero_grad(); loss.backward(); optimizer.step()
        scheduler.step()
        if (epoch+1) % 50 == 0:
            acc = evaluate(root, test_x, test_y)
            print(f" Epoch {epoch+1}: acc={acc:.4f}")

    root_acc = evaluate(root, test_x, test_y)
    root_cv = profile_model(root)
    print(f" Root: acc={root_acc:.4f}, cv={root_cv:.4f}")

    # ── Sweep ratios ──
    print(f"\n{'='*70}")
    print("RATIO SWEEP")
    print(f"{'='*70}\n")

    results = []

    for ratio, name in RATIOS:
        scales = build_cascade(START_DIM, END_DIM, ratio)
        n_steps = len(scales) - 1

        t0 = time.time()
        final_acc, total_epochs, final_cv, steps = run_cascade(
            root, task, test_x, test_y, scales, device)
        elapsed = time.time() - t0

        # Efficiency metric: accuracy retained per heal epoch
        acc_retained = final_acc / max(root_acc, 1e-8)
        efficiency = acc_retained / max(total_epochs, 1)

        result = {
            "ratio": ratio,
            "name": name,
            "scales": scales,
            "n_steps": n_steps,
            "final_acc": final_acc,
            "final_cv": final_cv,
            "total_epochs": total_epochs,
            "acc_retained": acc_retained,
            "efficiency": efficiency,
            "elapsed": elapsed,
            "steps": steps,
        }
        results.append(result)

        scale_str = "→".join(str(s) for s in scales)
        print(f" r={ratio:.3f} ({name:20s}): {n_steps} steps "
              f"acc={final_acc:.4f} ret={acc_retained:.1%} "
              f"cv={final_cv:.4f} eff={efficiency:.4f} "
              f"[{scale_str}]")

    # ── Direct jump baseline ──
    direct = PatternModel(SEQ_LEN, END_DIM, N_CLASSES, N_LAYERS).to(device)
    transfer_weights(root, direct)
    direct_acc = evaluate(direct, test_x, test_y)
    heal_one_epoch(direct, task)
    direct_heal_acc = evaluate(direct, test_x, test_y)

    # ══════════════════════════════════════════════════════════════
    # ANALYSIS
    # ══════════════════════════════════════════════════════════════

    # Sort by final accuracy
    results.sort(key=lambda x: x["final_acc"], reverse=True)

    print(f"\n{'='*70}")
    print("RESULTS — SORTED BY ACCURACY")
    print(f"{'='*70}\n")

    print(f" {'Ratio':<22s} {'Steps':>5s} {'Acc':>7s} {'Retained':>9s} "
          f"{'CV':>7s} {'Epochs':>6s} {'Eff':>7s}")
    print(f" {'─'*22} {'─'*5} {'─'*7} {'─'*9} {'─'*7} {'─'*6} {'─'*7}")

    for r in results:
        # Star the "natural constant" ratios in the table.
        marker = " ★" if "golden" in r["name"] or "0.295" in r["name"] or "√2" in r["name"] else ""
        print(f" {r['name']:<22s} {r['n_steps']:>5d} {r['final_acc']:>7.4f} "
              f"{r['acc_retained']:>8.1%} {r['final_cv']:>7.4f} "
              f"{r['total_epochs']:>6d} {r['efficiency']:>7.4f}{marker}")

    # NOTE: the "Direct 256→64" label hardcodes the default START/END dims.
    print(f"\n {'Direct 256→64':22s} {'1':>5s} {direct_heal_acc:>7.4f} "
          f"{direct_heal_acc/root_acc:>8.1%} {'—':>7s} {'1':>6s}")
    print(f" {'Root (256)':22s} {'—':>5s} {root_acc:>7.4f} "
          f"{'100.0%':>9s} {root_cv:>7.4f} {'200':>6s}")

    # ── Find optimal ──
    best = results[0]
    print(f"\n OPTIMAL RATIO: {best['name']}")
    print(f" Accuracy: {best['final_acc']:.4f} ({best['acc_retained']:.1%} retained)")
    print(f" Steps: {best['n_steps']}")
    print(f" Scales: {'→'.join(str(s) for s in best['scales'])}")
    print(f" CV: {best['final_cv']:.4f} (root: {root_cv:.4f})")

    # ── Natural constant comparison ──
    phi_result = next(r for r in results if "golden" in r["name"])
    sqrt2_result = next(r for r in results if "√2" in r["name"])
    phil_result = next(r for r in results if "0.295" in r["name"])

    print(f"\n NATURAL CONSTANTS:")
    print(f" 1/φ (0.618): acc={phi_result['final_acc']:.4f} "
          f"steps={phi_result['n_steps']} scales={'→'.join(str(s) for s in phi_result['scales'])}")
    print(f" 1/√2 (0.707): acc={sqrt2_result['final_acc']:.4f} "
          f"steps={sqrt2_result['n_steps']} scales={'→'.join(str(s) for s in sqrt2_result['scales'])}")
    print(f" 1-0.295(0.705): acc={phil_result['final_acc']:.4f} "
          f"steps={phil_result['n_steps']} scales={'→'.join(str(s) for s in phil_result['scales'])}")

    # ── Pareto analysis: accuracy vs training cost ──
    print(f"\n PARETO FRONTIER (accuracy vs epochs):")
    print(f" {'─'*50}")
    pareto = []
    best_acc_so_far = 0
    for r in sorted(results, key=lambda x: x["total_epochs"]):
        if r["final_acc"] > best_acc_so_far:
            best_acc_so_far = r["final_acc"]
            pareto.append(r)
            print(f" {r['total_epochs']:2d} epochs → {r['final_acc']:.4f} "
                  f"({r['name']})")

    print(f"\nDone.")
    return results


if __name__ == "__main__":
    results = run_experiment()
eff=0.2353 [256→180→127→90→64]\n", " r=0.707 (0.707 (1/√2) ): 4 steps acc=0.8194 ret=93.7% cv=0.1404 eff=0.2342 [256→181→128→91→64]\n", " r=0.750 (0.750 ): 5 steps acc=0.8350 ret=95.5% cv=0.1483 eff=0.1909 [256→192→144→108→81→64]\n", " r=0.800 (0.800 ): 7 steps acc=0.8626 ret=98.6% cv=0.1432 eff=0.1409 [256→205→164→131→105→84→67→64]\n", " r=0.850 (0.850 ): 9 steps acc=0.8680 ret=99.2% cv=0.1400 eff=0.1103 [256→218→185→157→133→113→96→82→70→64]\n", " r=0.900 (0.900 ): 14 steps acc=0.8852 ret=101.2% cv=0.1316 eff=0.0723 [256→230→207→186→167→150→135→122→110→99→89→80→72→65→64]\n", " r=0.950 (0.950 ): 27 steps acc=0.8916 ret=101.9% cv=0.1318 eff=0.0378 [256→243→231→219→208→198→188→179→170→162→154→146→139→132→125→119→113→107→102→97→92→87→83→79→75→71→67→64]\n", "\n", "======================================================================\n", "RESULTS — SORTED BY ACCURACY\n", "======================================================================\n", "\n", " Ratio Steps Acc Retained CV Epochs Eff\n", " ────────────────────── ───── ─────── ───────── ─────── ────── ───────\n", " 0.950 27 0.8916 101.9% 0.1318 27 0.0378\n", " 0.900 14 0.8852 101.2% 0.1316 14 0.0723\n", " 0.850 9 0.8680 99.2% 0.1400 9 0.1103\n", " 0.800 7 0.8626 98.6% 0.1432 7 0.1409\n", " 0.650 4 0.8416 96.2% 0.1357 4 0.2406\n", " 0.750 5 0.8350 95.5% 0.1483 5 0.1909\n", " 0.700 4 0.8248 94.3% 0.1418 4 0.2358\n", " 0.550 3 0.8236 94.2% 0.1478 3 0.3139\n", " 0.705 (1-0.295) 4 0.8232 94.1% 0.1378 4 0.2353 ★\n", " 0.707 (1/√2) 4 0.8194 93.7% 0.1404 4 0.2342 ★\n", " 0.600 3 0.8148 93.2% 0.1406 3 0.3105\n", " 0.618 (1/φ golden) 3 0.8114 92.8% 0.1378 3 0.3092 ★\n", " 0.500 (halving) 2 0.7724 88.3% 0.1617 2 0.4416\n", "\n", " Direct 256→64 1 0.7426 84.9% — 1\n", " Root (256) — 0.8746 100.0% 0.0929 200\n", "\n", " OPTIMAL RATIO: 0.950\n", " Accuracy: 0.8916 (101.9% retained)\n", " Steps: 27\n", " Scales: 256→243→231→219→208→198→188→179→170→162→154→146→139→132→125→119→113→107→102→97→92→87→83→79→75→71→67→64\n", " CV: 
0.1318 (root: 0.0929)\n", "\n", " NATURAL CONSTANTS:\n", " 1/φ (0.618): acc=0.8114 steps=3 scales=256→158→98→64\n", " 1/√2 (0.707): acc=0.8194 steps=4 scales=256→181→128→91→64\n", " 1-0.295(0.705): acc=0.8232 steps=4 scales=256→180→127→90→64\n", "\n", " PARETO FRONTIER (accuracy vs epochs):\n", " ──────────────────────────────────────────────────\n", " 2 epochs → 0.7724 (0.500 (halving))\n", " 3 epochs → 0.8236 (0.550)\n", " 4 epochs → 0.8416 (0.650)\n", " 7 epochs → 0.8626 (0.800)\n", " 9 epochs → 0.8680 (0.850)\n", " 14 epochs → 0.8852 (0.900)\n", " 27 epochs → 0.8916 (0.950)\n", "\n", "Done.\n" ] } ] }, { "cell_type": "markdown", "source": [ "# bert rescaling" ], "metadata": { "id": "0wpUVBCiXmJg" } }, { "cell_type": "code", "source": [ "# ============================================================================\n", "# ITERATIVE GEOMETRIC CASCADE ON PRETRAINED BERT\n", "#\n", "# Take BERT-base (768-dim, 12 layers, 110M params) and cascade it down:\n", "# 768 → 672 → 576 → 480 → 384\n", "#\n", "# At each scale:\n", "# 1. SVD-project ALL weight matrices from parent\n", "# 2. Evaluate MLM accuracy (can it still predict masked words?)\n", "# 3. Optionally heal with 1 epoch of MLM\n", "# 4. Project to next scale\n", "#\n", "# No fine-tuning on any downstream task. 
# Pure compression of pretrained
# knowledge via geometric projection.
# ============================================================================

import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional
from transformers import (
    BertForMaskedLM, BertTokenizer, BertConfig,
    DataCollatorForLanguageModeling
)
from datasets import load_dataset


# ══════════════════════════════════════════════════════════════════
# GEOMETRIC UTILITIES
# ══════════════════════════════════════════════════════════════════

def cayley_menger_vol2(pts):
    """Batched squared simplex volume via the Cayley-Menger determinant.

    pts: (B, V, D) — B batches of V points in D dims.
    Returns: (B,) squared volumes of the (V-1)-simplex spanned by each batch.

    Runs in fp32 with autocast disabled: determinants are numerically
    fragile under fp16/bf16.
    """
    with torch.amp.autocast("cuda", enabled=False):
        pts = pts.float()
        # Pairwise squared distances: d2[b, i, j] = ||p_i - p_j||^2
        diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)
        d2 = (diff * diff).sum(-1)
        B, V, _ = d2.shape
        # Bordered Cayley-Menger matrix: first row/col are 1s, [0,0] is 0.
        cm = torch.zeros(B, V + 1, V + 1, device=d2.device, dtype=torch.float32)
        cm[:, 0, 1:] = 1
        cm[:, 1:, 0] = 1
        cm[:, 1:, 1:] = d2
        sign = (-1.0) ** V
        fact = math.factorial(V - 1)
        return sign / ((2.0 ** (V - 1)) * fact * fact) * torch.linalg.det(cm)


def pentachoron_cv(W, n_samples=200):
    """Coefficient of variation of random 5-row simplex volumes of W.

    Samples `n_samples` random 5-row subsets of W, computes each subset's
    4-simplex (pentachoron) volume, and returns std/mean over the positive
    volumes. Returns 0.0 when W is not 2-D with >= 5 rows, or when fewer
    than 10 positive volumes were collected.
    """
    if W.dim() != 2 or W.shape[0] < 5:
        return 0.0
    n_rows = W.shape[0]
    vols = []
    for _ in range(n_samples):
        idx = torch.randperm(n_rows, device=W.device)[:5]
        v2 = cayley_menger_vol2(W[idx].unsqueeze(0))
        # relu clamps tiny negative determinants caused by round-off
        v = torch.sqrt(F.relu(v2[0]) + 1e-12).item()
        if v > 0:
            vols.append(v)
    if len(vols) < 10:
        return 0.0
    a = np.array(vols, dtype=np.float64)
    return float(a.std() / max(a.mean(), 1e-12))


# ══════════════════════════════════════════════════════════════════
# SVD PROJECTION
# ══════════════════════════════════════════════════════════════════

@torch.no_grad()
def svd_project_matrix(W, out_dim, in_dim):
    """Project a weight matrix to (out_dim, in_dim) via truncated SVD.

    FIX: use full_matrices=False. Only the first k <= min(W.shape) singular
    vectors are used below, so the reduced SVD gives identical output — but
    full_matrices=True materializes a (rows x rows) U, which for BERT's
    30522x768 embedding matrix is a ~3.7 GB fp32 allocation.
    """
    W = W.float()
    U, S, Vt = torch.linalg.svd(W, full_matrices=False)
    k = min(S.shape[0], out_dim, in_dim)
    U_k = U[:min(W.shape[0], out_dim), :k]
    Vt_k = Vt[:k, :min(W.shape[1], in_dim)]
    W_small = U_k @ torch.diag(S[:k]) @ Vt_k
    # Zero-pad when the target is larger than the source in either dim.
    result = torch.zeros(out_dim, in_dim, dtype=W.dtype, device=W.device)
    r, c = W_small.shape
    result[:r, :c] = W_small
    return result


@torch.no_grad()
def svd_project_vector(v, target_dim):
    """Project a 1D vector (bias, layernorm) by truncation or zero-padding."""
    if v.shape[0] == target_dim:
        return v.clone()
    elif v.shape[0] > target_dim:
        return v[:target_dim].clone()
    else:
        result = torch.zeros(target_dim, dtype=v.dtype, device=v.device)
        result[:v.shape[0]] = v
        return result


@torch.no_grad()
def svd_project_embedding(E, target_dim):
    """Project embedding matrix (vocab_size, hidden) → (vocab_size, target_dim).

    NOTE(review): not called by create_scaled_bert below — its generic 2-D
    branch handles embedding matrices via svd_project_matrix. Kept for API
    compatibility.
    """
    E = E.float()
    # Keep all vocab rows; reduce the hidden dim via SVD.
    U, S, Vt = torch.linalg.svd(E, full_matrices=False)
    k = min(S.shape[0], target_dim)
    projected = U[:, :k] @ torch.diag(S[:k]) @ Vt[:k, :target_dim]
    if projected.shape[1] < target_dim:
        result = torch.zeros(E.shape[0], target_dim, dtype=E.dtype, device=E.device)
        result[:, :projected.shape[1]] = projected
        return result
    return projected


# ══════════════════════════════════════════════════════════════════
# BERT WEIGHT TRANSFER
# ══════════════════════════════════════════════════════════════════

@torch.no_grad()
def create_scaled_bert(source_model, target_hidden, target_intermediate, device):
    """Create a BERT with smaller hidden/intermediate dims, SVD-projecting
    every weight tensor from `source_model`.

    Shape-equal params are copied verbatim, 2-D params SVD-projected,
    1-D params truncated/padded; higher-rank params keep the child's fresh
    init. Dropout is zeroed in the child config so evaluation right after
    projection is deterministic.
    """
    src_config = source_model.config

    new_config = BertConfig(
        vocab_size=src_config.vocab_size,
        hidden_size=target_hidden,
        num_hidden_layers=src_config.num_hidden_layers,
        num_attention_heads=src_config.num_attention_heads,
        intermediate_size=target_intermediate,
        max_position_embeddings=src_config.max_position_embeddings,
        type_vocab_size=src_config.type_vocab_size,
        hidden_act=src_config.hidden_act,
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
    )

    new_model = BertForMaskedLM(new_config).to(device)
    src_sd = source_model.state_dict()
    new_sd = new_model.state_dict()

    transferred = {}
    for name, param in new_sd.items():
        if name not in src_sd:
            continue
        src_p = src_sd[name].to(device)

        if src_p.shape == param.shape:
            transferred[name] = src_p.clone()
        elif src_p.dim() == 2:
            transferred[name] = svd_project_matrix(
                src_p, param.shape[0], param.shape[1])
        elif src_p.dim() == 1:
            transferred[name] = svd_project_vector(src_p, param.shape[0])
        else:
            # Higher-rank tensors: keep the child's fresh init.
            transferred[name] = param.clone()

    # strict=False: buffers absent from `transferred` keep their defaults.
    missing, unexpected = new_model.load_state_dict(transferred, strict=False)
    print(f"  Transferred {len(transferred)}/{len(new_sd)} params, "
          f"{len(missing)} missing, {len(unexpected)} unexpected")

    return new_model


# ══════════════════════════════════════════════════════════════════
# PROFILING
# ══════════════════════════════════════════════════════════════════
"\n", "@torch.no_grad()\n", "def profile_bert(model, tag=\"\"):\n", " \"\"\"Profile CV of attention and FFN weight matrices.\"\"\"\n", " cvs = []\n", " for name, param in model.named_parameters():\n", " if param.dim() == 2 and param.shape[0] >= 5 and param.shape[1] >= 5:\n", " if \"weight\" in name and (\"dense\" in name or \"query\" in name\n", " or \"key\" in name or \"value\" in name):\n", " cv = pentachoron_cv(param.detach(), n_samples=100)\n", " cvs.append(cv)\n", " mean_cv = np.mean(cvs) if cvs else 0.0\n", " n_params = sum(p.numel() for p in model.parameters())\n", " if tag:\n", " print(f\" [{tag}] {n_params:,} params, mean CV={mean_cv:.4f} \"\n", " f\"(across {len(cvs)} weight matrices)\")\n", " return mean_cv, n_params\n", "\n", "\n", "# ══════════════════════════════════════════════════════════════════\n", "# EVALUATION: MLM accuracy on short stories\n", "# ══════════════════════════════════════════════════════════════════\n", "\n", "@torch.no_grad()\n", "def evaluate_mlm(model, tokenizer, texts, device, mask_prob=0.15, max_len=128):\n", " \"\"\"\n", " Mask random tokens, see if model predicts them correctly.\n", " Returns: top-1 accuracy, top-5 accuracy.\n", " \"\"\"\n", " model.eval()\n", " total_correct_1 = 0\n", " total_correct_5 = 0\n", " total_masked = 0\n", "\n", " for text in texts:\n", " tokens = tokenizer(text, return_tensors=\"pt\", max_length=max_len,\n", " truncation=True, padding=False).to(device)\n", " input_ids = tokens[\"input_ids\"][0]\n", " seq_len = input_ids.shape[0]\n", "\n", " if seq_len < 5:\n", " continue\n", "\n", " # Create masks (skip [CLS], [SEP], [PAD])\n", " special_mask = torch.zeros(seq_len, dtype=torch.bool, device=device)\n", " special_mask[0] = True # CLS\n", " special_mask[seq_len - 1] = True # SEP\n", " special_mask[input_ids == tokenizer.pad_token_id] = True\n", "\n", " maskable = ~special_mask\n", " n_mask = max(1, int(maskable.sum().item() * mask_prob))\n", " mask_positions = maskable.nonzero(as_tuple=True)[0]\n", 
" if len(mask_positions) == 0:\n", " continue\n", " chosen = mask_positions[torch.randperm(len(mask_positions))[:n_mask]]\n", "\n", " # Save originals\n", " original_ids = input_ids[chosen].clone()\n", "\n", " # Mask\n", " masked_ids = input_ids.clone()\n", " masked_ids[chosen] = tokenizer.mask_token_id\n", "\n", " # Forward\n", " outputs = model(masked_ids.unsqueeze(0),\n", " attention_mask=tokens[\"attention_mask\"])\n", " logits = outputs.logits[0, chosen] # (n_mask, vocab_size)\n", "\n", " # Top-1\n", " preds = logits.argmax(dim=-1)\n", " total_correct_1 += (preds == original_ids).sum().item()\n", "\n", " # Top-5\n", " top5 = logits.topk(5, dim=-1).indices\n", " total_correct_5 += (top5 == original_ids.unsqueeze(-1)).any(dim=-1).sum().item()\n", "\n", " total_masked += n_mask\n", "\n", " if total_masked == 0:\n", " return 0.0, 0.0\n", " return total_correct_1 / total_masked, total_correct_5 / total_masked\n", "\n", "\n", "# ══════════════════════════════════════════════════════════════════\n", "# HEAL: minimal MLM training\n", "# ══════════════════════════════════════════════════════════════════\n", "\n", "def heal_mlm(model, tokenizer, texts, device, n_epochs=1,\n", " lr=5e-5, max_len=128, batch_size=16):\n", " \"\"\"Quick MLM training to heal projection distortion.\"\"\"\n", " model.train()\n", "\n", " # Tokenize\n", " encodings = tokenizer(texts, max_length=max_len, truncation=True,\n", " padding=\"max_length\", return_tensors=\"pt\")\n", " dataset_ids = encodings[\"input_ids\"]\n", " dataset_mask = encodings[\"attention_mask\"]\n", "\n", " collator = DataCollatorForLanguageModeling(\n", " tokenizer=tokenizer, mlm=True, mlm_probability=0.15)\n", "\n", " optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n", " n_samples = dataset_ids.shape[0]\n", "\n", " total_loss = 0\n", " n_batches = 0\n", "\n", " for epoch in range(n_epochs):\n", " perm = torch.randperm(n_samples)\n", " for i in range(0, n_samples, batch_size):\n", " idx = perm[i:i+batch_size]\n", " 
batch_ids = dataset_ids[idx]\n", " batch_mask = dataset_mask[idx]\n", "\n", " # Manual masking\n", " collated = collator([{\"input_ids\": ids, \"attention_mask\": m}\n", " for ids, m in zip(batch_ids, batch_mask)])\n", "\n", " input_ids = collated[\"input_ids\"].to(device)\n", " attention_mask = collated[\"attention_mask\"].to(device)\n", " labels = collated[\"labels\"].to(device)\n", "\n", " outputs = model(input_ids=input_ids,\n", " attention_mask=attention_mask,\n", " labels=labels)\n", " loss = outputs.loss\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", " total_loss += loss.item()\n", " n_batches += 1\n", "\n", " return total_loss / max(n_batches, 1)\n", "\n", "\n", "# ══════════════════════════════════════════════════════════════════\n", "# EXPERIMENT\n", "# ══════════════════════════════════════════════════════════════════\n", "\n", "def run_experiment():\n", " print(\"=\" * 70)\n", " print(\"ITERATIVE CASCADE ON PRETRAINED BERT-BASE\")\n", " print(\"=\" * 70)\n", "\n", " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " print(f\" Device: {device}\")\n", "\n", " # ── Configuration ──\n", " # hidden_size must be divisible by num_heads=12\n", " # 768/12=64, 672/12=56, 576/12=48, 480/12=40, 384/12=32\n", " SCALES = [768, 720, 672, 624, 576, 528, 480, 432, 384]\n", "\n", " # intermediate scales proportionally: 3072 → ...\n", " INTER_SCALES = [3072, 2880, 2688, 2496, 2304, 2112, 1920, 1728, 1536]\n", " N_EVAL_TEXTS = 200\n", " N_HEAL_TEXTS = 5000\n", " HEAL_EPOCHS = 5\n", "\n", " print(f\" Scales: {' → '.join(str(s) for s in SCALES)}\")\n", " print(f\" Compression: {SCALES[0]} → {SCALES[-1]} \"\n", " f\"({SCALES[-1]/SCALES[0]:.0%})\")\n", "\n", " # ── Load data ──\n", " print(f\"\\n Loading evaluation data...\")\n", " ds = load_dataset(\"wikitext\", \"wikitext-103-raw-v1\", split=\"validation\")\n", " eval_texts = [r[\"text\"].strip() for r in ds if len(r[\"text\"].strip()) > 50]\n", " eval_texts = 
eval_texts[:N_EVAL_TEXTS]\n", " print(f\" {len(eval_texts)} eval texts\")\n", "\n", " ds_train = load_dataset(\"wikitext\", \"wikitext-103-raw-v1\", split=\"train\")\n", " heal_texts = [r[\"text\"].strip() for r in ds_train if len(r[\"text\"].strip()) > 100]\n", " heal_texts = heal_texts[:N_HEAL_TEXTS]\n", " print(f\" {len(heal_texts)} heal texts\")\n", "\n", " # ── Load BERT-base ──\n", " print(f\"\\n Loading BERT-base...\")\n", " tokenizer = BertTokenizer.from_pretrained(\"google-bert/bert-base-uncased\")\n", " root_model = BertForMaskedLM.from_pretrained(\"google-bert/bert-base-uncased\").to(device)\n", " root_model.eval()\n", "\n", " # ── Profile + evaluate root ──\n", " root_cv, root_params = profile_bert(root_model, \"Root 768-dim\")\n", "\n", " print(f\" Evaluating root MLM accuracy...\")\n", " root_top1, root_top5 = evaluate_mlm(root_model, tokenizer, eval_texts, device)\n", " print(f\" Root: top1={root_top1:.4f} top5={root_top5:.4f}\")\n", "\n", " # ── Cascade ──\n", " results = [{\n", " \"scale\": SCALES[0],\n", " \"inter\": INTER_SCALES[0],\n", " \"params\": root_params,\n", " \"cv\": root_cv,\n", " \"top1_proj\": root_top1,\n", " \"top5_proj\": root_top5,\n", " \"top1_heal\": root_top1,\n", " \"top5_heal\": root_top5,\n", " \"heal_loss\": 0,\n", " \"heal_epochs\": 0,\n", " }]\n", "\n", " parent_model = root_model\n", "\n", " for i in range(1, len(SCALES)):\n", " hidden = SCALES[i]\n", " inter = INTER_SCALES[i]\n", " parent_hidden = SCALES[i-1]\n", "\n", " print(f\"\\n{'='*70}\")\n", " print(f\"SCALE {i}: {parent_hidden} → {hidden} \"\n", " f\"({(parent_hidden-hidden)/parent_hidden:.0%} reduction)\")\n", " print(f\"{'='*70}\")\n", "\n", " # ── Project ──\n", " print(f\" Projecting...\")\n", " t0 = time.time()\n", " child_model = create_scaled_bert(\n", " parent_model, hidden, inter, device)\n", " proj_time = time.time() - t0\n", " print(f\" Projection took {proj_time:.1f}s\")\n", "\n", " # ── Profile + evaluate after projection ──\n", " child_cv, 
child_params = profile_bert(child_model, f\"Projected {hidden}-dim\")\n", "\n", " print(f\" Evaluating MLM after projection...\")\n", " proj_top1, proj_top5 = evaluate_mlm(\n", " child_model, tokenizer, eval_texts, device)\n", " print(f\" After projection: top1={proj_top1:.4f} top5={proj_top5:.4f}\")\n", "\n", " # ── Heal ──\n", " print(f\" Healing ({HEAL_EPOCHS} epoch MLM)...\")\n", " t0 = time.time()\n", " heal_loss = heal_mlm(child_model, tokenizer, heal_texts, device,\n", " n_epochs=HEAL_EPOCHS, lr=5e-5)\n", " heal_time = time.time() - t0\n", " print(f\" Heal loss: {heal_loss:.4f} ({heal_time:.1f}s)\")\n", "\n", " # ── Profile + evaluate after heal ──\n", " heal_cv, _ = profile_bert(child_model, f\"Healed {hidden}-dim\")\n", "\n", " print(f\" Evaluating MLM after heal...\")\n", " heal_top1, heal_top5 = evaluate_mlm(\n", " child_model, tokenizer, eval_texts, device)\n", " print(f\" After heal: top1={heal_top1:.4f} top5={heal_top5:.4f}\")\n", "\n", " results.append({\n", " \"scale\": hidden,\n", " \"inter\": inter,\n", " \"params\": child_params,\n", " \"cv\": heal_cv,\n", " \"top1_proj\": proj_top1,\n", " \"top5_proj\": proj_top5,\n", " \"top1_heal\": heal_top1,\n", " \"top5_heal\": heal_top5,\n", " \"heal_loss\": heal_loss,\n", " \"heal_epochs\": HEAL_EPOCHS,\n", " })\n", "\n", " parent_model = child_model\n", "\n", " # ── Direct jump: 768 → 384 ──\n", " print(f\"\\n{'='*70}\")\n", " print(f\"DIRECT PROJECTION: {SCALES[0]} → {SCALES[-1]}\")\n", " print(f\"{'='*70}\")\n", "\n", " direct_model = create_scaled_bert(\n", " root_model, SCALES[-1], INTER_SCALES[-1], device)\n", " direct_cv, direct_params = profile_bert(direct_model, f\"Direct {SCALES[-1]}-dim\")\n", " direct_top1, direct_top5 = evaluate_mlm(\n", " direct_model, tokenizer, eval_texts, device)\n", " print(f\" Direct: top1={direct_top1:.4f} top5={direct_top5:.4f}\")\n", "\n", " # ── Report ──\n", " print(f\"\\n{'='*70}\")\n", " print(f\"RESULTS\")\n", " print(f\"{'='*70}\\n\")\n", "\n", " print(f\" 
{'Scale':>6s} {'Params':>12s} {'Top1(proj)':>11s} {'Top1(heal)':>11s} \"\n", " f\"{'Top5(proj)':>11s} {'Top5(heal)':>11s} {'CV':>7s}\")\n", " print(f\" {'─'*6} {'─'*12} {'─'*11} {'─'*11} {'─'*11} {'─'*11} {'─'*7}\")\n", "\n", " for r in results:\n", " print(f\" {r['scale']:>6d} {r['params']:>12,} {r['top1_proj']:>11.4f} \"\n", " f\"{r['top1_heal']:>11.4f} {r['top5_proj']:>11.4f} \"\n", " f\"{r['top5_heal']:>11.4f} {r['cv']:>7.4f}\")\n", "\n", " print(f\"\\n DIRECT {SCALES[-1]}: {direct_params:>12,} \"\n", " f\"top1={direct_top1:.4f} top5={direct_top5:.4f} cv={direct_cv:.4f}\")\n", "\n", " # ── Retention ──\n", " final = results[-1]\n", " print(f\"\\n SUMMARY:\")\n", " print(f\" Root: {root_params:>12,} params top1={root_top1:.4f} top5={root_top5:.4f}\")\n", " print(f\" Cascade: {final['params']:>12,} params \"\n", " f\"top1={final['top1_heal']:.4f} top5={final['top5_heal']:.4f}\")\n", " print(f\" Direct: {direct_params:>12,} params \"\n", " f\"top1={direct_top1:.4f} top5={direct_top5:.4f}\")\n", " print(f\" Compression: {root_params/final['params']:.1f}×\")\n", " print(f\" Top1 retained (cascade): {final['top1_heal']/root_top1:.1%}\")\n", " print(f\" Top1 retained (direct): {direct_top1/root_top1:.1%}\")\n", "\n", " print(f\"\\nDone.\")\n", " return results\n", "\n", "\n", "if __name__ == \"__main__\":\n", " results = run_experiment()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000, "referenced_widgets": [ "d5c2a63f6f8544e79268b9bade807345", "f027a57fcb6a49cfbbd28b95a6c0adf7", "4df257d047cb4e7ab2a0a6c57be58b81", "214861fbbd124b45ba164e934f91a024", "f6cf4e1cd78c4c0da8bf756a7cc8760a", "bd3077d18b9640abb14a4c385cac7c39", "c39efce145764de59d3dc9cb7a818d33", "b9fc695d6d70408690bbd7ef8bb5841a", "a1486fd248674e9584e90d907b5156c2", "3472ba9857a543568ec3cd2f340571d6", "89eae54541d447bdaa9625cbbe357e55" ] }, "id": "dOhM9mTJXnQl", "outputId": "f68265c8-ba6a-4182-fba7-e5736bbaf37d" }, "execution_count": 7, "outputs": [ { "output_type": 
"stream", "name": "stdout", "text": [ "======================================================================\n", "ITERATIVE CASCADE ON PRETRAINED BERT-BASE\n", "======================================================================\n", " Device: cuda\n", " Scales: 768 → 720 → 672 → 624 → 576 → 528 → 480 → 432 → 384\n", " Compression: 768 → 384 (50%)\n", "\n", " Loading evaluation data...\n", " 200 eval texts\n", " 5000 heal texts\n", "\n", " Loading BERT-base...\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "Loading weights: 0%| | 0/202 [00:00 0:\n", " vols.append(v)\n", " if len(vols) < 10:\n", " return 0.0\n", " a = np.array(vols, dtype=np.float64)\n", " return float(a.std() / max(a.mean(), 1e-12))\n", "\n", "def profile_bert(model, tag=\"\"):\n", " cvs = []\n", " for name, param in model.named_parameters():\n", " if param.dim() == 2 and param.shape[0] >= 5 and param.shape[1] >= 5:\n", " if \"weight\" in name and (\"dense\" in name or \"query\" in name\n", " or \"key\" in name or \"value\" in name):\n", " cv = pentachoron_cv(param.detach(), n_samples=100)\n", " cvs.append(cv)\n", " mean_cv = np.mean(cvs) if cvs else 0.0\n", " n_params = sum(p.numel() for p in model.parameters())\n", " if tag:\n", " print(f\" [{tag}] {n_params:,} params, CV={mean_cv:.4f} ({len(cvs)} matrices)\")\n", " return mean_cv, n_params\n", "\n", "\n", "# ══════════════════════════════════════════════════════════════════\n", "# FIX 3: SVD PROJECTION WITH PROPER BIAS HANDLING\n", "# ══════════════════════════════════════════════════════════════════\n", "\n", "@torch.no_grad()\n", "def svd_project_weight_and_bias(W, b, out_dim, in_dim):\n", " \"\"\"\n", " Project weight matrix via truncated SVD AND project bias\n", " into the new SVD output basis.\n", "\n", " W: (D_out, D_in) → (out_dim, in_dim)\n", " b: (D_out,) → (out_dim,) projected via U_k.T\n", "\n", " Returns: W_new, b_new\n", " \"\"\"\n", " W = W.float()\n", " U, S, Vt = torch.linalg.svd(W, full_matrices=True)\n", 
" k = min(S.shape[0], out_dim, in_dim)\n", "\n", " U_k = U[:, :k] # (D_out, k) — output basis\n", " S_k = S[:k] # (k,)\n", " Vt_k = Vt[:k, :] # (k, D_in)\n", "\n", " # Truncate input dimension\n", " Vt_k_trunc = Vt_k[:, :min(W.shape[1], in_dim)]\n", "\n", " # Reconstruct at target dimensions\n", " W_new = torch.zeros(out_dim, in_dim, dtype=W.dtype, device=W.device)\n", " core = U_k[:min(W.shape[0], out_dim), :] @ torch.diag(S_k) @ Vt_k_trunc\n", " r, c = core.shape\n", " W_new[:r, :c] = core\n", "\n", " # Project bias into new output basis\n", " b_new = torch.zeros(out_dim, dtype=W.dtype, device=W.device)\n", " if b is not None:\n", " b = b.float()\n", " # b_projected = U_k[:out_dim, :].T @ b → but U_k might be (D_out, k) with k < out_dim\n", " # Use the same truncation as W\n", " b_proj = U_k[:min(W.shape[0], out_dim), :].T @ b[:min(W.shape[0], out_dim)]\n", " b_new[:min(k, out_dim)] = b_proj[:min(k, out_dim)]\n", "\n", " return W_new, b_new\n", "\n", "\n", "@torch.no_grad()\n", "def svd_project_matrix_only(W, out_dim, in_dim):\n", " \"\"\"SVD project weight matrix without bias.\"\"\"\n", " W = W.float()\n", " U, S, Vt = torch.linalg.svd(W, full_matrices=True)\n", " k = min(S.shape[0], out_dim, in_dim)\n", " U_k = U[:min(W.shape[0], out_dim), :k]\n", " Vt_k = Vt[:k, :min(W.shape[1], in_dim)]\n", " W_small = U_k @ torch.diag(S[:k]) @ Vt_k\n", " result = torch.zeros(out_dim, in_dim, dtype=W.dtype, device=W.device)\n", " r, c = W_small.shape\n", " result[:r, :c] = W_small\n", " return result\n", "\n", "\n", "# ══════════════════════════════════════════════════════════════════\n", "# FIX 4: L1 MAGNITUDE PRUNING FOR FFN INTERMEDIATE\n", "# ══════════════════════════════════════════════════════════════════\n", "\n", "@torch.no_grad()\n", "def l1_prune_ffn(W_up, b_up, W_down, b_down, target_intermediate):\n", " \"\"\"\n", " Prune FFN intermediate dimension by keeping rows/cols with highest L1 norm.\n", " Preserves coordinate alignment with GELU nonlinearity.\n", "\n", " 
W_up: (src_inter, src_hidden) — expands\n", " b_up: (src_inter,)\n", " W_down: (src_hidden, src_inter) — contracts\n", " b_down: (src_hidden,)\n", "\n", " Returns: pruned W_up, b_up, W_down (columns pruned)\n", " \"\"\"\n", " src_inter = W_up.shape[0]\n", " if src_inter <= target_intermediate:\n", " return W_up, b_up, W_down\n", "\n", " # Importance = L1 norm of each intermediate neuron\n", " # Combined from both up-projection row and down-projection column\n", " importance = W_up.float().abs().sum(dim=1) + W_down.float().abs().sum(dim=0)\n", "\n", " # Keep top-k\n", " _, keep_idx = importance.topk(target_intermediate)\n", " keep_idx = keep_idx.sort().values\n", "\n", " W_up_pruned = W_up[keep_idx, :]\n", " b_up_pruned = b_up[keep_idx] if b_up is not None else None\n", " W_down_pruned = W_down[:, keep_idx]\n", " # b_down stays same dimension (hidden_size)\n", "\n", " return W_up_pruned, b_up_pruned, W_down_pruned\n", "\n", "\n", "# ══════════════════════════════════════════════════════════════════\n", "# CORRECTED BERT PROJECTION\n", "# ══════════════════════════════════════════════════════════════════\n", "\n", "@torch.no_grad()\n", "def create_projected_bert(source_model, target_hidden, target_intermediate, device):\n", " \"\"\"\n", " Project BERT with:\n", " - SVD + proper bias projection for attention/embedding matrices\n", " - L1 magnitude pruning for FFN intermediate (respects GELU)\n", " \"\"\"\n", " src_config = source_model.config\n", " new_config = BertConfig(\n", " vocab_size=src_config.vocab_size,\n", " hidden_size=target_hidden,\n", " num_hidden_layers=src_config.num_hidden_layers,\n", " num_attention_heads=src_config.num_attention_heads,\n", " intermediate_size=target_intermediate,\n", " max_position_embeddings=src_config.max_position_embeddings,\n", " type_vocab_size=src_config.type_vocab_size,\n", " hidden_act=src_config.hidden_act,\n", " hidden_dropout_prob=0.0,\n", " attention_probs_dropout_prob=0.0,\n", " )\n", " target = 
BertForMaskedLM(new_config).to(device)\n", " src = source_model\n", "\n", " # ── Embeddings (SVD on hidden dim, keep vocab) ──\n", " for emb_name in [\"word_embeddings\", \"position_embeddings\", \"token_type_embeddings\"]:\n", " src_w = getattr(src.bert.embeddings, emb_name).weight.data\n", " tgt_w = getattr(target.bert.embeddings, emb_name).weight.data\n", " tgt_w.copy_(svd_project_matrix_only(src_w, tgt_w.shape[0], tgt_w.shape[1]))\n", "\n", " # Embedding LayerNorm — truncate (element-wise, no rotation)\n", " target.bert.embeddings.LayerNorm.weight.data.copy_(\n", " src.bert.embeddings.LayerNorm.weight.data[:target_hidden])\n", " target.bert.embeddings.LayerNorm.bias.data.copy_(\n", " src.bert.embeddings.LayerNorm.bias.data[:target_hidden])\n", "\n", " # ── Encoder layers ──\n", " for i, (src_layer, tgt_layer) in enumerate(\n", " zip(src.bert.encoder.layer, target.bert.encoder.layer)):\n", "\n", " # Q, K, V: (src_hidden, src_hidden) → (target_hidden, target_hidden)\n", " for attr in [\"query\", \"key\", \"value\"]:\n", " src_mod = getattr(src_layer.attention.self, attr)\n", " tgt_mod = getattr(tgt_layer.attention.self, attr)\n", " W_new, b_new = svd_project_weight_and_bias(\n", " src_mod.weight.data, src_mod.bias.data,\n", " target_hidden, target_hidden)\n", " tgt_mod.weight.data.copy_(W_new)\n", " tgt_mod.bias.data.copy_(b_new)\n", "\n", " # Attention output: (src_hidden, src_hidden) → (target_hidden, target_hidden)\n", " W_new, b_new = svd_project_weight_and_bias(\n", " src_layer.attention.output.dense.weight.data,\n", " src_layer.attention.output.dense.bias.data,\n", " target_hidden, target_hidden)\n", " tgt_layer.attention.output.dense.weight.data.copy_(W_new)\n", " tgt_layer.attention.output.dense.bias.data.copy_(b_new)\n", "\n", " # Attention LayerNorm — truncate\n", " tgt_layer.attention.output.LayerNorm.weight.data.copy_(\n", " src_layer.attention.output.LayerNorm.weight.data[:target_hidden])\n", " 
tgt_layer.attention.output.LayerNorm.bias.data.copy_(\n", " src_layer.attention.output.LayerNorm.bias.data[:target_hidden])\n", "\n", " # FFN: L1 magnitude pruning (respects GELU coordinate alignment)\n", " W_up = src_layer.intermediate.dense.weight.data.to(device)\n", " b_up = src_layer.intermediate.dense.bias.data.to(device)\n", " W_down = src_layer.output.dense.weight.data.to(device)\n", " b_down = src_layer.output.dense.bias.data.to(device)\n", "\n", " # First prune intermediate dimension\n", " W_up_p, b_up_p, W_down_p = l1_prune_ffn(\n", " W_up, b_up, W_down, b_down, target_intermediate)\n", "\n", " # Then SVD-project the hidden dimensions with proper bias\n", " W_up_final, b_up_final = svd_project_weight_and_bias(\n", " W_up_p, b_up_p, target_intermediate, target_hidden)\n", " tgt_layer.intermediate.dense.weight.data.copy_(W_up_final)\n", " tgt_layer.intermediate.dense.bias.data.copy_(b_up_final)\n", "\n", " W_down_final, b_down_final = svd_project_weight_and_bias(\n", " W_down_p, b_down, target_hidden, target_intermediate)\n", " tgt_layer.output.dense.weight.data.copy_(W_down_final)\n", " tgt_layer.output.dense.bias.data.copy_(b_down_final)\n", "\n", " # Output LayerNorm — truncate\n", " tgt_layer.output.LayerNorm.weight.data.copy_(\n", " src_layer.output.LayerNorm.weight.data[:target_hidden])\n", " tgt_layer.output.LayerNorm.bias.data.copy_(\n", " src_layer.output.LayerNorm.bias.data[:target_hidden])\n", "\n", " # ── MLM Head ──\n", " if hasattr(src.cls.predictions.transform, 'dense'):\n", " W_new, b_new = svd_project_weight_and_bias(\n", " src.cls.predictions.transform.dense.weight.data,\n", " src.cls.predictions.transform.dense.bias.data,\n", " target_hidden, target_hidden)\n", " target.cls.predictions.transform.dense.weight.data.copy_(W_new)\n", " target.cls.predictions.transform.dense.bias.data.copy_(b_new)\n", "\n", " if hasattr(src.cls.predictions.transform, 'LayerNorm'):\n", " target.cls.predictions.transform.LayerNorm.weight.data.copy_(\n", " 
src.cls.predictions.transform.LayerNorm.weight.data[:target_hidden])\n", " target.cls.predictions.transform.LayerNorm.bias.data.copy_(\n", " src.cls.predictions.transform.LayerNorm.bias.data[:target_hidden])\n", "\n", " if hasattr(src.cls.predictions, 'bias'):\n", " target.cls.predictions.bias.data.copy_(src.cls.predictions.bias.data)\n", "\n", " return target\n", "\n", "\n", "# ══════════════════════════════════════════════════════════════════\n", "# FIX 2: FROZEN PER-LAYER PROJECTORS\n", "# ══════════════════════════════════════════════════════════════════\n", "\n", "class FrozenLayerProjectors(nn.Module):\n", " \"\"\"\n", " Per-layer projectors initialized from Procrustes, then FROZEN.\n", " The student must move to match the fixed target — no shortcut collapse.\n", " \"\"\"\n", " def __init__(self, teacher_dim, student_dim, n_layers, device):\n", " super().__init__()\n", " self.projectors = nn.ModuleList([\n", " nn.Linear(teacher_dim, student_dim, bias=False).to(device)\n", " for _ in range(n_layers + 1)\n", " ])\n", "\n", " @torch.no_grad()\n", " def init_from_layer_procrustes(self, teacher_model, student_dim, device):\n", " teacher_dim = teacher_model.config.hidden_size\n", " for i, layer in enumerate(teacher_model.bert.encoder.layer):\n", " weights = [\n", " layer.attention.self.query.weight.data.T,\n", " layer.attention.self.key.weight.data.T,\n", " layer.attention.self.value.weight.data.T,\n", " layer.intermediate.dense.weight.data.T,\n", " ]\n", " L = torch.cat(weights, dim=1).float().to(device)\n", " U, S, Vt = torch.linalg.svd(L, full_matrices=False)\n", " P_layer = U[:, :student_dim] # (teacher_dim, student_dim)\n", " self.projectors[i + 1].weight.data.copy_(P_layer.T)\n", " if i == 0:\n", " self.projectors[0].weight.data.copy_(P_layer.T)\n", "\n", " # FREEZE all projectors\n", " for p in self.parameters():\n", " p.requires_grad = False\n", "\n", " print(f\" {len(self.projectors)} projectors initialized + FROZEN\")\n", "\n", " def forward(self, 
teacher_hiddens):\n", " projected = []\n", " for t_h, proj in zip(teacher_hiddens, self.projectors):\n", " projected.append(proj(t_h.float()))\n", " return projected\n", "\n", "\n", "# ══════════════════════════════════════════════════════════════════\n", "# FIX 1: DISTILLATION LOSS ON ACTIVE TOKENS ONLY\n", "# ══════════════════════════════════════════════════════════════════\n", "\n", "class TeacherGuidedHealerV3:\n", " def __init__(self, teacher_model, projectors, device,\n", " mlm_weight=1.0, distill_weight=2.0):\n", " self.teacher = teacher_model\n", " self.teacher.eval()\n", " self.projectors = projectors # FROZEN\n", " self.device = device\n", " self.mlm_weight = mlm_weight\n", " self.distill_weight = distill_weight\n", "\n", " def compute_distillation_loss(self, student_model, input_ids, attention_mask):\n", " with torch.no_grad():\n", " teacher_out = self.teacher.bert(\n", " input_ids=input_ids, attention_mask=attention_mask,\n", " output_hidden_states=True, return_dict=True)\n", "\n", " student_out = student_model.bert(\n", " input_ids=input_ids, attention_mask=attention_mask,\n", " output_hidden_states=True, return_dict=True)\n", "\n", " # Per-layer projection (frozen projectors)\n", " projected = self.projectors(teacher_out.hidden_states)\n", " student_hiddens = student_out.hidden_states\n", "\n", " n_layers = min(len(projected), len(student_hiddens))\n", " total_loss = torch.tensor(0.0, device=self.device)\n", "\n", " # FIX 1: Only compute loss on active (non-padding) tokens\n", " active_mask = attention_mask.float() # (B, seq), 1=active, 0=pad\n", " n_active = active_mask.sum().clamp(min=1.0)\n", "\n", " for layer_idx in range(1, n_layers):\n", " t_proj = projected[layer_idx] # (B, seq, student_dim)\n", " s_h = student_hiddens[layer_idx].float()\n", "\n", " # Cosine similarity per token\n", " t_norm = F.normalize(t_proj, dim=-1)\n", " s_norm = F.normalize(s_h, dim=-1)\n", " cos_sim = (t_norm * s_norm).sum(-1) # (B, seq)\n", "\n", " # Mask out padding, 
average over active tokens only\n", " cos_sim_active = cos_sim * active_mask\n", " layer_loss = 1.0 - cos_sim_active.sum() / n_active\n", " total_loss = total_loss + layer_loss\n", "\n", " return total_loss / max(n_layers - 1, 1)\n", "\n", " def heal(self, student_model, tokenizer, texts, n_epochs=5,\n", " lr=5e-5, max_len=128, batch_size=16):\n", " student_model.train()\n", " # Only student params — projectors are frozen\n", " optimizer = torch.optim.AdamW(student_model.parameters(), lr=lr)\n", "\n", " enc = tokenizer(texts, max_length=max_len, truncation=True,\n", " padding=\"max_length\", return_tensors=\"pt\")\n", " ids, masks = enc[\"input_ids\"], enc[\"attention_mask\"]\n", " collator = DataCollatorForLanguageModeling(\n", " tokenizer=tokenizer, mlm=True, mlm_probability=0.15)\n", " n = ids.shape[0]\n", " total_loss = 0\n", " n_batches = 0\n", "\n", " for epoch in range(n_epochs):\n", " perm = torch.randperm(n)\n", " for i in range(0, n, batch_size):\n", " idx = perm[i:i+batch_size]\n", " batch = [{\"input_ids\": ids[j], \"attention_mask\": masks[j]}\n", " for j in idx]\n", " c = collator(batch)\n", " c_ids = c[\"input_ids\"].to(self.device)\n", " c_mask = c[\"attention_mask\"].to(self.device)\n", " c_labels = c[\"labels\"].to(self.device)\n", "\n", " mlm_out = student_model(\n", " input_ids=c_ids, attention_mask=c_mask, labels=c_labels)\n", "\n", " # Distillation on unmasked input\n", " orig_ids = ids[idx].to(self.device)\n", " orig_mask = masks[idx].to(self.device)\n", " distill_loss = self.compute_distillation_loss(\n", " student_model, orig_ids, orig_mask)\n", "\n", " loss = (self.mlm_weight * mlm_out.loss +\n", " self.distill_weight * distill_loss)\n", "\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", " total_loss += loss.item()\n", " n_batches += 1\n", "\n", " if (epoch + 1) % 2 == 0 or epoch == 0:\n", " nb = max(n_batches, 1)\n", " print(f\" Epoch {epoch+1}: loss={total_loss/nb:.4f} \"\n", " 
f\"(mlm≈{mlm_out.loss.item():.3f}, \"\n", " f\"distill≈{distill_loss.item():.3f})\")\n", "\n", " return total_loss / max(n_batches, 1)\n", "\n", "\n", "# ══════════════════════════════════════════════════════════════════\n", "# FIX 5: SEEDED EVALUATION\n", "# ══════════════════════════════════════════════════════════════════\n", "\n", "@torch.no_grad()\n", "def evaluate_mlm(model, tokenizer, texts, device, mask_prob=0.15,\n", " max_len=128, seed=42):\n", " \"\"\"Deterministic masking for consistent cross-scale comparison.\"\"\"\n", " model.eval()\n", " gen = torch.Generator().manual_seed(seed)\n", " total_1 = total_5 = total_m = 0\n", "\n", " for text in texts:\n", " tokens = tokenizer(text, return_tensors=\"pt\", max_length=max_len,\n", " truncation=True, padding=False).to(device)\n", " input_ids = tokens[\"input_ids\"][0]\n", " seq_len = input_ids.shape[0]\n", " if seq_len < 5:\n", " continue\n", " special = torch.zeros(seq_len, dtype=torch.bool, device=device)\n", " special[0] = special[seq_len-1] = True\n", " special[input_ids == tokenizer.pad_token_id] = True\n", " maskable = (~special).nonzero(as_tuple=True)[0]\n", " if len(maskable) == 0:\n", " continue\n", " n_mask = max(1, int(len(maskable) * mask_prob))\n", " chosen = maskable[torch.randperm(len(maskable), generator=gen)[:n_mask]]\n", " orig = input_ids[chosen].clone()\n", " masked = input_ids.clone()\n", " masked[chosen] = tokenizer.mask_token_id\n", " logits = model(masked.unsqueeze(0),\n", " attention_mask=tokens[\"attention_mask\"]).logits[0, chosen]\n", " total_1 += (logits.argmax(-1) == orig).sum().item()\n", " top5 = logits.topk(5, dim=-1).indices\n", " total_5 += (top5 == orig.unsqueeze(-1)).any(-1).sum().item()\n", " total_m += n_mask\n", " if total_m == 0:\n", " return 0.0, 0.0\n", " return total_1 / total_m, total_5 / total_m\n", "\n", "\n", "# ══════════════════════════════════════════════════════════════════\n", "# EXPERIMENT\n", "# 
══════════════════════════════════════════════════════════════════\n", "\n", "def run_experiment():\n", " print(\"=\" * 70)\n", " print(\"TEACHER-GUIDED CASCADE v3 — ALL FIXES\")\n", " print(\"=\" * 70)\n", "\n", " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " print(f\" Device: {device}\")\n", "\n", " SCALES = [768, 720, 672, 624, 576, 528, 480, 432, 384]\n", " INTER_SCALES = [3072, 2880, 2688, 2496, 2304, 2112, 1920, 1728, 1536]\n", " N_EVAL = 200\n", " N_HEAL = 5000\n", " HEAL_EPOCHS = 5\n", "\n", " print(f\" Scales: {' → '.join(str(s) for s in SCALES)}\")\n", " print(f\" Fixes: padding-masked loss, frozen projectors, \"\n", " f\"SVD bias projection, L1 FFN pruning, seeded eval\")\n", "\n", " # ── Data ──\n", " print(f\"\\n Loading data...\")\n", " ds_val = load_dataset(\"wikitext\", \"wikitext-103-raw-v1\", split=\"validation\")\n", " eval_texts = [r[\"text\"].strip() for r in ds_val if len(r[\"text\"].strip()) > 50][:N_EVAL]\n", " ds_train = load_dataset(\"wikitext\", \"wikitext-103-raw-v1\", split=\"train\")\n", " heal_texts = [r[\"text\"].strip() for r in ds_train if len(r[\"text\"].strip()) > 100][:N_HEAL]\n", " print(f\" {len(eval_texts)} eval, {len(heal_texts)} heal texts\")\n", "\n", " # ── Teacher ──\n", " print(f\"\\n Loading BERT-base (teacher)...\")\n", " tokenizer = BertTokenizer.from_pretrained(\"google-bert/bert-base-uncased\")\n", " teacher = BertForMaskedLM.from_pretrained(\"google-bert/bert-base-uncased\").to(device)\n", " teacher.eval()\n", " for p in teacher.parameters():\n", " p.requires_grad = False\n", "\n", " root_cv, root_params = profile_bert(teacher, \"Teacher 768\")\n", " root_top1, root_top5 = evaluate_mlm(teacher, tokenizer, eval_texts, device)\n", " print(f\" Teacher: top1={root_top1:.4f} top5={root_top5:.4f}\")\n", "\n", " results = [{\n", " \"scale\": 768, \"params\": root_params, \"cv\": root_cv,\n", " \"top1_proj\": root_top1, \"top5_proj\": root_top5,\n", " \"top1_heal\": root_top1, 
\"top5_heal\": root_top5,\n", " }]\n", "\n", " parent = teacher\n", " n_encoder_layers = teacher.config.num_hidden_layers\n", "\n", " for i in range(1, len(SCALES)):\n", " hidden = SCALES[i]\n", " inter = INTER_SCALES[i]\n", " parent_hidden = SCALES[i-1]\n", "\n", " print(f\"\\n{'='*70}\")\n", " print(f\"SCALE {i}: {parent_hidden} → {hidden} \"\n", " f\"({(parent_hidden-hidden)/parent_hidden:.0%} reduction)\")\n", " print(f\"{'='*70}\")\n", "\n", " # ── Project ──\n", " print(f\" Projecting (SVD + L1 FFN prune)...\")\n", " t0 = time.time()\n", " child = create_projected_bert(parent, hidden, inter, device)\n", " print(f\" Projection: {time.time()-t0:.1f}s\")\n", "\n", " child_cv, child_params = profile_bert(child, f\"Projected {hidden}\")\n", " proj_top1, proj_top5 = evaluate_mlm(child, tokenizer, eval_texts, device)\n", " print(f\" After proj: top1={proj_top1:.4f} top5={proj_top5:.4f}\")\n", "\n", " # ── Frozen per-layer projectors ──\n", " print(f\" Initializing frozen per-layer projectors...\")\n", " projectors = FrozenLayerProjectors(768, hidden, n_encoder_layers, device)\n", " projectors.init_from_layer_procrustes(teacher, hidden, device)\n", "\n", " # ── Teacher-guided healing ──\n", " print(f\" Healing ({HEAL_EPOCHS} epochs)...\")\n", " healer = TeacherGuidedHealerV3(\n", " teacher, projectors, device,\n", " mlm_weight=1.0, distill_weight=2.0)\n", " t0 = time.time()\n", " heal_loss = healer.heal(child, tokenizer, heal_texts,\n", " n_epochs=HEAL_EPOCHS, lr=5e-5)\n", " print(f\" Heal: {time.time()-t0:.1f}s\")\n", "\n", " heal_cv, _ = profile_bert(child, f\"Healed {hidden}\")\n", " heal_top1, heal_top5 = evaluate_mlm(child, tokenizer, eval_texts, device)\n", " print(f\" After heal: top1={heal_top1:.4f} top5={heal_top5:.4f}\")\n", "\n", " results.append({\n", " \"scale\": hidden, \"params\": child_params, \"cv\": heal_cv,\n", " \"top1_proj\": proj_top1, \"top5_proj\": proj_top5,\n", " \"top1_heal\": heal_top1, \"top5_heal\": heal_top5,\n", " })\n", "\n", " parent 
= child\n", "\n", " # ── Report ──\n", " print(f\"\\n{'='*70}\")\n", " print(f\"RESULTS\")\n", " print(f\"{'='*70}\\n\")\n", "\n", " print(f\" {'Scale':>6s} {'Params':>12s} {'Top1(proj)':>11s} {'Top1(heal)':>11s} \"\n", " f\"{'Top5(proj)':>11s} {'Top5(heal)':>11s} {'CV':>7s}\")\n", " print(f\" {'─'*6} {'─'*12} {'─'*11} {'─'*11} {'─'*11} {'─'*11} {'─'*7}\")\n", "\n", " for r in results:\n", " print(f\" {r['scale']:>6d} {r['params']:>12,} {r['top1_proj']:>11.4f} \"\n", " f\"{r['top1_heal']:>11.4f} {r['top5_proj']:>11.4f} \"\n", " f\"{r['top5_heal']:>11.4f} {r['cv']:>7.4f}\")\n", "\n", " final = results[-1]\n", " print(f\"\\n SUMMARY:\")\n", " print(f\" Teacher: {root_params:>12,} top1={root_top1:.4f} top5={root_top5:.4f}\")\n", " print(f\" Cascade: {final['params']:>12,} \"\n", " f\"top1={final['top1_heal']:.4f} top5={final['top5_heal']:.4f}\")\n", " print(f\" Compression: {root_params/final['params']:.1f}×\")\n", " print(f\" Top1 retained: {final['top1_heal']/root_top1:.1%}\")\n", "\n", " print(f\"\\n ALL APPROACHES:\")\n", " print(f\" v1 Independent SVD + blind MLM: 61.5%\")\n", " print(f\" v2 + teacher global P: 62.5%\")\n", " print(f\" v2 + per-layer projectors (buggy): ???\")\n", " print(f\" v3 all fixes: {final['top1_heal']/root_top1:.1%}\")\n", "\n", " print(f\"\\nDone.\")\n", " return results\n", "\n", "\n", "if __name__ == \"__main__\":\n", " results = run_experiment()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000, "referenced_widgets": [ "ae1bdb89a9704d09a1c03ec82354905e", "8d03191c19b64a698be6d6fc141817cb", "8d24919b783b47748aefac2a6c234313", "38bea3901f6c4e339d17a0269f0e535b", "e4655e0917814b1b950d53a8f59d1fa5", "325923e6a4054298be70581c2cd1061d", "97bcf95ab5ac4823b2f696f42db534da", "866d9713b1f5436b924a5505e36b9e7d", "8f481ca713e04111a7dcab82f19a072b", "8566ff523c5d4bb5b799cdd12a95c633", "bd53b01466524897aea749b68000684a" ] }, "id": "ddAQ1-RGp3Fx", "outputId": "a9f9c9f5-fc9a-4120-d1c4-06a50e9f7483" }, 
"execution_count": 20, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "======================================================================\n", "TEACHER-GUIDED CASCADE v3 — ALL FIXES\n", "======================================================================\n", " Device: cuda\n", " Scales: 768 → 720 → 672 → 624 → 576 → 528 → 480 → 432 → 384\n", " Fixes: padding-masked loss, frozen projectors, SVD bias projection, L1 FFN pruning, seeded eval\n", "\n", " Loading data...\n", " 200 eval, 5000 heal texts\n", "\n", " Loading BERT-base (teacher)...\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "Loading weights: 0%| | 0/202 [00:00