# O047's picture
# Update init.py
# c507eca verified
# ============================================================================
# Full initialization, loader, HLA mapping helpers, and prediction functions.
# Paste this file into your project as init.py. Importing this module will
# initialize both MHC-I and MHC-II engines (ENGINE_MHC1, ENGINE_MHC2).
# ============================================================================
import re
import os
import json
from collections import defaultdict
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
from peft import PeftModel
from huggingface_hub import hf_hub_download, login
from datasets import load_dataset
# OPTIM:
#from sklearn.metrics import precision_recall_curve
import pandas as pd
import numpy as np
from contextlib import nullcontext
# --- CONFIGURATION ---
# Hugging Face access token, read once at import time from the environment.
TOKEN = os.environ.get("HF_TOKEN")
if TOKEN is None:
    # Fail at import rather than on the first hub request.
    raise ValueError("HF_TOKEN environment variable is not set.")
# Authenticate this process against the Hugging Face Hub.
login(TOKEN)
# Fallback decision thresholds for the EL and BA outputs.
DEFAULT_EL_THRESHOLD = 0.60
DEFAULT_BA_THRESHOLD = 0.60

# Hub repositories and head filenames for the MHC-I engine.
MHC1_CONFIG = dict(
    model="O047/esm2_MHC-I_Reforged_Single",
    mapping="O047/MHC-I_HLA_Mapping",
    ba_db="O047/MHC-I_BA_Data",
    el_db="O047/MHC-I_EL_Data",
    eval_db="O047/MHC-I_EVAL",
    reg_head="regHead_MHC-I.pt",
    clf_head="clfHead_MHC-I.pt",
)

# Hub repositories and head filenames for the MHC-II engine (same keys).
MHC2_CONFIG = dict(
    model="O047/esm2_MHC-II_Reforged_Single",
    mapping="O047/MHC-II_HLA_Mapping",
    ba_db="O047/MHC-II_BA_Data",
    el_db="O047/MHC-II_EL_Data",
    eval_db="O047/MHC-II_EVAL",
    reg_head="regHead_MHC-II.pt",
    clf_head="clfHead_MHC-II.pt",
)
# --- Advanced HLA parsing and mapping helpers (preserve all variants) ---
# Locus-name sub-pattern: a multi-letter name with at most one trailing digit
# (e.g. DRB1, DPA1) or a single letter.  Limiting how many digits the name may
# absorb is what lets separator-free input such as "DRB10401" split into
# name + fields instead of swallowing field digits into the name.
_GENE_PATTERN = r"([A-Za-z]{2,}\d?|[A-Za-z])"
# name, optional separator, 2-3 digit first field, optional separator,
# 2 digit second field.  Tried first so that a two-field reading wins.
_ALLELE_RE = re.compile(_GENE_PATTERN + r"[*\-_:]?(\d{2,3})[:_\-]?(\d{2})$")
# Same shape with a single field only (e.g. "DRB1*04").
_ALLELE_ONE_FIELD_RE = re.compile(_GENE_PATTERN + r"[*\-_:]?(\d{2,3})$")
# Leading "HLA-" in any letter case.
_HLA_PREFIX_RE = re.compile(r"^HLA-", re.IGNORECASE)
# Explicit chain separators: runs of "/" or "-".
_CHAIN_SPLIT_RE = re.compile(r"[/\-]+")
# Boundary where a new name starts right after a digit or underscore.  The
# look-behind keeps the split from firing at position 0 or inside a name.
_CONCAT_SPLIT_RE = re.compile(r"(?<=[0-9_])(?=[A-Za-z]+[0-9])")
# A piece made only of digits (two or more) is a field, not a name.
_FIELD_ONLY_RE = re.compile(r"\d{2,}")


