A protein subcellular localisation prediction model fine-tuned from the [ESM2-8M model](https://www.science.org/doi/full/10.1126/science.ade2574). Model deployment follows Synthyra's [fastESM](https://huggingface.co/Synthyra) series.
The dataset comes from the [DeepLoc project](https://services.healthtech.dtu.dk/services/DeepLoc-2.1/).
Load the tokenizer and model:

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_id = "leexiaohua/subloc_small"

tokenizer = AutoTokenizer.from_pretrained(model_id)
# trust_remote_code=True lets transformers run the custom modelling code shipped with the model
model = AutoModelForSequenceClassification.from_pretrained(model_id, trust_remote_code=True)
model.to(device)  # the model must be on the same device as the inputs
model.eval()
```
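Define a prediction helper that returns every location whose probability exceeds 0.5:

```python
def predict_sublocation(sequence, model, tokenizer, device):
    """Return {label: probability} for each label with sigmoid probability > 0.5.

    If no label passes the threshold, the single most likely label is returned.
    """
    inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # Multi-label head: an independent sigmoid is applied to each logit
    logits = outputs.logits if hasattr(outputs, "logits") else outputs
    probs = torch.sigmoid(logits).cpu().numpy()[0]

    id2label = model.config.id2label
    results = {}
    for i, prob in enumerate(probs):
        if prob > 0.5:
            # id2label keys may be ints or strings depending on how the config was saved
            label = id2label.get(i) or id2label.get(str(i))
            results[label or f"Unknown_{i}"] = float(prob)

    # No label above the threshold: fall back to the highest-scoring one
    if not results:
        max_idx = int(probs.argmax())
        label = id2label.get(max_idx) or id2label.get(str(max_idx))
        results[label or f"Unknown_{max_idx}"] = float(probs[max_idx])

    return results
```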
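The decoding step in `predict_sublocation` treats localisation as a multi-label problem: each logit goes through its own sigmoid, every label above 0.5 is kept, and the top label is reported if none clears the threshold. The sketch below reproduces only that decoding logic with NumPy, so it runs without downloading the model. The helper name `decode_multilabel`, its `threshold` parameter and the three-entry label mapping are hypothetical, chosen for illustration; the real mapping lives in `model.config.id2label`.

```python
import numpy as np


def decode_multilabel(logits, id2label, threshold=0.5):
    """Hypothetical helper mirroring the decoding logic of predict_sublocation.

    logits: 1-D sequence of raw scores for a single protein.
    id2label: mapping from integer class index to label name.
    """
    # Element-wise sigmoid turns each logit into an independent probability
    probs = 1.0 / (1.0 + np.exp(-np.asarray(logits, dtype=np.float64)))
    results = {
        id2label.get(i, f"Unknown_{i}"): float(p)
        for i, p in enumerate(probs)
        if p > threshold
    }
    if not results:
        # Nothing passed the threshold: keep the single most likely label
        top = int(probs.argmax())
        results[id2label.get(top, f"Unknown_{top}")] = float(probs[top])
    return results


# Hypothetical label mapping, for illustration only
example_labels = {0: "Cytoplasm", 1: "Nucleus", 2: "Soluble"}
print(decode_multilabel([3.0, -2.0, 6.0], example_labels))
print(decode_multilabel([-1.0, -0.5, -3.0], example_labels))
```

With these made-up logits the first call keeps `Cytoplasm` and `Soluble`, while the second call, where no probability exceeds 0.5, falls back to `Nucleus`.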
An example:

```python
test_seq = "MSRLEAKKPSLCKSEPLTTERVRTTLSVLKRIVTSCYGPSGRLKQLHNGFGGYVCTTSQSSALLSHLLVTHPILKILTASIQNHVSSFSDCGLFTAILCCNLIENVQRLGLTPTTVIRLNKHLLSLCISYLKSETCGCRIPVDFSSTQILLCLVRSILTSKPACMLTRKETEHVSALILRAFLLTIPENAEGHIILGKSLIVPLKGQRVIDSTVLPGILIEMSEVQLMRLLPIKKSTALKVALFCTTLSGDTSDTGEGTVVVSYGVSLENAVLDQLLNLGRQLISDHVDLVLCQKVIHPSLKQFLNMHRIIAIDRIGVTLMEPLTKMTGTQPIGSLGSICPNSYGSVKDVCTAKFGSKHFFHLIPNEATICSLLLCNRNDTAWDELKLTCQTALHVLQLTLKEPWALLGGGCTETHLAAYIRHKTHNDPESILKDDECTQTELQLIAEAFCSALESVVGSLEHDGGEILTDMKYGHLWSVQADSPCVANWPDLLSQCGCGLYNSQEELNWSFLRSTRRPFVPQSCLPHEAVGSASNLTLDCLTAKLSGLQVAVETANLILDLSYVIEDKN"
predictions = predict_sublocation(test_seq, model, tokenizer, device)
print(f"Result: {predictions}")
```
The output should be similar to:

```
Result: {'Cytoplasm': 0.9772326350212097, 'Soluble': 0.998727023601532}
```
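To score several proteins, sequences are typically read from a FASTA file and passed one by one to `predict_sublocation`. The sketch below is a minimal pure-Python FASTA reader under two assumptions: the helper name `read_fasta` is hypothetical, and the tokenizer is assumed to add one start and one end token (as ESM tokenizers do), so `max_length=1024` leaves room for at most 1,022 residues and longer sequences are truncated without warning. The toy records are fragments of the example sequence above, and the model call is commented out so the sketch runs without the model.

```python
import io
from typing import Dict, List, Optional, TextIO

# Assumption: one start and one end token are added, so 1024 - 2 residues fit
MAX_RESIDUES = 1024 - 2


def read_fasta(handle: TextIO) -> Dict[str, str]:
    """Hypothetical helper: parse FASTA records into {header: sequence}."""
    records: Dict[str, str] = {}
    header: Optional[str] = None
    chunks: List[str] = []
    for line in handle:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        if line.startswith(">"):
            if header is not None:
                records[header] = "".join(chunks)
            header = line[1:].strip()
            chunks = []
        else:
            chunks.append(line.upper())  # sequences may be wrapped over several lines
    if header is not None:
        records[header] = "".join(chunks)
    return records


# Toy input: fragments of the example sequence, wrapped across lines
fasta_text = """>protein_a
MSRLEAKKPSLCKSEPLTTERV
RTTLSVLKRIVTSCYGPSGR
>protein_b
LKQLHNGFGG
"""

for name, seq in read_fasta(io.StringIO(fasta_text)).items():
    if len(seq) > MAX_RESIDUES:
        print(f"{name}: {len(seq)} residues, only the first {MAX_RESIDUES} will be used")
    # predictions = predict_sublocation(seq, model, tokenizer, device)
    print(name, len(seq))
```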
