---
tags:
- pytorch_model_hub_mixin
- model_hub_mixin
---

This model adds a classification head on top of [LofiAmazon/BarcodeBERT-Entire-BOLD](https://huggingface.co/LofiAmazon/BarcodeBERT-Entire-BOLD). The head is a single linear layer applied to the DNA embedding concatenated with environmental layer data. The model was trained on 
[BOLD-Embeddings-Ecolayers-Amazon](https://huggingface.co/datasets/LofiAmazon/BOLD-Embeddings-Ecolayers-Amazon) to predict taxonomic genera. That dataset includes DNA embeddings and ecological raster layer data for each sample location.

## Example Usage

First, download the `StandardScaler` that was used to normalise environmental layer values during training. You can find it [here](https://huggingface.co/spaces/LofiAmazon/LofiAmazonSpace/blob/main/scaler.pkl).

You will also need to download all ecological layers from [here](https://huggingface.co/datasets/LofiAmazon/Global-Ecolayers/tree/main).

The model outputs one logit per genus present in the training dataset; applying a softmax turns these logits into a probability distribution over genera. You can find the mapping between the index in the vector and the genus name in [this file](https://huggingface.co/spaces/LofiAmazon/LofiAmazonSpace/blob/main/genus_labels.json).
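As a rough sketch of how the output vector might be read, the snippet below applies a temperature-scaled softmax to a list of logits and ranks the genera by probability. It is plain Python for illustration only. The function names, the example labels, and the example logits are hypothetical, and it assumes the labels are available as a list in which position `i` holds the genus name for class `i`; check the actual layout of `genus_labels.json` before relying on that.

```python
import math


def softmax_with_temperature(logits, temperature=0.2):
    """Convert raw logits to probabilities; a lower temperature sharpens the result."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def top_genera(probs, labels, k=3):
    """Return the k most probable (genus, probability) pairs, best first."""
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Hypothetical labels and logits, for illustration only.
# Assumption: labels[i] is the genus name for class index i.
example_labels = ["GenusA", "GenusB", "GenusC"]
example_logits = [1.0, 3.0, 2.0]
example_probs = softmax_with_temperature(example_logits)
print(top_genera(example_probs, example_labels, k=2))
```

With the real model, the logits would come from the classifier (for example via `logits.squeeze().tolist()`), and the labels from the linked file.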

```py
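import pickle

import numpy as np
import rasterio
import torch
from huggingface_hub import PyTorchModelHubMixin
from rasterio.sample import sample_gen
from torch import nn
from transformers import BertConfig, BertForMaskedLM, PreTrainedTokenizerFast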


class DNASeqClassifier(nn.Module, PyTorchModelHubMixin):
    def __init__(self, bert_model, env_dim, num_classes):
        super(DNASeqClassifier, self).__init__()
        self.bert = bert_model
        self.env_dim = env_dim
        self.num_classes = num_classes
        self.fc = nn.Linear(768 + env_dim, num_classes)

    def forward(self, bert_inputs, env_data):
        outputs = self.bert(**bert_inputs)
        dna_embeddings = outputs.hidden_states[-1].mean(1)
        combined = torch.cat((dna_embeddings, env_data), dim=1)
        logits = self.fc(combined)

        return logits

classification_model = DNASeqClassifier.from_pretrained(
    "LofiAmazon/BarcodeBERT-Finetuned-Amazon",
    bert_model=BertForMaskedLM(
        BertConfig(vocab_size=259, output_hidden_states=True),
    ),
)

ecolayers = [
    "median_elevation_1km.tiff",
    "human_footprint.tiff",
    "population_density_1km.tif",
    "annual_precipitation.tif",
    "precipitation_seasonality.tif",
    "annual_mean_air_temp.tif",
    "temp_seasonality.tif",
]

with open("scaler.pkl", "rb") as f:
    scaler = pickle.load(f)

tokenizer = PreTrainedTokenizerFast.from_pretrained("LofiAmazon/BarcodeBERT-Entire-BOLD")

# The DNA sequence you want to predict.
# There should be a space after every 4 characters.
# The sequence may also have unknown characters which are not A,C,T,G.
# The maximum DNA sequence length (not counting spaces) should be 660 characters
dna_sequence = "AACA ATGT ATTT A-T- TTCG CCCT TGTG AATT TATT ..."
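# Location where the DNA was sampled.
# Note: rasterio samples points as (x, y); for rasters in a geographic CRS
# that is (longitude, latitude), so check the order against your layers.
coords = (-3.009083, -58.68281)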

# Tokenize the DNA sequence
dna_tokenized = tokenizer(dna_sequence, return_tensors="pt")

# Obtain the environmental data from the coordinates
env_data = []
for layer in ecolayers:
    with rasterio.open(layer) as dataset:
        # Sample the raster at the coordinates; each result holds one value per band
        results = list(sample_gen(dataset, [coords]))
    # Average over bands to get a single value for this layer
    layer_data = np.mean(results[0])
    env_data.append(layer_data)
env_data = scaler.transform([env_data])
env_data = torch.from_numpy(env_data).to(torch.float32)

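# Obtain genus prediction
classification_model.eval()  # inference mode: disables dropout
with torch.no_grad():  # gradients are not needed for prediction
    logits = classification_model(dna_tokenized, env_data)
temperature = 0.2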

# Obtain the final genus probabilities
probs = torch.softmax(logits / temperature, dim=1).squeeze()
```
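
The example above expects the DNA sequence to already contain a space after every 4 characters. If your sequences are stored as plain strings, a small helper along these lines could prepare them. This is a minimal sketch: `format_dna_sequence` is a hypothetical name, and truncating sequences longer than 660 characters (rather than rejecting them) is an assumption, not something the model card prescribes.

```python
def format_dna_sequence(raw_sequence, chunk_size=4, max_length=660):
    """Insert a space after every `chunk_size` characters of a DNA string."""
    # Drop any whitespace already present so chunks are counted on bases only.
    bases = "".join(raw_sequence.split())
    # Assumption: over-long sequences are truncated rather than rejected.
    bases = bases[:max_length]
    chunks = [bases[i:i + chunk_size] for i in range(0, len(bases), chunk_size)]
    return " ".join(chunks)


print(format_dna_sequence("AACAATGTATTTA-T-TTCG"))
# prints: AACA ATGT ATTT A-T- TTCG
```

Unknown characters such as `-` are passed through unchanged, matching the note in the example that sequences may contain characters other than A, C, T and G.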