| | |
| | |
| |
|
| | from transformers import EsmForSequenceClassification, AutoTokenizer |
| | import torch |
| |
|
| |
|
def model_fn(model_dir):
    """Load the sequence classifier and its tokenizer from ``model_dir``.

    Args:
        model_dir: Path to a directory holding the saved Hugging Face
            model and tokenizer files.

    Returns:
        A ``(model, tokenizer)`` tuple, in that order, for ``predict_fn``.
    """
    # device_map="auto" lets transformers choose device placement for the weights.
    classifier = EsmForSequenceClassification.from_pretrained(
        model_dir,
        device_map="auto",
    )
    tok = AutoTokenizer.from_pretrained(model_dir)
    return classifier, tok
| |
|
| |
|
def predict_fn(data, model_and_tokenizer):
    """Run the classifier on a request payload and report the class-1 score.

    Args:
        data: Deserialized request payload (dict-like). Its ``"inputs"``
            entry is tokenized. If that key is absent, ``data`` itself is
            passed to the tokenizer. The payload is not modified.
        model_and_tokenizer: The ``(model, tokenizer)`` pair returned by
            ``model_fn``.

    Returns:
        ``{"membrane_probability": float}``: the sigmoid of the logit at
        class index 1 for the first sequence in the batch.
    """
    model, tokenizer = model_and_tokenizer
    model.eval()
    # Read rather than pop so the caller's payload is left untouched.
    inputs = data.get("inputs", data)
    encoding = tokenizer(inputs, return_tensors="pt")
    encoding = {key: tensor.to(model.device) for key, tensor in encoding.items()}
    # Inference only: disable autograd so no graph or activations are kept
    # alive for a backward pass that never happens.
    with torch.no_grad():
        logits = model(**encoding).logits
    # NOTE(review): sigmoid scores each logit independently (multi-label style).
    # If the classifier was trained single-label over 2 classes, a softmax over
    # dim=-1 would be the calibrated choice. Confirm against the training setup.
    probs = torch.sigmoid(logits).cpu()
    # Only the first sequence's class-1 score is reported; extra rows are ignored.
    return {"membrane_probability": probs[0][1].item()}
| |
|