def _format_single_allele_token(token: str) -> str:
    """
    Convert a single token like 'DRB1_0401', 'DRB10401', 'DRB1*04:01', 'DRB1-04-01'
    into canonical 'HLA-DRB1*04:01'.

    Returns None for None.  A token that already starts with 'HLA-' (any case)
    and contains both '*' and ':' is returned with the prefix upper-cased and
    the rest untouched.  A token that matches neither pattern but contains '*'
    is split at the first '*'; anything else is returned with spaces removed
    and '/' or '-' runs collapsed to a single '-'.
    """
    if token is None:
        return None
    s = str(token).strip().replace(" ", "")
    s = re.sub(r"[-/]+", "-", s)
    if s.upper().startswith("HLA-") and "*" in s and ":" in s:
        # Slice instead of splitting on the literal "HLA-": the prefix may be
        # lower or mixed case, in which case a split would not remove it.
        return "HLA-" + s[4:]
    body = _HLA_PREFIX_RE.sub("", s)
    m = _ALLELE_RE.match(body) or _ALLELE_ONE_FIELD_RE.match(body)
    if m:
        gene, *fields = m.groups()
        return f"HLA-{gene.upper()}*{':'.join(fields)}"
    if "*" in body:
        # Use the prefix-free text so the result never carries "HLA-HLA-".
        left, right = body.split("*", 1)
        right = right.replace("_", ":").replace("-", ":")
        if ":" not in right and len(right) >= 4:
            right = right[:2] + ":" + right[2:]
        return f"HLA-{left.upper()}*{right}"
    return s


def _normalize_allele(a: str) -> str:
    """
    Normalize allele or allele-pair strings into canonical lookup keys.
    Handles:
    - single alleles: 'DRB1_0401' -> 'HLA-DRB1*04:01'
    - chain pairs: 'DPA1_04_01_DPB1_85_01' -> 'HLA-DPA1*04:01-DPB1*85:01'
    - separators: '/', '_', '-', ' ' are tolerated

    Returns None for None and '' for blank input.  Only the first chain keeps
    the 'HLA-' prefix; later chains are joined with '-' without it.
    """
    if a is None:
        return None
    s = str(a).strip()
    if s == "":
        return s
    s = s.replace(" ", "")

    # Split on explicit chain separators, then on name boundaries inside each
    # chunk (concatenated pairs such as 'DPA1_04_01_DPB1_85_01').
    pieces = []
    for chunk in _CHAIN_SPLIT_RE.split(s):
        for piece in _CONCAT_SPLIT_RE.split(chunk):
            piece = piece.strip("_")
            if piece:
                pieces.append(piece)

    # A bare 'HLA' piece is just the prefix cut off by the dash split; drop it
    # unless that would leave nothing at all.
    pieces = [p for p in pieces if p.upper() != "HLA"] or pieces

    # Dash-separated fields ('DRB1-04-01') arrive as digit-only pieces; glue
    # them back onto the name they belong to.
    tokens = []
    for piece in pieces:
        if tokens and _FIELD_ONLY_RE.fullmatch(piece):
            tokens[-1] = f"{tokens[-1]}_{piece}"
        else:
            tokens.append(piece)

    formatted = []
    for position, tok in enumerate(tokens):
        text = _format_single_allele_token(tok)
        if position > 0:
            text = _HLA_PREFIX_RE.sub("", text)
        formatted.append(text)
    return "-".join(formatted)
def build_hla_map_preserve_variants(map_ds):
    """
    Build the allele -> pseudosequence lookup while keeping every spelling.

    Rows are grouped by the normalized form of their 'allele' value.  Each
    group shares one pseudosequence: the first non-blank one in the group,
    or the first row's value when none is non-blank.  For every group the
    original spellings, the normalized key and a punctuation-free compact key
    are all registered, in that order.

    Returns (hla_map, stats) where stats reports raw row count, number of
    registered keys, number of normalized groups and number of groups that
    contain more than one row.
    """
    records = list(map_ds)
    variants_by_key = defaultdict(list)
    for record in records:
        allele_raw = record.get("allele")
        pseudo_raw = record.get("pseudosequence")
        if allele_raw is None:
            continue
        allele_text = str(allele_raw).strip()
        variants_by_key[_normalize_allele(allele_text)].append((allele_text, pseudo_raw))

    lookup = {}
    for canonical, variants in variants_by_key.items():
        # First usable pseudosequence wins; a group is never empty, so the
        # first row's value is always available as the fallback.
        shared_pseudo = next(
            (pseudo for _, pseudo in variants if pseudo and str(pseudo).strip()),
            variants[0][1],
        )
        for allele_text, _ in variants:
            if allele_text:
                lookup[allele_text] = shared_pseudo
        if canonical:
            lookup[canonical] = shared_pseudo
            squeezed = (
                canonical.replace("*", "")
                .replace(":", "")
                .replace("-", "")
                .replace("HLA", "")
            )
            if squeezed:
                lookup[squeezed] = shared_pseudo

    multi_row_groups = sum(1 for variants in variants_by_key.values() if len(variants) > 1)
    summary = {
        "raw_rows": len(records),
        "registered_keys": len(lookup),
        "normalized_groups": len(variants_by_key),
        "duplicate_groups": multi_row_groups,
    }
    return lookup, summary
