Spaces:
Build error
Build error
| """ | |
| tf_hub.py | |
| Find text embeddings using pre-trained TensorFlow Hub models | |
| """ | |
| import os | |
| import pickle | |
| import numpy as np | |
| from arxiv_public_data.config import DIR_OUTPUT, LOGGER | |
| from arxiv_public_data.embeddings.util import batch_fulltext | |
logger = LOGGER.getChild('embds')

# TensorFlow / TF-Hub / sentencepiece are optional at import time: the module
# still loads without them, but the embedding functions below will fail.
try:
    import tensorflow as tf
    import tensorflow_hub as hub
    import sentencepiece as spm
except ImportError as e:
    # logger.warn() is deprecated in the logging module; use warning().
    # Also add the missing space between "and" and "'sentencepiece'" that the
    # implicit string concatenation previously dropped.
    logger.warning("This module requires 'tensorflow', 'tensorflow-hub', and "
                   "'sentencepiece'\n"
                   'Please install these modules to use tf_hub.py')
# TF-Hub module URLs and call signatures for the pre-trained encoders.
UNIV_SENTENCE_ENCODER_URL = ('https://tfhub.dev/google/'
                             'universal-sentence-encoder/2')

ELMO_URL = "https://tfhub.dev/google/elmo/2"
ELMO_KWARGS = dict(signature='default', as_dict=True)
ELMO_MODULE_KWARGS = dict(trainable=True)
# key of the mean-pooled contextualized word representations in ELMo's output dict
ELMO_DICTKEY = 'default'

DIR_EMBEDDING = os.path.join(DIR_OUTPUT, 'embeddings')
# makedirs(exist_ok=True) avoids the check-then-create race of the previous
# exists()+mkdir pair and also creates any missing parent directories.
os.makedirs(DIR_EMBEDDING, exist_ok=True)
def elmo_strings(batches, filename, batchsize=32):
    """
    Compute and save vector embeddings of lists of strings in batches

    Parameters
    ----------
    batches : iterable of strings to be embedded
    filename : str
        filename to store embeddings
    (optional)
    batchsize : int
        size of batches
    """
    # Build the TF1 graph once: ELMo module, a string placeholder, and the
    # grouped variable/table initializers; then freeze it.
    graph = tf.Graph()
    with graph.as_default():
        elmo = hub.Module(ELMO_URL, **ELMO_MODULE_KWARGS)
        text_input = tf.placeholder(dtype=tf.string, shape=[None])
        embeddings = elmo(text_input, **ELMO_KWARGS)
        init_op = tf.group([tf.global_variables_initializer(),
                            tf.tables_initializer()])
    graph.finalize()

    with tf.Session(graph=graph) as sess:
        sess.run(init_op)
        for batch_num, batch in enumerate(batches):
            logger.info("Computing/saving batch {}".format(batch_num))
            # 'default' holds the mean-pooling of the contextualized word reps
            result = sess.run(embeddings, feed_dict={text_input: batch})
            # append-mode per batch so partial progress survives interruption
            with open(filename, 'ab') as fout:
                pickle.dump(result[ELMO_DICTKEY], fout)
# TF-Hub URL of the "lite" Universal Sentence Encoder (sentencepiece-tokenized variant)
UNIV_SENTENCE_LITE = "https://tfhub.dev/google/universal-sentence-encoder-lite/2"
def get_sentence_piece_model():
    """
    Return the filesystem path of the sentencepiece model bundled with the
    Universal Sentence Encoder Lite TF-Hub module.
    """
    with tf.Session() as sess:
        lite_module = hub.Module(UNIV_SENTENCE_LITE)
        spm_path_op = lite_module(signature="spm_path")
        return sess.run(spm_path_op)
def process_to_IDs_in_sparse_format(sp, sentences):
    """
    Encode sentences with the sentencepiece processor `sp` and return the
    result in tf.SparseTensor-similar format: (values, indices, dense_shape)

    Parameters
    ----------
    sp : sentencepiece.SentencePieceProcessor
        loaded sentencepiece model; must expose `EncodeAsIds(str) -> list[int]`
    sentences : iterable of str

    Returns
    -------
    values : list of int
        token ids in row-major order
    indices : list of [row, col]
        position of each value in the dense matrix
    dense_shape : tuple of (n_sentences, max_sentence_length)
    """
    ids = [sp.EncodeAsIds(sentence) for sentence in sentences]
    # default=0 keeps an empty batch from raising ValueError in max(),
    # returning ([], [], (0, 0)) instead
    max_len = max((len(row) for row in ids), default=0)
    dense_shape = (len(ids), max_len)
    values = [token for row in ids for token in row]
    indices = [[r, c] for r, row in enumerate(ids) for c in range(len(row))]
    return (values, indices, dense_shape)
