Commit 550efd9 (verified) · committed by andrewdalpino · 1 Parent(s): cd256e6

Update README.md

  1. README.md +32 -9
README.md CHANGED
@@ -48,8 +48,6 @@ Then, we'll load the model weights from HuggingFace Hub and the GO graph using `
 ```python
 import torch
 
-import obonet
-
 from esm.tokenization import EsmSequenceTokenizer
 
 from esmc_function_classifier.model import EsmcGoTermClassifier
@@ -57,28 +55,53 @@ from esmc_function_classifier.model import EsmcGoTermClassifier
 
 model_name = "andrewdalpino/ESMC-300M-Protein-Function"
 
-# Visit https://geneontology.org/docs/download-ontology/ to download.
-go_db_path = "./dataset/go-basic.obo"
-
 sequence = "MPPKGHKKTADGDFRPVNSAGNTIQAKQKYSIDDLLYPKSTIKNLAKETLPDDAIISKDALTAIQRAATLFVSYMASHGNASAEAGGRKKIT"
 
 top_p = 0.5
 
-graph = obonet.read_obo(go_db_path)
-
 tokenizer = EsmSequenceTokenizer()
 
 model = EsmcGoTermClassifier.from_pretrained(model_name)
 
-model.load_gene_ontology(graph)
-
 out = tokenizer(sequence, max_length=2048, truncation=True)
 
 input_ids = torch.tensor(out["input_ids"], dtype=torch.int64)
 
+go_term_probabilities = model.predict_terms(
+    input_ids, top_p=top_p
+)
+```
+
+You can also output the gene-ontology (GO) `networkx` subgraph for a given sequence like in the example below. You'll need an up-to-date gene ontology database that you can import using the `obonet` package.
+
+```python
+import networkx as nx
+
+import obonet
+
+
+# Visit https://geneontology.org/docs/download-ontology/ to download.
+go_db_path = "./dataset/go-basic.obo"
+
+graph = obonet.read_obo(go_db_path)
+
+model.load_gene_ontology(graph)
+
 subgraph, go_term_probabilities = model.predict_subgraph(
     input_ids, top_p=top_p
 )
+
+json = nx.node_link_data(subgraph)
+
+print(json)
+```
+
+### Quantized Model
+
+To quantize the model weights using int8 call the `quantize_weights()` method. Any model can be quantized, but we recommend one that has been quantization-aware trained (QAT) for the best performance. The `group_size` argument controls the granularity at which quantization scales are computed.
+
+```python
+model.quantize_weights(group_size=64)
 ```
 
 ## Code Repository
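
A note on the new "Quantized Model" text: the README says `group_size` controls the granularity at which quantization scales are computed, but it does not show how `quantize_weights()` does this internally. As a rough illustration of per-group scales in general, here is a minimal pure-Python sketch of symmetric absmax int8 quantization. The helpers `quantize_groupwise`, `dequantize_groupwise`, and `mean_abs_error` are hypothetical names made up for this example and are not part of the `esmc_function_classifier` API. The library's actual scheme may differ.

```python
import random


def quantize_groupwise(weights, group_size):
    """Symmetric absmax int8 quantization with one scale per group (illustrative only)."""
    quantized, scales = [], []

    for start in range(0, len(weights), group_size):
        group = weights[start : start + group_size]

        # One scale per group: map the largest magnitude in the group to 127.
        # Fall back to 1.0 for an all-zero group to avoid dividing by zero.
        scale = max(abs(w) for w in group) / 127 or 1.0

        scales.append(scale)
        quantized.extend(max(-127, min(127, round(w / scale))) for w in group)

    return quantized, scales


def dequantize_groupwise(quantized, scales, group_size):
    """Recover approximate float weights from int8 values and their group scales."""
    return [q * scales[i // group_size] for i, q in enumerate(quantized)]


def mean_abs_error(weights, group_size):
    """Average round-trip error after quantizing and dequantizing."""
    quantized, scales = quantize_groupwise(weights, group_size)
    restored = dequantize_groupwise(quantized, scales, group_size)

    return sum(abs(w - r) for w, r in zip(weights, restored)) / len(weights)


random.seed(0)

# Toy "weight row": mostly small values plus a single large outlier.
weights = [random.gauss(0.0, 0.02) for _ in range(256)]
weights[10] = 1.5

print(mean_abs_error(weights, group_size=256))  # one scale shared by the whole row
print(mean_abs_error(weights, group_size=64))   # four scales, finer granularity
```

In this toy setup, the outlier inflates the scale only for the group that contains it, so the smaller group size gives a lower average round-trip error at the cost of storing more scales. That trade-off is the usual reason a `group_size` knob exists.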