PuLam committed on
Commit
e8148e6
·
verified ·
1 Parent(s): aacc571

Upload 4 files

Browse files
Files changed (3) hide show
  1. config.pt +3 -0
  2. pythonfile.py +283 -0
  3. vocab.pkl +3 -0
config.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f2ed8867fb6ae6249b66e8569ea9aa3ec8a061211d77dcf8a781986fc44e9666
3
+ size 860
pythonfile.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # poems_generation.py
2
+
3
+ import os
4
+ import re
5
+ import time
6
+ import pandas as pd
7
+ import torch
8
+ import numpy as np
9
+ import torch.nn as nn
10
+ from tqdm import tqdm
11
+ import torch.nn.functional as F
12
+ from sklearn.model_selection import train_test_split
13
+ from underthesea import word_tokenize
14
+ import pickle
15
+ from torch.utils.data import Dataset, DataLoader
16
+ from dataclasses import dataclass
17
+
18
# Download the datasets
# NOTE(review): both URLs resolve to a file named "poems_dataset.csv"; wget
# saves the second one as "poems_dataset.csv.1", which nothing below reads —
# confirm which corpus is actually intended.
os.system("wget https://huggingface.co/datasets/Libosa2707/vietnamese-poem/resolve/main/poems_dataset.csv")
os.system("wget https://huggingface.co/datasets/phamson02/vietnamese-poetry-corpus/resolve/main/poems_dataset.csv")

# Install necessary libraries
# NOTE(review): this runs only after `from underthesea import word_tokenize`
# has already executed at the top of the file, so it cannot satisfy that
# import; `word_tokenize` is also never used below — looks vestigial.
os.system("pip install underthesea")
24
+
25
+ # Define functions for saving and loading pickle files
26
def save_pkl(save_object, save_file):
    """Serialize `save_object` to `save_file` with the highest pickle protocol."""
    with open(save_file, 'wb') as handle:
        pickle.dump(save_object, handle, protocol=pickle.HIGHEST_PROTOCOL)
29
+
30
def load_pkl(load_file):
    """Deserialize and return the object stored in `load_file`."""
    with open(load_file, 'rb') as handle:
        return pickle.load(handle)
34
+
35
# Load the dataset
data = pd.read_csv("poems_dataset.csv")

# Display the first few rows of the dataset
print(data.head())

# Split the dataset into training and validation sets (fixed seed for reproducibility)
train_df, val_df = train_test_split(data, test_size=0.3, random_state=42)

# Prepare the documents: one raw poem string per row. `.tolist()` already
# yields a plain list, so the original `[doc for doc in ...]` wrappers were
# redundant copies and have been dropped.
train_documents = train_df['content'].tolist()
val_documents = val_df['content'].tolist()
47
+
48
+ # Vocabulary class definition
49
class Vocabulary:
    """Bidirectional word <-> id mapping with <pad>/<unk>/<s>/</s> specials."""

    def __init__(self):
        # Reserved ids for the four special tokens.
        self.pad_id, self.unk_id, self.sos_id, self.eos_id = 0, 1, 2, 3
        self.word2id = {
            '<pad>': self.pad_id,
            '<unk>': self.unk_id,
            '<s>': self.sos_id,
            '</s>': self.eos_id,
        }
        self.id2word = {idx: tok for tok, idx in self.word2id.items()}

    def __getitem__(self, word):
        # Unknown words fall back to the <unk> id.
        return self.word2id.get(word, self.unk_id)

    def __contains__(self, word):
        return word in self.word2id

    def __len__(self):
        return len(self.word2id)

    def lookup_tokens(self, word_indexes: list):
        """Map a list of integer ids back to their token strings."""
        return [self.id2word[idx] for idx in word_indexes]

    def add(self, word):
        """Register `word` if new; return its id either way."""
        if word in self:
            return self[word]
        new_id = len(self.word2id)
        self.word2id[word] = new_id
        self.id2word[new_id] = word
        return new_id

    def corpus_to_tensor(self, corpus, is_tokenized=False):
        """Convert documents to a list of 1-D LongTensors of token ids."""
        docs = corpus if is_tokenized else self.tokenize_corpus(corpus)
        tensors = []
        for doc in tqdm(docs):
            tensors.append(torch.tensor([self[w] for w in doc], dtype=torch.long))
        return tensors

    def tensor_to_corpus(self, tensor):
        """Inverse of corpus_to_tensor: id tensors back to token lists."""
        documents = []
        for ids in tqdm(tensor):
            documents.append([self.id2word[i.item()] for i in ids])
        return documents

    @staticmethod
    def tokenize_corpus(corpus):
        """Regex-tokenize each document and wrap it in <s> ... </s>."""
        print("Tokenize the corpus...")
        tokenized = []
        for doc in tqdm(corpus):
            tokenized.append(['<s>'] + re.findall(r'(\w+|[^\w\s]|\S+|\n)', doc) + ['</s>'])
        return tokenized

    @classmethod
    def from_documents(cls, documents):
        """Build a vocabulary from the unique tokens of `documents`."""
        vocab = cls()
        unique_words = {w for doc in documents for w in re.findall(r'\w+|\S|\n', doc)}
        for word in unique_words:
            vocab.add(word)
        return vocab

    @classmethod
    def from_pretrained(cls, save_dir):
        """Load a vocabulary previously written by save_pretrained."""
        with open(os.path.join(save_dir, "vocab.pkl"), 'rb') as file:
            pretrained_vocab = pickle.load(file)
        return cls.init_vocab_from_pretrained(pretrained_vocab)

    @staticmethod
    def init_vocab_from_pretrained(pretrained_vocab):
        """Wrap a plain word->id dict (must include the specials) in a Vocabulary."""
        vocab = Vocabulary()
        vocab.word2id.update(pretrained_vocab)
        vocab.id2word = {idx: tok for tok, idx in vocab.word2id.items()}
        return vocab

    def save_pretrained(self, save_dir):
        """Persist only the word->id dict; id2word is rebuilt on load."""
        os.makedirs(save_dir, exist_ok=True)
        with open(os.path.join(save_dir, "vocab.pkl"), 'wb') as file:
            pickle.dump(self.word2id, file)
