divyanshu94 commited on
Commit
b437a5e
·
1 Parent(s): 79be5ba

commit files to HF hub

Browse files
agri_custom_pipeline.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import csv
import re
import string
import warnings

import nltk
import numpy as np
import torch
from nltk.corpus import stopwords
from nltk.corpus import wordnet
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize
from transformers import AutoTokenizer, BertForSequenceClassification, Pipeline
12
+
13
# Fetch the NLTK resources required by the pre-processing steps below
# (POS tagging, stopword removal, lemmatization, word tokenization).
for _resource in ('averaged_perceptron_tagger', 'stopwords', 'wordnet', 'punkt'):
    nltk.download(_resource)

# Suppress warnings emitted by downstream libraries.
warnings.filterwarnings('ignore')
21
+
22
+ # pre-processing modules
23
# pre-processing modules
class RemovePunctuation:
    """Callable that keeps a punctuation set and drops selected characters from it.

    Starts from ``string.punctuation``; each call permanently removes the given
    punctuation character(s) from the retained set and returns what is left.
    NOTE: the instance is stateful — removals accumulate across calls.
    """

    def __init__(self):
        # Full ASCII punctuation set; characters are removed from it on demand.
        self.punctuation = string.punctuation

    def __call__(self, punctuations):
        """Remove *punctuations* from the retained set.

        :param punctuations: a string or iterable of punctuation characters to
            drop (e.g. ``"?"``); an empty string drops nothing.
        :return: the punctuation characters still retained by this instance.
        """
        if isinstance(punctuations, str):
            punctuations = list(punctuations)
        # A single translate() pass drops every requested character at once
        # (the original rebuilt the translation table per character).
        self.punctuation = self.punctuation.translate(
            str.maketrans('', '', ''.join(punctuations))
        )
        return self.punctuation
45
+
46
+
47
# Shared module-level instance; used by ProcessText.remove_punctuation_text.
remove_punctuation = RemovePunctuation()
49
+
50
+
51
def get_wordnet_pos(tag):
    """Map a Penn Treebank POS tag to the matching WordNet POS constant.

    Tags are matched by their first letter (J=adjective, V=verb, N=noun,
    R=adverb); anything unrecognized defaults to noun.
    """
    prefix_to_pos = {
        'J': wordnet.ADJ,
        'V': wordnet.VERB,
        'N': wordnet.NOUN,
        'R': wordnet.ADV,
    }
    # Default to noun when the tag family is not recognized.
    return prefix_to_pos.get(tag[:1], wordnet.NOUN)
62
+
63
+
64
class ProcessText(object):
    """Text-cleaning pipeline used before classification.

    ``__call__`` applies, in order: punctuation removal, lower-casing,
    stopword removal, POS-aware lemmatization, word de-duplication, number
    removal and short-word filtering.
    """

    @staticmethod
    def remove_punctuation_text(text):
        """Remove punctuation characters from *text*, keeping word tokens."""
        tokens = re.findall(r'\w+|[^\s\w]+', text)
        # remove_punctuation("") returns the currently retained punctuation
        # set; build the translation table once instead of per token.
        table = str.maketrans('', '', remove_punctuation(""))
        cleaned = (token.translate(table) for token in tokens)
        return " ".join(word for word in cleaned if word != "")

    @staticmethod
    def remove_stopwords(text):
        """Drop English stopwords (case-insensitive match) from *text*."""
        stop_words = set(stopwords.words('english'))
        words = word_tokenize(text)
        return ' '.join(word for word in words if word.lower() not in stop_words)

    @staticmethod
    def lower_casing(text):
        """Return *text* converted to lower case."""
        return text.lower()

    @staticmethod
    def lemmatize_text(text):
        """Lemmatize each word to its base form, using its POS tag as context."""
        lemmatizer = WordNetLemmatizer()
        tagged_words = nltk.pos_tag(word_tokenize(text))
        return ' '.join(
            lemmatizer.lemmatize(word, pos=get_wordnet_pos(tag))
            for word, tag in tagged_words
        )

    @staticmethod
    def remove_duplicates_and_sort(text):
        """Remove duplicate words, keeping each word's first-occurrence order."""
        # dict.fromkeys preserves insertion order, giving the same
        # first-occurrence ordering as the original
        # sorted(set(words), key=words.index) in O(n) instead of O(n^2).
        return ' '.join(dict.fromkeys(text.split()))

    @staticmethod
    def remove_numbers(text):
        """Delete every run of digits from *text*."""
        return re.sub(r'\d+', '', text)

    @staticmethod
    def include_words_with_len_greater_than_2(text):
        """Keep only words longer than two characters."""
        return ' '.join(word for word in text.split() if len(word) > 2)

    def __call__(self, text):
        """Run the full cleaning pipeline over *text* and return the result."""
        # Remove any punctuation, then normalize case.
        text = self.remove_punctuation_text(text)
        text = self.lower_casing(text)

        # Stopwords such as "is", "the", etc. that convey no meaning are removed.
        text = self.remove_stopwords(text)

        # Lemmatization converts words to their base or root form,
        # considering their context and part of speech.
        text = self.lemmatize_text(text)

        # Words are treated independently in this problem, so duplicates
        # can be dropped while keeping first-occurrence order.
        text = self.remove_duplicates_and_sort(text)

        # Finally drop numbers and very short words.
        return self.include_words_with_len_greater_than_2(self.remove_numbers(text))
