FalloutGeckoAllDay commited on
Commit
6ec47f3
·
verified ·
1 Parent(s): 774ca36

Upload Zouhayer_v3.ipynb

Browse files
Files changed (1) hide show
  1. Zouhayer_v3.ipynb +2116 -0
Zouhayer_v3.ipynb ADDED
@@ -0,0 +1,2116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "language_info": {
14
+ "name": "python"
15
+ },
16
+ "accelerator": "GPU"
17
+ },
18
+ "cells": [
19
+ {
20
+ "cell_type": "code",
21
+ "source": [
22
+ "from google.colab import drive\n",
23
+ "drive.mount('/content/drive')"
24
+ ],
25
+ "metadata": {
26
+ "colab": {
27
+ "base_uri": "https://localhost:8080/"
28
+ },
29
+ "id": "1JHSkdJSJLlV",
30
+ "outputId": "4bc51836-8266-4a47-d50c-a3d7b7abdce0"
31
+ },
32
+ "execution_count": 1,
33
+ "outputs": [
34
+ {
35
+ "output_type": "stream",
36
+ "name": "stdout",
37
+ "text": [
38
+ "Mounted at /content/drive\n"
39
+ ]
40
+ }
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": 2,
46
+ "metadata": {
47
+ "colab": {
48
+ "base_uri": "https://localhost:8080/"
49
+ },
50
+ "id": "pYiI4zfBztum",
51
+ "outputId": "cc5f1000-52d0-433f-a44a-b6526a3daf4f"
52
+ },
53
+ "outputs": [
54
+ {
55
+ "output_type": "stream",
56
+ "name": "stdout",
57
+ "text": [
58
+ "Collecting fasttext\n",
59
+ " Downloading fasttext-0.9.3.tar.gz (73 kB)\n",
60
+ "\u001b[?25l \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m0.0/73.4 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m73.4/73.4 kB\u001b[0m \u001b[31m4.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
61
+ "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
62
+ " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
63
+ " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
64
+ "Collecting pybind11>=2.2 (from fasttext)\n",
65
+ " Using cached pybind11-3.0.3-py3-none-any.whl.metadata (10 kB)\n",
66
+ "Requirement already satisfied: setuptools>=0.7.0 in /usr/local/lib/python3.12/dist-packages (from fasttext) (75.2.0)\n",
67
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.12/dist-packages (from fasttext) (2.0.2)\n",
68
+ "Using cached pybind11-3.0.3-py3-none-any.whl (313 kB)\n",
69
+ "Building wheels for collected packages: fasttext\n",
70
+ " Building wheel for fasttext (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
71
+ " Created wheel for fasttext: filename=fasttext-0.9.3-cp312-cp312-linux_x86_64.whl size=4653979 sha256=a6c11b55dfb9ca99d6c2ad8c026d2d0bcee4cd11dbfb8cd9864ca7de1260b569\n",
72
+ " Stored in directory: /root/.cache/pip/wheels/20/27/95/a7baf1b435f1cbde017cabdf1e9688526d2b0e929255a359c6\n",
73
+ "Successfully built fasttext\n",
74
+ "Installing collected packages: pybind11, fasttext\n",
75
+ "Successfully installed fasttext-0.9.3 pybind11-3.0.3\n"
76
+ ]
77
+ }
78
+ ],
79
+ "source": [
80
+ "!pip install fasttext"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "source": [
86
+ "import pandas as pd\n",
87
+ "import numpy as np\n",
88
+ "corpus_full = pd.read_csv(\"/content/ArabiziProfanityDatasetByMotaz.csv\").astype(str)\n",
89
+ "corpus_text = np.array(corpus_full[\"text\"].tolist(), dtype=str)"
90
+ ],
91
+ "metadata": {
92
+ "id": "R0bAuNkz2HDx"
93
+ },
94
+ "execution_count": 3,
95
+ "outputs": []
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "source": [
100
+ "import fasttext\n",
101
+ "import pandas as pd\n",
102
+ "import numpy as np\n",
103
+ "import os\n",
104
+ "# Ensure data is loaded\n",
105
+ "corpus_full = pd.read_csv(\"/content/ArabiziProfanityDatasetByMotaz.csv\").astype(str)\n",
106
+ "corpus_text = np.array(corpus_full[\"text\"].tolist(), dtype=str)\n",
107
+ "\n",
108
+ "# Save the corpus to a text file as required by fastText\n",
109
+ "with open('corpus.txt', 'w', encoding='utf-8') as f:\n",
110
+ " for line in corpus_text:\n",
111
+ " f.write(str(line) + '\\n')\n",
112
+ "\n",
113
+ "# Train the model with verbose=2 to see progress\n",
114
+ "print(\"Starting training...\")\n",
115
+ "model = fasttext.train_unsupervised(input=\"corpus.txt\", model=\"skipgram\", lr=0.05, dim=256, minn=2, maxn=8, wordNgrams=3, ws=5, epoch=20, verbose=2)\n",
116
+ "model.save_model(\"TunisianEmbeddings.ftz\")\n",
117
+ "print(\"Model saved as cc.ftz\")"
118
+ ],
119
+ "metadata": {
120
+ "colab": {
121
+ "base_uri": "https://localhost:8080/"
122
+ },
123
+ "id": "cGatmAbG1SE4",
124
+ "outputId": "7479bde2-2edb-41af-d7cf-dfca701e5fdb"
125
+ },
126
+ "execution_count": null,
127
+ "outputs": [
128
+ {
129
+ "output_type": "stream",
130
+ "name": "stdout",
131
+ "text": [
132
+ "Starting training...\n",
133
+ "Model saved as cc.ftz\n"
134
+ ]
135
+ }
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "code",
140
+ "source": [
141
+ "# Copies the file to the root of your Google Drive\n",
142
+ "!cp /content/TunisianEmbeddings.ftz /content/drive/MyDrive/"
143
+ ],
144
+ "metadata": {
145
+ "id": "KGkrmaYfJ021"
146
+ },
147
+ "execution_count": null,
148
+ "outputs": []
149
+ },
150
+ {
151
+ "cell_type": "code",
152
+ "source": [
153
+ "import torch"
154
+ ],
155
+ "metadata": {
156
+ "id": "xWdg88mW8Nn9"
157
+ },
158
+ "execution_count": null,
159
+ "outputs": []
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "source": [
164
+ "import torch.nn as nn"
165
+ ],
166
+ "metadata": {
167
+ "id": "IngqsV4b8QTL"
168
+ },
169
+ "execution_count": null,
170
+ "outputs": []
171
+ },
172
+ {
173
+ "cell_type": "code",
174
+ "source": [
175
+ "import torch\n",
176
+ "import torch.nn as nn\n",
177
+ "\n",
178
+ "# \u2500\u2500 LiGRU cell \u2014 LayerNorm, no BN \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
179
+ "class LiGRUCell(nn.Module):\n",
180
+ " def __init__(self, input_size, hidden_size, dropout=0.0):\n",
181
+ " super().__init__()\n",
182
+ " self.hidden_size = hidden_size\n",
183
+ " self.Wz = nn.Linear(input_size + hidden_size, hidden_size)\n",
184
+ " self.Wh = nn.Linear(input_size + hidden_size, hidden_size)\n",
185
+ " self.ln_h = nn.LayerNorm(hidden_size)\n",
186
+ " self.drop = nn.Dropout(dropout)\n",
187
+ "\n",
188
+ " def forward(self, x, h):\n",
189
+ " combined = torch.cat([x, h], dim=-1)\n",
190
+ " z = torch.sigmoid(self.Wz(combined))\n",
191
+ " h_candidate = torch.relu(self.ln_h(self.Wh(combined)))\n",
192
+ " h_candidate = self.drop(h_candidate)\n",
193
+ " return (1 - z) * h + z * h_candidate\n",
194
+ "\n",
195
+ "\n",
196
+ "# \u2500\u2500 Bidirectional LiGRU \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
197
+ "class BiLiGRU(nn.Module):\n",
198
+ " def __init__(self, input_size, hidden_size, dropout=0.4):\n",
199
+ " super().__init__()\n",
200
+ " self.hidden_size = hidden_size\n",
201
+ " self.fwd_cell = LiGRUCell(input_size, hidden_size, dropout=dropout)\n",
202
+ " self.bwd_cell = LiGRUCell(input_size, hidden_size, dropout=dropout)\n",
203
+ " self.out_drop = nn.Dropout(dropout)\n",
204
+ "\n",
205
+ " def forward(self, x):\n",
206
+ " batch, seq_len, _ = x.shape\n",
207
+ " h_fwd = torch.zeros(batch, self.hidden_size, device=x.device)\n",
208
+ " h_bwd = torch.zeros(batch, self.hidden_size, device=x.device)\n",
209
+ " fwd_out, bwd_out = [], []\n",
210
+ " for t in range(seq_len):\n",
211
+ " h_fwd = self.fwd_cell(x[:, t, :], h_fwd)\n",
212
+ " h_bwd = self.bwd_cell(x[:, seq_len - 1 - t, :], h_bwd)\n",
213
+ " fwd_out.append(h_fwd)\n",
214
+ " bwd_out.append(h_bwd)\n",
215
+ " bwd_out.reverse()\n",
216
+ " out = torch.stack(\n",
217
+ " [torch.cat([f, b], dim=-1) for f, b in zip(fwd_out, bwd_out)], dim=1\n",
218
+ " ) # [B, T, H*2]\n",
219
+ " return self.out_drop(out)\n",
220
+ "\n",
221
+ "\n",
222
+ "# \u2500\u2500 Asfour v3 \u2014 anti-overfitting redesign \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
223
+ "class Asfour(nn.Module):\n",
224
+ " \"\"\"\n",
225
+ " v3 vs v2 changes (all anti-overfitting):\n",
226
+ " - REMOVED RoPE: overkill for \u226430-token sequences; added memorisation risk\n",
227
+ " - REMOVED CNN block: too many stacked components for dataset size\n",
228
+ " - hidden_size 64 \u2192 48: further capacity cut (~67% fewer total params vs v2)\n",
229
+ " - dropout 0.25 \u2192 0.40: meaningful regularisation throughout the recurrent path\n",
230
+ " - max+avg pooling \u2192 learned attention pooling: fewer params, differentiable focus\n",
231
+ " - head simplified to Dropout \u2192 Linear: removes extra dense layer that was overfitting\n",
232
+ " \"\"\"\n",
233
+ " def __init__(self, input_dim=256, hidden_size=48, num_classes=1, dropout=0.4):\n",
234
+ " super().__init__()\n",
235
+ "\n",
236
+ " # Input projection with built-in regularisation\n",
237
+ " self.proj = nn.Sequential(\n",
238
+ " nn.Linear(input_dim, hidden_size),\n",
239
+ " nn.LayerNorm(hidden_size),\n",
240
+ " nn.ReLU(),\n",
241
+ " nn.Dropout(dropout),\n",
242
+ " )\n",
243
+ "\n",
244
+ " self.bigru = BiLiGRU(\n",
245
+ " input_size=hidden_size,\n",
246
+ " hidden_size=hidden_size,\n",
247
+ " dropout=dropout,\n",
248
+ " ) # output: [B, T, hidden_size*2]\n",
249
+ "\n",
250
+ " # Learned attention pooling \u2014 collapses sequence without bloating params\n",
251
+ " self.attn = nn.Linear(hidden_size * 2, 1, bias=False)\n",
252
+ "\n",
253
+ " # Minimal head: single linear after dropout\n",
254
+ " self.head = nn.Sequential(\n",
255
+ " nn.Dropout(dropout),\n",
256
+ " nn.Linear(hidden_size * 2, num_classes),\n",
257
+ " )\n",
258
+ "\n",
259
+ " def forward(self, x):\n",
260
+ " # x: [B, T, 256]\n",
261
+ " x = self.proj(x) # [B, T, H]\n",
262
+ " x = self.bigru(x) # [B, T, H*2]\n",
263
+ "\n",
264
+ " # Attention pooling\n",
265
+ " w = torch.softmax(self.attn(x), dim=1) # [B, T, 1]\n",
266
+ " x = (x * w).sum(dim=1) # [B, H*2]\n",
267
+ "\n",
268
+ " return self.head(x) # [B, 1]\n"
269
+ ],
270
+ "metadata": {
271
+ "id": "KPoh1f6y8SPY"
272
+ },
273
+ "execution_count": 4,
274
+ "outputs": []
275
+ },
276
+ {
277
+ "cell_type": "code",
278
+ "source": [
279
+ "import fasttext\n",
280
+ "\n",
281
+ "model = fasttext.load_model(\"/content/drive/MyDrive/cc.ftz\")"
282
+ ],
283
+ "metadata": {
284
+ "id": "2h4pWkfxE_yS"
285
+ },
286
+ "execution_count": 5,
287
+ "outputs": []
288
+ },
289
+ {
290
+ "cell_type": "code",
291
+ "metadata": {
292
+ "id": "0c95f755"
293
+ },
294
+ "source": [
295
+ "i",
296
+ "m",
297
+ "p",
298
+ "o",
299
+ "r",
300
+ "t",
301
+ " ",
302
+ "t",
303
+ "o",
304
+ "r",
305
+ "c",
306
+ "h",
307
+ "\n",
308
+ "f",
309
+ "r",
310
+ "o",
311
+ "m",
312
+ " ",
313
+ "t",
314
+ "o",
315
+ "r",
316
+ "c",
317
+ "h",
318
+ ".",
319
+ "u",
320
+ "t",
321
+ "i",
322
+ "l",
323
+ "s",
324
+ ".",
325
+ "d",
326
+ "a",
327
+ "t",
328
+ "a",
329
+ " ",
330
+ "i",
331
+ "m",
332
+ "p",
333
+ "o",
334
+ "r",
335
+ "t",
336
+ " ",
337
+ "D",
338
+ "a",
339
+ "t",
340
+ "a",
341
+ "s",
342
+ "e",
343
+ "t",
344
+ ",",
345
+ " ",
346
+ "D",
347
+ "a",
348
+ "t",
349
+ "a",
350
+ "L",
351
+ "o",
352
+ "a",
353
+ "d",
354
+ "e",
355
+ "r",
356
+ "\n",
357
+ "i",
358
+ "m",
359
+ "p",
360
+ "o",
361
+ "r",
362
+ "t",
363
+ " ",
364
+ "n",
365
+ "u",
366
+ "m",
367
+ "p",
368
+ "y",
369
+ " ",
370
+ "a",
371
+ "s",
372
+ " ",
373
+ "n",
374
+ "p",
375
+ "\n",
376
+ "\n",
377
+ "c",
378
+ "l",
379
+ "a",
380
+ "s",
381
+ "s",
382
+ " ",
383
+ "T",
384
+ "u",
385
+ "n",
386
+ "i",
387
+ "s",
388
+ "i",
389
+ "a",
390
+ "n",
391
+ "D",
392
+ "a",
393
+ "t",
394
+ "a",
395
+ "s",
396
+ "e",
397
+ "t",
398
+ "(",
399
+ "D",
400
+ "a",
401
+ "t",
402
+ "a",
403
+ "s",
404
+ "e",
405
+ "t",
406
+ ")",
407
+ ":",
408
+ "\n",
409
+ " ",
410
+ " ",
411
+ " ",
412
+ " ",
413
+ "d",
414
+ "e",
415
+ "f",
416
+ " ",
417
+ "_",
418
+ "_",
419
+ "i",
420
+ "n",
421
+ "i",
422
+ "t",
423
+ "_",
424
+ "_",
425
+ "(",
426
+ "s",
427
+ "e",
428
+ "l",
429
+ "f",
430
+ ",",
431
+ " ",
432
+ "t",
433
+ "e",
434
+ "x",
435
+ "t",
436
+ "s",
437
+ ",",
438
+ " ",
439
+ "l",
440
+ "a",
441
+ "b",
442
+ "e",
443
+ "l",
444
+ "s",
445
+ ",",
446
+ " ",
447
+ "f",
448
+ "a",
449
+ "s",
450
+ "t",
451
+ "t",
452
+ "e",
453
+ "x",
454
+ "t",
455
+ "_",
456
+ "m",
457
+ "o",
458
+ "d",
459
+ "e",
460
+ "l",
461
+ ",",
462
+ " ",
463
+ "m",
464
+ "a",
465
+ "x",
466
+ "_",
467
+ "l",
468
+ "e",
469
+ "n",
470
+ "=",
471
+ "3",
472
+ "0",
473
+ ")",
474
+ ":",
475
+ "\n",
476
+ " ",
477
+ " ",
478
+ " ",
479
+ " ",
480
+ " ",
481
+ " ",
482
+ " ",
483
+ " ",
484
+ "s",
485
+ "e",
486
+ "l",
487
+ "f",
488
+ ".",
489
+ "t",
490
+ "e",
491
+ "x",
492
+ "t",
493
+ "s",
494
+ " ",
495
+ "=",
496
+ " ",
497
+ "t",
498
+ "e",
499
+ "x",
500
+ "t",
501
+ "s",
502
+ "\n",
503
+ " ",
504
+ " ",
505
+ " ",
506
+ " ",
507
+ " ",
508
+ " ",
509
+ " ",
510
+ " ",
511
+ "s",
512
+ "e",
513
+ "l",
514
+ "f",
515
+ ".",
516
+ "l",
517
+ "a",
518
+ "b",
519
+ "e",
520
+ "l",
521
+ "s",
522
+ " ",
523
+ "=",
524
+ " ",
525
+ "l",
526
+ "a",
527
+ "b",
528
+ "e",
529
+ "l",
530
+ "s",
531
+ "\n",
532
+ " ",
533
+ " ",
534
+ " ",
535
+ " ",
536
+ " ",
537
+ " ",
538
+ " ",
539
+ " ",
540
+ "s",
541
+ "e",
542
+ "l",
543
+ "f",
544
+ ".",
545
+ "f",
546
+ "t",
547
+ "_",
548
+ "m",
549
+ "o",
550
+ "d",
551
+ "e",
552
+ "l",
553
+ " ",
554
+ "=",
555
+ " ",
556
+ "f",
557
+ "a",
558
+ "s",
559
+ "t",
560
+ "t",
561
+ "e",
562
+ "x",
563
+ "t",
564
+ "_",
565
+ "m",
566
+ "o",
567
+ "d",
568
+ "e",
569
+ "l",
570
+ "\n",
571
+ " ",
572
+ " ",
573
+ " ",
574
+ " ",
575
+ " ",
576
+ " ",
577
+ " ",
578
+ " ",
579
+ "s",
580
+ "e",
581
+ "l",
582
+ "f",
583
+ ".",
584
+ "m",
585
+ "a",
586
+ "x",
587
+ "_",
588
+ "l",
589
+ "e",
590
+ "n",
591
+ " ",
592
+ "=",
593
+ " ",
594
+ "m",
595
+ "a",
596
+ "x",
597
+ "_",
598
+ "l",
599
+ "e",
600
+ "n",
601
+ "\n",
602
+ "\n",
603
+ " ",
604
+ " ",
605
+ " ",
606
+ " ",
607
+ "d",
608
+ "e",
609
+ "f",
610
+ " ",
611
+ "_",
612
+ "_",
613
+ "l",
614
+ "e",
615
+ "n",
616
+ "_",
617
+ "_",
618
+ "(",
619
+ "s",
620
+ "e",
621
+ "l",
622
+ "f",
623
+ ")",
624
+ ":",
625
+ "\n",
626
+ " ",
627
+ " ",
628
+ " ",
629
+ " ",
630
+ " ",
631
+ " ",
632
+ " ",
633
+ " ",
634
+ "r",
635
+ "e",
636
+ "t",
637
+ "u",
638
+ "r",
639
+ "n",
640
+ " ",
641
+ "l",
642
+ "e",
643
+ "n",
644
+ "(",
645
+ "s",
646
+ "e",
647
+ "l",
648
+ "f",
649
+ ".",
650
+ "t",
651
+ "e",
652
+ "x",
653
+ "t",
654
+ "s",
655
+ ")",
656
+ "\n",
657
+ "\n",
658
+ " ",
659
+ " ",
660
+ " ",
661
+ " ",
662
+ "d",
663
+ "e",
664
+ "f",
665
+ " ",
666
+ "_",
667
+ "_",
668
+ "g",
669
+ "e",
670
+ "t",
671
+ "i",
672
+ "t",
673
+ "e",
674
+ "m",
675
+ "_",
676
+ "_",
677
+ "(",
678
+ "s",
679
+ "e",
680
+ "l",
681
+ "f",
682
+ ",",
683
+ " ",
684
+ "i",
685
+ "d",
686
+ "x",
687
+ ")",
688
+ ":",
689
+ "\n",
690
+ " ",
691
+ " ",
692
+ " ",
693
+ " ",
694
+ " ",
695
+ " ",
696
+ " ",
697
+ " ",
698
+ "t",
699
+ "e",
700
+ "x",
701
+ "t",
702
+ " ",
703
+ "=",
704
+ " ",
705
+ "s",
706
+ "t",
707
+ "r",
708
+ "(",
709
+ "s",
710
+ "e",
711
+ "l",
712
+ "f",
713
+ ".",
714
+ "t",
715
+ "e",
716
+ "x",
717
+ "t",
718
+ "s",
719
+ "[",
720
+ "i",
721
+ "d",
722
+ "x",
723
+ "]",
724
+ ")",
725
+ "\n",
726
+ " ",
727
+ " ",
728
+ " ",
729
+ " ",
730
+ " ",
731
+ " ",
732
+ " ",
733
+ " ",
734
+ "l",
735
+ "a",
736
+ "b",
737
+ "e",
738
+ "l",
739
+ " ",
740
+ "=",
741
+ " ",
742
+ "s",
743
+ "e",
744
+ "l",
745
+ "f",
746
+ ".",
747
+ "l",
748
+ "a",
749
+ "b",
750
+ "e",
751
+ "l",
752
+ "s",
753
+ "[",
754
+ "i",
755
+ "d",
756
+ "x",
757
+ "]",
758
+ "\n",
759
+ "\n",
760
+ " ",
761
+ " ",
762
+ " ",
763
+ " ",
764
+ " ",
765
+ " ",
766
+ " ",
767
+ " ",
768
+ "#",
769
+ " ",
770
+ "V",
771
+ "e",
772
+ "c",
773
+ "t",
774
+ "o",
775
+ "r",
776
+ "i",
777
+ "z",
778
+ "e",
779
+ " ",
780
+ "o",
781
+ "n",
782
+ " ",
783
+ "t",
784
+ "h",
785
+ "e",
786
+ " ",
787
+ "f",
788
+ "l",
789
+ "y",
790
+ "\n",
791
+ " ",
792
+ " ",
793
+ " ",
794
+ " ",
795
+ " ",
796
+ " ",
797
+ " ",
798
+ " ",
799
+ "w",
800
+ "o",
801
+ "r",
802
+ "d",
803
+ "s",
804
+ " ",
805
+ "=",
806
+ " ",
807
+ "t",
808
+ "e",
809
+ "x",
810
+ "t",
811
+ ".",
812
+ "s",
813
+ "p",
814
+ "l",
815
+ "i",
816
+ "t",
817
+ "(",
818
+ ")",
819
+ "[",
820
+ ":",
821
+ "s",
822
+ "e",
823
+ "l",
824
+ "f",
825
+ ".",
826
+ "m",
827
+ "a",
828
+ "x",
829
+ "_",
830
+ "l",
831
+ "e",
832
+ "n",
833
+ "]",
834
+ "\n",
835
+ " ",
836
+ " ",
837
+ " ",
838
+ " ",
839
+ " ",
840
+ " ",
841
+ " ",
842
+ " ",
843
+ "v",
844
+ "e",
845
+ "c",
846
+ "s",
847
+ " ",
848
+ "=",
849
+ " ",
850
+ "[",
851
+ "s",
852
+ "e",
853
+ "l",
854
+ "f",
855
+ ".",
856
+ "f",
857
+ "t",
858
+ "_",
859
+ "m",
860
+ "o",
861
+ "d",
862
+ "e",
863
+ "l",
864
+ ".",
865
+ "g",
866
+ "e",
867
+ "t",
868
+ "_",
869
+ "w",
870
+ "o",
871
+ "r",
872
+ "d",
873
+ "_",
874
+ "v",
875
+ "e",
876
+ "c",
877
+ "t",
878
+ "o",
879
+ "r",
880
+ "(",
881
+ "w",
882
+ ")",
883
+ " ",
884
+ "f",
885
+ "o",
886
+ "r",
887
+ " ",
888
+ "w",
889
+ " ",
890
+ "i",
891
+ "n",
892
+ " ",
893
+ "w",
894
+ "o",
895
+ "r",
896
+ "d",
897
+ "s",
898
+ "]",
899
+ "\n",
900
+ "\n",
901
+ " ",
902
+ " ",
903
+ " ",
904
+ " ",
905
+ " ",
906
+ " ",
907
+ " ",
908
+ " ",
909
+ "#",
910
+ " ",
911
+ "P",
912
+ "a",
913
+ "d",
914
+ "d",
915
+ "i",
916
+ "n",
917
+ "g",
918
+ "\n",
919
+ " ",
920
+ " ",
921
+ " ",
922
+ " ",
923
+ " ",
924
+ " ",
925
+ " ",
926
+ " ",
927
+ "i",
928
+ "f",
929
+ " ",
930
+ "l",
931
+ "e",
932
+ "n",
933
+ "(",
934
+ "v",
935
+ "e",
936
+ "c",
937
+ "s",
938
+ ")",
939
+ " ",
940
+ "<",
941
+ " ",
942
+ "s",
943
+ "e",
944
+ "l",
945
+ "f",
946
+ ".",
947
+ "m",
948
+ "a",
949
+ "x",
950
+ "_",
951
+ "l",
952
+ "e",
953
+ "n",
954
+ ":",
955
+ "\n",
956
+ " ",
957
+ " ",
958
+ " ",
959
+ " ",
960
+ " ",
961
+ " ",
962
+ " ",
963
+ " ",
964
+ " ",
965
+ " ",
966
+ " ",
967
+ " ",
968
+ "p",
969
+ "a",
970
+ "d",
971
+ "d",
972
+ "i",
973
+ "n",
974
+ "g",
975
+ " ",
976
+ "=",
977
+ " ",
978
+ "[",
979
+ "n",
980
+ "p",
981
+ ".",
982
+ "z",
983
+ "e",
984
+ "r",
985
+ "o",
986
+ "s",
987
+ "(",
988
+ "2",
989
+ "5",
990
+ "6",
991
+ ")",
992
+ " ",
993
+ "f",
994
+ "o",
995
+ "r",
996
+ " ",
997
+ "_",
998
+ " ",
999
+ "i",
1000
+ "n",
1001
+ " ",
1002
+ "r",
1003
+ "a",
1004
+ "n",
1005
+ "g",
1006
+ "e",
1007
+ "(",
1008
+ "s",
1009
+ "e",
1010
+ "l",
1011
+ "f",
1012
+ ".",
1013
+ "m",
1014
+ "a",
1015
+ "x",
1016
+ "_",
1017
+ "l",
1018
+ "e",
1019
+ "n",
1020
+ " ",
1021
+ "-",
1022
+ " ",
1023
+ "l",
1024
+ "e",
1025
+ "n",
1026
+ "(",
1027
+ "v",
1028
+ "e",
1029
+ "c",
1030
+ "s",
1031
+ ")",
1032
+ ")",
1033
+ "]",
1034
+ "\n",
1035
+ " ",
1036
+ " ",
1037
+ " ",
1038
+ " ",
1039
+ " ",
1040
+ " ",
1041
+ " ",
1042
+ " ",
1043
+ " ",
1044
+ " ",
1045
+ " ",
1046
+ " ",
1047
+ "v",
1048
+ "e",
1049
+ "c",
1050
+ "s",
1051
+ ".",
1052
+ "e",
1053
+ "x",
1054
+ "t",
1055
+ "e",
1056
+ "n",
1057
+ "d",
1058
+ "(",
1059
+ "p",
1060
+ "a",
1061
+ "d",
1062
+ "d",
1063
+ "i",
1064
+ "n",
1065
+ "g",
1066
+ ")",
1067
+ "\n",
1068
+ "\n",
1069
+ " ",
1070
+ " ",
1071
+ " ",
1072
+ " ",
1073
+ " ",
1074
+ " ",
1075
+ " ",
1076
+ " ",
1077
+ "r",
1078
+ "e",
1079
+ "t",
1080
+ "u",
1081
+ "r",
1082
+ "n",
1083
+ " ",
1084
+ "t",
1085
+ "o",
1086
+ "r",
1087
+ "c",
1088
+ "h",
1089
+ ".",
1090
+ "t",
1091
+ "e",
1092
+ "n",
1093
+ "s",
1094
+ "o",
1095
+ "r",
1096
+ "(",
1097
+ "n",
1098
+ "p",
1099
+ ".",
1100
+ "a",
1101
+ "r",
1102
+ "r",
1103
+ "a",
1104
+ "y",
1105
+ "(",
1106
+ "v",
1107
+ "e",
1108
+ "c",
1109
+ "s",
1110
+ ")",
1111
+ ",",
1112
+ " ",
1113
+ "d",
1114
+ "t",
1115
+ "y",
1116
+ "p",
1117
+ "e",
1118
+ "=",
1119
+ "t",
1120
+ "o",
1121
+ "r",
1122
+ "c",
1123
+ "h",
1124
+ ".",
1125
+ "f",
1126
+ "l",
1127
+ "o",
1128
+ "a",
1129
+ "t",
1130
+ "3",
1131
+ "2",
1132
+ ")",
1133
+ ",",
1134
+ " ",
1135
+ "t",
1136
+ "o",
1137
+ "r",
1138
+ "c",
1139
+ "h",
1140
+ ".",
1141
+ "t",
1142
+ "e",
1143
+ "n",
1144
+ "s",
1145
+ "o",
1146
+ "r",
1147
+ "(",
1148
+ "[",
1149
+ "l",
1150
+ "a",
1151
+ "b",
1152
+ "e",
1153
+ "l",
1154
+ "]",
1155
+ ",",
1156
+ " ",
1157
+ "d",
1158
+ "t",
1159
+ "y",
1160
+ "p",
1161
+ "e",
1162
+ "=",
1163
+ "t",
1164
+ "o",
1165
+ "r",
1166
+ "c",
1167
+ "h",
1168
+ ".",
1169
+ "f",
1170
+ "l",
1171
+ "o",
1172
+ "a",
1173
+ "t",
1174
+ "3",
1175
+ "2",
1176
+ ")",
1177
+ "\n",
1178
+ "\n",
1179
+ "#",
1180
+ " ",
1181
+ "I",
1182
+ "m",
1183
+ "p",
1184
+ "l",
1185
+ "e",
1186
+ "m",
1187
+ "e",
1188
+ "n",
1189
+ "t",
1190
+ "a",
1191
+ "t",
1192
+ "i",
1193
+ "o",
1194
+ "n",
1195
+ "\n",
1196
+ "X",
1197
+ "_",
1198
+ "t",
1199
+ "e",
1200
+ "x",
1201
+ "t",
1202
+ " ",
1203
+ "=",
1204
+ " ",
1205
+ "c",
1206
+ "o",
1207
+ "r",
1208
+ "p",
1209
+ "u",
1210
+ "s",
1211
+ "_",
1212
+ "f",
1213
+ "u",
1214
+ "l",
1215
+ "l",
1216
+ "[",
1217
+ "'",
1218
+ "t",
1219
+ "e",
1220
+ "x",
1221
+ "t",
1222
+ "'",
1223
+ "]",
1224
+ ".",
1225
+ "v",
1226
+ "a",
1227
+ "l",
1228
+ "u",
1229
+ "e",
1230
+ "s",
1231
+ "\n",
1232
+ "y",
1233
+ " ",
1234
+ "=",
1235
+ " ",
1236
+ "(",
1237
+ "c",
1238
+ "o",
1239
+ "r",
1240
+ "p",
1241
+ "u",
1242
+ "s",
1243
+ "_",
1244
+ "f",
1245
+ "u",
1246
+ "l",
1247
+ "l",
1248
+ "[",
1249
+ "'",
1250
+ "l",
1251
+ "a",
1252
+ "b",
1253
+ "e",
1254
+ "l",
1255
+ "'",
1256
+ "]",
1257
+ ".",
1258
+ "v",
1259
+ "a",
1260
+ "l",
1261
+ "u",
1262
+ "e",
1263
+ "s",
1264
+ ".",
1265
+ "a",
1266
+ "s",
1267
+ "t",
1268
+ "y",
1269
+ "p",
1270
+ "e",
1271
+ "(",
1272
+ "f",
1273
+ "l",
1274
+ "o",
1275
+ "a",
1276
+ "t",
1277
+ ")",
1278
+ " ",
1279
+ "+",
1280
+ " ",
1281
+ "1",
1282
+ ")",
1283
+ " ",
1284
+ "/",
1285
+ " ",
1286
+ "2",
1287
+ "\n",
1288
+ "\n",
1289
+ "f",
1290
+ "r",
1291
+ "o",
1292
+ "m",
1293
+ " ",
1294
+ "s",
1295
+ "k",
1296
+ "l",
1297
+ "e",
1298
+ "a",
1299
+ "r",
1300
+ "n",
1301
+ ".",
1302
+ "m",
1303
+ "o",
1304
+ "d",
1305
+ "e",
1306
+ "l",
1307
+ "_",
1308
+ "s",
1309
+ "e",
1310
+ "l",
1311
+ "e",
1312
+ "c",
1313
+ "t",
1314
+ "i",
1315
+ "o",
1316
+ "n",
1317
+ " ",
1318
+ "i",
1319
+ "m",
1320
+ "p",
1321
+ "o",
1322
+ "r",
1323
+ "t",
1324
+ " ",
1325
+ "t",
1326
+ "r",
1327
+ "a",
1328
+ "i",
1329
+ "n",
1330
+ "_",
1331
+ "t",
1332
+ "e",
1333
+ "s",
1334
+ "t",
1335
+ "_",
1336
+ "s",
1337
+ "p",
1338
+ "l",
1339
+ "i",
1340
+ "t",
1341
+ "\n",
1342
+ "t",
1343
+ "r",
1344
+ "a",
1345
+ "i",
1346
+ "n",
1347
+ "_",
1348
+ "t",
1349
+ "x",
1350
+ "t",
1351
+ ",",
1352
+ " ",
1353
+ "v",
1354
+ "a",
1355
+ "l",
1356
+ "_",
1357
+ "t",
1358
+ "x",
1359
+ "t",
1360
+ ",",
1361
+ " ",
1362
+ "t",
1363
+ "r",
1364
+ "a",
1365
+ "i",
1366
+ "n",
1367
+ "_",
1368
+ "y",
1369
+ ",",
1370
+ " ",
1371
+ "v",
1372
+ "a",
1373
+ "l",
1374
+ "_",
1375
+ "y",
1376
+ " ",
1377
+ "=",
1378
+ " ",
1379
+ "t",
1380
+ "r",
1381
+ "a",
1382
+ "i",
1383
+ "n",
1384
+ "_",
1385
+ "t",
1386
+ "e",
1387
+ "s",
1388
+ "t",
1389
+ "_",
1390
+ "s",
1391
+ "p",
1392
+ "l",
1393
+ "i",
1394
+ "t",
1395
+ "(",
1396
+ "X",
1397
+ "_",
1398
+ "t",
1399
+ "e",
1400
+ "x",
1401
+ "t",
1402
+ ",",
1403
+ " ",
1404
+ "y",
1405
+ ",",
1406
+ " ",
1407
+ "t",
1408
+ "e",
1409
+ "s",
1410
+ "t",
1411
+ "_",
1412
+ "s",
1413
+ "i",
1414
+ "z",
1415
+ "e",
1416
+ "=",
1417
+ "0",
1418
+ ".",
1419
+ "2",
1420
+ ",",
1421
+ " ",
1422
+ "r",
1423
+ "a",
1424
+ "n",
1425
+ "d",
1426
+ "o",
1427
+ "m",
1428
+ "_",
1429
+ "s",
1430
+ "t",
1431
+ "a",
1432
+ "t",
1433
+ "e",
1434
+ "=",
1435
+ "4",
1436
+ "2",
1437
+ ")",
1438
+ "\n",
1439
+ "\n",
1440
+ "#",
1441
+ " ",
1442
+ "C",
1443
+ "r",
1444
+ "e",
1445
+ "a",
1446
+ "t",
1447
+ "e",
1448
+ " ",
1449
+ "t",
1450
+ "h",
1451
+ "e",
1452
+ " ",
1453
+ "l",
1454
+ "a",
1455
+ "z",
1456
+ "y",
1457
+ " ",
1458
+ "l",
1459
+ "o",
1460
+ "a",
1461
+ "d",
1462
+ "e",
1463
+ "r",
1464
+ "s",
1465
+ "\n",
1466
+ "t",
1467
+ "r",
1468
+ "a",
1469
+ "i",
1470
+ "n",
1471
+ "_",
1472
+ "d",
1473
+ "s",
1474
+ " ",
1475
+ "=",
1476
+ " ",
1477
+ "T",
1478
+ "u",
1479
+ "n",
1480
+ "i",
1481
+ "s",
1482
+ "i",
1483
+ "a",
1484
+ "n",
1485
+ "D",
1486
+ "a",
1487
+ "t",
1488
+ "a",
1489
+ "s",
1490
+ "e",
1491
+ "t",
1492
+ "(",
1493
+ "t",
1494
+ "r",
1495
+ "a",
1496
+ "i",
1497
+ "n",
1498
+ "_",
1499
+ "t",
1500
+ "x",
1501
+ "t",
1502
+ ",",
1503
+ " ",
1504
+ "t",
1505
+ "r",
1506
+ "a",
1507
+ "i",
1508
+ "n",
1509
+ "_",
1510
+ "y",
1511
+ ",",
1512
+ " ",
1513
+ "m",
1514
+ "o",
1515
+ "d",
1516
+ "e",
1517
+ "l",
1518
+ ")",
1519
+ " ",
1520
+ "#",
1521
+ " ",
1522
+ "'",
1523
+ "m",
1524
+ "o",
1525
+ "d",
1526
+ "e",
1527
+ "l",
1528
+ "'",
1529
+ " ",
1530
+ "i",
1531
+ "s",
1532
+ " ",
1533
+ "y",
1534
+ "o",
1535
+ "u",
1536
+ "r",
1537
+ " ",
1538
+ "F",
1539
+ "a",
1540
+ "s",
1541
+ "t",
1542
+ "T",
1543
+ "e",
1544
+ "x",
1545
+ "t",
1546
+ " ",
1547
+ "o",
1548
+ "b",
1549
+ "j",
1550
+ "e",
1551
+ "c",
1552
+ "t",
1553
+ "\n",
1554
+ "v",
1555
+ "a",
1556
+ "l",
1557
+ "_",
1558
+ "d",
1559
+ "s",
1560
+ " ",
1561
+ "=",
1562
+ " ",
1563
+ "T",
1564
+ "u",
1565
+ "n",
1566
+ "i",
1567
+ "s",
1568
+ "i",
1569
+ "a",
1570
+ "n",
1571
+ "D",
1572
+ "a",
1573
+ "t",
1574
+ "a",
1575
+ "s",
1576
+ "e",
1577
+ "t",
1578
+ "(",
1579
+ "v",
1580
+ "a",
1581
+ "l",
1582
+ "_",
1583
+ "t",
1584
+ "x",
1585
+ "t",
1586
+ ",",
1587
+ " ",
1588
+ "v",
1589
+ "a",
1590
+ "l",
1591
+ "_",
1592
+ "y",
1593
+ ",",
1594
+ " ",
1595
+ "m",
1596
+ "o",
1597
+ "d",
1598
+ "e",
1599
+ "l",
1600
+ ")",
1601
+ "\n",
1602
+ "\n",
1603
+ "t",
1604
+ "r",
1605
+ "a",
1606
+ "i",
1607
+ "n",
1608
+ "_",
1609
+ "l",
1610
+ "o",
1611
+ "a",
1612
+ "d",
1613
+ "e",
1614
+ "r",
1615
+ " ",
1616
+ "=",
1617
+ " ",
1618
+ "D",
1619
+ "a",
1620
+ "t",
1621
+ "a",
1622
+ "L",
1623
+ "o",
1624
+ "a",
1625
+ "d",
1626
+ "e",
1627
+ "r",
1628
+ "(",
1629
+ "t",
1630
+ "r",
1631
+ "a",
1632
+ "i",
1633
+ "n",
1634
+ "_",
1635
+ "d",
1636
+ "s",
1637
+ ",",
1638
+ " ",
1639
+ "b",
1640
+ "a",
1641
+ "t",
1642
+ "c",
1643
+ "h",
1644
+ "_",
1645
+ "s",
1646
+ "i",
1647
+ "z",
1648
+ "e",
1649
+ "=",
1650
+ "3",
1651
+ "2",
1652
+ ",",
1653
+ " ",
1654
+ "s",
1655
+ "h",
1656
+ "u",
1657
+ "f",
1658
+ "f",
1659
+ "l",
1660
+ "e",
1661
+ "=",
1662
+ "T",
1663
+ "r",
1664
+ "u",
1665
+ "e",
1666
+ ",",
1667
+ " ",
1668
+ "n",
1669
+ "u",
1670
+ "m",
1671
+ "_",
1672
+ "w",
1673
+ "o",
1674
+ "r",
1675
+ "k",
1676
+ "e",
1677
+ "r",
1678
+ "s",
1679
+ "=",
1680
+ "2",
1681
+ ")",
1682
+ "\n",
1683
+ "v",
1684
+ "a",
1685
+ "l",
1686
+ "_",
1687
+ "l",
1688
+ "o",
1689
+ "a",
1690
+ "d",
1691
+ "e",
1692
+ "r",
1693
+ " ",
1694
+ "=",
1695
+ " ",
1696
+ "D",
1697
+ "a",
1698
+ "t",
1699
+ "a",
1700
+ "L",
1701
+ "o",
1702
+ "a",
1703
+ "d",
1704
+ "e",
1705
+ "r",
1706
+ "(",
1707
+ "v",
1708
+ "a",
1709
+ "l",
1710
+ "_",
1711
+ "d",
1712
+ "s",
1713
+ ",",
1714
+ " ",
1715
+ "b",
1716
+ "a",
1717
+ "t",
1718
+ "c",
1719
+ "h",
1720
+ "_",
1721
+ "s",
1722
+ "i",
1723
+ "z",
1724
+ "e",
1725
+ "=",
1726
+ "3",
1727
+ "2",
1728
+ ",",
1729
+ " ",
1730
+ "s",
1731
+ "h",
1732
+ "u",
1733
+ "f",
1734
+ "f",
1735
+ "l",
1736
+ "e",
1737
+ "=",
1738
+ "F",
1739
+ "a",
1740
+ "l",
1741
+ "s",
1742
+ "e",
1743
+ ",",
1744
+ " ",
1745
+ "n",
1746
+ "u",
1747
+ "m",
1748
+ "_",
1749
+ "w",
1750
+ "o",
1751
+ "r",
1752
+ "k",
1753
+ "e",
1754
+ "r",
1755
+ "s",
1756
+ "=",
1757
+ "0",
1758
+ ")"
1759
+ ],
1760
+ "execution_count": 12,
1761
+ "outputs": []
1762
+ },
1763
+ {
1764
+ "cell_type": "code",
1765
+ "source": [
1766
+ "class EarlyStopping:\n",
1767
+ " def __init__(self, patience=5, delta=0):\n",
1768
+ " self.patience = patience\n",
1769
+ " self.delta = delta\n",
1770
+ " self.best_score = None\n",
1771
+ " self.early_stop = False\n",
1772
+ " self.counter = 0\n",
1773
+ " self.best_model_state = None\n",
1774
+ "\n",
1775
+ " def __call__(self, val_loss, model):\n",
1776
+ " score = -val_loss\n",
1777
+ "\n",
1778
+ " if self.best_score is None:\n",
1779
+ " self.best_score = score\n",
1780
+ " self.best_model_state = model.state_dict()\n",
1781
+ " elif score < self.best_score + self.delta:\n",
1782
+ " self.counter += 1\n",
1783
+ " if self.counter >= self.patience:\n",
1784
+ " self.early_stop = True\n",
1785
+ " else:\n",
1786
+ " self.best_score = score\n",
1787
+ " self.best_model_state = model.state_dict()\n",
1788
+ " self.counter = 0\n",
1789
+ "\n",
1790
+ " def load_best_model(self, model):\n",
1791
+ " model.load_state_dict(self.best_model_state)"
1792
+ ],
1793
+ "metadata": {
1794
+ "id": "ChlmWz3nEOTZ"
1795
+ },
1796
+ "execution_count": 7,
1797
+ "outputs": []
1798
+ },
1799
+ {
1800
+ "cell_type": "code",
1801
+ "source": [
1802
+ "def label_smoothed_bce(outputs, targets, smoothing=0.15):\n",
1803
+ " \"\"\"Softens 0\u21920.075 and 1\u21920.925 \u2014 stronger regularisation than 0.1.\"\"\"\n",
1804
+ " with torch.no_grad():\n",
1805
+ " targets = targets * (1 - smoothing) + 0.5 * smoothing\n",
1806
+ " return nn.BCEWithLogitsLoss()(outputs, targets)\n"
1807
+ ],
1808
+ "metadata": {
1809
+ "id": "pOjRjrd1Fbqc"
1810
+ },
1811
+ "execution_count": 8,
1812
+ "outputs": []
1813
+ },
1814
+ {
1815
+ "cell_type": "code",
1816
+ "source": [
1817
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
1818
+ "\n",
1819
+ "# v3: hidden_size=48, dropout=0.4\n",
1820
+ "net = Asfour(input_dim=256, hidden_size=48, num_classes=1, dropout=0.4).to(device)\n",
1821
+ "\n",
1822
+ "# Stronger weight decay (0.01 \u2192 0.05) to penalise large weights harder\n",
1823
+ "optimizer = torch.optim.AdamW(net.parameters(), lr=1e-3, weight_decay=0.05)\n",
1824
+ "\n",
1825
+ "NUM_EPOCHS = 30\n",
1826
+ "scheduler = torch.optim.lr_scheduler.OneCycleLR(\n",
1827
+ " optimizer,\n",
1828
+ " max_lr=1e-3,\n",
1829
+ " steps_per_epoch=len(train_loader),\n",
1830
+ " epochs=NUM_EPOCHS,\n",
1831
+ " pct_start=0.15, # slightly longer warmup\n",
1832
+ " anneal_strategy=\"cos\",\n",
1833
+ ")\n",
1834
+ "\n",
1835
+ "# Patience bumped to 7 \u2014 gives the model more room under strong regularisation\n",
1836
+ "es = EarlyStopping(patience=7, delta=0.002)\n",
1837
+ "\n",
1838
+ "for epoch in range(NUM_EPOCHS):\n",
1839
+ " net.train()\n",
1840
+ " train_loss = 0\n",
1841
+ " for batch_x, batch_y in train_loader:\n",
1842
+ " batch_x, batch_y = batch_x.to(device), batch_y.to(device)\n",
1843
+ " optimizer.zero_grad()\n",
1844
+ " outputs = net(batch_x)\n",
1845
+ " loss = label_smoothed_bce(outputs, batch_y)\n",
1846
+ " loss.backward()\n",
1847
+ " torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)\n",
1848
+ " optimizer.step()\n",
1849
+ " scheduler.step() # OneCycleLR steps per batch\n",
1850
+ " train_loss += loss.item()\n",
1851
+ "\n",
1852
+ " net.eval()\n",
1853
+ " val_loss = 0\n",
1854
+ " with torch.no_grad():\n",
1855
+ " for batch_x, batch_y in val_loader:\n",
1856
+ " batch_x, batch_y = batch_x.to(device), batch_y.to(device)\n",
1857
+ " outputs = net(batch_x)\n",
1858
+ " v_loss = nn.BCEWithLogitsLoss()(outputs, batch_y)\n",
1859
+ " val_loss += v_loss.item()\n",
1860
+ "\n",
1861
+ " avg_train = train_loss / len(train_loader)\n",
1862
+ " avg_val = val_loss / len(val_loader)\n",
1863
+ "\n",
1864
+ " print(f\"Epoch {epoch+1:02d}/{NUM_EPOCHS}: \"\n",
1865
+ " f\"Train {avg_train:.4f} Val {avg_val:.4f} \"\n",
1866
+ " f\"Gap {avg_val - avg_train:+.4f} \"\n",
1867
+ " f\"LR {optimizer.param_groups[0]['lr']:.2e}\")\n",
1868
+ "\n",
1869
+ " es(avg_val, net)\n",
1870
+ " if es.early_stop:\n",
1871
+ " print(\"Early stopping \u2014 loading best checkpoint\")\n",
1872
+ " es.load_best_model(net)\n",
1873
+ " break\n"
1874
+ ],
1875
+ "metadata": {
1876
+ "colab": {
1877
+ "base_uri": "https://localhost:8080/"
1878
+ },
1879
+ "id": "-QykeL9J__rL",
1880
+ "outputId": "19c2f7b1-c256-4aa4-ca12-9e5b96ec52b3"
1881
+ },
1882
+ "execution_count": 9,
1883
+ "outputs": [
1884
+ {
1885
+ "output_type": "stream",
1886
+ "name": "stdout",
1887
+ "text": [
1888
+ "Epoch 01/20: Train 0.5372 Val 0.3916 Gap -0.1456 LR 5.20e-04\n",
1889
+ "Epoch 02/20: Train 0.4560 Val 0.3694 Gap -0.0866 LR 1.00e-03\n",
1890
+ "Epoch 03/20: Train 0.4334 Val 0.3507 Gap -0.0827 LR 9.92e-04\n",
1891
+ "Epoch 04/20: Train 0.4195 Val 0.3435 Gap -0.0760 LR 9.70e-04\n",
1892
+ "Epoch 05/20: Train 0.4082 Val 0.3458 Gap -0.0624 LR 9.33e-04\n",
1893
+ "Epoch 06/20: Train 0.3990 Val 0.3487 Gap -0.0503 LR 8.83e-04\n",
1894
+ "Epoch 07/20: Train 0.3891 Val 0.3431 Gap -0.0460 LR 8.21e-04\n",
1895
+ "Epoch 08/20: Train 0.3800 Val 0.3416 Gap -0.0384 LR 7.50e-04\n",
1896
+ "Epoch 09/20: Train 0.3715 Val 0.3430 Gap -0.0285 LR 6.71e-04\n",
1897
+ "Early stopping \u2014 loading best checkpoint\n"
1898
+ ]
1899
+ }
1900
+ ]
1901
+ },
1902
+ {
1903
+ "cell_type": "code",
1904
+ "source": [
1905
+ "from sklearn.metrics import classification_report, roc_auc_score\n",
1906
+ "\n",
1907
+ "net.eval()\n",
1908
+ "all_preds, all_labels = [], []\n",
1909
+ "\n",
1910
+ "with torch.no_grad():\n",
1911
+ " for batch_x, batch_y in val_loader:\n",
1912
+ " batch_x = batch_x.to(device)\n",
1913
+ " outputs = torch.sigmoid(net(batch_x))\n",
1914
+ " preds = (outputs.cpu() > 0.5).float()\n",
1915
+ " all_preds.extend(preds.view(-1).tolist())\n",
1916
+ " all_labels.extend(batch_y.view(-1).tolist())\n",
1917
+ "\n",
1918
+ "print(classification_report(all_labels, all_preds, target_names=[\"clean\", \"profane\"]))\n",
1919
+ "print(\"AUC:\", roc_auc_score(all_labels, all_preds))"
1920
+ ],
1921
+ "metadata": {
1922
+ "id": "cB08BYR6UA2Q",
1923
+ "colab": {
1924
+ "base_uri": "https://localhost:8080/"
1925
+ },
1926
+ "outputId": "bf4aa0af-0a5b-4840-d53c-192810ff4caa"
1927
+ },
1928
+ "execution_count": 10,
1929
+ "outputs": [
1930
+ {
1931
+ "output_type": "stream",
1932
+ "name": "stdout",
1933
+ "text": [
1934
+ " precision recall f1-score support\n",
1935
+ "\n",
1936
+ " clean 0.82 0.85 0.83 4281\n",
1937
+ " profane 0.88 0.85 0.87 5461\n",
1938
+ "\n",
1939
+ " accuracy 0.85 9742\n",
1940
+ " macro avg 0.85 0.85 0.85 9742\n",
1941
+ "weighted avg 0.85 0.85 0.85 9742\n",
1942
+ "\n",
1943
+ "AUC: 0.8510448534833718\n"
1944
+ ]
1945
+ }
1946
+ ]
1947
+ },
1948
+ {
1949
+ "cell_type": "code",
1950
+ "metadata": {
1951
+ "colab": {
1952
+ "base_uri": "https://localhost:8080/",
1953
+ "height": 228
1954
+ },
1955
+ "id": "31b32c69",
1956
+ "outputId": "5a467d40-7cf9-40ca-8418-2768aa16e430"
1957
+ },
1958
+ "source": [
1959
+ "from sklearn.metrics import accuracy_score\n",
1960
+ "\n",
1961
+ "# Final Evaluation\n",
1962
+ "net.eval()\n",
1963
+ "probs = []\n",
1964
+ "with torch.no_grad():\n",
1965
+ " for batch_x, _ in val_loader:\n",
1966
+ " probs.extend(torch.sigmoid(net(batch_x.to(device))).cpu().numpy())\n",
1967
+ "\n",
1968
+ "# Try to find the best threshold to maximize accuracy\n",
1969
+ "thresholds = np.linspace(0.1, 0.9, 81)\n",
1970
+ "accuracies = [accuracy_score(all_val_labels, (np.array(probs) > t).astype(int)) for t in thresholds]\n",
1971
+ "best_t = thresholds[np.argmax(accuracies)]\n",
1972
+ "\n",
1973
+ "print(f\"Best Threshold: {best_t:.2f}\")\n",
1974
+ "print(f\"Final Accuracy at Best Threshold: {max(accuracies):.4f}\")\n",
1975
+ "\n",
1976
+ "# Final Report\n",
1977
+ "final_preds = (np.array(probs) > best_t).astype(int)\n",
1978
+ "print(classification_report(all_val_labels, final_preds, target_names=[\"clean\", \"profane\"]))"
1979
+ ],
1980
+ "execution_count": 11,
1981
+ "outputs": [
1982
+ {
1983
+ "output_type": "error",
1984
+ "ename": "NameError",
1985
+ "evalue": "name 'all_val_labels' is not defined",
1986
+ "traceback": [
1987
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1988
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
1989
+ "\u001b[0;32m/tmp/ipykernel_1236/2364656557.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;31m# Try to find the best threshold to maximize accuracy\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0mthresholds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinspace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0.1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.9\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m81\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0maccuracies\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0maccuracy_score\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mall_val_labels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprobs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mastype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mthresholds\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 13\u001b[0m \u001b[0mbest_t\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mthresholds\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maccuracies\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
1990
+ "\u001b[0;31mNameError\u001b[0m: name 'all_val_labels' is not defined"
1991
+ ]
1992
+ }
1993
+ ]
1994
+ },
1995
+ {
1996
+ "cell_type": "code",
1997
+ "source": [
1998
+ "def count_params(model):\n",
1999
+ " total = sum(p.numel() for p in model.parameters())\n",
2000
+ " trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
2001
+ " print(f\"Total params: {total:,}\")\n",
2002
+ " print(f\"Trainable params: {trainable:,}\")"
2003
+ ],
2004
+ "metadata": {
2005
+ "id": "V6X-oV-qXKvu"
2006
+ },
2007
+ "execution_count": null,
2008
+ "outputs": []
2009
+ },
2010
+ {
2011
+ "cell_type": "code",
2012
+ "source": [
2013
+ "count_params(net)"
2014
+ ],
2015
+ "metadata": {
2016
+ "colab": {
2017
+ "base_uri": "https://localhost:8080/"
2018
+ },
2019
+ "id": "12esUm_uXK-W",
2020
+ "outputId": "e845b9f5-0086-4118-d4fb-73fc5dc030a1"
2021
+ },
2022
+ "execution_count": null,
2023
+ "outputs": [
2024
+ {
2025
+ "output_type": "stream",
2026
+ "name": "stdout",
2027
+ "text": [
2028
+ "Total params: 2,433,281\n",
2029
+ "Trainable params: 2,433,281\n"
2030
+ ]
2031
+ }
2032
+ ]
2033
+ },
2034
+ {
2035
+ "cell_type": "code",
2036
+ "source": [
2037
+ "# \u2500\u2500 Inference helper (fixed: uses word vectors, not sentence vector) \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
2038
+ "def predict(text, ft_model, net, device, max_len=30, threshold=0.5):\n",
2039
+ " \"\"\"Correctly feeds per-word fastText vectors into Asfour.\"\"\"\n",
2040
+ " words = text.split()[:max_len]\n",
2041
+ " vecs = [ft_model.get_word_vector(w) for w in words]\n",
2042
+ " if len(vecs) < max_len:\n",
2043
+ " import numpy as np\n",
2044
+ " vecs += [np.zeros(256)] * (max_len - len(vecs))\n",
2045
+ " import numpy as np\n",
2046
+ " x = torch.tensor(np.array(vecs), dtype=torch.float32) # [T, 256]\n",
2047
+ " x = x.unsqueeze(0).to(device) # [1, T, 256]\n",
2048
+ " net.eval()\n",
2049
+ " with torch.no_grad():\n",
2050
+ " prob = torch.sigmoid(net(x)).item()\n",
2051
+ " label = 'profane' if prob > threshold else 'clean'\n",
2052
+ " print(f\"Text : '{text}'\")\n",
2053
+ " print(f\"Prob : {prob:.3f} \u2192 {label}\")\n",
2054
+ " return prob\n",
2055
+ "\n",
2056
+ "\n",
2057
+ "sample_text = \"3asba lik ye bechirMNST\"\n",
2058
+ "predict(sample_text, model, net, device)\n"
2059
+ ],
2060
+ "metadata": {
2061
+ "colab": {
2062
+ "base_uri": "https://localhost:8080/"
2063
+ },
2064
+ "id": "GfY49fuEXTWp",
2065
+ "outputId": "c6617fc4-9377-4b6f-dd86-e748274415d1"
2066
+ },
2067
+ "execution_count": null,
2068
+ "outputs": [
2069
+ {
2070
+ "output_type": "stream",
2071
+ "name": "stdout",
2072
+ "text": [
2073
+ "\n",
2074
+ "Text: '3asba lik ye bechirMNST'\n",
2075
+ "Profanity probability: 0.010\n",
2076
+ "Prediction: profane\n"
2077
+ ]
2078
+ }
2079
+ ]
2080
+ },
2081
+ {
2082
+ "cell_type": "code",
2083
+ "source": [],
2084
+ "metadata": {
2085
+ "id": "7Zvj40YIXWfb"
2086
+ },
2087
+ "execution_count": null,
2088
+ "outputs": []
2089
+ },
2090
+ {
2091
+ "cell_type": "code",
2092
+ "metadata": {
2093
+ "id": "53114288"
2094
+ },
2095
+ "source": [
2096
+ "save_path = '/content/drive/MyDrive/Asfour_V3.pt'\n",
2097
+ "torch.save({\n",
2098
+ " 'model_state_dict': net.state_dict(),\n",
2099
+ " 'optimizer_state_dict': optimizer.state_dict(),\n",
2100
+ " 'best_threshold': best_t,\n",
2101
+ " 'hyperparameters': {\n",
2102
+ " 'input_dim': 256,\n",
2103
+ " 'hidden_size': 48,\n",
2104
+ " 'num_classes': 1,\n",
2105
+ " 'dropout': 0.4,\n",
2106
+ " 'version': 'v3',\n",
2107
+ " }\n",
2108
+ "}, save_path)\n",
2109
+ "\n",
2110
+ "print(f\"Model saved to: {save_path}\")\n"
2111
+ ],
2112
+ "execution_count": null,
2113
+ "outputs": []
2114
+ }
2115
+ ]
2116
+ }