Spaces:
Sleeping
Sleeping
File size: 748 Bytes
78f28d5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
import torch
from main import StriMap_pHLA, StriMap_TCRpHLA, load_test_data
def load_model(model_path="model.pt", device=None):
    """Construct a ``StriMap_pHLA`` model and load its weights from ``model_path``.

    Args:
        model_path: Checkpoint path. It is passed to the constructor as
            ``model_save_path`` and then to ``load_model`` to load weights.
        device: Forwarded unchanged to ``StriMap_pHLA``. How ``None`` is
            resolved is decided by that class, not here.

    Returns:
        A ``(model, device)`` tuple, where ``device`` is the argument exactly
        as it was passed in.
    """
    constructor_kwargs = dict(
        device=device,
        model_save_path=model_path,
        cache_save=False,
    )
    loaded = StriMap_pHLA(**constructor_kwargs)
    loaded.load_model(model_path)
    return loaded, device
def predict_from_df(df, model, *, hla_dict_path='HLA_dict.npy', batch_size=128):
    """Run ``model`` on ``df`` and return the processed table with a ``Prediction`` column.

    Args:
        df: Raw input table, passed to ``load_test_data`` as ``df_test``. The
            required columns are whatever ``load_test_data`` expects (TODO confirm).
        model: A loaded model exposing ``prepare_embeddings`` and ``predict``,
            e.g. the first element returned by ``load_model``.
        hla_dict_path: HLA dictionary file forwarded to ``load_test_data``.
            Defaults to the previously hard-coded ``'HLA_dict.npy'``.
        batch_size: Batch size forwarded to ``model.predict``.

    Returns:
        The DataFrame returned by ``load_test_data``, with a ``Prediction``
        column holding the first value returned by ``model.predict`` (called
        with ``return_probs=True`` and ``use_kfold=False``). If that DataFrame
        already had a ``label`` column, its original values are kept.
        Otherwise no ``label`` column is present in the result.
    """
    df = load_test_data(
        df_test=df,
        hla_dict_path=hla_dict_path,
    )
    model.prepare_embeddings(
        df,
        force_recompute=False,
    )

    # model.predict is fed a placeholder label of 1 for every row. Presumably
    # it needs a label column even at inference time (TODO confirm). Stash any
    # real labels first so the placeholder does not silently overwrite them.
    original_label = df['label'].copy() if 'label' in df.columns else None
    df['label'] = 1

    # Release PyTorch's cached, unused CUDA memory before the batched run.
    torch.cuda.empty_cache()
    predictions, _ = model.predict(
        df,
        batch_size=batch_size,
        return_probs=True,
        use_kfold=False,
    )
    df["Prediction"] = predictions

    # Remove the placeholder label, or put the caller's labels back.
    if original_label is None:
        df = df.drop(columns=['label'])
    else:
        df['label'] = original_label
    return df