154
+
155
+
156
def write_csv(file_path, rows):
    """Write *rows* (an iterable of row sequences) to *file_path* as CSV.

    Overwrites any existing file.  Uses a bare LF line terminator and UTF-8
    encoding so the output is identical across platforms.

    :param file_path: destination path for the CSV file
    :param rows: iterable of row sequences passed to ``csv.writer.writerows``
    """
    # NOTE: requires the module-level ``import csv`` (it was missing in the
    # original file's imports).
    with open(file_path, "w", newline="", encoding="utf-8") as data_file:
        csv_writer = csv.writer(data_file, lineterminator="\n")
        csv_writer.writerows(rows)
163
+
164
+
165
+ # custom inference pipeline
166
# custom inference pipeline
class AgriClfPipeline(Pipeline):
    """Custom HF pipeline that labels text as "agri" or "non-agri".

    ``preprocess`` cleans and tokenizes the text, ``_forward`` runs the BERT
    classifier, and ``postprocess`` extracts the label string.
    """

    def _sanitize_parameters(self, **kwargs):
        """Route the optional ``text`` kwarg to preprocess; no forward/postprocess kwargs."""
        preprocess_kwargs = {}
        if "text" in kwargs:
            preprocess_kwargs["text"] = kwargs["text"]
        return preprocess_kwargs, {}, {}

    def preprocess(self, text, **kwargs):
        """Clean *text*, tokenize it, and build padded input ids + attention mask.

        :return: dict with ``input_idsPred`` and ``attention_maskPred`` tensors,
            ``None`` if cleaning did not yield a string, or ``-1`` on error.
            NOTE(review): the -1/None returns crash later in _forward; consider
            raising instead — kept for backward compatibility.
        """
        processed_description = ProcessText()(text)
        try:
            if isinstance(processed_description, str):
                # Cache the tokenizer on the instance — the original reloaded
                # it from the hub on every single call.
                if not hasattr(self, "_pred_tokenizer"):
                    self._pred_tokenizer = AutoTokenizer.from_pretrained(
                        "divyanshu94/agriBERT_clfModel"
                    )
                predToken = self._pred_tokenizer.encode(
                    str(processed_description), add_special_tokens=True
                )

                # Pad with id 0 up to the fixed length used at training time.
                # Encodings longer than max_len are left untruncated, as before.
                max_len = 155
                padded_predToken = np.array([predToken + [0] * (max_len - len(predToken))])
                predAttention_mask = np.where(padded_predToken != 0, 1, 0)

                return {
                    "input_idsPred": torch.tensor(padded_predToken),
                    "attention_maskPred": torch.tensor(predAttention_mask),
                }
        except Exception as error:
            print("{}".format(str(error)))
            return -1

    def _forward(self, model_inputs):
        """Run the classifier and map its argmax logit to a label string."""
        input_idsPred = model_inputs["input_idsPred"]
        attention_maskPred = model_inputs["attention_maskPred"]
        # Bug fix: the original unconditionally moved the model and inputs to
        # "cuda" (despite its own "if available" comment), crashing on
        # CPU-only machines.  Pick whichever device is actually available.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = self.model.to(device)

        with torch.no_grad():
            output = self.model(
                input_idsPred.to(device),
                token_type_ids=None,
                attention_mask=attention_maskPred.to(device),
            )
            # Argmax over the logits: index 1 means the "agri" class.
            prediction = 1 if int(np.argmax(output.logits.cpu().numpy())) == 1 else 0

        return {"logits": "agri" if prediction == 1 else "non-agri"}

    def postprocess(self, model_outputs, **kwargs):
        """Return the label string produced by ``_forward``."""
        return model_outputs["logits"]
