Spaces:
Runtime error
Runtime error
Commit
·
8c44dfd
1
Parent(s):
e6827a9
Create tridentmodel.py
Browse files- tridentmodel.py +241 -0
tridentmodel.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""TridentModel.ipynb
|
| 3 |
+
|
| 4 |
+
Automatically generated by Colaboratory.
|
| 5 |
+
|
| 6 |
+
Original file is located at
|
| 7 |
+
https://colab.research.google.com/drive/1u07dSU0DoKnNzGzySXMTisXnaloqpUEO
|
| 8 |
+
|
| 9 |
+
TRIDENT MODEL IMPLEMENTATION
|
| 10 |
+
|
| 11 |
+
Date: 14 January 2023
|
| 12 |
+
Authors: Egheosa Ogbomo & Amran Mohammed (The Polymer Guys)
|
| 13 |
+
Description: This script combines three ML-based models to identify whether an input text is related to green plastics or not.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
# NOTE: `pip install transformers` was a notebook shell command left over from the
# Colab export. It is not valid Python and made the module fail with a SyntaxError
# on import. Install the dependency outside the script instead (for example by
# listing `transformers` in requirements.txt).
|
| 17 |
+
|
| 18 |
+
########## IMPORTING REQUIRED PYTHON PACKAGES ##########
|
| 19 |
+
import pandas as pd
|
| 20 |
+
import tensorflow as tf
|
| 21 |
+
import numpy as np
|
| 22 |
+
import matplotlib.pyplot as plt
|
| 23 |
+
from transformers import AutoTokenizer, AutoModel
|
| 24 |
+
import torch
|
| 25 |
+
import math
|
| 26 |
+
import time
|
| 27 |
+
import csv
|
| 28 |
+
import pandas as pd
|
| 29 |
+
import nltk
|
| 30 |
+
from nltk.tokenize import word_tokenize
|
| 31 |
+
from nltk.corpus import stopwords
|
| 32 |
+
nltk.download('stopwords')
|
| 33 |
+
nltk.download('punkt')
|
| 34 |
+
import string
|
| 35 |
+
|
| 36 |
+
########## DEFINING FUNCTIONS FOR MODEL IMPLEMENTATIONS ##########
|
| 37 |
+
|
| 38 |
+
### Input data cleaner
|
| 39 |
+
all_stopwords = stopwords.words('english')  # Making sure to only use English stopwords
extra_stopwords = ['ii', 'iii']  # Can add extra stopwords to be removed from dataset/input abstracts
all_stopwords.extend(extra_stopwords)

# Punctuation/symbol characters that are deleted from the text during cleaning.
_REMOVE_CHARS = ['[', ']', '{', '}', ';', '(', ')', ',', '.', ':', '/', '-', '#', '?', '@', '£', '$']
_REMOVE_CHARS_TABLE = str.maketrans('', '', ''.join(_REMOVE_CHARS))

# Location the cleaned class list has always been written to; override it (or
# disable writing) through the ``output_path`` argument of ``clean_data``.
_DEFAULT_CLEANED_CSV = 'E:/Users/eeo21/Startup/CPC_Classifications_List/additionalcleanedclasses.csv'


def _clean_text(text, stopword_set):
    """
    Apply the text-cleaning pipeline to a single string.

    Steps, in order: tokenise with NLTK, drop stopwords, delete the characters
    in ``_REMOVE_CHARS``, lower-case, drop repeated words (first occurrence
    wins) and drop every word that contains a numeric character.

    :param text: the raw string to clean
    :param stopword_set: set of stopwords to remove
    :return: the cleaned string
    """
    tokens = word_tokenize(text)
    # The stopword comparison happens BEFORE lower-casing, so capitalised
    # stopwords (e.g. "The") are kept. This is left as-is so the cleaned text
    # stays identical to what this pipeline has always produced.
    kept = [word for word in tokens if word not in stopword_set]
    stripped = ' '.join(kept).translate(_REMOVE_CHARS_TABLE)
    # split(' ') rather than split(): tokens that consisted only of removed
    # characters leave empty entries behind, exactly as before.
    words = [word.lower() for word in stripped.split(' ')]
    words = list(dict.fromkeys(words))  # remove duplicate words, keeping order
    return ' '.join(word for word in words if not any(ch.isnumeric() for ch in word))


