Spaces:
Running
Running
Domen Jemec committed on
Commit ·
0df3e88
1
Parent(s): e9f6911
hla model v0.1
Browse files- assets/hla/encoder/trimer_bow_hla.json +1 -0
- assets/hla/model/rfm_hla.pkl +3 -0
- hla_class.py +81 -0
- pages/HLA_Type_Prediction.py +7 -4
- requirements.txt +4 -1
assets/hla/encoder/trimer_bow_hla.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"AAA": 0, "AAC": 1, "AAG": 2, "AAT": 3, "ACA": 4, "ACC": 5, "ACG": 6, "ACT": 7, "AGA": 8, "AGC": 9, "AGG": 10, "AGT": 11, "ATA": 12, "ATC": 13, "ATG": 14, "ATT": 15, "CAA": 16, "CAC": 17, "CAG": 18, "CAT": 19, "CCA": 20, "CCC": 21, "CCG": 22, "CCT": 23, "CGA": 24, "CGC": 25, "CGG": 26, "CGT": 27, "CTA": 28, "CTC": 29, "CTG": 30, "CTT": 31, "GAA": 32, "GAC": 33, "GAG": 34, "GAT": 35, "GCA": 36, "GCC": 37, "GCG": 38, "GCT": 39, "GGA": 40, "GGC": 41, "GGG": 42, "GGT": 43, "GTA": 44, "GTC": 45, "GTG": 46, "GTT": 47, "TAA": 48, "TAC": 49, "TAG": 50, "TAT": 51, "TCA": 52, "TCC": 53, "TCG": 54, "TCT": 55, "TGA": 56, "TGC": 57, "TGG": 58, "TGT": 59, "TTA": 60, "TTC": 61, "TTG": 62, "TTT": 63, "UNK": 64}
|
assets/hla/model/rfm_hla.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0b91bfdc19c2dc37096dc1be9c5703766db0598ef9a2d2f9c0a941a9cd75d1db
|
| 3 |
+
size 7008153
|
hla_class.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import re
|
| 3 |
+
import numpy as np
|
| 4 |
+
import joblib
|
| 5 |
+
|
| 6 |
+
asset_home = './assets/hla/'
|
| 7 |
+
|
| 8 |
+
def tokenize_sequence(sequence, n=3, stride=1):
    """Split a sequence into space-separated, overlapping n-grams.

    Every character that is not an ASCII letter (whitespace, digits,
    punctuation, line breaks) is removed and the remainder is upper-cased
    before the sliding window is applied.

    Args:
        sequence: Raw sequence text.
        n: Window (n-gram) length. Defaults to 3 (trimers).
        stride: Step between consecutive window starts. Defaults to 1.

    Returns:
        The n-grams joined by single spaces, or an empty string when the
        cleaned sequence is shorter than ``n``.

    Raises:
        ValueError: If ``n`` or ``stride`` is smaller than 1.
    """
    if n < 1 or stride < 1:
        raise ValueError('n and stride must be positive integers')

    # Keep letters only, then normalise case so 'acg' and 'ACG' are one token.
    cleaned = re.sub(r'[^a-zA-Z]', '', sequence).upper()

    # Only full-length windows are emitted; a trailing partial window is dropped.
    ngrams = [
        cleaned[start:start + n]
        for start in range(0, len(cleaned) - n + 1, stride)
    ]
    return ' '.join(ngrams)
|
| 28 |
+
|
| 29 |
+
def bow_embedding(sequence):
    """Encode a sequence as a bag-of-words count vector over the vocabulary.

    The vocabulary (token -> column index) is read from
    ``encoder/trimer_bow_hla.json`` under ``asset_home``. The sequence is
    tokenized with ``tokenize_sequence`` and every token increments its
    column; tokens missing from the vocabulary increment the ``'UNK'``
    column instead.

    Args:
        sequence: Raw sequence text.

    Returns:
        An integer NumPy array of shape ``(1, len(vocabulary))`` holding
        the token counts. It is all zeros when the sequence yields no token.
    """
    embed_path = asset_home + 'encoder/trimer_bow_hla.json'
    with open(embed_path, 'r', encoding='utf-8') as json_file:
        value_to_index = json.load(json_file)
    unknown_token = 'UNK'

    bow_matrix = np.zeros((1, len(value_to_index)), dtype=int)

    # split() without an argument returns [] for an empty token string.
    # split(' ') would return [''], which was previously counted as one
    # unknown token for inputs too short to contain a single n-gram.
    for token in tokenize_sequence(sequence).split():
        col_idx = value_to_index.get(token)
        if col_idx is None:
            col_idx = value_to_index[unknown_token]
        bow_matrix[0, col_idx] += 1

    return bow_matrix