def universal_sentence_encoder_lite(batches, filename, spm_path, batchsize=32):
    """
    Compute and save vector embeddings of lists of strings in batches

    Parameters
    ----------
    batches : iterable of strings to be embedded
    filename : str
        filename to store embeddings
    spm_path : str
        path to sentencepiece model from `get_sentence_piece_model`
    (optional)
    batchsize : int
        size of batches
    """
    # NOTE(review): `batchsize` appears unused here — `batches` seems to be
    # expected to arrive pre-batched; confirm against callers.
    sp = spm.SentencePieceProcessor()
    sp.Load(spm_path)
    # Build the TF1 graph once: the lite encoder consumes token ids supplied
    # as the three components of a sparse tensor.
    g = tf.Graph()
    with g.as_default():
        module = hub.Module(UNIV_SENTENCE_LITE)
        input_placeholder = tf.sparse_placeholder(
            tf.int64, shape=(None, None)
        )
        embeddings = module(
            inputs=dict(
                values=input_placeholder.values, indices=input_placeholder.indices,
                dense_shape=input_placeholder.dense_shape
            )
        )
        init_op = tf.group([tf.global_variables_initializer(),
                            tf.tables_initializer()])
    # freeze the graph so no ops are accidentally added during the run loop
    g.finalize()
    with tf.Session(graph=g) as sess:
        sess.run(init_op)
        for i, batch in enumerate(batches):
            # tokenize the batch into (values, indices, dense_shape) for the
            # sparse placeholder
            values, indices, dense_shape = process_to_IDs_in_sparse_format(sp, batch)
            logger.info("Computing/saving batch {}".format(i))
            emb = sess.run(
                embeddings,
                feed_dict={
                    input_placeholder.values: values,
                    input_placeholder.indices: indices,
                    input_placeholder.dense_shape: dense_shape
                }
            )
            # append-mode per batch so partial progress survives interruption
            with open(filename, 'ab') as fout:
                pickle.dump(emb, fout)
def create_save_embeddings(batches, filename, encoder, headers=None, encoder_args=(),
                           encoder_kwargs=None, savedir=DIR_EMBEDDING):
    """
    Create vector embeddings of strings and save them to filename

    Parameters
    ----------
    batches : iterator of strings
    filename : str
        embeddings will be saved in DIR_EMBEDDING/embeddings/filename
    encoder : function(batches, savename, *args, **kwargs)
        encodes strings in batches into vectors and saves them
    (optional)
    headers : list of things to save in embeddings file first
    encoder_args : tuple
        extra positional arguments for `encoder`
    encoder_kwargs : dict
        extra keyword arguments for `encoder`
    savedir : str
        directory in which to save the embeddings file

    Examples
    --------
    # For list of strings, create batched numpy array of objects
    batches = np.array_split(
        np.array(strings, dtype='object'), len(strings)//batchsize
    )
    headers = []
    # For the fulltext which cannot fit in memory, use `util.batch_fulltext`
    md_index, all_ids, batch_gen = batch_fulltext()
    headers = [md_index, all_ids]
    # Universal Sentence Encoder Lite:
    spm_path = get_sentence_piece_model()
    create_save_embeddings(batches, filename, universal_sentence_encoder_lite,
                           headers=headers, encoder_args=(spm_path,))
    # ELMO:
    create_save_embeddings(strings, filename, elmo_strings, headers=headers)
    """
    # [] / {} defaults are shared across calls (mutable default argument);
    # use None sentinels instead. Callers passing their own [] / {} see no change.
    headers = [] if headers is None else headers
    encoder_kwargs = {} if encoder_kwargs is None else encoder_kwargs

    # exist_ok avoids the check-then-create race of exists()+makedirs
    os.makedirs(savedir, exist_ok=True)
    savename = os.path.join(savedir, filename)
    # write headers first, in append mode, so the encoder's pickles follow them
    with open(savename, 'ab') as fout:
        for h in headers:
            pickle.dump(h, fout)
    logger.info("Saving embeddings to {}".format(savename))
    encoder(batches, savename, *encoder_args,
            **encoder_kwargs)