def clean_data(input, type='Dataframe', output_path=_DEFAULT_CLEANED_CSV):
    """
    As preparation for use with the text similarity model, clean either a
    dataframe of classifications or a single input string so that embeddings
    can be calculated for it. The cleaning:
        • skips dataframe rows that contain no values at all
        • removes stop words (e.g., by, a, an, he, she, it)
        • removes punctuation/symbol characters
        • lower-cases the text and removes repeated words
        • removes words that contain digits
        • (dataframe only) removes entries with a duplicate cleaned description

    For a dataframe, the last non-missing value of each row is taken as the
    class and all preceding values are joined to form the description.

    Note: ``input`` and ``type`` shadow Python builtins; the names are kept
    because callers pass them by keyword.

    :param input: Either a dataframe or an individual string
    :param type: 'Dataframe' or 'String'; tells the function how to treat ``input``
    :param output_path: (dataframe only) where the cleaned dataframe is written
        as CSV; pass None to skip writing
    :return: (if dataframe) a dataframe with columns 'Class' and 'Description'
        holding the class codes and their 'cleaned' descriptions
    :return: (if string) a 'cleaned' version of the input string
    :raises ValueError: if ``type`` is neither 'Dataframe' nor 'String'
    """
    stopword_set = set(all_stopwords)  # built once per call for O(1) lookups

    if type == 'String':
        return _clean_text(input, stopword_set)
    if type != 'Dataframe':
        raise ValueError("type must be 'Dataframe' or 'String', got {!r}".format(type))

    rows = []
    for i in range(len(input)):
        # Positional access, so dataframes without a default 0..n-1 index work too.
        row_list = input.iloc[i, :].values.flatten().tolist()
        present = [x for x in row_list if str(x) != 'nan']
        if not present:
            continue
        description = ' '.join(str(x).strip() for x in present[:-1])
        rows.append([present[-1], _clean_text(description, stopword_set)])

    cleaneddf = pd.DataFrame(rows, columns=['Class', 'Description'])
    cleaneddf = cleaneddf.drop_duplicates(subset=['Description'])
    if output_path is not None:
        cleaneddf.to_csv(output_path, index=False)
    return cleaneddf
|
| 119 |
+
|
| 120 |
+
### Mean Pooler
|
| 121 |
+
"""
|
| 122 |
+
Performs a mean pooling to reduce dimension of embedding
|
| 123 |
+
"""
|
| 124 |
+
def mean_pooling(model_output, attention_mask):
    """
    Perform an attention-mask-weighted mean over the token embeddings.

    Token embeddings are multiplied by the attention mask, summed along
    dimension 1 and divided by the number of unmasked tokens, so masked
    (padding) positions do not contribute to the result.

    :param model_output: model output whose first element holds the token
        embeddings (a torch tensor)
    :param attention_mask: torch tensor with one fewer dimension than the token
        embeddings; non-zero marks positions to include
    :return: torch tensor of pooled embeddings with dimension 1 reduced away.
        ``.numpy()`` is available on the result, as callers in this file expect.
    """
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = torch.sum(token_embeddings * input_mask_expanded, 1)
    # Clamp the token count so a fully masked row divides by 1e-9 instead of 0.
    token_counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return summed / token_counts
|
| 128 |
+
|
| 129 |
+
### Sentence Embedder
|
| 130 |
+
# Loaded (tokenizer, model) pairs keyed by model path, so that repeated calls
# (e.g. once per class) do not reload the model from disk every time.
_EMBEDDER_CACHE = {}


