openfree commited on
Commit
7e4fb15
ยท
verified ยท
1 Parent(s): dccecb5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +465 -0
app.py ADDED
@@ -0,0 +1,465 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import time
4
+ import hashlib
5
+ from typing import List, Dict, Tuple
6
+
7
+ import streamlit as st
8
+ import requests
9
+
10
+ # Optional heavy deps; guard imports so the app still loads
11
+ try:
12
+ import torch
13
+ from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM
14
+ except Exception as e:
15
+ torch = None
16
+ AutoTokenizer = None
17
+ AutoModel = None
18
+ AutoModelForMaskedLM = None
19
+
20
+ try:
21
+ from datasets import load_dataset
22
+ except Exception:
23
+ load_dataset = None
24
+
25
+ try:
26
+ from sentence_transformers import SentenceTransformer
27
+ except Exception:
28
+ SentenceTransformer = None
29
+
30
+ try:
31
+ import faiss # faiss-cpu
32
+ except Exception:
33
+ faiss = None
34
+
35
+ try:
36
+ from Bio import SeqIO
37
+ except Exception:
38
+ SeqIO = None
39
+
40
+ APP_TITLE = "BioSeq Chat: Protein & DNA Assistant"
41
+ DISCLAIMER = (
42
+ "This tool is for research/education and is not a medical device. "
43
+ "Do not use outputs for diagnosis or treatment decisions."
44
+ )
45
+
46
+ # --------------- Helpers ---------------
47
+
48
+ def get_secret(name: str, fallback: str = "") -> str:
49
+ """Get secret from st.secrets, environment, or fallback"""
50
+ try:
51
+ return st.secrets.get(name, os.environ.get(name, fallback))
52
+ except Exception:
53
+ return os.environ.get(name, fallback)
54
+
55
+ def brave_search(query: str, count: int = 5) -> List[Dict]:
56
+ """Search using Brave Search API"""
57
+ key = get_secret("BRAVE_API_KEY", "")
58
+ if not key:
59
+ return [{"title": "BRAVE_API_KEY is missing",
60
+ "url": "",
61
+ "snippet": "Set BRAVE_API_KEY in Space secrets or sidebar to enable web search."}]
62
+
63
+ url = "https://api.search.brave.com/res/v1/web/search"
64
+ headers = {
65
+ "Accept": "application/json",
66
+ "X-Subscription-Token": key,
67
+ "Accept-Encoding": "gzip"
68
+ }
69
+ params = {"q": query, "count": count, "country": "us"}
70
+
71
+ try:
72
+ r = requests.get(url, headers=headers, params=params, timeout=15)
73
+ r.raise_for_status()
74
+ data = r.json()
75
+ results = []
76
+ for item in data.get("web", {}).get("results", [])[:count]:
77
+ results.append({
78
+ "title": item.get("title", ""),
79
+ "url": item.get("url", ""),
80
+ "snippet": item.get("description", ""),
81
+ })
82
+ if not results:
83
+ results = [{"title": "No results", "url": "", "snippet": "Query returned no results."}]
84
+ return results
85
+ except Exception as e:
86
+ return [{"title": "Search error", "url": "", "snippet": str(e)}]
87
+
88
+ def call_fireworks(messages: List[Dict], temperature: float = 0.6, max_tokens: int = 1024) -> str:
89
+ """Call Fireworks AI chat completion API"""
90
+ api_key = get_secret("FIREWORKS_API_KEY", "")
91
+ if not api_key:
92
+ return "FIREWORKS_API_KEY is missing. Set it in Secrets or the sidebar."
93
+
94
+ url = "https://api.fireworks.ai/inference/v1/chat/completions"
95
+ payload = {
96
+ "model": "accounts/fireworks/models/qwen3-235b-a22b-instruct-2507",
97
+ "max_tokens": max_tokens,
98
+ "top_p": 1,
99
+ "top_k": 40,
100
+ "presence_penalty": 0,
101
+ "frequency_penalty": 0,
102
+ "temperature": temperature,
103
+ "messages": messages
104
+ }
105
+ headers = {
106
+ "Accept": "application/json",
107
+ "Content-Type": "application/json",
108
+ "Authorization": f"Bearer {api_key}"
109
+ }
110
+
111
+ try:
112
+ r = requests.post(url, headers=headers, data=json.dumps(payload), timeout=60)
113
+ r.raise_for_status()
114
+ data = r.json()
115
+ return data["choices"][0]["message"]["content"]
116
+ except Exception as e:
117
+ return f"[Fireworks API error] {e}"
118
+
119
+ def load_text_from_file(upload) -> str:
120
+ """Load text from uploaded file"""
121
+ name = upload.name.lower()
122
+ content = upload.read()
123
+
124
+ try:
125
+ text = content.decode("utf-8", errors="ignore")
126
+ except Exception:
127
+ text = str(content)
128
+
129
+ # FASTA quick parse
130
+ if name.endswith((".fa", ".fasta", ".faa", ".fna")) and SeqIO is not None:
131
+ upload.seek(0)
132
+ try:
133
+ records = list(SeqIO.parse(upload, "fasta"))
134
+ seqs = []
135
+ for r in records:
136
+ seqs.append(f">{r.id}\n{str(r.seq)}")
137
+ return "\n\n".join(seqs)
138
+ except Exception:
139
+ return text
140
+
141
+ return text
142
+
143
+ def build_vector_index(texts: List[str], embedder_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
144
+ """Build FAISS vector index from texts"""
145
+ if SentenceTransformer is None or faiss is None:
146
+ return None, None, None
147
+
148
+ try:
149
+ model = SentenceTransformer(embedder_name)
150
+ emb = model.encode(texts, show_progress_bar=False, normalize_embeddings=True)
151
+ dim = emb.shape[1]
152
+ index = faiss.IndexFlatIP(dim)
153
+ index.add(emb.astype("float32"))
154
+ return index, emb, model
155
+ except Exception as e:
156
+ st.warning(f"Failed to build index: {e}")
157
+ return None, None, None
158
+
159
+ def search_index(query: str, index, model, texts: List[str], k: int = 4):
160
+ """Search vector index"""
161
+ if index is None or model is None:
162
+ return []
163
+
164
+ try:
165
+ q = model.encode([query], normalize_embeddings=True)
166
+ D, I = index.search(q.astype("float32"), k)
167
+ hits = []
168
+ for idx, score in zip(I[0], D[0]):
169
+ if 0 <= idx < len(texts):
170
+ hits.append({"score": float(score), "text": texts[idx]})
171
+ return hits
172
+ except Exception:
173
+ return []
174
+
175
+ def esm2_embed(seq: str, model_id: str = "facebook/esm2_t6_8M_UR50D") -> Dict:
176
+ """Generate ESM-2 embedding for protein sequence"""
177
+ if AutoTokenizer is None or AutoModelForMaskedLM is None or torch is None:
178
+ return {"error": "Transformers/torch not available"}
179
+
180
+ try:
181
+ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
182
+ model = AutoModelForMaskedLM.from_pretrained(model_id, trust_remote_code=True)
183
+ model.eval()
184
+
185
+ with torch.no_grad():
186
+ toks = tokenizer(seq, return_tensors="pt")
187
+ out = model(**toks, output_hidden_states=True)
188
+ hidden = out.hidden_states[-1].mean(dim=1).squeeze(0) # [hidden_size]
189
+ vec = hidden.detach().cpu().numpy()
190
+ return {"embedding": vec.tolist(), "hidden_size": vec.shape[0]}
191
+ except Exception as e:
192
+ return {"error": str(e)}
193
+
194
+ def dna_embed(seq: str, model_id: str = "zhihan1996/DNABERT-2-117M") -> Dict:
195
+ """Generate DNABERT-2 or Nucleotide Transformer embedding for DNA sequence"""
196
+ if AutoTokenizer is None or AutoModel is None or torch is None:
197
+ return {"error": "Transformers/torch not available"}
198
+
199
+ try:
200
+ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
201
+ model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
202
+ model.eval()
203
+
204
+ with torch.no_grad():
205
+ toks = tokenizer(seq, return_tensors="pt", truncation=True, max_length=4096)
206
+ out = model(**toks, output_hidden_states=True)
207
+ hidden = out.last_hidden_state.mean(dim=1).squeeze(0)
208
+ vec = hidden.detach().cpu().numpy()
209
+ return {"embedding": vec.tolist(), "hidden_size": vec.shape[0]}
210
+ except Exception as e:
211
+ return {"error": str(e)}
212
+
213
+ def chunk_text(text: str, chunk_size: int = 1200, overlap: int = 200) -> List[str]:
214
+ """Chunk text with overlap"""
215
+ text = text.replace("\r\n", "\n")
216
+ chunks = []
217
+ start = 0
218
+
219
+ while start < len(text):
220
+ end = min(len(text), start + chunk_size)
221
+ chunks.append(text[start:end])
222
+ start = end - overlap
223
+ if start < 0:
224
+ start = 0
225
+ if end >= len(text):
226
+ break
227
+
228
+ return chunks
229
+
230
+ def safe_len(obj, default=0):
231
+ """Safely get length of object"""
232
+ try:
233
+ return len(obj)
234
+ except Exception:
235
+ return default
236
+
237
+ # --------------- UI ---------------
238
+
239
+ st.set_page_config(page_title=APP_TITLE, page_icon="๐Ÿงฌ", layout="wide")
240
+ st.title(APP_TITLE)
241
+ st.caption(DISCLAIMER)
242
+
243
+ # Sidebar configuration
244
+ with st.sidebar:
245
+ st.header("Keys and settings")
246
+ fw_key = st.text_input("FIREWORKS_API_KEY", value=get_secret("FIREWORKS_API_KEY", ""), type="password")
247
+ brave_key = st.text_input("BRAVE_API_KEY", value=get_secret("BRAVE_API_KEY", ""), type="password")
248
+
249
+ if fw_key:
250
+ os.environ["FIREWORKS_API_KEY"] = fw_key
251
+ if brave_key:
252
+ os.environ["BRAVE_API_KEY"] = brave_key
253
+
254
+ st.markdown("### Model selections")
255
+ esm2_id = st.text_input(
256
+ "Protein model (ESM-2)",
257
+ value="facebook/esm2_t6_8M_UR50D",
258
+ help="Try larger models like facebook/esm2_t33_650M_UR50D if resources allow."
259
+ )
260
+ dna_id = st.text_input(
261
+ "DNA model",
262
+ value="zhihan1996/DNABERT-2-117M",
263
+ help="Alternative: InstaDeepAI/nucleotide-transformer-500m-human-ref"
264
+ )
265
+
266
+ use_web = st.checkbox("Use Brave web search for context", value=True)
267
+ web_k = st.slider("Web results", 1, 10, 4)
268
+
269
+ st.markdown("### Datasets (optional)")
270
+ ds_hint = "Enter a Hugging Face dataset repo id, e.g., 'genomics-benchmark/jaspar_motifs'"
271
+ dataset_ids = st.text_area("Datasets to load (one per line)", value="", help=ds_hint)
272
+
273
+ st.divider()
274
+ st.markdown("Files you upload are indexed locally and used for answers.")
275
+
276
+ # Main tabs
277
+ tabs = st.tabs(["Chat", "Protein", "DNA", "Examples", "About"])
278
+
279
+ # File upload and indexing
280
+ with st.expander("Upload files for context (txt/csv/json/fasta/vcf)", expanded=True):
281
+ uploads = st.file_uploader(
282
+ "Add files",
283
+ type=["txt", "md", "csv", "tsv", "json", "fa", "fasta", "faa", "fna", "vcf"],
284
+ accept_multiple_files=True
285
+ )
286
+ docs = []
287
+ if uploads:
288
+ for up in uploads:
289
+ try:
290
+ txt = load_text_from_file(up)
291
+ docs.extend(chunk_text(txt))
292
+ except Exception as e:
293
+ st.warning(f"Failed to read {up.name}: {e}")
294
+ st.caption(f"Indexed chunks: {len(docs)}")
295
+
296
+ # Build vector index
297
+ index = None
298
+ index_model = None
299
+ if docs:
300
+ with st.spinner("Building vector index..."):
301
+ index, emb, index_model = build_vector_index(docs)
302
+
303
+ # Load datasets
304
+ loaded_datasets = []
305
+ if dataset_ids.strip():
306
+ if load_dataset is None:
307
+ st.warning("datasets library not available")
308
+ else:
309
+ for rid in [x.strip() for x in dataset_ids.splitlines() if x.strip()]:
310
+ with st.spinner(f"Loading dataset {rid} ..."):
311
+ try:
312
+ ds = load_dataset(rid)
313
+ # Show a sample without materializing fully
314
+ sample = ""
315
+ for split in ds.keys():
316
+ try:
317
+ row = ds[split][0]
318
+ sample = json.dumps(row, ensure_ascii=False)[:500]
319
+ break
320
+ except Exception:
321
+ pass
322
+ loaded_datasets.append((rid, sample))
323
+ st.success(f"Loaded {rid}")
324
+ except Exception as e:
325
+ st.error(f"Failed to load {rid}: {e}")
326
+
327
+ def build_context(user_query: str) -> Tuple[str, List[Dict]]:
328
+ """Build context from various sources"""
329
+ pieces = []
330
+ sources = []
331
+
332
+ # From uploaded files
333
+ if index is not None and index_model is not None and docs:
334
+ hits = search_index(user_query, index, index_model, docs, k=4)
335
+ for h in hits:
336
+ pieces.append(f"[FILE] {h['text'][:800]}")
337
+ sources.append({"type": "file", "text": h["text"][:200]})
338
+
339
+ # From datasets
340
+ for rid, sample in loaded_datasets:
341
+ if sample:
342
+ pieces.append(f"[DATASET {rid}] {sample}")
343
+ sources.append({"type": "dataset", "id": rid})
344
+
345
+ # From web
346
+ if use_web:
347
+ results = brave_search(user_query, count=web_k)
348
+ for r in results:
349
+ snippet = r.get("snippet", "")
350
+ url = r.get("url", "")
351
+ title = r.get("title", "")
352
+ pieces.append(f"[WEB] {title}\n{snippet}\n{url}")
353
+ sources.append({"type": "web", "title": title, "url": url})
354
+
355
+ context = "\n\n---\n\n".join(pieces)[:6000]
356
+ return context, sources
357
+
358
+ def chat_answer(user_query: str) -> Tuple[str, List[Dict]]:
359
+ """Generate chat answer with context"""
360
+ context, sources = build_context(user_query)
361
+ system = (
362
+ "You are a concise, careful bioinformatics assistant for protein and DNA. "
363
+ "Answer with factual, verifiable statements. "
364
+ "When uncertain, say so briefly. "
365
+ "Never give medical advice. Provide short references as plain URLs or titles if present in context. "
366
+ "User uploads and web/dataset snippets are provided as context below."
367
+ )
368
+ prompt = f"Context:\n{context}\n\nUser question:\n{user_query}\n\nAnswer in Korean if the user used Korean; otherwise match user language."
369
+ messages = [
370
+ {"role": "system", "content": system},
371
+ {"role": "user", "content": prompt}
372
+ ]
373
+ answer = call_fireworks(messages, temperature=0.4, max_tokens=1200)
374
+ return answer, sources
375
+
376
+ # Chat tab
377
+ with tabs[0]:
378
+ st.subheader("Chat")
379
+ q = st.text_area("Ask a question about protein/DNA", value="ESM-2 ์ž„๋ฒ ๋”ฉ์€ ๋‹จ๋ฐฑ์งˆ ๊ธฐ๋Šฅ ํ•ด์„์— ์–ด๋–ป๊ฒŒ ๋„์›€๋˜๋‚˜์š”?")
380
+
381
+ if st.button("Answer", type="primary"):
382
+ with st.spinner("Thinking..."):
383
+ ans, srcs = chat_answer(q)
384
+ st.write(ans)
385
+
386
+ if srcs:
387
+ st.markdown("#### Sources")
388
+ for s in srcs:
389
+ if s.get("type") == "web" and s.get("url"):
390
+ st.markdown(f"- {s.get('title','web')}: {s.get('url')}")
391
+ elif s.get("type") == "dataset":
392
+ st.markdown(f"- dataset: {s.get('id')}")
393
+ elif s.get("type") == "file":
394
+ snippet = s.get("text", "")
395
+ st.markdown(f"- file snippet: {snippet[:120]}...")
396
+
397
+ # Protein tab
398
+ with tabs[1]:
399
+ st.subheader("Protein analysis")
400
+ seq = st.text_area("Protein sequence (FASTA seq only; single sequence)", value="MKTIIALSYIFCLVFADYKDDDDK")
401
+
402
+ col1, col2 = st.columns(2)
403
+ with col1:
404
+ st.caption("ESM-2 embedding")
405
+ if st.button("Run ESM-2", key="run_esm2"):
406
+ with st.spinner("Computing ESM-2 embedding..."):
407
+ out = esm2_embed(seq, esm2_id)
408
+ if "error" in out:
409
+ st.error(out["error"])
410
+ else:
411
+ st.success(f"Vector size: {out['hidden_size']}")
412
+ st.json({"embedding_preview": out["embedding"][:8]})
413
+
414
+ with col2:
415
+ st.caption("Quick stats")
416
+ s = seq.replace("\n", "").replace(" ", "")
417
+ length = len(s)
418
+ aa_set = sorted(set(list(s)))
419
+ st.write(f"Length: {length}")
420
+ st.write(f"Unique AAs: {''.join(aa_set)[:30]}")
421
+
422
+ # DNA tab
423
+ with tabs[2]:
424
+ st.subheader("DNA analysis")
425
+ dseq = st.text_area("DNA sequence (ACGT only)", value="ATGCGTACGTAGCTAGCTAGCTAGGCTAGC")
426
+
427
+ col3, col4 = st.columns(2)
428
+ with col3:
429
+ st.caption("DNABERT-2 / Nucleotide Transformer embedding")
430
+ if st.button("Run DNA embed", key="run_dna"):
431
+ with st.spinner("Computing DNA embedding..."):
432
+ out = dna_embed(dseq, dna_id)
433
+ if "error" in out:
434
+ st.error(out["error"])
435
+ else:
436
+ st.success(f"Vector size: {out['hidden_size']}")
437
+ st.json({"embedding_preview": out["embedding"][:8]})
438
+
439
+ with col4:
440
+ st.caption("GC content")
441
+ s = dseq.upper().replace("N", "")
442
+ if len(s) > 0:
443
+ gc = (s.count("G") + s.count("C")) / len(s)
444
+ else:
445
+ gc = 0
446
+ st.write(f"Length: {len(s)}")
447
+ st.write(f"GC: {gc:.3f}")
448
+
449
+ # Examples tab
450
+ with tabs[3]:
451
+ st.subheader("Examples")
452
+ st.markdown("- ์—…๋กœ๋“œํ•œ FASTA์—์„œ ํŠน์ • ๋‹จ๋ฐฑ์งˆ์˜ ๊ธฐ๋Šฅ ์š”์•ฝ๊ณผ ๋ณ€์ด ์˜ํ–ฅ ์งˆ๋ฌธ")
453
+ st.markdown("- DNA ์„œ์—ด์—์„œ ํ”„๋กœ๋ชจํ„ฐ ๊ฐ€๋Šฅ์„ฑ๊ณผ ์ „์‚ฌ์ธ์ž ๋ชจํ‹ฐํ”„ ๊ด€๋ จ ๊ทผ๊ฑฐ ์š”์ฒญ")
454
+ st.markdown("- Enzyme active site ๊ทผ์ ‘ ๋ณ€์ด์˜ ๋ฆฌ์Šคํฌ ํ•ด์„(์—ฐ๊ตฌ ๊ด€์ )")
455
+ st.markdown("- ENCODE/UniProt/AlphaFold ๊ฐœ๋… ์„ค๋ช… ์š”์ฒญ")
456
+ st.markdown("- RAG ๊ธฐ๋ฐ˜์œผ๋กœ ๋ฌธ์„œ ์ธ์šฉ๊ณผ ํ•จ๊ป˜ ๊ฐ„๋žต ๋‹ต๋ณ€ ์š”์ฒญ")
457
+
458
+ # About tab
459
+ with tabs[4]:
460
+ st.subheader("About this Space")
461
+ st.write("Models suggested: ESM-2 for proteins; DNABERT-2 or Nucleotide Transformer for DNA.")
462
+ st.write("Datasets commonly used: UniProtKB, AlphaFoldDB, ENCODE, JASPAR, ClinVar.")
463
+ st.write("Web search powered by Brave Search if API key is provided.")
464
+ st.write("")
465
+ st.info(DISCLAIMER)