def resolve_allele_key(query, hla_map):
    """
    Find the hla_map key that corresponds to a user-supplied allele string.

    Lookup order, first hit wins:
    1) the stripped query itself
    2) its normalized canonical form
    3) the compact (punctuation-free) form of the canonical form
    4) upper-cased, then lower-cased query
    5) first key whose compact form contains the query's compact form

    Returns the matching key, or None when the query is None or nothing matches.
    """
    if query is None:
        return None

    def _squash(text):
        # Drop '*', ':', '-' and the literal 'HLA', in that order.
        return text.replace("*", "").replace(":", "").replace("-", "").replace("HLA", "")

    raw = str(query).strip()
    if raw in hla_map:
        return raw

    canonical = _normalize_allele(raw)
    if canonical:
        if canonical in hla_map:
            return canonical
        squashed = _squash(canonical)
        if squashed and squashed in hla_map:
            return squashed

    for cased in (raw.upper(), raw.lower()):
        if cased in hla_map:
            return cased

    # Last resort: substring match on compact forms, in map iteration order.
    needle = _squash(raw)
    for candidate in hla_map:
        haystack = _squash(candidate)
        if needle and needle in haystack:
            return candidate
    return None
# --- MODEL ARCHITECTURES ---
class ProteinHead(nn.Module):
    """Regression head for the BA output.

    A three-layer MLP: input_dim -> input_dim -> input_dim // 2 -> 1, with
    GELU and dropout (p=0.2) after each of the first two linear layers.
    """

    def __init__(self, input_dim):
        super().__init__()
        half_dim = input_dim // 2
        # Layer positions (0, 3, 6 for the linears) define the state-dict
        # keys, so the order here must not change.
        stack = [
            nn.Linear(input_dim, input_dim),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(input_dim, half_dim),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(half_dim, 1),
        ]
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        """Map features whose last dimension is input_dim to one value each."""
        prediction = self.mlp(x)
        return prediction
class ImprovedProteinHead(nn.Module):
    """Classification head for the EL output.

    Scores every position with a small attention network, pools the sequence
    with the softmax of those scores (taken over dim 1), runs the pooled
    vector through an MLP, optionally scales and shifts the result, and
    clamps it to [-10, 10].
    """

    def __init__(self, input_dim, use_scale=False, scale_factor=1.0, use_bias=True, bias_value=-2.92):
        super().__init__()
        quarter_dim = input_dim // 4
        # Attention scorer is created before the MLP; module order and
        # positions determine the state-dict keys.
        self.attention = nn.Sequential(
            nn.Linear(input_dim, quarter_dim),
            nn.Tanh(),
            nn.Linear(quarter_dim, 1),
        )
        mlp_layers = []
        for _ in range(2):
            mlp_layers.extend(
                [
                    nn.Linear(input_dim, input_dim),
                    nn.LayerNorm(input_dim),
                    nn.GELU(),
                    nn.Dropout(0.5),
                ]
            )
        mlp_layers.append(nn.Linear(input_dim, 1))
        self.mlp = nn.Sequential(*mlp_layers)
        # Fixed (non-learned) post-processing applied to the MLP output.
        self.use_scale = use_scale
        self.scale_factor = scale_factor
        self.use_bias = use_bias
        self.bias_value = bias_value

    def forward(self, x):
        """Pool x over dim 1 with attention weights and return clamped logits."""
        scores = self.attention(x)
        # Softmax is computed in float32 and cast back to the input dtype.
        weights = torch.softmax(scores, dim=1, dtype=torch.float32).to(x.dtype)
        pooled = torch.sum(x * weights, dim=1)
        logit = self.mlp(pooled)
        if self.use_scale:
            logit = logit * self.scale_factor
        if self.use_bias:
            logit = logit + self.bias_value
        return logit.clamp(min=-10.0, max=10.0)
