jrmd committed on
Commit
60ef55c
·
0 Parent(s):

Initial commit

Browse files
.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ wandb
2
+ custom_bert_model.torch
3
+ __pycache__
4
+ .gradio
DL2_BERT_Model_Based_Classification.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
DL2_BERT_Model_Based_Classification.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import pandas as pd
6
+ import seaborn as sns
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.optim as optim
10
+ from sklearn.metrics import (
11
+ accuracy_score,
12
+ confusion_matrix,
13
+ precision_score,
14
+ recall_score,
15
+ )
16
+ from torch.utils.data import DataLoader, Dataset, Subset
17
+ from transformers import AutoTokenizer, BertModel
18
+
19
+ import wandb
20
+
21
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
22
+ NUM_EPOCHS = 5
23
+ BATCH_SIZE = 16
24
+ SAVED_MODEL_PATH = "custom_bert_model.torch"
25
+ SAVED_TARGET_CAT_PATH = "bbc-news-categories.torch"
26
+ DS_PATH = "bbc-news-data.csv"
27
+
28
+
29
+ from typing import DefaultDict
30
+
31
+
32
class CustomBertDataset(Dataset):
    """Text-classification dataset built from a tab-separated news file.

    Each non-header, non-blank line is split on tabs; column 0 is taken as
    the category label and column 3 as the article text (assumed layout of
    bbc-news-data.csv -- confirm if the source file format changes).  The
    samples are re-ordered round-robin over the categories so consecutive
    items cycle through the classes, and the sorted category names are saved
    to disk so evaluation/demo code can decode prediction indices later.
    """

    def __init__(
        self,
        file_path,
        model_path="google-bert/bert-base-uncased",
        saved_target_cats_path=None,
    ):
        """Load, parse, class-balance and measure the dataset.

        Args:
            file_path: path to the tab-separated news data file.
            model_path: Hugging Face id of the BERT tokenizer to use.
            saved_target_cats_path: where to persist the sorted category
                names; defaults to the module-level SAVED_TARGET_CAT_PATH.
        """
        # Late-bind the default so the module constant is only read when a
        # dataset is actually constructed.
        if saved_target_cats_path is None:
            saved_target_cats_path = SAVED_TARGET_CAT_PATH

        self.model_path = model_path
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)

        # Read the raw file once and close the handle deterministically
        # (the original left the file object open).
        with open(file_path) as fh:
            raw_lines = fh.readlines()

        # Keep [text, category] pairs; skip the header (i == 0) and blank
        # lines.  Split each line only once instead of twice per field.
        parsed = []
        for i, line in enumerate(raw_lines):
            if i == 0 or line == "\n":
                continue
            fields = re.split(r"\t+", line.replace("\n", ""))
            parsed.append([fields[3], fields[0]])
        self.lines = np.array(parsed)

        self.corpus = np.array(self.lines[:, 0])
        self.elem_cats = self.lines[:, 1]
        self.unique_cats = sorted(set(self.elem_cats))
        self.num_class = len(self.unique_cats)
        # Map category name -> integer class index (sorted, stable).
        self.cats_dict = {cat: i for i, cat in enumerate(self.unique_cats)}
        self.targets = np.array([self.cats_dict[cat] for cat in self.elem_cats])

        # Persist the label names so eval_step / the demo can decode indices.
        torch.save(self.unique_cats, saved_target_cats_path)

        # Interleave the samples round-robin across classes so a contiguous
        # index split (train/test Subset ranges) still covers every class.
        balanced_corpus, balanced_targets = self._interleave_by_class(
            list(self.corpus), list(self.targets)
        )
        self.corpus = np.array(balanced_corpus)
        self.targets = np.array(balanced_targets)

        # Longest tokenized sample, capped at BERT's 512-token limit.
        self.max_len = 0
        for sent in self.corpus:
            input_ids = self.tokenizer.encode(sent, add_special_tokens=True)
            self.max_len = max(self.max_len, len(input_ids))
        self.max_len = min(self.max_len, 512)
        print(f"Max length : {self.max_len}")

    @staticmethod
    def _interleave_by_class(texts, labels):
        """Round-robin the (text, label) pairs over their labels.

        Returns two parallel lists where consecutive items cycle through the
        classes (in order of first appearance) until every sample has been
        emitted; classes that run out of samples are simply skipped.
        """
        # Group texts by label, preserving first-appearance key order
        # (plain insertion-ordered dict; no typing.DefaultDict needed).
        buckets = {}
        for text, label in zip(texts, labels):
            buckets.setdefault(label, []).append(text)

        out_texts, out_labels = [], []
        total = len(texts)
        while len(out_texts) < total:
            for label, bucket in buckets.items():
                if bucket:
                    out_texts.append(bucket.pop(0))
                    out_labels.append(label)
        return out_texts, out_labels

    def __len__(self):
        """Number of samples in the dataset."""
        return len(self.corpus)

    def __getitem__(self, idx):
        """Return (input_ids, attention_mask, target) tensors for sample idx.

        Texts are padded/truncated to ``self.max_len``; the leading batch
        dimension added by the tokenizer is squeezed away so the DataLoader
        can stack samples itself.
        """
        text = self.corpus[idx]
        target = self.targets[idx]
        encoded_input = self.tokenizer.encode_plus(
            text,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return (
            encoded_input["input_ids"].squeeze(0),
            encoded_input["attention_mask"].squeeze(0),
            torch.tensor(target, dtype=torch.long),
        )
105
+
106
+
107
class CustomBertModel(nn.Module):
    """Sentence classifier: frozen pretrained BERT encoder + linear head.

    The BERT parameters are frozen so only the final projection layer is
    trainable during optimization.
    """

    def __init__(self, num_class, model_path="google-bert/bert-base-uncased"):
        """Build the model.

        Args:
            num_class: number of output categories.
            model_path: Hugging Face id of the pretrained BERT encoder.
        """
        super().__init__()
        self.model_path = model_path
        self.num_class = num_class

        self.bert = BertModel.from_pretrained(self.model_path)
        # Freeze the encoder: exclude all BERT parameters from gradients.
        for param in self.bert.parameters():
            param.requires_grad = False
        # Project the encoder's hidden size onto the class logits.
        self.proj_lin = nn.Linear(self.bert.config.hidden_size, self.num_class)

    def forward(self, input_ids, attention_mask):
        """Return unnormalized class logits of shape (batch, num_class)."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Use the hidden state of the first ([CLS]) token as the sentence
        # representation.
        cls_state = encoded.last_hidden_state[:, 0, :]
        return self.proj_lin(cls_state)
125
+
126
+
127
def train_step(model, train_dataloader, loss_fn, optimizer):
    """Train ``model`` for NUM_EPOCHS over ``train_dataloader`` and save it.

    Logs the per-step loss to ``run`` (a wandb run created in the
    ``__main__`` block and read here as a module-level global) and writes
    the final state dict to SAVED_MODEL_PATH.

    Args:
        model: the classifier to optimize (expected to already be on
            ``device`` -- set up by the caller).
        train_dataloader: yields (input_ids, attention_mask, target) batches.
        loss_fn: criterion mapping (logits, targets) to a scalar loss.
        optimizer: optimizer over the trainable parameters of ``model``.
    """
    num_iterations = len(train_dataloader)

    for epoch in range(NUM_EPOCHS):
        print(f"Training Epoch n° {epoch}")
        model.train()

        # Unpack each batch directly; the original's batch[:][0] copied the
        # whole batch list per field and shadowed the builtin `input`.
        for step, (input_ids, attention_mask, target) in enumerate(train_dataloader):
            output = model(input_ids.to(device), attention_mask.to(device))
            loss = loss_fn(output, target.to(device))

            # Standard backward/update cycle.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Log the Python float, not the live CUDA tensor (avoids
            # keeping the autograd graph alive in the logger).
            run.log({"Training loss": loss.item()})

            print(f"Epoch {epoch+1} | step {step+1} / {num_iterations} | loss : {loss}")

    # Persist the trained weights for eval_step / the demo app.
    torch.save(model.state_dict(), SAVED_MODEL_PATH)
    print(f"Model saved at {SAVED_MODEL_PATH}")
156
+
157
+
158
def eval_step(
    test_dataloader,
    loss_fn,
    num_class,
    saved_model_path=SAVED_MODEL_PATH,
    saved_target_cats_path=SAVED_TARGET_CAT_PATH,
):
    """Evaluate the checkpoint saved at ``saved_model_path``.

    Rebuilds the classifier, loads the trained weights, accumulates
    predictions over ``test_dataloader``, logs the per-batch loss to the
    module-global wandb ``run``, prints the accuracy and displays a
    confusion matrix labeled with the category names stored at
    ``saved_target_cats_path``.

    Args:
        test_dataloader: yields (input_ids, attention_mask, target) batches.
        loss_fn: criterion mapping (logits, targets) to a scalar loss.
        num_class: number of output categories (to rebuild the model).
        saved_model_path: path of the state dict written by ``train_step``.
        saved_target_cats_path: path of the category list written by
            ``CustomBertDataset``.
    """
    y_pred = []
    y_true = []

    # Rebuild the architecture and load the trained weights.
    # NOTE(review): weights_only=False unpickles arbitrary objects -- only
    # acceptable because the checkpoint is produced locally by train_step.
    saved_model = CustomBertModel(num_class)
    saved_model.load_state_dict(torch.load(saved_model_path, weights_only=False))
    saved_model = saved_model.to(device)
    saved_model.eval()  # inference mode: disable dropout etc.
    print(f"Model loaded from path :{saved_model_path}")

    with torch.no_grad():
        # Unpack directly instead of the original batch[:][i] copies, and
        # avoid shadowing the builtin `input`.
        for input_ids, attention_mask, target in test_dataloader:
            output = saved_model(input_ids.to(device), attention_mask.to(device))
            loss = loss_fn(output, target.to(device))

            # Log the Python float rather than the live tensor.
            run.log({"Eval loss": loss.item()})
            print(f"Eval loss : {loss}")
            y_pred.extend(output.cpu().numpy().argmax(axis=1))
            y_true.extend(target.cpu().numpy())

    # Decode integer predictions back to category names.
    class_labels = torch.load(saved_target_cats_path, weights_only=False)
    true_labels = [class_labels[i] for i in y_true]
    pred_labels = [class_labels[i] for i in y_pred]

    print(f"Accuracy : {accuracy_score(true_labels, pred_labels)}")

    cm = confusion_matrix(true_labels, pred_labels, labels=class_labels)
    df_cm = pd.DataFrame(cm, index=class_labels, columns=class_labels)
    sns.heatmap(df_cm, annot=True, fmt="d")
    plt.title("Confusion Matrix for BBC News Dataset")
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.show()
209
+
210
+
211
+ if __name__ == "__main__":
212
+
213
+ wandb.login()
214
+ run = wandb.init(project="DIT-Bert-bbc-news-project")
215
+ our_bert_dataset = CustomBertDataset(DS_PATH)
216
+ print(f"Size of bert dataset : {len(our_bert_dataset)}")
217
+ train_dataset = Subset(our_bert_dataset, range(int(len(our_bert_dataset) * 0.8)))
218
+ test_dataset = Subset(
219
+ our_bert_dataset, range(int(len(our_bert_dataset) * 0.8), len(our_bert_dataset))
220
+ )
221
+
222
+ train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
223
+ test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
224
+
225
+ our_bert_model = CustomBertModel(our_bert_dataset.num_class)
226
+ our_bert_model = our_bert_model.to(device)
227
+
228
+ loss_fn = nn.CrossEntropyLoss()
229
+ optimizer = optim.SGD(
230
+ filter(lambda p: p.requires_grad, our_bert_model.parameters()), lr=0.01
231
+ )
232
+
233
+ train_step(our_bert_model, train_dataloader, loss_fn, optimizer)
234
+
235
+ eval_step(test_dataloader, loss_fn, our_bert_dataset.num_class)
README.md ADDED
File without changes
W&B Chart 20_05_2025 00_42_07.png ADDED

Git LFS Details

  • SHA256: 5fe03054854608a6bc58209d79d4aadf130a28a01cf51ab0436dbae739a0ad10
  • Pointer size: 131 Bytes
  • Size of remote file: 471 kB
W&B Chart 20_05_2025 00_42_20.png ADDED

Git LFS Details

  • SHA256: 68f926f7a172e41a8e52521f5bf5b109bd33c7690611d06be300d66d4f2f1310
  • Pointer size: 131 Bytes
  • Size of remote file: 307 kB
bbc-news-categories.torch ADDED
Binary file (1.42 kB). View file
 
bbc-news-data.csv ADDED
The diff for this file is too large to render. See raw diff
 
demo.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import pandas as pd
4
+ import seaborn as sns
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.optim as optim
8
+ from sklearn.metrics import (
9
+ accuracy_score,
10
+ confusion_matrix,
11
+ precision_score,
12
+ recall_score,
13
+ )
14
+ from torch.utils.data import DataLoader, Dataset, Subset
15
+ from transformers import AutoTokenizer, BertModel
16
+
17
+ from DL2_BERT_Model_Based_Classification import CustomBertModel
18
+
19
+ SAVED_TARGET_CAT_PATH = "bbc-news-categories.torch"
20
+ # The actual model is too large to be stored in GitHub.
21
+ # So it is available at the following URL : https://drive.google.com/file/d/1o-TDzHJwQfgw_y9R5PWo4TigkpYuSKmd/view?usp=sharing
22
+ SAVED_MODEL_PATH = "custom_bert_model.torch"
23
+
24
+
25
def find_category(
    input,
    saved_model_path=SAVED_MODEL_PATH,
    model_path="google-bert/bert-base-uncased",
    saved_target_cats_path=SAVED_TARGET_CAT_PATH,
):
    """Classify a news text and return the predicted category name.

    NOTE(review): the tokenizer and model are reloaded on every call --
    fine for a demo, but cache them at module level if latency matters.

    Args:
        input: raw article text to classify (parameter name shadows the
            builtin but is kept for interface compatibility).
        saved_model_path: path of the trained state dict.
        model_path: Hugging Face id of the tokenizer/encoder.
        saved_target_cats_path: path of the saved category-name list.

    Returns:
        The predicted category label as a string.
    """
    # Category names saved during training; list index == model class index.
    class_labels = torch.load(
        saved_target_cats_path, weights_only=False, map_location=torch.device("cpu")
    )

    # Rebuild the classifier and load the trained weights on CPU.
    # NOTE(review): weights_only=False unpickles arbitrary objects -- only
    # acceptable because the checkpoint is produced by our own training run.
    saved_model = CustomBertModel(len(class_labels))
    saved_model.load_state_dict(
        torch.load(
            saved_model_path, weights_only=False, map_location=torch.device("cpu")
        )
    )
    saved_model.eval()  # inference mode: disable dropout etc.
    print(f"Model loaded from path :{saved_model_path}")

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    encoded_input = tokenizer.encode_plus(
        input,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        output = saved_model(
            encoded_input["input_ids"],
            encoded_input["attention_mask"],
        )

    # argmax over the logits gives the predicted class index; the original
    # pre-initialized y_pred = "" for no reason -- return directly.
    return class_labels[output.squeeze(0).numpy().argmax()]
64
+
65
+
66
# Minimal Gradio interface: one free-text input box, the predicted news
# category name as text output.
demo = gr.Interface(
    fn=find_category,
    inputs=["text"],
    outputs=["text"],
)

# Launch the local web UI (blocking call).
demo.launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ wandb
2
+ transformers
3
+ numpy
4
+ pandas
5
+ openpyxl
6
+ torch
7
+ torchvision
8
+ torchaudio
9
+ gradio