Domen Jemec commited on
Commit
a92d46c
·
1 Parent(s): 0df3e88
Files changed (2) hide show
  1. hla_class.py +4 -2
  2. pages/HLA_Type_Prediction.py +8 -8
hla_class.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
4
  import joblib
5
 
6
  asset_home = './assets/hla/'
 
7
 
8
  def tokenize_sequence(sequence):
9
  # configs
@@ -47,9 +48,10 @@ def bow_embedding(sequence):
47
 
48
  return bow_matrix
49
 
50
- def predict_class(encoding, conf_thresh=0.1):
51
  model_path = asset_home + 'model/rfm_hla.pkl'
52
  prediction = []
 
53
 
54
  model = joblib.load(model_path)
55
 
@@ -70,7 +72,7 @@ def predict_class(encoding, conf_thresh=0.1):
70
  if prob > conf_thresh: # Check if the probability is above the threshold
71
  # Get the class name from the model's classes_
72
  class_name = model.classes_[class_index]
73
- prediction = prediction + [f'HLA-{class_name}: {prob:.4f}'] # Print class name and probability
74
  any_class_above_threshold = True
75
 
76
  if not any_class_above_threshold:
 
4
  import joblib
5
 
6
  asset_home = './assets/hla/'
7
+ CONFIDENCE_THRESHOLD = 0.1
8
 
9
  def tokenize_sequence(sequence):
10
  # configs
 
48
 
49
  return bow_matrix
50
 
51
+ def predict_class(encoding, conf_thresh=CONFIDENCE_THRESHOLD):
52
  model_path = asset_home + 'model/rfm_hla.pkl'
53
  prediction = []
54
+ link_base = 'https://www.genecards.org/cgi-bin/carddisp.pl?gene='
55
 
56
  model = joblib.load(model_path)
57
 
 
72
  if prob > conf_thresh: # Check if the probability is above the threshold
73
  # Get the class name from the model's classes_
74
  class_name = model.classes_[class_index]
75
+ prediction = prediction + [f'[HLA-{class_name}]({link_base}HLA-{class_name}) with confidence {prob:.2f}'] # Print class name and probability
76
  any_class_above_threshold = True
77
 
78
  if not any_class_above_threshold:
pages/HLA_Type_Prediction.py CHANGED
@@ -1,8 +1,8 @@
1
  import streamlit as st
2
  import numpy
3
- from hla_class import bow_embedding, predict_class
4
 
5
- min_seq_length = 100
6
  max_seq_length = 10000
7
 
8
  st.set_page_config(page_title='HLA Gene Prediction')
@@ -14,17 +14,17 @@ st.markdown(
14
  )
15
 
16
 
17
- input_sequence = st.text_area('Enter your sequence to analyze',
18
- max_chars=max_seq_length,
19
  help=f'enter between {min_seq_length} and {max_seq_length} characters without line breaks',
20
  placeholder='aactaaaagactgacaaaatttttagtctctcgAATCGGGG...')
 
21
 
22
  if st.button('Predict', disabled=len(input_sequence)< min_seq_length):
23
  enc_seq = bow_embedding(input_sequence)
24
- prediction = predict_class(enc_seq)
25
- combined_string = '\n'.join(prediction)
26
- st.markdown('## HLA Model Prediction')
27
- st.markdown(combined_string)
28
 
29
 
30
  footer='''<style>
 
1
  import streamlit as st
2
  import numpy
3
+ from hla_class import bow_embedding, predict_class, CONFIDENCE_THRESHOLD
4
 
5
+ min_seq_length = 5
6
  max_seq_length = 10000
7
 
8
  st.set_page_config(page_title='HLA Gene Prediction')
 
14
  )
15
 
16
 
17
+ input_sequence = st.text_area('Enter your sequence to analyze',max_chars=max_seq_length,
 
18
  help=f'enter between {min_seq_length} and {max_seq_length} characters without line breaks',
19
  placeholder='aactaaaagactgacaaaatttttagtctctcgAATCGGGG...')
20
+ conf_filt = st.slider('Confidence Cutoff', min_value=0.00, max_value=1.00, value=CONFIDENCE_THRESHOLD)
21
 
22
  if st.button('Predict', disabled=len(input_sequence)< min_seq_length):
23
  enc_seq = bow_embedding(input_sequence)
24
+ prediction = predict_class(enc_seq, conf_filt)
25
+ st.markdown('### HLA Model Prediction')
26
+ for pred in prediction:
27
+ st.markdown(pred)
28
 
29
 
30
  footer='''<style>