207
+
config.json CHANGED
@@ -1,10 +1,19 @@
1
  {
2
- "_name_or_path": "bert-base-uncased",
3
  "architectures": [
4
  "BertForSequenceClassification"
5
  ],
6
  "attention_probs_dropout_prob": 0.1,
7
  "classifier_dropout": null,
 
 
 
 
 
 
 
 
 
8
  "gradient_checkpointing": false,
9
  "hidden_act": "gelu",
10
  "hidden_dropout_prob": 0.1,
 
1
  {
2
+ "_name_or_path": "divyanshu94/agriBERT_clfModel",
3
  "architectures": [
4
  "BertForSequenceClassification"
5
  ],
6
  "attention_probs_dropout_prob": 0.1,
7
  "classifier_dropout": null,
8
+ "custom_pipelines": {
9
+ "agri-classification": {
10
+ "impl": "agri_custom_pipeline.AgriClfPipeline",
11
+ "pt": [
12
+ "BertForSequenceClassification"
13
+ ],
14
+ "tf": []
15
+ }
16
+ },
17
  "gradient_checkpointing": false,
18
  "hidden_act": "gelu",
19
  "hidden_dropout_prob": 0.1,
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:45309f28b751ad0638bfd0282d791aea0a1c9cebca6c91fcbe32fed149104c6b
3
- size 438003505
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6a8d8c2c691a6ff22dc0c8e73f309e492cdeeaf4e0c956731b9ee9123e3f5846
3
+ size 438000689
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json CHANGED
@@ -6,7 +6,7 @@
6
  "normalized": false,
7
  "rstrip": false,
8
  "single_word": false,
9
- "special": false
10
  },
11
  "100": {
12
  "content": "[UNK]",
@@ -14,7 +14,7 @@
14
  "normalized": false,
15
  "rstrip": false,
16
  "single_word": false,
17
- "special": false
18
  },
19
  "101": {
20
  "content": "[CLS]",
@@ -22,7 +22,7 @@
22
  "normalized": false,
23
  "rstrip": false,
24
  "single_word": false,
25
- "special": false
26
  },
27
  "102": {
28
  "content": "[SEP]",
@@ -30,7 +30,7 @@
30
  "normalized": false,
31
  "rstrip": false,
32
  "single_word": false,
33
- "special": false
34
  },
35
  "103": {
36
  "content": "[MASK]",
@@ -38,7 +38,7 @@
38
  "normalized": false,
39
  "rstrip": false,
40
  "single_word": false,
41
- "special": false
42
  }
43
  },
44
  "additional_special_tokens": [],
@@ -54,6 +54,5 @@
54
  "strip_accents": null,
55
  "tokenize_chinese_chars": true,
56
  "tokenizer_class": "DistilBertTokenizer",
57
- "tokenizer_file": "/root/.cache/huggingface/hub/models--distilbert-base-uncased/snapshots/6cdc0aad91f5ae2e6712e91bc7b65d1cf5c05411/tokenizer.json",
58
  "unk_token": "[UNK]"
59
  }
 
6
  "normalized": false,
7
  "rstrip": false,
8
  "single_word": false,
9
+ "special": true
10
  },
11
  "100": {
12
  "content": "[UNK]",
 
14
  "normalized": false,
15
  "rstrip": false,
16
  "single_word": false,
17
+ "special": true
18
  },
19
  "101": {
20
  "content": "[CLS]",
 
22
  "normalized": false,
23
  "rstrip": false,
24
  "single_word": false,
25
+ "special": true
26
  },
27
  "102": {
28
  "content": "[SEP]",
 
30
  "normalized": false,
31
  "rstrip": false,
32
  "single_word": false,
33
+ "special": true
34
  },
35
  "103": {
36
  "content": "[MASK]",
 
38
  "normalized": false,
39
  "rstrip": false,
40
  "single_word": false,
41
+ "special": true
42
  }
43
  },
44
  "additional_special_tokens": [],
 
54
  "strip_accents": null,
55
  "tokenize_chinese_chars": true,
56
  "tokenizer_class": "DistilBertTokenizer",
 
57
  "unk_token": "[UNK]"
58
  }