# --- THRESHOLD UTILITIES ---
""" OPTIM:
def compute_optimal_threshold(y_true, y_probs, default=0.6):
try:
if len(np.unique(y_true)) < 2:
return default
p, r, t = precision_recall_curve(y_true, y_probs)
f1 = (2 * p * r) / (p + r + 1e-8)
idx = np.argmax(f1)
return t[idx] if idx < len(t) else default
except Exception:
return default
def calculate_engine_thresholds(engine, eval_repo, sample_size=5000):
print(f" [*] Calculating dynamic thresholds from {eval_repo}...")
try:
eval_df = load_dataset(eval_repo, split="test", token=TOKEN).to_pandas()
if len(eval_df) > sample_size:
eval_df = eval_df.sample(sample_size, random_state=42).reset_index(drop=True)
if 'pseudosequence' in eval_df.columns:
eval_df.rename(columns={'pseudosequence': 'allele'}, inplace=True)
seqs = []
for _, row in eval_df.iterrows():
p = row["peptide"]
a = row["allele"]
c = row.get("context", "")
seqs.append(f"{p} [SEP] {a} [SEP] {c}")
ba_preds, el_probs = [], []
BATCH_SIZE = 512
device = engine.get("device", "cpu")
use_cuda_amp = (device != "cpu") and torch.cuda.is_available()
for i in range(0, len(seqs), BATCH_SIZE):
batch = seqs[i:i+BATCH_SIZE]
toks = engine['tokenizer'](batch, return_tensors="pt", padding=True, truncation=True, max_length=128)
try:
toks = toks.to(device)
except Exception:
toks = {k: v.to(device) for k, v in toks.items()}
amp_ctx = torch.cuda.amp.autocast() if use_cuda_amp else nullcontext()
with torch.no_grad(), amp_ctx:
outputs = engine['model'](**toks)
last_hidden = outputs.last_hidden_state
ba_emb = last_hidden.mean(dim=1).to(dtype=torch.float32)
ba_batch_preds = engine['regHead'](ba_emb).squeeze(-1).cpu().numpy()
ba_preds.extend(np.asarray(ba_batch_preds).ravel().tolist())
el_logits = engine['clfHead'](last_hidden.to(dtype=torch.float32)).squeeze(-1).cpu().numpy()
el_batch_probs = 1.0 / (1.0 + np.exp(-np.asarray(el_logits).ravel()))
el_probs.extend(el_batch_probs.tolist())
y_true = eval_df["score"].values
el_thresh = compute_optimal_threshold(y_true, el_probs, default=0.6)
ba_thresh = compute_optimal_threshold(y_true, ba_preds, default=0.5)
print(f" [>] Calculated EL Threshold: {el_thresh:.4f}")
print(f" [>] Calculated BA Threshold: {ba_thresh:.4f}")
return el_thresh, ba_thresh
except Exception as e:
print(f" [!] Threshold calculation failed ({e}). Defaulting to 0.6 for both.")
return 0.6, 0.6
"""
# --- ENGINE LOADER (uses preserved-variant HLA map) ---
def _load_head_weights(head, state, label):
    """Copy ``state`` into ``head`` non-strictly and report key mismatches.

    ``strict=False`` tolerates a checkpoint whose key set differs from the
    module, but a skipped key leaves that layer at its initial values, so
    every mismatch is printed instead of being ignored silently.
    """
    outcome = head.load_state_dict(state, strict=False)
    if outcome.missing_keys:
        print(f" [!] {label}: keys missing from checkpoint (left at init): {list(outcome.missing_keys)}")
    if outcome.unexpected_keys:
        print(f" [!] {label}: unexpected checkpoint keys ignored: {list(outcome.unexpected_keys)}")


