leexiaohua commited on
Commit
8970aae
·
verified ·
1 Parent(s): 798ea70

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +74 -0
README.md ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ ---
4
+
5
+ A protein Subcellular localisation prediction model based on [ESM2-8M model] (https://www.science.org/doi/full/10.1126/science.ade2574) fine-tuning. Model deployment references Synthira's [fastESM] (https://huggingface.co/Synthyra) series.
6
+
7
+ The dataset comes from the [DeepLoc project] (https://services.healthtech.dtu.dk/services/DeepLoc-2.1/).
8
+
9
+ ![evaluation_metrics](./ESM2_Subloc_Metrics.png)
10
+
11
+ ```
12
+ from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
13
+ import torch
14
+
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+
17
+ model_id = "leexiaohua/subloc_small"
18
+
19
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
20
+
21
+ model = AutoModelForSequenceClassification.from_pretrained(
22
+ "leexiaohua/subloc_small",
23
+ trust_remote_code=True
24
+ )
25
+
26
+ model.eval()
27
+
28
+ ```
29
+
30
+ ```
31
+ def predict_sublocation(sequence, model, tokenizer, device):
32
+
33
+ inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024)
34
+ inputs = {k: v.to(device) for k, v in inputs.items()}
35
+
36
+ with torch.no_grad():
37
+ outputs = model(**inputs)
38
+
39
+ logits = outputs.logits if hasattr(outputs, "logits") else outputs
40
+ probs = torch.sigmoid(logits).cpu().numpy()[0]
41
+
42
+ id2label = model.config.id2label
43
+ results = {}
44
+
45
+ for i, prob in enumerate(probs):
46
+ if prob > 0.5:
47
+
48
+ label = id2label.get(i) or id2label.get(str(i))
49
+ if label:
50
+ results[label] = float(prob)
51
+ else:
52
+ results[f"Unknown_{i}"] = float(prob)
53
+
54
+ if not results:
55
+ max_idx = int(probs.argmax())
56
+ label = id2label.get(max_idx) or id2label.get(str(max_idx))
57
+ results[label or f"Unknown_{max_idx}"] = float(probs[max_idx])
58
+
59
+ return results
60
+ ```
61
+
62
+ An example:
63
+
64
+ ```
65
+ test_seq = "MSRLEAKKPSLCKSEPLTTERVRTTLSVLKRIVTSCYGPSGRLKQLHNGFGGYVCTTSQSSALLSHLLVTHPILKILTASIQNHVSSFSDCGLFTAILCCNLIENVQRLGLTPTTVIRLNKHLLSLCISYLKSETCGCRIPVDFSSTQILLCLVRSILTSKPACMLTRKETEHVSALILRAFLLTIPENAEGHIILGKSLIVPLKGQRVIDSTVLPGILIEMSEVQLMRLLPIKKSTALKVALFCTTLSGDTSDTGEGTVVVSYGVSLENAVLDQLLNLGRQLISDHVDLVLCQKVIHPSLKQFLNMHRIIAIDRIGVTLMEPLTKMTGTQPIGSLGSICPNSYGSVKDVCTAKFGSKHFFHLIPNEATICSLLLCNRNDTAWDELKLTCQTALHVLQLTLKEPWALLGGGCTETHLAAYIRHKTHNDPESILKDDECTQTELQLIAEAFCSALESVVGSLEHDGGEILTDMKYGHLWSVQADSPCVANWPDLLSQCGCGLYNSQEELNWSFLRSTRRPFVPQSCLPHEAVGSASNLTLDCLTAKLSGLQVAVETANLILDLSYVIEDKN"
66
+ predictions = predict_sublocation(test_seq, model, tokenizer, device)
67
+ print(f"Result: {predictions}")
68
+ ```
69
+
70
+ The output will be similar to:
71
+
72
+ ```text
73
+ Result: {'Cytoplasm': 0.9772326350212097, 'Soluble': 0.998727023601532}
74
+ ```