File size: 2,474 Bytes
89be9f9
 
 
 
 
59874d6
 
 
89be9f9
457a981
89be9f9
 
34274e5
 
 
 
 
 
59874d6
34274e5
457a981
 
 
 
 
34274e5
 
 
 
 
 
457a981
 
 
 
 
 
 
 
 
 
34274e5
457a981
89be9f9
 
814d067
34274e5
59874d6
 
 
 
 
 
 
457a981
 
814d067
 
457a981
 
814d067
 
 
 
 
 
 
34274e5
814d067
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import tensorflow as tf
import pandas as pd

# guide (spacer) length in nucleotides
GUIDE_LEN = 23
# nucleotides of target-site context flanking the guide match on the 5' side
CONTEXT_5P = 3
# nucleotides of target-site context flanking the guide match on the 3' side
CONTEXT_3P = 0
# total width of each target-site window extracted from the transcript
TARGET_LEN = CONTEXT_5P + GUIDE_LEN + CONTEXT_3P
# integer token assigned to each nucleotide for model input
NUCLEOTIDE_TOKENS = dict(zip(['A', 'C', 'G', 'T'], [0, 1, 2, 3]))
# Watson-Crick complement used to derive guide sequences from target sites
NUCLEOTIDE_COMPLEMENT = dict(zip(['A', 'C', 'G', 'T'], ['T', 'G', 'C', 'A']))


def process_data(transcript_seq: str):
    """Extract every candidate target site from a transcript and build model inputs.

    Slides a TARGET_LEN-wide window over the transcript, derives the complementary
    guide sequence for each window, tokenizes both, and one-hot encodes them into a
    single flat feature tensor per site.

    :param transcript_seq: transcript nucleotide sequence (A/C/G/T, any case)
    :return: (target_seq, guide_seq, model_inputs) where target_seq and guide_seq are
             lists of strings and model_inputs is a float tensor of shape
             [num_sites, 4 * TARGET_LEN + 4 * TARGET_LEN]
    :raises KeyError: if the transcript contains a non-ACGT character inside a guide
             region (NUCLEOTIDE_COMPLEMENT has no entry for it) -- NOTE(review):
             the lookup table below tolerates such characters via its OOV bucket,
             but the complement step does not; confirm whether Ns should be handled.
    """
    # normalize case so dictionary lookups work
    transcript_seq = transcript_seq.upper()

    # enumerate all target sites; the '+ 1' includes the final window, which the
    # previous 'range(len(seq) - TARGET_LEN)' bound silently dropped (off-by-one)
    target_seq = [transcript_seq[i: i + TARGET_LEN] for i in range(len(transcript_seq) - TARGET_LEN + 1)]

    # guide = complement of the GUIDE_LEN core of each target window (context trimmed)
    guide_seq = [seq[CONTEXT_5P:len(seq) - CONTEXT_3P] for seq in target_seq]
    guide_seq = [''.join([NUCLEOTIDE_COMPLEMENT[nt] for nt in seq]) for seq in guide_seq]

    # tokenize sequences; unknown characters fall into the single OOV bucket (id 4)
    nucleotide_table = tf.lookup.StaticVocabularyTable(
        initializer=tf.lookup.KeyValueTensorInitializer(
            keys=tf.constant(list(NUCLEOTIDE_TOKENS.keys()), dtype=tf.string),
            values=tf.constant(list(NUCLEOTIDE_TOKENS.values()), dtype=tf.int64)),
        num_oov_buckets=1)
    target_tokens = nucleotide_table.lookup(tf.stack([list(t) for t in target_seq], axis=0))
    guide_tokens = nucleotide_table.lookup(tf.stack([list(g) for g in guide_seq], axis=0))

    # pad guides back to TARGET_LEN with token 255; one_hot(depth=4) maps 255 (and
    # any OOV token) to an all-zero vector, i.e. "no nucleotide"
    pad_5p = 255 * tf.ones([guide_tokens.shape[0], CONTEXT_5P], dtype=guide_tokens.dtype)
    pad_3p = 255 * tf.ones([guide_tokens.shape[0], CONTEXT_3P], dtype=guide_tokens.dtype)
    guide_tokens = tf.concat([pad_5p, guide_tokens, pad_3p], axis=1)

    # one-hot encode and flatten targets and guides into a single feature vector per site
    model_inputs = tf.concat([
        tf.reshape(tf.one_hot(target_tokens, depth=4), [len(target_seq), -1]),
        tf.reshape(tf.one_hot(guide_tokens, depth=4), [len(guide_tokens), -1]),
        ], axis=-1)

    return target_seq, guide_seq, model_inputs


def tiger_predict(transcript_seq: str):
    """Predict normalized LFC for every guide targeting the given transcript.

    Loads the saved Keras model from the local 'model' directory, parses the
    transcript into target/guide pairs, and runs the model on them.

    :param transcript_seq: transcript nucleotide sequence (A/C/G/T, any case)
    :return: pandas DataFrame with columns 'Guide' and 'Normalized LFC'
    :raises FileNotFoundError: if no saved model directory exists
    """
    # load model; raise instead of print-and-exit() so callers can handle the
    # error (bare exit() in a library function kills the whole interpreter)
    if not os.path.exists('model'):
        raise FileNotFoundError('no saved model!')
    tiger = tf.keras.models.load_model('model')

    # parse transcript sequence
    target_seq, guide_seq, model_inputs = process_data(transcript_seq)

    # get predictions; predict_step returns a tensor (not numpy), hence the
    # explicit squeeze + .numpy() conversion below
    normalized_lfc = tiger.predict_step(model_inputs)
    predictions = pd.DataFrame({'Guide': guide_seq, 'Normalized LFC': tf.squeeze(normalized_lfc).numpy()})

    return predictions


if __name__ == '__main__':

    # smoke test: a short repeating transcript, lower-cased to exercise case handling
    example_transcript = 'ACGTACGTACGTACGTACGTACGTACGTACGT'.lower()
    predictions = tiger_predict(example_transcript)
    print(predictions)