varshith1110 commited on
Commit
e4d634a
·
verified ·
1 Parent(s): a53b362

Upload llmprop_utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. llmprop_utils.py +151 -0
llmprop_utils.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import json
3
+ import glob
4
+ import torch
5
+ import tarfile
6
+ import datetime
7
+
8
+ # for metrics
9
+ from torchmetrics.classification import BinaryAUROC
10
+ from sklearn.metrics import roc_auc_score
11
+
12
def writeToJSON(data, where_to_save):
    """Serialize *data* (a dictionary) as JSON into the file *where_to_save*."""
    with open(where_to_save, "w", encoding="utf8") as handle:
        json.dump(data, handle)
19
+
20
def readJSON(input_file):
    """Load the JSON file *input_file* and return the parsed object (a dict/list).

    Undecodable bytes are ignored and non-strict parsing is used, so stray
    control characters inside strings do not abort the load.
    """
    with open(input_file, "r", encoding="utf-8", errors='ignore') as infile:
        return json.load(infile, strict=False)
30
+
31
def writeTEXT(data, where_to_save):
    """Write each element of *data* on its own line (str()-converted) to *where_to_save*."""
    with open(where_to_save, "w", encoding="utf-8") as outfile:
        outfile.writelines(str(item) + "\n" for item in data)
36
+
37
def readTEXT_to_LIST(input_file):
    """Return the lines of *input_file* as a list; trailing newlines are preserved."""
    with open(input_file, "r", encoding="utf-8") as infile:
        return list(infile)
43
+
44
def saveCSV(df, where_to_save):
    """Persist the DataFrame *df* to *where_to_save* as CSV, without the index column."""
    df.to_csv(path_or_buf=where_to_save, index=False)
46
+
47
def time_format(total_time):
    """Convert a duration *total_time* (seconds) into an "h:mm:ss" string."""
    whole_seconds = int(round(total_time))
    return str(datetime.timedelta(seconds=whole_seconds))
54
+
55
def z_normalizer(labels):
    """Z-score normalize a tensor of labels: (x - mean) / std.

    The standard deviation is clamped away from zero so a constant label
    vector does not produce NaN/inf values.
    """
    center = torch.mean(labels)
    spread = torch.std(labels).clamp(min=1e-8)
    return (labels - center) / spread
65
+
66
def z_denormalize(scaled_labels, labels_mean, labels_std):
    """Invert z-score scaling: x * std + mean."""
    return scaled_labels * labels_std + labels_mean
69
+
70
def min_max_scaling(labels):
    """Min-max normalize *labels* into [0, 1].

    The range is clamped away from zero, so a constant input maps to all
    zeros instead of dividing by zero.
    """
    lowest = torch.min(labels)
    highest = torch.max(labels)
    span = (highest - lowest).clamp(min=1e-8)
    return (labels - lowest) / span
79
+
80
def mm_denormalize(scaled_labels, min_val, max_val):
    """Invert min-max scaling given the original minimum and maximum."""
    return scaled_labels * (max_val - min_val) + min_val
84
+
85
def log_scaling(labels):
    """Log-scale labels via log(1 + x); numerically stable for values near zero."""
    return torch.log1p(labels)
89
+
90
def ls_denormalize(scaled_labels):
    """Invert log scaling via expm1, i.e. exp(x) - 1."""
    return torch.expm1(scaled_labels)
93
+
94
def compressCheckpointsWithTar(filename):
    """Compress the checkpoint file *filename* into "<stem>.tar.gz".

    The stem is *filename* minus its last three characters — this assumes a
    3-character extension such as ".pt" (matches how the rest of the file
    names checkpoints).

    Fix: the archive is now opened with a context manager, so the handle is
    closed (and gzip data flushed) even if tar.add() raises; the original
    leaked the open tarfile on error.
    """
    filename_for_tar = filename[0:-3]
    with tarfile.open(f"{filename_for_tar}.tar.gz", "w:gz") as tar:
        tar.add(filename)
99
+
100
def decompressTarCheckpoints(tar_filename):
    """Extract the archive *tar_filename* into the current working directory.

    Fix: the archive is opened with a context manager so the handle is closed
    even if extraction raises; the original leaked it on error.

    NOTE(review): extractall() on an untrusted archive can write outside the
    working directory (tar path traversal, CVE-2007-4559). If these archives
    can come from untrusted sources, pass filter="data" (Python >= 3.12).
    """
    with tarfile.open(tar_filename) as tar:
        tar.extractall()
104
+
105
def replace_bond_lengths_with_num(sentence):
    """Replace bond lengths (e.g. "1.54 Å" or en-dash ranges "1.2–1.5 Å") with "[NUM]".

    The result is stripped of leading/trailing whitespace.
    """
    # Number, optional decimal part, optional en-dash range, optional spaces, then Å.
    length_pattern = r"\d+(\.\d+)?(?:–\d+(\.\d+)?)?\s*Å"
    return re.sub(length_pattern, "[NUM]", sentence).strip()