def _load_checkpoint_thresholds(config, ckpt_folder):
    """Read ``thresholds.json`` from the checkpoint folder of ``config['model']``.

    Returns ``(el_threshold, ba_threshold, source_label)``.  The label records
    which keys were actually present in the file ("CHECKPOINT", "PARTIAL (MIXED
    DEFAULT + CHECKPOINT)" or "DEFAULT (FALLBACK)"), so a checkpoint value
    that happens to equal the default is still reported as coming from the
    checkpoint.  Any failure (download, parse, conversion) falls back to both
    defaults together.
    """
    try:
        threshold_path = hf_hub_download(
            repo_id=config['model'],
            subfolder=ckpt_folder,
            filename="thresholds.json",
            token=TOKEN
        )
        with open(threshold_path, "r") as f:
            threshold_data = json.load(f)
        el_thresh = float(threshold_data.get("el_threshold", DEFAULT_EL_THRESHOLD))
        ba_thresh = float(threshold_data.get("ba_threshold", DEFAULT_BA_THRESHOLD))
        present = [key in threshold_data for key in ("el_threshold", "ba_threshold")]
    except Exception as e:
        print("\n" + "=" * 80)
        print("WARNING: CHECKPOINT THRESHOLDS NOT FOUND")
        print(f"MODEL : {config['model']}")
        print(f"REASON: {e}")
        print(f"USING DEFAULT EL THRESHOLD = {DEFAULT_EL_THRESHOLD:.6f}")
        print(f"USING DEFAULT BA THRESHOLD = {DEFAULT_BA_THRESHOLD:.6f}")
        print("=" * 80 + "\n")
        return DEFAULT_EL_THRESHOLD, DEFAULT_BA_THRESHOLD, "DEFAULT (FALLBACK)"

    print("\n" + "=" * 80)
    print("THRESHOLDS LOADED FROM CHECKPOINT")
    print(f"MODEL : {config['model']}")
    print(f"EL THRESHOLD : {el_thresh:.6f}")
    print(f"BA THRESHOLD : {ba_thresh:.6f}")
    print("=" * 80 + "\n")
    if all(present):
        source_label = "CHECKPOINT"
    elif any(present):
        source_label = "PARTIAL (MIXED DEFAULT + CHECKPOINT)"
    else:
        source_label = "DEFAULT (FALLBACK)"
    return el_thresh, ba_thresh, source_label


