genomenet committed on
Commit
25669cc
·
1 Parent(s): 747cf48

Add BERT metagenome embedding extraction app

Browse files
Files changed (4) hide show
  1. Dockerfile +27 -0
  2. README.md +30 -4
  3. app.py +233 -0
  4. requirements.txt +5 -0
Dockerfile ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ ENV PYTHONUNBUFFERED=1
4
+ ENV PYTHONDONTWRITEBYTECODE=1
5
+ ENV TF_CPP_MIN_LOG_LEVEL=2
6
+
7
+ RUN apt-get update && apt-get install -y --no-install-recommends \
8
+ build-essential \
9
+ && rm -rf /var/lib/apt/lists/*
10
+
11
+ RUN useradd -m -u 1000 user
12
+ USER user
13
+ ENV HOME=/home/user
14
+ ENV PATH=/home/user/.local/bin:$PATH
15
+
16
+ WORKDIR /home/user/app
17
+
18
+ COPY --chown=user:user requirements.txt .
19
+
20
+ RUN pip install --no-cache-dir --upgrade pip && \
21
+ pip install --no-cache-dir -r requirements.txt
22
+
23
+ COPY --chown=user:user . .
24
+
25
+ EXPOSE 7860
26
+
27
+ CMD ["python", "app.py"]
README.md CHANGED
@@ -1,10 +1,36 @@
1
  ---
2
- title: Bert Embedding
3
- emoji: 🔥
4
- colorFrom: green
5
  colorTo: gray
6
  sdk: docker
7
  pinned: false
 
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: BERT Metagenome Embeddings
3
+ emoji: 🧬
4
+ colorFrom: gray
5
  colorTo: gray
6
  sdk: docker
7
  pinned: false
8
+ license: mit
9
  ---
10
 
11
+ # bert-embedding
12
+
13
+ Extract embeddings from DNA sequences using a BERT model pretrained on metagenomic sequences.
14
+
15
+ ## Model
16
+
17
+ | | |
18
+ |---|---|
19
+ | architecture | BERT, 24 layers, 768 hidden, 12 heads |
20
+ | parameters | ~430M |
21
+ | input | DNA sequence (min 1000 bp) |
22
+ | output | 768-dim embedding |
23
+ | source | [genomenet/bert-metagenome](https://huggingface.co/genomenet/bert-metagenome) |
24
+
25
+ ## Deployment
26
+
27
+ ```bash
28
+ cd /vol/hpcprojects/pmuench/crispr_tool/bert-embedding
29
+ git add -A && git commit -m "update" && git push
30
+ ```
31
+
32
+ ## Acknowledgements
33
+
34
+ - BMBF de.NBI / GenomeNet
35
+ - DFG SPP 2141
36
+ - HZI BIFO
app.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ BERT Metagenome Embeddings - HuggingFace Spaces App
3
+ """
4
+
5
+ import os
6
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
7
+
8
+ import gradio as gr
9
+ import numpy as np
10
+ import tensorflow as tf
11
+ from huggingface_hub import hf_hub_download
12
+
13
+ # Model config
14
+ MODEL_REPO = "genomenet/bert-metagenome"
15
+ MODEL_FILE = "bert_1k_3.h5"
16
+ WINDOW_SIZE = 1000
17
+ EMBEDDING_LAYER = "layer_transformer_block_21"
18
+ EMBEDDING_DIM = 768
19
+
20
+ # Singleton model
21
+ _model = None
22
+ _embedding_model = None
23
+
24
def get_model():
    """Return the cached embedding sub-model, loading it on first use.

    On the first call the weights file ``MODEL_FILE`` is fetched from the
    ``MODEL_REPO`` Hub repository, the full Keras model is loaded without
    compiling, and a second model is built that maps the original input to
    the output of the layer named ``EMBEDDING_LAYER``.

    Returns:
        The ``tf.keras.Model`` whose output is the ``EMBEDDING_LAYER``
        activations.

    Raises:
        Whatever ``hf_hub_download``, ``load_model`` or ``get_layer`` raise
        (download failure, unreadable file, unknown layer name).
    """
    global _model, _embedding_model
    # Gate on the object that is actually returned, so a previous partial
    # failure can never make this function hand out ``None``.
    if _embedding_model is not None:
        return _embedding_model

    print("Downloading model...")
    model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
    print(f"Loading model from {model_path}...")
    full_model = tf.keras.models.load_model(model_path, compile=False)
    embedding_model = tf.keras.Model(
        inputs=full_model.input,
        outputs=full_model.get_layer(EMBEDDING_LAYER).output,
    )

    # Publish both globals only once everything succeeded, so a failed
    # attempt is simply retried on the next call.
    _model = full_model
    _embedding_model = embedding_model
    print("Model loaded.")
    return _embedding_model
37
+
38
def get_gpu_status():
    """Return a short label naming the first GPU TensorFlow reports.

    Returns:
        ``"GPU: <device name>"`` for the first physical GPU device, or
        ``"CPU only"`` when TensorFlow lists no GPU devices.
    """
    detected = tf.config.list_physical_devices('GPU')
    return f"GPU: {detected[0].name}" if detected else "CPU only"
43
+
44
+ # Tokenization
45
# One-hot rows in A, C, G, T column order; 'N' spreads its weight uniformly.
NUCLEOTIDE_MAP = {
    'A': [1, 0, 0, 0],
    'C': [0, 1, 0, 0],
    'G': [0, 0, 1, 0],
    'T': [0, 0, 0, 1],
    'N': [0.25, 0.25, 0.25, 0.25],
}

# Ambiguity letters that are encoded exactly like 'N'.
_AMBIGUITY_CODES = frozenset('RYSWKMBDHV')


def tokenize(sequence):
    """Encode a nucleotide string as a float32 matrix with 4 columns.

    The input is upper-cased and 'U' is read as 'T'. A, C, G and T become
    one-hot rows, 'N' and the ambiguity letters become a uniform 0.25 row,
    and every other character is skipped, so the result can have fewer rows
    than ``len(sequence)``.

    Args:
        sequence: Nucleotide string (any case).

    Returns:
        ``np.ndarray`` of dtype float32 and shape ``(n_encoded, 4)``; an
        input with nothing to encode yields shape ``(0, 4)``.
    """
    normalized = sequence.upper().replace('U', 'T')
    uniform_row = NUCLEOTIDE_MAP['N']
    rows = [
        NUCLEOTIDE_MAP.get(base, uniform_row)
        for base in normalized
        if base in NUCLEOTIDE_MAP or base in _AMBIGUITY_CODES
    ]
    # reshape keeps the array 2-D even when no character was encoded.
    return np.array(rows, dtype=np.float32).reshape(-1, 4)
62
+
63
def validate_sequence(sequence):
    """Check that a sequence is non-empty, well-formed and long enough.

    The sequence is upper-cased and 'U' is read as 'T' before checking.
    Spaces, newlines, carriage returns and tabs are tolerated and are not
    counted towards the length.

    Args:
        sequence: Raw sequence text.

    Returns:
        ``(True, "")`` when the sequence is usable, otherwise
        ``(False, message)`` where ``message`` explains the problem.
    """
    if not sequence or not sequence.strip():
        return False, "Sequence is empty"

    normalized = sequence.upper().replace('U', 'T')
    allowed = set('ACGTNRYSWKMBDHV')
    tolerated_whitespace = set(' \n\r\t')

    invalid = set(normalized) - allowed - tolerated_whitespace
    if invalid:
        return False, f"Invalid characters: {invalid}"

    # Only real bases count towards the minimum length.
    n_bases = sum(1 for base in normalized if base in allowed)
    if n_bases < WINDOW_SIZE:
        return False, f"Sequence too short: {n_bases} < {WINDOW_SIZE} bp"
    return True, ""
75
+
76
def strip_fasta_header(text):
    """Return the bare sequence from FASTA or raw sequence text.

    Lines whose first non-blank character is '>' are dropped as headers and
    the remaining lines are concatenated with all whitespace removed. Any
    line-ending convention is accepted, so no stray carriage returns are
    left in the result. When the text holds several records, their
    sequences are joined into one string.

    Args:
        text: FASTA-formatted or raw sequence text.

    Returns:
        The concatenated sequence characters, or ``""`` when nothing but
        headers and whitespace was supplied.
    """
    pieces = []
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if line.startswith('>'):
            continue
        # str.split() with no argument splits on any whitespace run.
        pieces.append(''.join(line.split()))
    return ''.join(pieces)
80
+
81
def embed_sequence(sequence, mode="mean", stride=100):
    """Embed a sequence by sliding a fixed-size window over it.

    Windows of ``WINDOW_SIZE`` characters start at 0, ``stride``,
    ``2 * stride``, ... as long as a full window fits; a tail shorter than
    one stride past the last window start is not embedded. Each window is
    tokenized, run through the embedding model and reduced over its
    positions straight away, so only one vector per window is kept in
    memory rather than the full per-position activations.

    Args:
        sequence: Cleaned nucleotide string of at least ``WINDOW_SIZE``
            characters.
        mode: ``"mean"`` (mean over positions, then over windows),
            ``"max"`` (max over positions, then over windows) or
            ``"per-window"`` (mean over positions, one row per window).
            Any other value behaves like ``"mean"``.
        stride: Step between consecutive window starts; must be >= 1.

    Returns:
        A 1-D vector for the pooled modes, or a 2-D array with one row per
        window for ``"per-window"``.

    Raises:
        ValueError: If ``stride`` is smaller than 1 or the sequence is
            shorter than ``WINDOW_SIZE``.
    """
    # UI widgets may hand the value over as a float; range() needs an int.
    stride = int(stride)
    if stride < 1:
        raise ValueError(f"stride must be a positive integer, got {stride}")

    last_start = len(sequence) - WINDOW_SIZE
    if last_start < 0:
        raise ValueError(
            f"Sequence too short: {len(sequence)} < {WINDOW_SIZE} bp"
        )

    model = get_model()
    # Only "max" reduces with a maximum; every other mode uses the mean.
    pool = np.max if mode == "max" else np.mean

    window_vectors = []
    for start in range(0, last_start + 1, stride):
        tokens = tokenize(sequence[start:start + WINDOW_SIZE])
        # predict() expects a leading batch axis; [0] removes it again.
        # assumes the layer output is (positions, features) per window —
        # TODO confirm against the model.
        activations = model.predict(tokens[np.newaxis, ...], verbose=0)[0]
        window_vectors.append(pool(activations, axis=0))

    per_window = np.stack(window_vectors)
    if mode == "per-window":
        return per_window
    return pool(per_window, axis=0)
113
+
114
+ # Example sequence
115
+ EXAMPLE_SEQUENCE = """ATGCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCT"""
116
+
117
def process(sequence: str, mode: str, stride: int):
    """Validate the input, compute embeddings and build the UI outputs.

    Args:
        sequence: FASTA or raw sequence text as entered by the user.
        mode: Pooling mode forwarded to ``embed_sequence``.
        stride: Step between windows, forwarded to ``embed_sequence``.

    Returns:
        A ``(markdown, path)`` tuple. On invalid input ``markdown`` is an
        error message and ``path`` is ``None``; otherwise ``markdown``
        summarizes the result and ``path`` points to a ``.npy`` file
        holding the embedding array.
    """
    import tempfile

    sequence = strip_fasta_header((sequence or "").strip())

    is_valid, error = validate_sequence(sequence)
    if not is_valid:
        return f"**Error**: {error}", None

    stride = int(stride)
    embedding = embed_sequence(sequence, mode=mode, stride=stride)

    per_window = mode == "per-window"
    file_name = "embeddings.npy" if per_window else "embedding.npy"
    # A fresh directory per request keeps the download name stable while
    # preventing concurrent requests from overwriting each other's file.
    # NOTE: these directories are not cleaned up here.
    path = os.path.join(tempfile.mkdtemp(prefix="bert_embedding_"), file_name)
    np.save(path, embedding)

    if per_window:
        summary_lines = [
            "## Embeddings extracted",
            "",
            "| | |",
            "|---|---|",
            f"| sequence length | {len(sequence):,} bp |",
            f"| windows | {embedding.shape[0]} |",
            f"| embedding dim | {embedding.shape[1]} |",
            f"| stride | {stride} bp |",
            f"| shape | {embedding.shape} |",
            "",
            "Download the `.npy` file for per-window embeddings.",
            "",
        ]
    else:
        # Single vector: preview the leading values inline.
        emb_str = ", ".join(f"{x:.4f}" for x in embedding[:10])
        summary_lines = [
            "## Embedding extracted",
            "",
            "| | |",
            "|---|---|",
            f"| sequence length | {len(sequence):,} bp |",
            f"| mode | {mode} |",
            f"| embedding dim | {len(embedding)} |",
            "",
            f"**First 10 dimensions**: [{emb_str}, ...]",
            "",
            "Full embedding saved to file.",
            "",
        ]
    return "\n".join(summary_lines), path
166
+
167
+ # CSS
168
+ CUSTOM_CSS = """
169
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500&display=swap');
170
+ * { font-family: 'Inter', system-ui, sans-serif !important; }
171
+ code, pre, textarea { font-family: 'SF Mono', Consolas, monospace !important; }
172
+ .gradio-container { max-width: 900px !important; background: #fafafa !important; }
173
+ """
174
+
175
+ # Build interface
176
+ with gr.Blocks(title="BERT Metagenome Embeddings") as demo:
177
+ gr.Markdown("""
178
+ # bert-embedding
179
+
180
+ Extract embeddings from DNA sequences. BERT model (430M params) pretrained on metagenomic sequences.
181
+ """)
182
+
183
+ with gr.Row():
184
+ with gr.Column(scale=1):
185
+ seq_input = gr.Textbox(
186
+ label="sequence",
187
+ placeholder="Paste DNA sequence (FASTA or raw)...",
188
+ lines=8,
189
+ value=EXAMPLE_SEQUENCE,
190
+ info="min 1000 bp"
191
+ )
192
+ mode_input = gr.Radio(
193
+ choices=["mean", "max", "per-window"],
194
+ value="mean",
195
+ label="pooling",
196
+ info="mean/max: single 768-dim vector | per-window: (n, 768) matrix"
197
+ )
198
+ stride_input = gr.Slider(
199
+ minimum=50, maximum=500, value=100, step=50,
200
+ label="stride",
201
+ info="step size between windows"
202
+ )
203
+ btn = gr.Button("extract", variant="primary")
204
+
205
+ with gr.Column(scale=1):
206
+ output = gr.Markdown()
207
+ download = gr.File(label="download")
208
+
209
+ btn.click(
210
+ process,
211
+ inputs=[seq_input, mode_input, stride_input],
212
+ outputs=[output, download]
213
+ )
214
+
215
+ gr.Markdown("""
216
+ ---
217
+ **Model**: [genomenet/bert-metagenome](https://huggingface.co/genomenet/bert-metagenome) |
218
+ **Related**: [CRISPR Detection](https://huggingface.co/spaces/genomenet/crispr-array-detection)
219
+ """)
220
+
221
if __name__ == "__main__":
    # Load the model before serving so the first request does not pay for
    # the download and load.
    print("Loading model...")
    model = get_model()
    print(f"Ready! {get_gpu_status()}")

    # NOTE(review): theme/css are passed to launch() rather than to
    # gr.Blocks() — confirm the installed Gradio version accepts them here.
    zinc = gr.themes.colors.zinc
    app_theme = gr.themes.Base(primary_hue=zinc, neutral_hue=zinc)
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        theme=app_theme,
        css=CUSTOM_CSS,
    )
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ tensorflow==2.15.1
2
+ keras==2.15.0
3
+ gradio>=4.0.0
4
+ numpy>=1.26.0,<2.0.0
5
+ huggingface_hub>=0.20.0