"""Score candidate 23-nt target sites in a transcript with a saved TIGER Keras model."""
import os
import sys

import pandas as pd
import tensorflow as tf

# Length of each guide/target site, in nucleotides.
GUIDE_LEN = 23
# Integer token per nucleotide; token order fixes the one-hot channel order (A, C, G, T).
NUCLEOTIDE_TOKENS = dict(zip(['A', 'C', 'G', 'T'], [0, 1, 2, 3]))

# Load the model once at import time so repeated predictions reuse it.
if os.path.exists('model'):
    tiger = tf.keras.models.load_model('model')
else:
    print('no saved model!')
    # Bare exit() would report status 0 (success); a missing model is an error.
    sys.exit(1)


def process_data(x):
    """Slide a GUIDE_LEN window over sequence ``x`` and encode each window.

    Args:
        x: nucleotide sequence (string or iterable of single characters);
           case-insensitive, must contain only A/C/G/T.

    Returns:
        Tuple of (target site strings, one-hot tensor of shape
        [num_sites, GUIDE_LEN * 4]) — one row per sliding window.

    Raises:
        ValueError: on characters outside A/C/G/T, or if the sequence is
            shorter than GUIDE_LEN (either would otherwise fail opaquely
            downstream).
    """
    x = [item.upper() for item in x]

    # Fail fast with a clear message instead of a bare KeyError mid-loop.
    invalid = set(x) - NUCLEOTIDE_TOKENS.keys()
    if invalid:
        raise ValueError(f'invalid nucleotide(s): {sorted(invalid)}')

    number_of_input = len(x) - GUIDE_LEN + 1
    if number_of_input < 1:
        raise ValueError(f'sequence must be at least {GUIDE_LEN} nt long, got {len(x)}')

    # Human-readable window strings, one per sliding position.
    input_gens = ["".join(x[i:i + GUIDE_LEN]) for i in range(number_of_input)]

    # Tokenize once, then gather each window's tokens into one flat list.
    token_x = [NUCLEOTIDE_TOKENS[item] for item in x]
    merged_token = []
    for i in range(number_of_input):
        merged_token.extend(token_x[i:i + GUIDE_LEN])

    # One-hot over the 4 nucleotide channels, then one flattened row per window.
    one_hot_x = tf.one_hot(merged_token, depth=4)
    model_input_x = tf.reshape(one_hot_x, [-1, GUIDE_LEN * 4])
    return input_gens, model_input_x


def tiger_predict(transcript_seq: str):
    """Predict normalized LFC for every 23-nt target site in ``transcript_seq``.

    Args:
        transcript_seq: nucleotide sequence of length >= GUIDE_LEN.

    Returns:
        pandas DataFrame with columns 'Target site' (the 23-nt window) and
        'Normalized LFC' (the model's score for that window).
    """
    # Parse transcript sequence into 23-nt target sequences and their one-hot encodings.
    target_seq, target_seq_one_hot = process_data(transcript_seq)

    # Get predictions; squeeze drops the trailing singleton output dimension.
    normalized_lfc = tiger.predict_step(target_seq_one_hot)
    predictions = pd.DataFrame({'Target site': target_seq,
                                'Normalized LFC': tf.squeeze(normalized_lfc).numpy()})
    return predictions


if __name__ == '__main__':
    # Simple smoke test on a short synthetic transcript.
    transcript_sequence = 'ACGTACGTACGTACGTACGTACGTACGTACGT'
    df = tiger_predict(transcript_sequence)
    print(df)