136
+
137
# Initialize Vocabulary
# Built from the training split only, so validation-only words map to <unk>.
vocab = Vocabulary.from_documents(train_documents)
139
+
140
+ # PoemGenerationDataset class definition
141
class PoemGenerationDataset(Dataset):
    """Dataset of poems pre-tokenized into id tensors, with a collate_fn that
    pads, optionally truncates to `max_length`, and builds shifted LM inputs."""

    def __init__(self, documents, vocab, max_length=None):
        self.vocab = vocab
        self.sos_idx = vocab["<s>"]
        self.eos_idx = vocab["</s>"]
        self.pad_idx = vocab["<pad>"]
        self.documents = documents
        self.max_length = max_length  # None disables truncation
        # Tokenize once up front; __getitem__ is then a cheap list lookup.
        self.tokenized_documents = self.vocab.tokenize_corpus(self.documents)
        self.tensor_data = self.vocab.corpus_to_tensor(self.tokenized_documents, is_tokenized=True)

    def __len__(self):
        return len(self.tensor_data)

    def __getitem__(self, idx):
        return self.tensor_data[idx]

    def shift_right(self, input_ids, pad_token=0):
        """Prepend a pad column and drop the last column: teacher-forcing inputs."""
        padding_column = torch.full_like(input_ids[:, :1], pad_token)
        shifted_ids = torch.cat([padding_column, input_ids[:, :-1]], dim=-1)
        return shifted_ids

    def collate_fn(self, examples):
        """Batch variable-length id tensors into padded inputs/labels."""
        examples = sorted(examples, key=len, reverse=True)
        if self.max_length is not None:
            # BUGFIX: only truncate sequences that actually exceed max_length
            # (re-appending </s> after the cut). The original applied the
            # cat unconditionally, appending a duplicate </s> to every
            # sequence that was already short enough.
            examples = [
                torch.cat([e[:self.max_length - 1], torch.tensor([self.eos_idx])])
                if len(e) > self.max_length else e
                for e in examples
            ]
        input_ids = torch.nn.utils.rnn.pad_sequence(examples, batch_first=True, padding_value=self.pad_idx)
        labels = input_ids.clone()
        input_ids = self.shift_right(input_ids, pad_token=self.pad_idx)
        return {"inputs": input_ids, "labels": labels}
