Domen Jemec committed on
Commit
0df3e88
·
1 Parent(s): e9f6911

hla model v0.1

Browse files
assets/hla/encoder/trimer_bow_hla.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"AAA": 0, "AAC": 1, "AAG": 2, "AAT": 3, "ACA": 4, "ACC": 5, "ACG": 6, "ACT": 7, "AGA": 8, "AGC": 9, "AGG": 10, "AGT": 11, "ATA": 12, "ATC": 13, "ATG": 14, "ATT": 15, "CAA": 16, "CAC": 17, "CAG": 18, "CAT": 19, "CCA": 20, "CCC": 21, "CCG": 22, "CCT": 23, "CGA": 24, "CGC": 25, "CGG": 26, "CGT": 27, "CTA": 28, "CTC": 29, "CTG": 30, "CTT": 31, "GAA": 32, "GAC": 33, "GAG": 34, "GAT": 35, "GCA": 36, "GCC": 37, "GCG": 38, "GCT": 39, "GGA": 40, "GGC": 41, "GGG": 42, "GGT": 43, "GTA": 44, "GTC": 45, "GTG": 46, "GTT": 47, "TAA": 48, "TAC": 49, "TAG": 50, "TAT": 51, "TCA": 52, "TCC": 53, "TCG": 54, "TCT": 55, "TGA": 56, "TGC": 57, "TGG": 58, "TGT": 59, "TTA": 60, "TTC": 61, "TTG": 62, "TTT": 63, "UNK": 64}
assets/hla/model/rfm_hla.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0b91bfdc19c2dc37096dc1be9c5703766db0598ef9a2d2f9c0a941a9cd75d1db
3
+ size 7008153
hla_class.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ import numpy as np
4
+ import joblib
5
+
6
+ asset_home = './assets/hla/'
7
+
8
def tokenize_sequence(sequence, n=3, stride=1):
    """Split a sequence string into overlapping n-grams.

    Every character that is not an ASCII letter (digits, whitespace,
    punctuation) is removed and the remainder is upper-cased before the
    sliding window is applied.

    Args:
        sequence: Raw sequence text.
        n: Window length in characters; defaults to 3 (trimers).
        stride: Step between consecutive window starts; defaults to 1.

    Returns:
        The n-grams joined by single spaces, or an empty string when the
        cleaned sequence is shorter than ``n``.

    Raises:
        ValueError: If ``n`` or ``stride`` is smaller than 1.
    """
    if n < 1 or stride < 1:
        raise ValueError('n and stride must both be >= 1')

    clean_sequence = re.sub(r'[^a-zA-Z]', '', sequence).upper()

    # One past the last index at which a full window of length n still fits.
    last_start = len(clean_sequence) - n + 1
    return ' '.join(
        clean_sequence[i:i + n] for i in range(0, last_start, stride)
    )
28
+
29
def bow_embedding(sequence):
    """Encode a sequence as a bag-of-words count vector.

    The vocabulary (token -> column index) is read from
    ``encoder/trimer_bow_hla.json`` under ``asset_home``. Tokens that are
    missing from the vocabulary are counted in the ``'UNK'`` column.

    Args:
        sequence: Raw sequence text; it is cleaned and tokenized by
            ``tokenize_sequence``.

    Returns:
        An integer numpy array of shape ``(1, len(vocabulary))`` holding the
        number of occurrences of each token. A sequence that produces no
        tokens yields an all-zero row.

    Raises:
        KeyError: If the vocabulary file has no ``'UNK'`` entry.
    """
    embed_path = asset_home + 'encoder/trimer_bow_hla.json'
    with open(embed_path, 'r', encoding='utf-8') as json_file:
        value_to_index = json.load(json_file)
    unknown_idx = value_to_index['UNK']

    bow_matrix = np.zeros((1, len(value_to_index)), dtype=int)

    # split() with no argument returns [] for an empty token string, whereas
    # split(' ') returns [''] and would record a spurious 'UNK' occurrence.
    for token in tokenize_sequence(sequence).split():
        bow_matrix[0, value_to_index.get(token, unknown_idx)] += 1

    return bow_matrix
49
+
50
def predict_class(encoding, conf_thresh=0.1):
    """Report the classes the stored model predicts above a threshold.

    The model is loaded with ``joblib`` from ``model/rfm_hla.pkl`` under
    ``asset_home`` on every call.

    Args:
        encoding: Feature rows passed unchanged to the model's
            ``predict_proba`` (for example the array from ``bow_embedding``).
        conf_thresh: A class is reported only when its probability is
            strictly greater than this value.

    Returns:
        A flat list of strings. For each input row, in order, it contains one
        ``'HLA-<class>: <probability>'`` entry per class above the threshold,
        most probable first, or a single ``'No class predicted'`` entry when
        no class qualifies.
    """
    model_path = asset_home + 'model/rfm_hla.pkl'
    model = joblib.load(model_path)

    prediction = []
    for probs in model.predict_proba(encoding):
        # Class indices ordered from most to least probable.
        ranked = np.argsort(probs)[::-1]
        hits = [
            f'HLA-{model.classes_[idx]}: {probs[idx]:.4f}'
            for idx in ranked
            if probs[idx] > conf_thresh
        ]
        # Extend instead of reassigning the list, so a row without any
        # qualifying class does not wipe the results of earlier rows.
        prediction.extend(hits or ['No class predicted'])

    return prediction
pages/HLA_Type_Prediction.py CHANGED
@@ -1,5 +1,6 @@
1
  import streamlit as st
2
  import numpy
 
3
 
4
  min_seq_length = 100
5
  max_seq_length = 10000
@@ -18,10 +19,12 @@ input_sequence = st.text_area('Enter your sequence to analyze',
18
  help=f'enter between {min_seq_length} and {max_seq_length} characters without line breaks',
19
  placeholder='aactaaaagactgacaaaatttttagtctctcgAATCGGGG...')
20
 
21
- if st.button('Predict', disabled=len(input_sequence)>= min_seq_length):
22
- status_text.text('Prediction Complete')
23
- gene = 'HLA A'
24
- st.markdown(f'The predicted squence is {gene}')
 
 
25
 
26
 
27
  footer='''<style>
 
1
  import streamlit as st
2
  import numpy
3
+ from hla_class import bow_embedding, predict_class
4
 
5
  min_seq_length = 100
6
  max_seq_length = 10000
 
19
  help=f'enter between {min_seq_length} and {max_seq_length} characters without line breaks',
20
  placeholder='aactaaaagactgacaaaatttttagtctctcgAATCGGGG...')
21
 
22
# Run the HLA model when the user presses "Predict"; the button stays
# disabled until the input reaches the minimum sequence length.
if st.button('Predict', disabled=len(input_sequence) < min_seq_length):
    enc_seq = bow_embedding(input_sequence)
    prediction = predict_class(enc_seq)
    st.markdown('## HLA Model Prediction')
    # Markdown treats a single newline as a soft break (rendered as a space),
    # so end each entry with two spaces to show one prediction per line.
    st.markdown('  \n'.join(prediction))
28
 
29
 
30
  footer='''<style>
requirements.txt CHANGED
@@ -1,2 +1,5 @@
1
  numpy
2
- pandas
 
 
 
 
1
  numpy
2
+ pandas
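3
+ # json: part of the Python standard library, not installable with pip
4
+ # re: part of the Python standard library, not installable with pip
5
+ joblib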