A protein subcellular localisation prediction model obtained by fine-tuning the [ESM2-8M model](https://www.science.org/doi/full/10.1126/science.ade2574). Model deployment follows Synthyra's [FastESM](https://huggingface.co/Synthyra) series.

The dataset comes from the [DeepLoc project](https://services.healthtech.dtu.dk/services/DeepLoc-2.1/).

evaluation_metrics

from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
import torch

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hub repository that holds both the tokenizer and the fine-tuned classifier.
model_id = "leexiaohua/subloc_small"

tokenizer = AutoTokenizer.from_pretrained(model_id)

# NOTE: trust_remote_code=True executes the modelling code shipped in the Hub
# repository; only enable it for repositories you trust.
model = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    trust_remote_code=True,
)

# The prediction helper moves its inputs to `device`, so the model has to live
# on the same device; otherwise inference fails with a device-mismatch error
# whenever CUDA is available.
model.to(device)
# Inference mode: disables dropout and other training-only behaviour.
model.eval()
def predict_sublocation(sequence, model, tokenizer, device, threshold=0.5):
    """Predict the localisation labels of a single protein sequence.

    The task is treated as multi-label: each logit is passed through an
    independent sigmoid and every label whose probability exceeds
    ``threshold`` is reported. If no label passes the threshold, the single
    most probable label is returned instead, so the result is never empty.

    Args:
        sequence: Protein sequence string handed to ``tokenizer``.
        model: Callable classifier exposing ``config.id2label``; its output is
            either an object with a ``logits`` attribute or the logits tensor
            itself. It is expected to already live on ``device``.
        tokenizer: Callable returning a mapping of tensors for ``sequence``.
        device: Device the tokenized inputs are moved to.
        threshold: Probability a label must exceed to be reported
            (defaults to 0.5).

    Returns:
        A dict mapping label name to probability (Python ``float``). Indices
        missing from ``id2label`` are reported as ``"Unknown_<index>"``.
    """
    # Inputs longer than 1024 tokens are truncated by the tokenizer.
    inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024)
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits if hasattr(outputs, "logits") else outputs
    # Cast to float32 first: NumPy has no bfloat16 dtype, so converting
    # half-precision logits directly would raise. Index 0 selects the single
    # sequence in the batch.
    probs = torch.sigmoid(logits.float()).cpu().numpy()[0]

    id2label = model.config.id2label

    def label_for(index):
        # id2label keys may be ints or, after a JSON round-trip, strings.
        return id2label.get(index) or id2label.get(str(index)) or f"Unknown_{index}"

    results = {
        label_for(index): float(prob)
        for index, prob in enumerate(probs)
        if prob > threshold
    }

    if not results:
        # Nothing passed the threshold: fall back to the top-scoring label.
        best = int(probs.argmax())
        results[label_for(best)] = float(probs[best])

    return results

An example:

# Example input: a protein sequence written as a plain string of residue letters.
test_seq = "MSRLEAKKPSLCKSEPLTTERVRTTLSVLKRIVTSCYGPSGRLKQLHNGFGGYVCTTSQSSALLSHLLVTHPILKILTASIQNHVSSFSDCGLFTAILCCNLIENVQRLGLTPTTVIRLNKHLLSLCISYLKSETCGCRIPVDFSSTQILLCLVRSILTSKPACMLTRKETEHVSALILRAFLLTIPENAEGHIILGKSLIVPLKGQRVIDSTVLPGILIEMSEVQLMRLLPIKKSTALKVALFCTTLSGDTSDTGEGTVVVSYGVSLENAVLDQLLNLGRQLISDHVDLVLCQKVIHPSLKQFLNMHRIIAIDRIGVTLMEPLTKMTGTQPIGSLGSICPNSYGSVKDVCTAKFGSKHFFHLIPNEATICSLLLCNRNDTAWDELKLTCQTALHVLQLTLKEPWALLGGGCTETHLAAYIRHKTHNDPESILKDDECTQTELQLIAEAFCSALESVVGSLEHDGGEILTDMKYGHLWSVQADSPCVANWPDLLSQCGCGLYNSQEELNWSFLRSTRRPFVPQSCLPHEAVGSASNLTLDCLTAKLSGLQVAVETANLILDLSYVIEDKN"
# Run the classifier; the result is a dict mapping each predicted label to its probability.
predictions = predict_sublocation(test_seq, model, tokenizer, device)
print(f"Result: {predictions}")

The output will be similar to:

Result: {'Cytoplasm': 0.9772326350212097, 'Soluble': 0.998727023601532}
Downloads last month
75
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support