# simple-geno-model / hla_class.py
# Domen Jemec
# updates
# a92d46c
# raw
# history blame
# 2.56 kB
import json
import re
import numpy as np
import joblib
# Directory prefix for the HLA assets; the encoder vocabulary and the model
# file paths below are built by string concatenation onto it, so the trailing
# slash is required. Being a relative path, it is resolved against the
# process's current working directory.
asset_home = './assets/hla/'
# Default for predict_class's conf_thresh: a class is reported only when its
# predicted probability is strictly greater than this value.
CONFIDENCE_THRESHOLD = 0.1
def tokenize_sequence(sequence: str, n: int = 3, stride: int = 1) -> str:
    """Split a sequence into n-grams joined by single spaces.

    Every character that is not an ASCII letter (whitespace, digits, gap
    symbols, ...) is removed and the remainder is upper-cased. A window of
    width ``n`` is then slid over the cleaned text in steps of ``stride``.

    Args:
        sequence: Raw sequence text.
        n: Width of each n-gram. Defaults to 3 (trimers).
        stride: Offset between the starts of consecutive n-grams.
            Defaults to 1 (maximally overlapping windows).

    Returns:
        The n-grams separated by single spaces, or an empty string when
        fewer than ``n`` letters remain after cleaning.

    Raises:
        ValueError: If ``n`` or ``stride`` is smaller than 1.
    """
    if n < 1 or stride < 1:
        raise ValueError('n and stride must be positive integers')

    # Keep letters only, then normalise case so tokens match regardless of
    # how the input was typed.
    clean_sequence = re.sub(r'[^a-zA-Z]', '', sequence).upper()

    # The last valid window starts at len - n; shorter inputs give an empty
    # range and therefore an empty result.
    window_starts = range(0, len(clean_sequence) - n + 1, stride)
    return ' '.join(clean_sequence[i:i + n] for i in window_starts)
def bow_embedding(sequence):
    """Encode a sequence as a single-row bag-of-words count matrix.

    The vocabulary (token -> column index) is read from
    ``encoder/trimer_bow_hla.json`` under ``asset_home`` on every call. The
    sequence is tokenized with ``tokenize_sequence`` and each token
    increments the count in its vocabulary column; tokens missing from the
    vocabulary are counted in the column of the ``'UNK'`` entry.

    Args:
        sequence: Raw sequence text.

    Returns:
        An integer ``numpy`` array of shape ``(1, V)``, where ``V`` is the
        number of vocabulary entries. It is all zeros when the sequence
        yields no tokens.

    Raises:
        OSError: If the vocabulary file cannot be opened.
        KeyError: If a token is missing from the vocabulary and the
            vocabulary has no ``'UNK'`` entry.
    """
    embed_path = asset_home + 'encoder/trimer_bow_hla.json'
    with open(embed_path, 'r', encoding='utf-8') as json_file:
        value_to_index = json.load(json_file)
    unknown_token = 'UNK'

    # NOTE(review): assumes the vocabulary values are integer column indices
    # in [0, V) - confirm against the encoder file.
    bow_matrix = np.zeros((1, len(value_to_index)), dtype=int)

    # split() with no argument returns [] for an empty token string, whereas
    # split(' ') returns [''] and would count a phantom unknown token for
    # sequences too short to produce any n-gram.
    for token in tokenize_sequence(sequence).split():
        col_idx = value_to_index.get(token)
        if col_idx is None:
            # Looked up lazily so a vocabulary without 'UNK' only fails when
            # an unknown token is actually encountered.
            col_idx = value_to_index[unknown_token]
        bow_matrix[0, col_idx] += 1
    return bow_matrix
def predict_class(encoding, conf_thresh=CONFIDENCE_THRESHOLD):
    """Predict HLA classes for encoded sequences and format them as links.

    Loads the model stored at ``model/rfm_hla.pkl`` under ``asset_home``,
    calls its ``predict_proba`` on ``encoding`` and, for every row, reports
    each class whose probability is strictly greater than ``conf_thresh``,
    most probable first.

    Args:
        encoding: Feature matrix passed unchanged to ``model.predict_proba``
            (for example the output of ``bow_embedding``).
        conf_thresh: Probability a class must exceed to be reported.

    Returns:
        A flat list of strings covering the rows of ``encoding`` in order.
        Each reported class yields a Markdown-style link of the form
        ``[HLA-<class>](<GeneCards URL>) with confidence <p>``; a row with
        no class above the threshold contributes the single entry
        ``'No class predicted'``.
    """
    model_path = asset_home + 'model/rfm_hla.pkl'
    link_base = 'https://www.genecards.org/cgi-bin/carddisp.pl?gene='

    # NOTE: joblib.load unpickles the file, which can execute arbitrary
    # code - the model file must come from a trusted source.
    model = joblib.load(model_path)
    class_probabilities = model.predict_proba(encoding)

    prediction = []
    for probs in class_probabilities:
        # Class indices ordered from most to least probable.
        ranked_indices = np.argsort(probs)[::-1]

        # Collect per row so that a row without any confident class adds its
        # own placeholder instead of discarding earlier rows' results.
        row_predictions = []
        for class_index in ranked_indices:
            prob = probs[class_index]
            if prob > conf_thresh:
                class_name = model.classes_[class_index]
                row_predictions.append(
                    f'[HLA-{class_name}]({link_base}HLA-{class_name}) with confidence {prob:.2f}'
                )
        if not row_predictions:
            row_predictions.append('No class predicted')
        prediction.extend(row_predictions)
    return prediction