def load_inference_engine(config, device="cpu"):
    """Load the model, heads, tokenizer and HLA map, and read the thresholds.

    Args:
        config: dict with at least the keys 'model', 'mapping', 'reg_head'
            and 'clf_head' (see MHC1_CONFIG / MHC2_CONFIG).
        device: target device for the model and heads; a string or anything
            ``torch.device`` accepts.

    Returns:
        A dict with the keys 'model', 'tokenizer', 'regHead', 'clfHead',
        'hla_map', 'device', 'hidden_dim', 'step', 'ba_db', 'el_db',
        'frontend_alleles', 'frontend_alleles_hash', 'el_threshold' and
        'ba_threshold'; or None when the latest checkpoint step cannot be
        determined or the prediction heads cannot be loaded.
    """
    print(f"\n[*] Initializing Engine for {config['model']} on {device}...")

    # 1. Load HLA Mapping (preserve variants)
    print(" [>] Fetching HLA Registry...")
    map_ds = load_dataset(config['mapping'], split="train", token=TOKEN)
    # Plain copy of the allele column for frontend use (detached list).
    allele_list = list(map_ds['allele'])
    frontend_alleles = pd.DataFrame({"Allele": allele_list})
    frontend_alleles_dict = {allele: idx for idx, allele in enumerate(allele_list)}
    hla_map, hla_stats = build_hla_map_preserve_variants(map_ds)
    print(f" [>] HLA mapping rows: {hla_stats['raw_rows']}; registered lookup keys: {hla_stats['registered_keys']}; normalized groups: {hla_stats['normalized_groups']}; duplicate groups: {hla_stats['duplicate_groups']}")

    # 2. Determine Latest Step
    try:
        state_path = hf_hub_download(repo_id=config['model'], filename="latest_training_state.txt", token=TOKEN)
        with open(state_path, 'r') as f:
            step = json.load(f)['last_step']
        print(f" [>] Auto-detected latest checkpoint: Step {step}")
    except Exception as e:
        print(f" [!] Could not auto-detect step ({e}). Exiting.")
        return None
    ckpt_folder = f"checkpoints/step_{step}"

    # 3. Load Tokenizer & Base
    print(" [>] Loading Base Model & Tokenizer (ESM-2 650M)...")
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    if "[SEP]" not in tokenizer.get_vocab():
        tokenizer.add_special_tokens({'sep_token': '[SEP]'})
    # Half precision only when the model will actually live on a CUDA device.
    # Keying this on CUDA availability alone would put an fp16 model on the
    # CPU whenever a GPU exists but device="cpu" is requested.
    on_cuda = torch.device(device).type == "cuda" and torch.cuda.is_available()
    dtype = torch.float16 if on_cuda else torch.float32
    base = AutoModel.from_pretrained(
        "facebook/esm2_t33_650M_UR50D",
        torch_dtype=dtype
    )
    base.resize_token_embeddings(len(tokenizer))
    hidden_dim = base.config.hidden_size

    # 4. Load LoRA
    print(" [>] Loading LoRA Adapters...")
    model = PeftModel.from_pretrained(base, config['model'], subfolder=ckpt_folder, token=TOKEN).to(device)
    model.eval()

    # 5. Load Heads (filenames come from config)
    print(" [>] Loading Prediction Heads...")
    regHead = ProteinHead(hidden_dim).to(device).float()
    clfHead = ImprovedProteinHead(hidden_dim, use_bias=True, bias_value=-2.92).to(device).float()
    try:
        reg_path = hf_hub_download(repo_id=config['model'], subfolder=ckpt_folder, filename=config['reg_head'], token=TOKEN)
        clf_path = hf_hub_download(repo_id=config['model'], subfolder=ckpt_folder, filename=config['clf_head'], token=TOKEN)
        # weights_only=True: the files are downloaded, so restrict unpickling
        # to tensors and plain containers rather than arbitrary objects.
        regHead_state = torch.load(reg_path, map_location=device, weights_only=True)
        if 'mlp.6.weight' not in regHead_state and 'mlp.5.weight' in regHead_state:
            # Checkpoint stores the final linear layer under index 5; move it
            # to index 6, where ProteinHead keeps it.
            regHead_state['mlp.6.weight'] = regHead_state.pop('mlp.5.weight')
            regHead_state['mlp.6.bias'] = regHead_state.pop('mlp.5.bias')
        _load_head_weights(regHead, regHead_state, config['reg_head'])
        clfHead_state = torch.load(clf_path, map_location=device, weights_only=True)
        _load_head_weights(clfHead, clfHead_state, config['clf_head'])
        regHead.eval()
        clfHead.eval()
    except Exception as e:
        print(f" [!] Error loading heads: {e}")
        return None

    # OPTIM 1 (disabled): preload training DBs for novelty flags.
    # print(" [>] Loading Training Databases for novelty checks...")
    # try:
    #     ba_ds = load_dataset(config['ba_db'], split="train", token=TOKEN)
    #     el_ds = load_dataset(config['el_db'], split="train", token=TOKEN)
    #     ba_set = set(ba_ds['peptide'])
    #     el_set = set(el_ds['peptide'])
    # except Exception as e:
    #     print(f" [!] Error loading DBs: {e}. Novelty checks disabled.")
    #     ba_set, el_set = set(), set()

    # 6. Novelty checks disabled for deployment
    print(" [>] Novelty database loading disabled.")
    ba_set = set()
    el_set = set()

    engine = {
        "model": model,
        "tokenizer": tokenizer,
        "regHead": regHead,
        "clfHead": clfHead,
        "hla_map": hla_map,
        "device": device,
        "hidden_dim": hidden_dim,
        "step": step,
        "ba_db": ba_set,
        "el_db": el_set,
        # Raw allele list (and allele -> row index) for the frontend.
        "frontend_alleles": frontend_alleles,
        "frontend_alleles_hash": frontend_alleles_dict
    }

    # OPTIM 2 (disabled): compute thresholds from the eval set at load time.
    # el_thresh, ba_thresh = calculate_engine_thresholds(engine, config['eval_db'])
    # engine["el_threshold"] = el_thresh
    # engine["ba_threshold"] = ba_thresh

    # ------------------------------------------------------------------
    # THRESHOLD LOADING
    # ------------------------------------------------------------------
    el_thresh, ba_thresh, source_label = _load_checkpoint_thresholds(config, ckpt_folder)
    engine["el_threshold"] = el_thresh
    engine["ba_threshold"] = ba_thresh

    # ============================================================
    # THRESHOLD FINALIZATION REPORT (PER ENGINE)
    # ============================================================
    print("\n" + "#" * 90)
    print("#" + " " * 88 + "#")
    print("# ENGINE THRESHOLD FINALIZATION REPORT #")
    print("#" + " " * 88 + "#")
    print("#" * 90)
    print(f"# MODEL : {config['model']}")
    print(f"# DEVICE : {device}")
    print(f"# SOURCE : {source_label}")
    print("#" + "-" * 88 + "#")
    print(f"# EL THRESHOLD : {el_thresh:.8f}")
    print(f"# BA THRESHOLD : {ba_thresh:.8f}")
    print("#" + "-" * 88 + "#")
    print("# STATUS SUMMARY")
    if source_label == "CHECKPOINT":
        print("# -> USING TRAINED CHECKPOINT THRESHOLDS")
    elif source_label == "DEFAULT (FALLBACK)":
        print("# -> USING DEFAULT THRESHOLDS (NO CHECKPOINT FOUND)")
    else:
        print("# -> MIXED CONFIGURATION DETECTED")
    print("#" * 90 + "\n")

    print(f"[*] {config['model']} Engine Initialization Complete.")
    return engine