108
+
109
def replace_bond_angles_with_ang(sentence):
    """Replace bond angles ("90°" or "90 degrees", with optional en-dash ranges) with "[ANG]".

    The result is stripped of leading/trailing whitespace.
    """
    angle_patterns = (
        r"\d+(\.\d+)?(?:–\d+(\.\d+)?)?\s*°",
        r"\d+(\.\d+)?(?:–\d+(\.\d+)?)?\s*degrees",
    )
    for angle_pattern in angle_patterns:
        sentence = re.sub(angle_pattern, "[ANG]", sentence)
    return sentence.strip()
113
+
114
def replace_bond_lengths_and_angles_with_num_and_ang(sentence):
    """Replace bond lengths with "[NUM]" and bond angles with "[ANG]" in one pass.

    Lengths are numbers (optionally en-dash ranges) followed by Å; angles are
    the same followed by ° or "degrees". The result is stripped.
    """
    substitutions = (
        (r"\d+(\.\d+)?(?:–\d+(\.\d+)?)?\s*Å", "[NUM]"),
        (r"\d+(\.\d+)?(?:–\d+(\.\d+)?)?\s*°", "[ANG]"),
        (r"\d+(\.\d+)?(?:–\d+(\.\d+)?)?\s*degrees", "[ANG]"),
    )
    for pattern, token in substitutions:
        sentence = re.sub(pattern, token, sentence)
    return sentence.strip()
119
+
120
def get_cleaned_stopwords():
    """Load the stopword lists under stopwords/en/*.txt and return a cleaned set.

    Cleaning removes trailing newlines/whitespace, purely numeric tokens, and
    the spelled-out digits one..nine.
    Word lists from https://github.com/igorbrigadir/stopwords

    Returns: a set of lowercase-ish stopword strings (as found in the files,
    whitespace-stripped).

    Fix: the digit filter now runs on the *stripped* words. The raw lines
    keep their trailing "\n" (readTEXT_to_LIST preserves it), so the original
    test `wrd.isdigit()` on raw lines was always False and numeric stopwords
    were never removed.
    """
    stopword_files = glob.glob("stopwords/en/*.txt")
    num_str = {'one','two','three','four','five','six','seven','eight','nine'}

    raw_words = set()
    for file_path in stopword_files:
        raw_words |= set(readTEXT_to_LIST(file_path))

    cleaned = {wrd.replace("\n", "").strip() for wrd in raw_words}
    return cleaned - {wrd for wrd in cleaned if wrd.isdigit()} - num_str
133
+
134
def remove_mat_stopwords(sentence):
    """Remove stopwords from *sentence*.

    Matching is case-insensitive (the stopword set is compared against the
    lowercased token) but the surviving words keep their original casing.
    Words are re-joined with single spaces.
    """
    stopwords_list = get_cleaned_stopwords()
    kept = [word for word in sentence.split()
            if word.lower() not in stopwords_list]
    return ' '.join(kept)
140
+
141
def get_sequence_len_stats(df, tokenizer, max_len):
    """Percentage of rows whose tokenized 'description' fits within *max_len* tokens.

    df: DataFrame with a 'description' column of strings.
    tokenizer: object exposing .tokenize(text) -> list of tokens.
    max_len: maximum sequence length counted as trainable.
    Returns: a float in [0, 100].

    Fix: an empty DataFrame now returns 0.0 instead of raising
    ZeroDivisionError.
    """
    if len(df) == 0:
        return 0.0
    n_within = sum(
        1 for tokens in df['description'].apply(tokenizer.tokenize)
        if len(tokens) <= max_len
    )
    return (n_within / len(df)) * 100
144
+
145
def get_roc_score(predictions, targets):
    """Compute the binary area under the ROC curve.

    predictions: raw model outputs (logits); sigmoid maps them to [0, 1].
    targets: binary ground-truth labels (0/1).
    Returns: a 0-dim torch tensor with the AUROC.

    Fixes vs. the original:
    - torchmetrics' BinaryAUROC keyword is `thresholds`, not `threshold`;
      the old constructor call raised on unexpected keyword arguments.
    - Probabilities are passed to the metric directly. Rounding them to 0/1
      first (as the original did) collapses the score ranking that AUROC is
      defined over, degenerating the metric to a single operating point.
    """
    roc_fn = BinaryAUROC(thresholds=None)
    target_tensor = torch.tensor(targets)
    pred_probs = torch.sigmoid(torch.tensor(predictions))
    return roc_fn(pred_probs, target_tensor)