def sentence_embedder(sentences, model_path):
    """
    Call the sentence similarity model to generate embeddings for input text.

    The tokenizer and model for ``model_path`` are loaded on first use and
    reused on later calls with the same path.

    :param sentences: input text in the form of a string
    :param model_path: path to the text similarity model
    :return: the mean-pooled embedding of the input text (a (1, 384) tensor
        for a single string, according to the original authors' note)
    """
    if model_path not in _EMBEDDER_CACHE:
        tokenizer = AutoTokenizer.from_pretrained(model_path)  # HuggingFace tokenizer for the model
        model = AutoModel.from_pretrained(model_path, from_tf=True)  # model weights are stored in TF format
        _EMBEDDER_CACHE[model_path] = (tokenizer, model)
    tokenizer, model = _EMBEDDER_CACHE[model_path]

    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
    # Compute token embeddings; no gradients are needed for inference.
    with torch.no_grad():
        model_output = model(**encoded_input)
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
    return sentence_embeddings
|
| 145 |
+
|
| 146 |
+
### Sentence Embedding Preparation Function
|
| 147 |
+
def convert_saved_embeddings(embedding_string):
    """
    Prepare a pre-computed embedding for comparison with new abstract embeddings.

    Pre-computed embeddings are saved as tensors in string format (for example
    ``'tensor([[0.1, 0.2]])'``), so they need to be parsed back into numbers in
    order to calculate cosine similarity. Only the bracketed part of the string
    is parsed, so a trailing suffix such as ``, dtype=torch.float64)`` is
    ignored rather than causing a parse error.

    :param embedding_string: a single saved embedding in string format
    :return: a float64 torch tensor of shape (1, number_of_values)
    :raises ValueError: if the string contains no numeric values or a value
        cannot be parsed as a float
    """
    text = embedding_string
    first_bracket = text.find('[')
    last_bracket = text.rfind(']')
    if first_bracket != -1 and last_bracket > first_bracket:
        # Keep only "[[...]]": drops the "tensor(" prefix and any keyword suffix.
        text = text[first_bracket:last_bracket + 1]
    text = text.replace('tensor', '')
    text = text.translate(str.maketrans('', '', '()[] '))

    # Skip empty fragments (e.g. from a trailing comma); float() itself
    # tolerates the newlines that long tensor reprs contain.
    values = [float(piece) for piece in text.split(',') if piece.strip()]
    if not values:
        raise ValueError('no numeric values found in embedding string')

    embedding = np.expand_dims(np.array(values), axis=0)
    return torch.from_numpy(embedding)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
### Generating Class Embeddings
|
| 169 |
+
|
| 170 |
+
Model_Path = 'Model_bert'  ### Insert Path to MODEL DIRECTORY here


def class_embbedding_generator(classes, output_path=None):
    """
    Generate (and optionally save) class embeddings.

    Takes 'cleaned' classes, as produced by the clean_data function, and
    computes a vector representation (the embedding) of each class description
    using the model at ``Model_Path``. Each class name is printed as it is
    processed.

    :param classes: dataframe whose first column holds the class name and whose
        second column holds the cleaned class description
    :param output_path: optional CSV path; when given, the resulting dataframe
        is also written there
    :return: dataframe with columns 'Class', 'Description' and 'Embedding'
        (earlier versions built this dataframe but discarded it)
    """
    class_embeddings = pd.DataFrame(columns=['Class', 'Description', 'Embedding'])
    for i in range(len(classes)):
        class_name = classes.iloc[i, 0]
        print(class_name)  # progress output
        class_description = classes.iloc[i, 1]
        class_description_embedding = sentence_embedder(class_description, Model_Path)
        # Round-trip through numpy so the stored value is a plain torch tensor.
        class_description_embedding = torch.from_numpy(class_description_embedding.numpy())
        class_embeddings.loc[len(class_embeddings)] = [class_name, class_description, class_description_embedding]

    if output_path is not None:
        class_embeddings.to_csv(output_path, index=False)
    return class_embeddings
|
| 187 |
+
|
| 188 |
+
### Broad Scope Classifier
|
| 189 |
+
Model_Path = 'Model_bert'  ### Insert Path to MODEL DIRECTORY here