|
| 49 |
+
|
| 50 |
+
def predict_class(encoding, conf_thresh=0.1):
    """Return formatted class predictions whose probability exceeds a threshold.

    The model is loaded from ``model/rfm_hla.pkl`` under ``asset_home`` and
    queried with ``predict_proba``. For each row of ``encoding`` the classes
    are listed in descending probability order, keeping only those with a
    probability strictly greater than ``conf_thresh``.

    Args:
        encoding: Feature matrix passed unchanged to ``model.predict_proba``.
        conf_thresh: Minimum probability (exclusive) for a class to be listed.

    Returns:
        A list of strings. Each kept class yields ``'HLA-<class>: <prob>'``
        with the probability formatted to four decimals. A row with no class
        above the threshold contributes ``'No class predicted'``.
    """
    model_path = asset_home + 'model/rfm_hla.pkl'

    # NOTE(security): joblib.load unpickles the file and can execute arbitrary
    # code, so the model file must come from a trusted source.
    model = joblib.load(model_path)

    class_probabilities = model.predict_proba(encoding)

    prediction = []
    for probs in class_probabilities:
        # Indices of the classes ordered from most to least probable.
        ranked = np.argsort(probs)[::-1]
        hits = [
            f'HLA-{model.classes_[class_index]}: {probs[class_index]:.4f}'
            for class_index in ranked
            if probs[class_index] > conf_thresh
        ]
        if hits:
            prediction.extend(hits)
        else:
            # Append rather than overwrite, so results already collected for
            # earlier rows are not discarded.
            prediction.append('No class predicted')

    return prediction
|
pages/HLA_Type_Prediction.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
import numpy
|
|
|
|
| 3 |
|
| 4 |
min_seq_length = 100
|
| 5 |
max_seq_length = 10000
|
|
@@ -18,10 +19,12 @@ input_sequence = st.text_area('Enter your sequence to analyze',
|
|
| 18 |
help=f'enter between {min_seq_length} and {max_seq_length} characters without line breaks',
|
| 19 |
placeholder='aactaaaagactgacaaaatttttagtctctcgAATCGGGG...')
|
| 20 |
|
| 21 |
-
if st.button('Predict', disabled=len(input_sequence)
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
footer='''<style>
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
import numpy
|
| 3 |
+
from hla_class import bow_embedding, predict_class
|
| 4 |
|
| 5 |
min_seq_length = 100
|
| 6 |
max_seq_length = 10000
|
|
|
|
| 19 |
help=f'enter between {min_seq_length} and {max_seq_length} characters without line breaks',
|
| 20 |
placeholder='aactaaaagactgacaaaatttttagtctctcgAATCGGGG...')
|
| 21 |
|
| 22 |
+
# Run the HLA model when the user clicks "Predict"; the button stays disabled
# until the input reaches the minimum sequence length.
if st.button('Predict', disabled=len(input_sequence) < min_seq_length):
    enc_seq = bow_embedding(input_sequence)
    prediction = predict_class(enc_seq)
    # Markdown collapses a lone newline into a space. Two trailing spaces
    # before the newline force a hard line break, so each predicted class is
    # rendered on its own line.
    combined_string = '  \n'.join(prediction)
    st.markdown('## HLA Model Prediction')
    st.markdown(combined_string)
|
| 28 |
|
| 29 |
|
| 30 |
footer='''<style>
|
requirements.txt
CHANGED
|
@@ -1,2 +1,5 @@
|
|
| 1 |
numpy
|
| 2 |
-
pandas
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
numpy
|
| 2 |
+
pandas
|
| 3 |
+
json
|
| 4 |
+
re
|
| 5 |
+
joblib
|