Sunaina792 committed on
Commit
a6bdbee
·
verified ·
1 Parent(s): 025feca

Upload 2 files

Browse files
Files changed (2) hide show
  1. normal_to_formal.ipynb +331 -0
  2. normal_to_genz.ipynb +0 -0
normal_to_formal.ipynb ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "metadata": {
7
+ "id": "4KDV129CjSUr"
8
+ },
9
+ "outputs": [
10
+ {
11
+ "data": {
12
+ "application/vnd.jupyter.widget-view+json": {
13
+ "model_id": "cfcbb81f755540bbbee503cce0b039eb",
14
+ "version_major": 2,
15
+ "version_minor": 0
16
+ },
17
+ "text/plain": [
18
+ "tokenizer_config.json: 0.00B [00:00, ?B/s]"
19
+ ]
20
+ },
21
+ "metadata": {},
22
+ "output_type": "display_data"
23
+ },
24
+ {
25
+ "data": {
26
+ "application/vnd.jupyter.widget-view+json": {
27
+ "model_id": "e2b5e0ddfc1741108fd7d92163bbea02",
28
+ "version_major": 2,
29
+ "version_minor": 0
30
+ },
31
+ "text/plain": [
32
+ "config.json: 0.00B [00:00, ?B/s]"
33
+ ]
34
+ },
35
+ "metadata": {},
36
+ "output_type": "display_data"
37
+ },
38
+ {
39
+ "data": {
40
+ "application/vnd.jupyter.widget-view+json": {
41
+ "model_id": "be20a20120464aa3964460614bc46c6b",
42
+ "version_major": 2,
43
+ "version_minor": 0
44
+ },
45
+ "text/plain": [
46
+ "spiece.model: 0%| | 0.00/792k [00:00<?, ?B/s]"
47
+ ]
48
+ },
49
+ "metadata": {},
50
+ "output_type": "display_data"
51
+ },
52
+ {
53
+ "data": {
54
+ "application/vnd.jupyter.widget-view+json": {
55
+ "model_id": "d0a80cf87a4a4035a09409c816e67b8f",
56
+ "version_major": 2,
57
+ "version_minor": 0
58
+ },
59
+ "text/plain": [
60
+ "tokenizer.json: 0.00B [00:00, ?B/s]"
61
+ ]
62
+ },
63
+ "metadata": {},
64
+ "output_type": "display_data"
65
+ },
66
+ {
67
+ "data": {
68
+ "application/vnd.jupyter.widget-view+json": {
69
+ "model_id": "4cc5466f071849e6a4b9338ddddfc7fc",
70
+ "version_major": 2,
71
+ "version_minor": 0
72
+ },
73
+ "text/plain": [
74
+ "special_tokens_map.json: 0.00B [00:00, ?B/s]"
75
+ ]
76
+ },
77
+ "metadata": {},
78
+ "output_type": "display_data"
79
+ },
80
+ {
81
+ "data": {
82
+ "application/vnd.jupyter.widget-view+json": {
83
+ "model_id": "c5db085240c84185a4fdeef9570873ff",
84
+ "version_major": 2,
85
+ "version_minor": 0
86
+ },
87
+ "text/plain": [
88
+ "pytorch_model.bin: 0%| | 0.00/892M [00:00<?, ?B/s]"
89
+ ]
90
+ },
91
+ "metadata": {},
92
+ "output_type": "display_data"
93
+ },
94
+ {
95
+ "data": {
96
+ "application/vnd.jupyter.widget-view+json": {
97
+ "model_id": "5ab33f6bec6e46e6a8159a42ef725590",
98
+ "version_major": 2,
99
+ "version_minor": 0
100
+ },
101
+ "text/plain": [
102
+ "model.safetensors: 0%| | 0.00/892M [00:00<?, ?B/s]"
103
+ ]
104
+ },
105
+ "metadata": {},
106
+ "output_type": "display_data"
107
+ },
108
+ {
109
+ "name": "stdout",
110
+ "output_type": "stream",
111
+ "text": [
112
+ "I am going to get that report now.\n",
113
+ "I love going to the movies.\n"
114
+ ]
115
+ }
116
+ ],
117
+ "source": [
118
+ "!pip install -q transformers torch\n",
119
+ "\n",
120
+ "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM\n",
121
+ "import torch\n",
122
+ "\n",
123
+ "model_id = \"rajistics/informal_formal_style_transfer\"\n",
124
+ "\n",
125
+ "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
126
+ "model = AutoModelForSeq2SeqLM.from_pretrained(model_id)\n",
127
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
128
+ "model.to(device)\n",
129
+ "\n",
130
+ "def informal_to_formal(text, max_new_tokens=64, num_beams=4):\n",
131
+ " inputs = tokenizer(text, return_tensors=\"pt\").to(device)\n",
132
+ " with torch.no_grad():\n",
133
+ " outputs = model.generate(\n",
134
+ " **inputs,\n",
135
+ " max_new_tokens=max_new_tokens,\n",
136
+ " num_beams=num_beams,\n",
137
+ " early_stopping=True,\n",
138
+ " no_repeat_ngram_size=2,\n",
139
+ " )\n",
140
+ " return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()\n",
141
+ "\n",
142
+ "# test\n",
143
+ "print(informal_to_formal(\"gimme that report now\"))\n",
144
+ "print(informal_to_formal(\"i loooooooooooooooooooooooove going to the movies.\"))\n"
145
+ ]
146
+ },
147
+ {
148
+ "cell_type": "code",
149
+ "execution_count": 3,
150
+ "metadata": {
151
+ "id": "hgYDbUJ3jleL"
152
+ },
153
+ "outputs": [],
154
+ "source": [
155
+ "def informal_to_formal_prefixed(text, **gen_kwargs):\n",
156
+ " prefixed = \"transfer Casual to Formal: \" + text\n",
157
+ " return informal_to_formal(prefixed, **gen_kwargs)\n"
158
+ ]
159
+ },
160
+ {
161
+ "cell_type": "code",
162
+ "execution_count": 4,
163
+ "metadata": {
164
+ "id": "K6KK6Rr2jwKt"
165
+ },
166
+ "outputs": [
167
+ {
168
+ "name": "stdout",
169
+ "output_type": "stream",
170
+ "text": [
171
+ "Colab notebook detected. To show errors in colab notebook, set debug=True in launch()\n",
172
+ "* Running on public URL: https://591bd78c0ee0426622.gradio.live\n",
173
+ "\n",
174
+ "This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n"
175
+ ]
176
+ },
177
+ {
178
+ "data": {
179
+ "text/html": [
180
+ "<div><iframe src=\"https://591bd78c0ee0426622.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
181
+ ],
182
+ "text/plain": [
183
+ "<IPython.core.display.HTML object>"
184
+ ]
185
+ },
186
+ "metadata": {},
187
+ "output_type": "display_data"
188
+ },
189
+ {
190
+ "data": {
191
+ "text/plain": []
192
+ },
193
+ "execution_count": 4,
194
+ "metadata": {},
195
+ "output_type": "execute_result"
196
+ }
197
+ ],
198
+ "source": [
199
+ "import gradio as gr\n",
200
+ "\n",
201
+ "def formal_interface(text, max_len, beams):\n",
202
+ " return informal_to_formal(text, max_new_tokens=int(max_len), num_beams=int(beams))\n",
203
+ "\n",
204
+ "demo = gr.Interface(\n",
205
+ " fn=formal_interface,\n",
206
+ " inputs=[\n",
207
+ " gr.Textbox(lines=3, label=\"Informal text\"),\n",
208
+ " gr.Slider(16, 128, value=64, step=4, label=\"Max new tokens\"),\n",
209
+ " gr.Slider(1, 8, value=4, step=1, label=\"Beams\"),\n",
210
+ " ],\n",
211
+ " outputs=gr.Textbox(label=\"Formal text\"),\n",
212
+ " title=\"Informal ➜ Formal \",\n",
213
+ ")\n",
214
+ "\n",
215
+ "demo.launch(share=True)\n"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "code",
220
+ "execution_count": 5,
221
+ "metadata": {
222
+ "id": "OGUU73oqj1gn"
223
+ },
224
+ "outputs": [
225
+ {
226
+ "name": "stdout",
227
+ "output_type": "stream",
228
+ "text": [
229
+ "Model saved to: my_formal_t5_model\n"
230
+ ]
231
+ }
232
+ ],
233
+ "source": [
234
+ "model.save_pretrained(\"my_formal_t5_model\")\n",
235
+ "print(\"Model saved to: my_formal_t5_model\")"
236
+ ]
237
+ },
238
+ {
239
+ "cell_type": "code",
240
+ "execution_count": null,
241
+ "metadata": {},
242
+ "outputs": [],
243
+ "source": []
244
+ },
245
+ {
246
+ "cell_type": "code",
247
+ "execution_count": null,
248
+ "metadata": {
249
+ "id": "tASct-9QlqNk"
250
+ },
251
+ "outputs": [],
252
+ "source": [
253
+ "# Alternative 1: Save only model weights (state_dict)\n",
254
+ "import torch\n",
255
+ "torch.save(model.state_dict(), \"formal_model_weights.pth\")\n",
256
+ "print(\"Model weights saved to: formal_model_weights.pth\")\n",
257
+ "\n",
258
+ "# To load later:\n",
259
+ "# model = AutoModelForSeq2SeqLM.from_pretrained(model_id)\n",
260
+ "# model.load_state_dict(torch.load(\"formal_model_weights.pth\"))\n",
261
+ "# model.to(device)"
262
+ ]
263
+ },
264
+ {
265
+ "cell_type": "code",
266
+ "execution_count": null,
267
+ "metadata": {},
268
+ "outputs": [],
269
+ "source": [
270
+ "# Alternative 2: Save in SafeTensors format (more secure and faster loading)\n",
271
+ "try:\n",
272
+ " from safetensors.torch import save_file\n",
273
+ " save_file(model.state_dict(), \"formal_model_weights.safetensors\")\n",
274
+ " print(\"Model saved in SafeTensors format: formal_model_weights.safetensors\")\n",
275
+ "except ImportError:\n",
276
+ " print(\"SafeTensors not installed. Install with: pip install safetensors\")\n",
277
+ "\n",
278
+ "# To load SafeTensors:\n",
279
+ "# from safetensors.torch import load_file\n",
280
+ "# state_dict = load_file(\"formal_model_weights.safetensors\")\n",
281
+ "# model.load_state_dict(state_dict)"
282
+ ]
283
+ },
284
+ {
285
+ "cell_type": "code",
286
+ "execution_count": null,
287
+ "metadata": {},
288
+ "outputs": [],
289
+ "source": [
290
+ "# Alternative 3: Save model and tokenizer to a custom directory\n",
291
+ "model.save_pretrained(\"./my_custom_formal_model\")\n",
292
+ "tokenizer.save_pretrained(\"./my_custom_formal_model\")\n",
293
+ "print(\"Model and tokenizer saved to: ./my_custom_formal_model/\")\n",
294
+ "\n",
295
+ "# Alternative 4: Push to Hugging Face Hub (requires huggingface_hub)\n",
296
+ "# from huggingface_hub import login\n",
297
+ "# login() # You'll need to authenticate\n",
298
+ "# model.push_to_hub(\"your-username/formal-style-transfer-model\")\n",
299
+ "# tokenizer.push_to_hub(\"your-username/formal-style-transfer-model\")\n",
300
+ "# print(\"Model pushed to Hugging Face Hub\")"
301
+ ]
302
+ }
303
+ ],
304
+ "metadata": {
305
+ "accelerator": "GPU",
306
+ "colab": {
307
+ "gpuType": "T4",
308
+ "private_outputs": true,
309
+ "provenance": []
310
+ },
311
+ "kernelspec": {
312
+ "display_name": "Python 3 (ipykernel)",
313
+ "language": "python",
314
+ "name": "python3"
315
+ },
316
+ "language_info": {
317
+ "codemirror_mode": {
318
+ "name": "ipython",
319
+ "version": 3
320
+ },
321
+ "file_extension": ".py",
322
+ "mimetype": "text/x-python",
323
+ "name": "python",
324
+ "nbconvert_exporter": "python",
325
+ "pygments_lexer": "ipython3",
326
+ "version": "3.12.12"
327
+ }
328
+ },
329
+ "nbformat": 4,
330
+ "nbformat_minor": 0
331
+ }
normal_to_genz.ipynb ADDED
The diff for this file is too large to render. See raw diff