# Minimum cosine similarity a green class must reach for the abstract to be
# flagged, per Sensitivity setting. Note that 'High' is the strictest threshold.
_SENSITIVITY_THRESHOLDS = {'High': 0.5, 'Medium': 0.40, 'Low': 0.35}


def broad_scope_class_predictor(class_embeddings, abstract_embedding, N=5, Sensitivity='Medium',
                                green_class_count=52):
    """
    Score an abstract embedding against every pre-computed class embedding.

    The cosine similarity between the abstract and each class is calculated.
    The abstract is flagged as green when any of the last ``green_class_count``
    classes scores at or above the threshold selected by ``Sensitivity``.
    The N best matches and the green flag are also printed.

    :param class_embeddings: dataframe of class embeddings; the first column is
        the class name and the third column is the embedding saved as a string
    :param abstract_embedding: a single abstract embedding (any tensor exposing
        ``.numpy()``)
    :param N: N highest matching classes to return, from highest to lowest, default is 5
    :param Sensitivity: 'High' (threshold 0.5), 'Medium' (0.40) or 'Low' (0.35)
    :param green_class_count: number of rows at the end of ``class_embeddings``
        that are treated as the green classes, default is 52
    :return: predictions: dataframe of the scores for all classes,
        HighestSimilarity: dataframe of the N most similar classes,
        GreenLikelihood: the string 'True' or 'False'
    :raises ValueError: if ``Sensitivity`` is not one of the supported settings
    """
    if Sensitivity not in _SENSITIVITY_THRESHOLDS:
        raise ValueError(
            "Sensitivity must be one of {}, got {!r}".format(sorted(_SENSITIVITY_THRESHOLDS), Sensitivity)
        )
    Threshold = _SENSITIVITY_THRESHOLDS[Sensitivity]

    # Loop-invariant work: convert the abstract once and build the metric once.
    abstract_vector = torch.from_numpy(abstract_embedding.numpy())
    cos = torch.nn.CosineSimilarity(dim=1)

    rows = []
    for i in range(len(class_embeddings)):
        class_name = class_embeddings.iloc[i, 0]
        embedding = convert_saved_embeddings(class_embeddings.iloc[i, 2])
        # Match dtypes explicitly (saved embeddings parse to float64) instead
        # of relying on implicit type promotion inside the similarity op.
        score = cos(abstract_vector.to(embedding.dtype), embedding).numpy().tolist()
        rows.append([class_name, score[0]])

    predictions = pd.DataFrame(rows, columns=['Class Name', 'Score'])
    # Keep the column numeric even when there are no rows, so nlargest works.
    predictions['Score'] = predictions['Score'].astype(float)

    greenpredictions = predictions.tail(green_class_count)
    is_green = any(float(score) >= Threshold for score in greenpredictions['Score'])
    GreenLikelihood = 'True' if is_green else 'False'

    HighestSimilarity = predictions.nlargest(N, ['Score'])
    print(HighestSimilarity)
    print(GreenLikelihood)
    return predictions, HighestSimilarity, GreenLikelihood
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
########## LOADING PRE-COMPUTED EMBEDDINGS ##########
|
| 231 |
+
class_embeddings = pd.read_csv('ClassEmbedd/MainClassEmbeddings.csv')

# Example abstract used to exercise the full pipeline: clean -> embed -> classify.
abstract = """
Described herein are strength characteristics and biodegradation of articles produced using one or more “green” sustainable polymers and one or more carbohydrate-based polymers. A compatibilizer can optionally be included in the article. In some cases, the article can include a film, a bag, a bottle, a cap or lid therefore, a sheet, a box or other container, a plate, a cup, utensils, or the like.
"""
abstract = clean_data(abstract, type='String')
abstract_embedding = sentence_embedder(abstract, Model_Path)
Number = 10  # how many of the best-matching classes to report
broad_scope_predictions = broad_scope_class_predictor(class_embeddings, abstract_embedding, Number, Sensitivity='High')

# Print the results tuple (all predictions, top matches, green flag). The
# original printed the function object `broad_scope_class_predictor` by mistake.
print(broad_scope_predictions)
|