vshulev commited on
Commit
19d17e3
·
verified ·
1 Parent(s): 15a72cb

Add example use instructions

Browse files
Files changed (1) hide show
  1. README.md +78 -3
README.md CHANGED
@@ -7,6 +7,81 @@ tags:
7
  This model adds a classification head on top of [LofiAmazon/BarcodeBERT-Entire-BOLD](https://huggingface.co/LofiAmazon/BarcodeBERT-Entire-BOLD), The classification head is a linear layer which concatenates the DNA embeddings with environmental layer data. This model has been trained with
8
  [BOLD-Embeddings-Ecolayers-Amazon](https://huggingface.co/datasets/LofiAmazon/BOLD-Embeddings-Ecolayers-Amazon) to predict taxonomic genuses. BOLD-Embeddings-Ecolayers-Amazon includes DNA embeddings and ecological raster layer data regarding sample location.
9
 
10
- This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
11
- - Library: [More Information Needed]
12
- - Docs: [More Information Needed]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  This model adds a classification head on top of [LofiAmazon/BarcodeBERT-Entire-BOLD](https://huggingface.co/LofiAmazon/BarcodeBERT-Entire-BOLD), The classification head is a linear layer which concatenates the DNA embeddings with environmental layer data. This model has been trained with
8
  [BOLD-Embeddings-Ecolayers-Amazon](https://huggingface.co/datasets/LofiAmazon/BOLD-Embeddings-Ecolayers-Amazon) to predict taxonomic genuses. BOLD-Embeddings-Ecolayers-Amazon includes DNA embeddings and ecological raster layer data regarding sample location.
9
 
10
+ ## Example Usage
11
+
12
+ First you need to download the `StandardScaler` used to normalise environmental layer values during training. You can find it [here](https://huggingface.co/spaces/LofiAmazon/LofiAmazonSpace/blob/main/scaler.pkl).
13
+
14
+ You will also need to download all ecological layers from [here](https://huggingface.co/datasets/LofiAmazon/Global-Ecolayers/tree/main).
15
+
16
+ The model will output a probability distribution over genera that were present in the training dataset. You can find the mapping between the index in the vector and the genus name in [this file](https://huggingface.co/spaces/LofiAmazon/LofiAmazonSpace/blob/main/genus_labels.json).
17
+
18
+ ```
19
+ import pickle
20
+
21
+ from transformers import PreTrainedTokenizerFast
22
+ import rasterio
23
+ from rasterio.sample import sample_gen
24
+
25
+
26
+ class DNASeqClassifier(nn.Module, PyTorchModelHubMixin):
27
+ def __init__(self, bert_model, env_dim, num_classes):
28
+ super(DNASeqClassifier, self).__init__()
29
+ self.bert = bert_model
30
+ self.env_dim = env_dim
31
+ self.num_classes = num_classes
32
+ self.fc = nn.Linear(768 + env_dim, num_classes)
33
+
34
+ def forward(self, bert_inputs, env_data):
35
+ outputs = self.bert(**bert_inputs)
36
+ dna_embeddings = outputs.hidden_states[-1].mean(1)
37
+ combined = torch.cat((dna_embeddings, env_data), dim=1)
38
+ logits = self.fc(combined)
39
+
40
+ return logits
41
+
42
+
43
+ ecolayers = [
44
+ "median_elevation_1km.tiff",
45
+ "human_footprint.tiff",
46
+ "population_density_1km.tif",
47
+ "annual_precipitation.tif",
48
+ "precipitation_seasonality.tif",
49
+ "annual_mean_air_temp.tif",
50
+ "temp_seasonality.tif",
51
+ ]
52
+
53
+ with open("scaler.pkl", "rb") as f:
54
+ scaler = pickle.load(f)
55
+
56
+ tokenizer = PreTrainedTokenizerFast.from_pretrained("LofiAmazon/BarcodeBERT-Entire-BOLD")
57
+
58
+ # The DNA sequence you want to predict.
59
+ # There should be a space after every 4 characters.
60
+ # The sequence may also have unknown characters which are not A,C,T,G.
61
+ # The maximum DNA sequence length (not counting spaces) should be 660 characters
62
+ dna_sequence = "AACA ATGT ATTT A-T- TTCG CCCT TGTG AATT TATT ..."
63
+ # Location where DNA was sampled
64
+ coords = (-3.009083, -58.68281)
65
+
66
+ # Tokenize the DNA sequence
67
+ dna_tokenized = tokenizer(dna_seq_preprocessed, return_tensors="pt")
68
+
69
+ # Obtain the environmental data from the coordinates
70
+ env_data = []
71
+ for layer in ecolayers:
72
+ with rasterio.open(layer) as dataset:
73
+ # Get the corresponding ecological values for the samples
74
+ results = sample_gen(dataset, [coords])
75
+ results = [r for r in results]
76
+ layer_data = np.mean(results[0])
77
+ env_data.append(layer_data)
78
+ env_data = scaler.transform([env_data])
79
+ env_data = torch.from_numpy(env_data).to(torch.float32)
80
+
81
+ # Obtain genus prediction
82
+ logits = classification_model(bert_inputs, env_data)
83
+ temperature = 0.2
84
+
85
+ # Obtain the final genus probabilities
86
+ probs = torch.softmax(logits / temperature, dim=1).squeeze()
87
+ ```