File size: 748 Bytes
78f28d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
from main import StriMap_pHLA, StriMap_TCRpHLA, load_test_data

def load_model(model_path="model.pt", device=None):
    """Construct a ``StriMap_pHLA`` predictor and load its weights.

    The predictor is created with ``model_save_path`` pointing at
    ``model_path`` and ``cache_save=False``. ``load_model(model_path)`` is
    then called on it.

    Args:
        model_path: Path passed both as ``model_save_path`` to the
            constructor and to ``load_model``. Defaults to ``"model.pt"``.
        device: Forwarded unchanged to ``StriMap_pHLA``. It may be ``None``;
            how ``None`` is interpreted is decided by ``StriMap_pHLA``
            (not visible here).

    Returns:
        A ``(predictor, device)`` tuple. ``device`` is the argument exactly as
        given (no resolution is done here, so it can still be ``None``).
    """
    # NOTE(review): cache_save=False presumably disables on-disk caching
    # inside StriMap_pHLA; confirm against its implementation.
    ctor_kwargs = dict(
        device=device,
        model_save_path=model_path,
        cache_save=False,
    )
    predictor = StriMap_pHLA(**ctor_kwargs)
    predictor.load_model(model_path)
    return predictor, device

def predict_from_df(df, model, *, batch_size=128, hla_dict_path='HLA_dict.npy'):
    """Run ``model`` over the rows of ``df`` and return them with a ``Prediction`` column.

    The frame is first passed through ``load_test_data`` together with the
    HLA dictionary at ``hla_dict_path``. Next, ``model.prepare_embeddings`` is
    called with ``force_recompute=False``. Finally, ``model.predict`` runs
    with ``return_probs=True`` and ``use_kfold=False``.

    Args:
        df: Input table handed to ``load_test_data`` as ``df_test``. Its
            required columns are defined by that helper (not visible here).
        model: Object exposing ``prepare_embeddings`` and ``predict``, such
            as the first element returned by ``load_model``.
        batch_size: Batch size forwarded to ``model.predict`` (default 128).
        hla_dict_path: Path forwarded to ``load_test_data`` as
            ``hla_dict_path`` (default ``'HLA_dict.npy'``).

    Returns:
        The DataFrame produced by ``load_test_data`` with a ``Prediction``
        column holding the first value returned by ``model.predict``.
        - If that frame already had a ``label`` column, the column is kept
          with its original values.
        - Otherwise no ``label`` column is present in the result.
    """
    df = load_test_data(
        df_test=df,
        hla_dict_path=hla_dict_path,
    )
    model.prepare_embeddings(
        df,
        force_recompute=False,
    )

    # A constant dummy label is set for the predict call.
    # NOTE(review): presumably model.predict requires a 'label' column; confirm.
    # Any real labels are set aside first so the dummy value does not
    # destroy them. (The original code dropped them unconditionally.)
    had_labels = 'label' in df.columns
    saved_labels = df['label'].copy() if had_labels else None
    df['label'] = 1

    # Release PyTorch's cached, unused CUDA memory before running inference.
    torch.cuda.empty_cache()
    predictions, _ = model.predict(
        df, batch_size=batch_size, return_probs=True, use_kfold=False
    )
    df["Prediction"] = predictions

    if had_labels:
        # Put the caller's labels back (same column position) instead of dropping them.
        df['label'] = saved_labels
    else:
        # Remove the dummy column added above.
        df = df.drop(columns=['label'])
    return df