import torch

from main import StriMap_pHLA, StriMap_TCRpHLA, load_test_data


def load_model(model_path="model.pt", device=None):
    """Build a StriMap_pHLA model and load weights from ``model_path``.

    Args:
        model_path: Path to the saved model file. It is passed both as the
            constructor's ``model_save_path`` and to ``model.load_model``.
        device: Forwarded unchanged to ``StriMap_pHLA``. How ``None`` is
            interpreted is decided by ``StriMap_pHLA`` (not visible here).

    Returns:
        Tuple ``(model, device)``. ``device`` is returned exactly as passed
        in (it may be ``None``); it is not resolved here.
    """
    model = StriMap_pHLA(
        device=device,
        model_save_path=model_path,
        cache_save=False,  # don't persist embedding caches from inference runs
    )
    model.load_model(model_path)
    return model, device


def predict_from_df(df, model, *, hla_dict_path="HLA_dict.npy", batch_size=128):
    """Run ``model`` on ``df`` and return the frame with a ``Prediction`` column.

    Args:
        df: Input table handed to ``load_test_data`` as ``df_test``.
            Its expected column schema is defined by ``load_test_data``
            (not visible here).
        model: A loaded model exposing ``prepare_embeddings`` and ``predict``
            (e.g. the one returned by ``load_model``).
        hla_dict_path: HLA dictionary file passed to ``load_test_data``.
        batch_size: Batch size passed to ``model.predict``.

    Returns:
        The frame returned by ``load_test_data`` with an added ``Prediction``
        column holding the first value returned by ``model.predict``
        (called with ``return_probs=True``). If the frame already had a
        ``label`` column, its original values are preserved. Otherwise no
        ``label`` column is present in the result.
    """
    df = load_test_data(
        df_test=df,
        hla_dict_path=hla_dict_path,
    )
    model.prepare_embeddings(
        df,
        force_recompute=False,
    )

    # model.predict is always called with a dummy label of 1, as before.
    # Previously, a caller-supplied 'label' column was overwritten here and
    # then dropped. Keep a copy so it can be restored afterwards.
    had_label = "label" in df.columns
    original_labels = df["label"].copy() if had_label else None
    df["label"] = 1

    # Release cached CUDA memory before the batched forward pass
    # (a no-op when CUDA has not been initialised).
    torch.cuda.empty_cache()
    predictions, _ = model.predict(
        df,
        batch_size=batch_size,
        return_probs=True,
        use_kfold=False,
    )
    df["Prediction"] = predictions

    if had_label:
        # Put the caller's labels back instead of leaking the dummy value.
        df["label"] = original_labels
    else:
        # The label column was only a placeholder for predict(); remove it.
        df = df.drop(columns=["label"])
    return df