172
+
173
# Initialize datasets and dataloaders
train_dataset = PoemGenerationDataset(train_documents, vocab, max_length=512)
val_dataset = PoemGenerationDataset(val_documents, vocab, max_length=512)
# Shuffle the training set each epoch; the original iterated the corpus in a
# fixed order every epoch, which is known to hurt SGD convergence.
train_dataloader = DataLoader(train_dataset, batch_size=5, shuffle=True, collate_fn=train_dataset.collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=5, collate_fn=val_dataset.collate_fn)
178
+
179
+ # LSTMForPoemGeneration model definition
180
class LSTMForPoemGeneration(nn.Module):
    """LSTM language model for next-token poem generation.

    `config` must provide: vocab_size, hidden_size, padding_idx, num_layers,
    dropout (see LSTMConfig).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.padding_idx)
        self.batch_norm = nn.BatchNorm1d(config.hidden_size)
        self.lstm = nn.LSTM(config.hidden_size,
                            config.hidden_size,
                            num_layers=config.num_layers,
                            dropout=config.dropout,
                            batch_first=True)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, input_ids, labels=None):
        """Return (logits, loss); loss is None when labels are omitted.

        logits has shape (batch, seq, vocab_size).
        """
        embeds = self.dropout(self.embedding(input_ids))
        # BatchNorm1d normalizes over dim 1, so permute to (batch, hidden, seq) and back.
        bn_output = self.batch_norm(embeds.permute(0, 2, 1)).permute(0, 2, 1)
        output, (hidden, cell) = self.lstm(bn_output)
        logits = self.lm_head(self.dropout(output))
        loss = None
        if labels is not None:
            # BUGFIX: skip <pad> targets. The original averaged the loss over
            # padding positions too, biasing training toward predicting pad.
            loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.padding_idx)
            labels = labels.to(logits.device)
            loss = loss_fct(logits.view(-1, logits.size(2)), labels.view(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, input_ids, max_length=100, temperature=1.0, eos_token_id=None):
        """Autoregressively sample tokens until `max_length` or EOS.

        eos_token_id defaults to the module-level `vocab.eos_id`, preserving
        the original (globally coupled) behavior; pass it explicitly to make
        the model self-contained.
        """
        if eos_token_id is None:
            eos_token_id = vocab.eos_id  # NOTE(review): global coupling kept for compat
        self.to(input_ids.device)
        self.eval()
        current_length = input_ids.size(1)
        while current_length < max_length:
            logits, _ = self.forward(input_ids)
            # Sample only from the distribution of the last position.
            logits = logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_token_id], dim=1)
            current_length += 1
            # .item() assumes batch size 1, as in the original.
            if next_token_id.item() == eos_token_id:
                break
        return input_ids

    def save_pretrained(self, save_dir):
        """Write config.pt and pytorch_model.pt under `save_dir`."""
        os.makedirs(save_dir, exist_ok=True)
        torch.save(self.config, os.path.join(save_dir, "config.pt"))
        torch.save(self.state_dict(), os.path.join(save_dir, "pytorch_model.pt"))

    @classmethod
    def from_pretrained(cls, saved_dir):
        """Rebuild a model from save_pretrained output; weights load on CPU."""
        config = torch.load(os.path.join(saved_dir, "config.pt"))
        model = cls(config)
        state_dict = torch.load(os.path.join(saved_dir, "pytorch_model.pt"), map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)
        return model
234
+
235
+ # Configuration for the LSTM model
236
@dataclass
class LSTMConfig:
    # Hyperparameters consumed by LSTMForPoemGeneration.
    vocab_size: int       # rows in the embedding table / output classes of lm_head
    padding_idx: int      # id of <pad>, passed to nn.Embedding(padding_idx=...)
    hidden_size: int      # shared embedding dim and LSTM hidden dim
    num_layers: int       # stacked LSTM layers
    dropout: float = 0.3  # used for nn.Dropout and the LSTM's inter-layer dropout
243
+
244
# Initialize and train the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_config = LSTMConfig(vocab_size=len(vocab), padding_idx=vocab.pad_id, hidden_size=128, num_layers=2)
model = LSTMForPoemGeneration(model_config).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for epoch in range(20):
    # --- training pass ---
    model.train()
    total_train_loss = 0.0
    for step, batch in enumerate(tqdm(train_dataloader), 1):
        optimizer.zero_grad()
        inputs = batch["inputs"].to(device)
        labels = batch["labels"].to(device)
        _, loss = model(inputs, labels)
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()
    print(f"Training Loss after epoch {epoch}: {total_train_loss/len(train_dataloader)}")

    # --- validation pass (no gradients) ---
    model.eval()
    total_val_loss = 0.0
    with torch.no_grad():
        for step, batch in enumerate(tqdm(val_dataloader), 1):
            inputs = batch["inputs"].to(device)
            labels = batch["labels"].to(device)
            _, loss = model(inputs, labels)
            total_val_loss += loss.item()
    print(f"Validation Loss after epoch {epoch}: {total_val_loss/len(val_dataloader)}")
273
+
274
# Save the model and vocabulary
model.save_pretrained("poem_model")
vocab.save_pretrained("poem_model")

# Generate sample text
input_sentence = "Mặt trời mọc ở đằng đông"
# BUGFIX: corpus_to_tensor wraps the prompt as <s> ... </s>; drop the trailing
# </s> so the model continues the prompt instead of being fed an
# end-of-sequence token as its last input.
prompt_ids = vocab.corpus_to_tensor([input_sentence])[0][:-1]
input_ids = prompt_ids.unsqueeze(0).to(device)
output = model.generate(input_ids, max_length=100, temperature=0.8)
output_text = " ".join(vocab.lookup_tokens(output.squeeze().tolist())).replace("<s>", "").replace("</s>", "")
print(output_text)
vocab.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5dbd72e54b70152ab1102a1c2218229f8a32aa0c3f56347e21ee6400bcb42806
3
+ size 160866