# --- PREDICTION (keeps all allele variants by default) ---
# Set to True to deduplicate by pseudosequence (faster); False to keep all variants (traceable)
DEDUP_BY_PSEUDO = False


def _base_predict(peptides, alleles, contexts, engine, model_name="Model"):
    """Score every peptide against every resolvable allele with one engine.

    Args:
        peptides: a peptide string or a sequence of them.
        alleles: an allele string or a sequence of them; each is resolved
            through ``resolve_allele_key`` against ``engine['hla_map']``.
        contexts: None (empty context for all), a single string (shared by
            all peptides), or one entry per peptide.  A length mismatch falls
            back to empty contexts with a warning.
        engine: dict produced by ``load_inference_engine``, or None.
        model_name: label used in printed messages.

    Returns:
        A DataFrame sorted by 'EL_Prob' (descending) with the columns
        Peptide, Allele, Context, BA_Score, EL_Prob, EL_Class, BA_Class,
        Seen_in_BA and Seen_in_EL; or an empty DataFrame when the engine is
        missing, the inputs are empty, or no allele can be resolved.  Batches
        that raise RuntimeError during inference are kept with NaN scores.
    """
    if engine is None:
        print(f"[!] {model_name} Engine not initialized.")
        return pd.DataFrame()

    # Normalize inputs
    if peptides is None or alleles is None:
        print("[!] Error: At least one peptide and one allele required.")
        return pd.DataFrame()
    if isinstance(peptides, str):
        peptides = [peptides]
    if isinstance(alleles, str):
        alleles = [alleles]
    if len(peptides) == 0 or len(alleles) == 0:
        print("[!] Error: At least one peptide and one allele required.")
        return pd.DataFrame()
    if contexts is None:
        contexts = [""] * len(peptides)
    elif isinstance(contexts, str):
        contexts = [contexts] * len(peptides)
    elif len(contexts) != len(peptides):
        print("[!] Warning: Context list length mismatch. Defaulting to empty contexts.")
        contexts = [""] * len(peptides)
    else:
        # A None entry would otherwise be formatted as the text "None" and
        # tokenized as if it were sequence.
        contexts = ["" if c is None else c for c in contexts]

    # 1. Resolve alleles -> pseudosequences (use resolver and preserve variants per DEDUP_BY_PSEUDO)
    valid_alleles = {}
    known_pseudos = set()
    unresolved = []
    for hla in alleles:
        matched_key = resolve_allele_key(hla, engine['hla_map'])
        pseudo = engine['hla_map'][matched_key] if matched_key else None
        if not matched_key or pseudo is None or not str(pseudo).strip():
            # Not in the registry, or registered without a usable
            # pseudosequence: there is nothing meaningful to feed the model.
            unresolved.append(hla)
            continue
        if DEDUP_BY_PSEUDO:
            if pseudo in known_pseudos:
                continue
            known_pseudos.add(pseudo)
        # With DEDUP_BY_PSEUDO off every matched variant is kept (traceability).
        valid_alleles[matched_key] = pseudo
    if unresolved:
        print(f"[!] Warning: {len(unresolved)} allele(s) could not be resolved in the {model_name} registry and were skipped: {unresolved}")
    if not valid_alleles:
        print(f"[!] Error: None of the provided alleles were found in the {model_name} registry.")
        return pd.DataFrame()

    # 2. Build cartesian product
    batch_data = [
        {"Peptide": pep, "Allele": allele, "pseudo": pseudo, "Context": ctx}
        for pep, ctx in zip(peptides, contexts)
        for allele, pseudo in valid_alleles.items()
    ]
    df = pd.DataFrame(batch_data)
    if df.empty:
        return pd.DataFrame()

    results = []
    BATCH_SIZE = 64
    device = engine.get("device", "cpu")
    # Compare on the device type so that both "cuda:0" strings and
    # torch.device objects are handled.
    use_cuda_amp = torch.device(device).type == "cuda" and torch.cuda.is_available()
    print(f"[*] {model_name}: Processing {len(peptides)} peptides × {len(valid_alleles)} unique alleles = {len(df)} predictions")

    # 3. Batched inference (single forward pass per batch, shared by both heads)
    for start in range(0, len(df), BATCH_SIZE):
        batch = df.iloc[start:start + BATCH_SIZE].reset_index(drop=True)
        seqs = [f"{p} [SEP] {ps} [SEP] {c}" for p, ps, c in zip(batch["Peptide"], batch["pseudo"], batch["Context"])]
        toks = engine['tokenizer'](seqs, return_tensors="pt", padding=True, truncation=True, max_length=128)
        try:
            toks = toks.to(device)
        except Exception:
            # Tokenizer output without a .to() method: move tensors one by one.
            toks = {k: v.to(device) for k, v in toks.items()}
        try:
            amp_ctx = torch.autocast(device_type="cuda") if use_cuda_amp else nullcontext()
            with torch.no_grad(), amp_ctx:
                outputs = engine['model'](**toks)
                last_hidden = outputs.last_hidden_state
                # BA: mean over dim 1, then the regression head in float32.
                # NOTE(review): the mean runs over every position, padding
                # included (no attention mask is applied), so a score can
                # depend on the batch's padded length - presumably this
                # matches how the head was trained; confirm before changing.
                ba_emb = last_hidden.mean(dim=1).to(dtype=torch.float32)
                ba_preds = engine['regHead'](ba_emb).squeeze(-1).cpu().numpy()
                ba_preds = np.asarray(ba_preds).ravel()
                # EL: the classification head pools the sequence itself.
                el_logits = engine['clfHead'](last_hidden.to(dtype=torch.float32)).squeeze(-1).cpu().numpy()
                el_logits = np.asarray(el_logits).ravel()
                el_probs = 1.0 / (1.0 + np.exp(-el_logits))
        except RuntimeError as e:
            # Keep the rows of a failed batch, with NaN scores.
            print(f"[!] Inference error on batch starting at {start}: {e}")
            torch.cuda.empty_cache()
            n = len(batch)
            ba_preds = np.full(n, np.nan)
            el_probs = np.full(n, np.nan)
        batch_res = batch.copy()
        batch_res['BA_Score'] = ba_preds
        batch_res['EL_Prob'] = el_probs
        results.append(batch_res)
    if not results:
        return pd.DataFrame()
    final_df = pd.concat(results, ignore_index=True)

    # 4. Thresholding & postprocessing
    el_t = engine.get('el_threshold', DEFAULT_EL_THRESHOLD)
    ba_t = engine.get('ba_threshold', DEFAULT_BA_THRESHOLD)
    # NaN scores are replaced by -1 before the comparison and the int cast.
    final_df['EL_Class'] = (final_df['EL_Prob'].fillna(-1) >= el_t).astype(int)
    final_df['BA_Class'] = (final_df['BA_Score'].fillna(-1) >= ba_t).astype(int)
    # OPTIM 3 (disabled): novelty flags from the training databases.
    # final_df['Seen_in_BA'] = final_df['Peptide'].isin(engine.get('ba_db', set()))
    # final_df['Seen_in_EL'] = final_df['Peptide'].isin(engine.get('el_db', set()))
    final_df['Seen_in_BA'] = False
    final_df['Seen_in_EL'] = False

    # Cleanup
    final_df = final_df.drop(columns=['pseudo'])
    final_df = final_df.sort_values(by="EL_Prob", ascending=False).reset_index(drop=True)
    print(f"[*] {model_name} Inference Complete.")
    return final_df
def predict_mhc1(peptides: list, alleles: list, contexts: list = None):
    """Run `_base_predict` on the MHC-I engine (ENGINE_MHC1) and return its result."""
    return _base_predict(
        peptides=peptides,
        alleles=alleles,
        contexts=contexts,
        engine=ENGINE_MHC1,
        model_name="MHC-I",
    )
def predict_mhc2(peptides: list, alleles: list, contexts: list = None):
    """Run `_base_predict` on the MHC-II engine (ENGINE_MHC2) and return its result."""
    return _base_predict(
        peptides=peptides,
        alleles=alleles,
        contexts=contexts,
        engine=ENGINE_MHC2,
        model_name="MHC-II",
    )
# --- (EXECUTION) INITIALIZATION HERE ---
# Runs at import time: print the banner, then load the MHC-I engine followed
# by the MHC-II engine, each with load_inference_engine's default device.
print(
    "\n" + "=" * 80,
    "INITIALIZING DUAL INFERENCE ENGINES (ESM-2 650M)",
    "=" * 80,
    sep="\n",
)
ENGINE_MHC1 = load_inference_engine(config=MHC1_CONFIG)
ENGINE_MHC2 = load_inference_engine(config=MHC2_CONFIG)