khalid99ml commited on
Commit
1623255
·
verified ·
1 Parent(s): 4f950e5
Files changed (1) hide show
  1. challange1.ipynb +237 -0
challange1.ipynb ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from datasets import load_dataset, DatasetDict\n",
10
+ "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer\n",
11
+ "from sklearn.model_selection import train_test_split\n",
12
+ "from huggingface_hub import notebook_login\n",
13
+ "import torch\n"
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "code",
18
+ "execution_count": 2,
19
+ "metadata": {},
20
+ "outputs": [
21
+ {
22
+ "data": {
23
+ "application/vnd.jupyter.widget-view+json": {
24
+ "model_id": "619aead440a04dbd9eafa156e3713251",
25
+ "version_major": 2,
26
+ "version_minor": 0
27
+ },
28
+ "text/plain": [
29
+ "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
30
+ ]
31
+ },
32
+ "metadata": {},
33
+ "output_type": "display_data"
34
+ }
35
+ ],
36
+ "source": [
37
+ "notebook_login()\n"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "metadata": {},
44
+ "outputs": [],
45
+ "source": [
46
+ "dataset = load_dataset(\"SKNahin/bengali-transliteration-data\")\n",
47
+ "\n",
48
+ "dataset = dataset[\"train\"].train_test_split(test_size=0.2, seed=42)"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": 6,
54
+ "metadata": {},
55
+ "outputs": [
56
+ {
57
+ "name": "stdout",
58
+ "output_type": "stream",
59
+ "text": [
60
+ "DatasetDict({\n",
61
+ " train: Dataset({\n",
62
+ " features: ['bn', 'rm'],\n",
63
+ " num_rows: 4004\n",
64
+ " })\n",
65
+ " validation: Dataset({\n",
66
+ " features: ['bn', 'rm'],\n",
67
+ " num_rows: 1002\n",
68
+ " })\n",
69
+ "})\n"
70
+ ]
71
+ }
72
+ ],
73
+ "source": [
74
+ "dataset = DatasetDict({\n",
75
+ " \"train\": dataset[\"train\"],\n",
76
+ " \"validation\": dataset[\"test\"]\n",
77
+ "})\n",
78
+ "print(dataset)"
79
+ ]
80
+ },
81
+ {
82
+ "cell_type": "code",
83
+ "execution_count": 11,
84
+ "metadata": {},
85
+ "outputs": [],
86
+ "source": [
87
+ "from transformers import MBartForConditionalGeneration, MBart50TokenizerFast\n"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": null,
93
+ "metadata": {},
94
+ "outputs": [
95
+ {
96
+ "ename": "NameError",
97
+ "evalue": "name 'MBartForConditionalGeneration' is not defined",
98
+ "output_type": "error",
99
+ "traceback": [
100
+ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
101
+ "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)",
102
+ "Cell \u001b[1;32mIn[1], line 2\u001b[0m\n\u001b[0;32m 1\u001b[0m model_name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mfacebook/mbart-large-50-many-to-many-mmt\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m----> 2\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[43mMBartForConditionalGeneration\u001b[49m\u001b[38;5;241m.\u001b[39mfrom_pretrained(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfacebook/mbart-large-50-many-to-many-mmt\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 3\u001b[0m tokenizer \u001b[38;5;241m=\u001b[39m MBart50TokenizerFast\u001b[38;5;241m.\u001b[39mfrom_pretrained(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfacebook/mbart-large-50-many-to-many-mmt\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 4\u001b[0m \u001b[38;5;66;03m#this process took huge time as it is very big compared to the internet bandwith, so the next steps could not be run until its finishing\u001b[39;00m\n",
103
+ "\u001b[1;31mNameError\u001b[0m: name 'MBartForConditionalGeneration' is not defined"
104
+ ]
105
+ }
106
+ ],
107
+ "source": [
108
+ "model_name='facebook/mbart-large-50-many-to-many-mmt'\n",
109
+ "model = MBartForConditionalGeneration.from_pretrained(\"facebook/mbart-large-50-many-to-many-mmt\")\n",
110
+ "tokenizer = MBart50TokenizerFast.from_pretrained(\"facebook/mbart-large-50-many-to-many-mmt\")\n",
111
+ "#this process took huge time as it is very big compared to the internet bandwith, and it is still running so the next steps could not be run until its finishing\n"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": 15,
117
+ "metadata": {},
118
+ "outputs": [],
119
+ "source": [
120
+ "def preprocess_function(examples):\n",
121
+ " inputs = examples[\"banglish\"]\n",
122
+ " targets = examples[\"bangla\"]\n",
123
+ " model_inputs = tokenizer(inputs, max_length=128, truncation=True, padding=\"max_length\")\n",
124
+ " with tokenizer.as_target_tokenizer():\n",
125
+ " labels = tokenizer(targets, max_length=128, truncation=True, padding=\"max_length\")\n",
126
+ " model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
127
+ " return model_inputs"
128
+ ]
129
+ },
130
+ {
131
+ "cell_type": "code",
132
+ "execution_count": null,
133
+ "metadata": {},
134
+ "outputs": [],
135
+ "source": [
136
+ "tokenized_datasets = dataset.map(preprocess_function, batched=True, remove_columns=dataset[\"train\"].column_names)"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "code",
141
+ "execution_count": null,
142
+ "metadata": {},
143
+ "outputs": [],
144
+ "source": [
145
+ "model = AutoModelForSeq2SeqLM.from_pretrained(model_name)"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "code",
150
+ "execution_count": null,
151
+ "metadata": {},
152
+ "outputs": [],
153
+ "source": [
154
+ "data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)"
155
+ ]
156
+ },
157
+ {
158
+ "cell_type": "code",
159
+ "execution_count": null,
160
+ "metadata": {},
161
+ "outputs": [],
162
+ "source": [
163
+ "training_args = Seq2SeqTrainingArguments(\n",
164
+ " output_dir=\"./results\",\n",
165
+ " evaluation_strategy=\"epoch\",\n",
166
+ " learning_rate=2e-5,\n",
167
+ " per_device_train_batch_size=16,\n",
168
+ " per_device_eval_batch_size=16,\n",
169
+ " weight_decay=0.01,\n",
170
+ " save_total_limit=2,\n",
171
+ " num_train_epochs=5,\n",
172
+ " predict_with_generate=True,\n",
173
+ " logging_dir=\"./logs\",\n",
174
+ " logging_strategy=\"epoch\",\n",
175
+ " save_strategy=\"epoch\"\n",
176
+ ")"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": null,
182
+ "metadata": {},
183
+ "outputs": [],
184
+ "source": [
185
+ "trainer = Seq2SeqTrainer(\n",
186
+ " model=model,\n",
187
+ " args=training_args,\n",
188
+ " train_dataset=tokenized_datasets[\"train\"],\n",
189
+ " eval_dataset=tokenized_datasets[\"validation\"],\n",
190
+ " tokenizer=tokenizer,\n",
191
+ " data_collator=data_collator\n",
192
+ ")"
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "code",
197
+ "execution_count": null,
198
+ "metadata": {},
199
+ "outputs": [],
200
+ "source": [
201
+ "#training start using the provided dataset for fine tuning\n",
202
+ "trainer.train()"
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "execution_count": null,
208
+ "metadata": {},
209
+ "outputs": [],
210
+ "source": [
211
+ "model.save_pretrained(\"banglish-to-bangla-model\")\n",
212
+ "tokenizer.save_pretrained(\"banglish-to-bangla-model\")"
213
+ ]
214
+ }
215
+ ],
216
+ "metadata": {
217
+ "kernelspec": {
218
+ "display_name": ".venv",
219
+ "language": "python",
220
+ "name": "python3"
221
+ },
222
+ "language_info": {
223
+ "codemirror_mode": {
224
+ "name": "ipython",
225
+ "version": 3
226
+ },
227
+ "file_extension": ".py",
228
+ "mimetype": "text/x-python",
229
+ "name": "python",
230
+ "nbconvert_exporter": "python",
231
+ "pygments_lexer": "ipython3",
232
+ "version": "3.12.3"
233
+ }
234
+ },
235
+ "nbformat": 4,
236
+ "nbformat_minor": 2
237
+ }